diff --git a/runtime/src/main/java/dev/cel/runtime/BUILD.bazel b/runtime/src/main/java/dev/cel/runtime/BUILD.bazel
index ac5599ba9..3b0e2762a 100644
--- a/runtime/src/main/java/dev/cel/runtime/BUILD.bazel
+++ b/runtime/src/main/java/dev/cel/runtime/BUILD.bazel
@@ -122,6 +122,7 @@ java_library(
deps = [
":evaluation_exception",
":evaluation_exception_builder",
+ ":function_binding",
":function_overload",
":function_resolver",
":resolved_overload",
@@ -141,6 +142,7 @@ cel_android_library(
deps = [
":evaluation_exception",
":evaluation_exception_builder",
+ ":function_binding_android",
":function_overload_android",
":function_resolver_android",
":resolved_overload_android",
@@ -609,6 +611,7 @@ java_library(
deps = [
":function_binding",
":runtime_equality",
+ "//common:operator",
"//common:options",
"//common/annotations",
"//runtime/standard:add",
@@ -666,6 +669,7 @@ cel_android_library(
deps = [
":function_binding_android",
":runtime_equality_android",
+ "//common:operator_android",
"//common:options",
"//common/annotations",
"//runtime/standard:add_android",
@@ -723,6 +727,7 @@ java_library(
tags = [
],
deps = [
+ ":evaluation_exception",
":function_overload",
"@maven//:com_google_errorprone_error_prone_annotations",
"@maven//:com_google_guava_guava",
@@ -735,6 +740,7 @@ cel_android_library(
tags = [
],
deps = [
+ ":evaluation_exception",
":function_overload_android",
"@maven//:com_google_errorprone_error_prone_annotations",
"@maven_android//:com_google_guava_guava",
@@ -774,7 +780,9 @@ java_library(
],
deps = [
":evaluation_exception",
+ "//runtime:unknown_attributes",
"@maven//:com_google_errorprone_error_prone_annotations",
+ "@maven//:com_google_guava_guava",
],
)
@@ -785,7 +793,9 @@ cel_android_library(
],
deps = [
":evaluation_exception",
+ "//runtime:unknown_attributes_android",
"@maven//:com_google_errorprone_error_prone_annotations",
+ "@maven_android//:com_google_guava_guava",
],
)
@@ -1173,7 +1183,6 @@ java_library(
],
deps = [
":function_overload",
- ":unknown_attributes",
"//:auto_value",
"//common/annotations",
"@maven//:com_google_errorprone_error_prone_annotations",
@@ -1188,7 +1197,6 @@ cel_android_library(
],
deps = [
":function_overload_android",
- ":unknown_attributes_android",
"//:auto_value",
"//common/annotations",
"@maven//:com_google_errorprone_error_prone_annotations",
diff --git a/runtime/src/main/java/dev/cel/runtime/CelFunctionBinding.java b/runtime/src/main/java/dev/cel/runtime/CelFunctionBinding.java
index 79b0f3f54..c7b63926b 100644
--- a/runtime/src/main/java/dev/cel/runtime/CelFunctionBinding.java
+++ b/runtime/src/main/java/dev/cel/runtime/CelFunctionBinding.java
@@ -14,8 +14,13 @@
package dev.cel.runtime;
+import static com.google.common.base.Preconditions.checkArgument;
+
+import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
import com.google.errorprone.annotations.Immutable;
+import java.util.Collection;
/**
* Binding consisting of an overload id, a Java-native argument signature, and an overload
@@ -35,7 +40,6 @@
*
*
Examples: string_startsWith_string, mathMax_list, lessThan_money_money
*/
-
@Immutable
public interface CelFunctionBinding {
String getOverloadId();
@@ -70,4 +74,23 @@ static CelFunctionBinding from(
return new FunctionBindingImpl(
overloadId, ImmutableList.copyOf(argTypes), impl, /* isStrict= */ true);
}
+
+ /** See {@link #fromOverloads(String, Collection)}. */
+ static ImmutableSet fromOverloads(
+ String functionName, CelFunctionBinding... overloadBindings) {
+ return fromOverloads(functionName, ImmutableList.copyOf(overloadBindings));
+ }
+
+ /**
+ * Creates a set of bindings for a function, enabling dynamic dispatch logic to select the correct
+ * overload at runtime based on argument types.
+ */
+ static ImmutableSet fromOverloads(
+ String functionName, Collection overloadBindings) {
+ checkArgument(!Strings.isNullOrEmpty(functionName), "Function name cannot be null or empty");
+ checkArgument(!overloadBindings.isEmpty(), "You must provide at least one binding.");
+
+ return FunctionBindingImpl.groupOverloadsToFunction(
+ functionName, ImmutableSet.copyOf(overloadBindings));
+ }
}
diff --git a/runtime/src/main/java/dev/cel/runtime/CelFunctionOverload.java b/runtime/src/main/java/dev/cel/runtime/CelFunctionOverload.java
index a1341cb21..3e30a2146 100644
--- a/runtime/src/main/java/dev/cel/runtime/CelFunctionOverload.java
+++ b/runtime/src/main/java/dev/cel/runtime/CelFunctionOverload.java
@@ -14,7 +14,9 @@
package dev.cel.runtime;
+import com.google.common.collect.ImmutableList;
import com.google.errorprone.annotations.Immutable;
+import java.util.Map;
/** Interface describing the general signature of all CEL custom function implementations. */
@Immutable
@@ -43,4 +45,39 @@ interface Unary {
interface Binary {
Object apply(T1 arg1, T2 arg2) throws CelEvaluationException;
}
+
+ /**
+ * Returns true if the overload's expected argument types match the types of the given arguments.
+ */
+ static boolean canHandle(
+ Object[] arguments, ImmutableList> parameterTypes, boolean isStrict) {
+ if (parameterTypes.size() != arguments.length) {
+ return false;
+ }
+ for (int i = 0; i < parameterTypes.size(); i++) {
+ Class> paramType = parameterTypes.get(i);
+ Object arg = arguments[i];
+ if (arg == null) {
+ // null can be assigned to messages, maps, and to objects.
+ // TODO: Remove null special casing
+ if (paramType != Object.class && !Map.class.isAssignableFrom(paramType)) {
+ return false;
+ }
+ continue;
+ }
+
+ if (arg instanceof Exception || arg instanceof CelUnknownSet) {
+ // Only non-strict functions can accept errors/unknowns as arguments to a function
+ if (!isStrict) {
+ // Skip assignability check below, but continue to validate remaining args
+ continue;
+ }
+ }
+
+ if (!paramType.isAssignableFrom(arg.getClass())) {
+ return false;
+ }
+ }
+ return true;
+ }
}
diff --git a/runtime/src/main/java/dev/cel/runtime/CelResolvedOverload.java b/runtime/src/main/java/dev/cel/runtime/CelResolvedOverload.java
index f6a0c4f99..2bcdf3a2d 100644
--- a/runtime/src/main/java/dev/cel/runtime/CelResolvedOverload.java
+++ b/runtime/src/main/java/dev/cel/runtime/CelResolvedOverload.java
@@ -19,7 +19,6 @@
import com.google.errorprone.annotations.Immutable;
import dev.cel.common.annotations.Internal;
import java.util.List;
-import java.util.Map;
/**
* Representation of a function overload which has been resolved to a specific set of argument types
@@ -80,41 +79,6 @@ public static CelResolvedOverload of(
* Returns true if the overload's expected argument types match the types of the given arguments.
*/
boolean canHandle(Object[] arguments) {
- return canHandle(arguments, getParameterTypes(), isStrict());
- }
-
- /**
- * Returns true if the overload's expected argument types match the types of the given arguments.
- */
- public static boolean canHandle(
- Object[] arguments, ImmutableList> parameterTypes, boolean isStrict) {
- if (parameterTypes.size() != arguments.length) {
- return false;
- }
- for (int i = 0; i < parameterTypes.size(); i++) {
- Class> paramType = parameterTypes.get(i);
- Object arg = arguments[i];
- if (arg == null) {
- // null can be assigned to messages, maps, and to objects.
- // TODO: Remove null special casing
- if (paramType != Object.class && !Map.class.isAssignableFrom(paramType)) {
- return false;
- }
- continue;
- }
-
- if (arg instanceof Exception || arg instanceof CelUnknownSet) {
- // Only non-strict functions can accept errors/unknowns as arguments to a function
- if (!isStrict) {
- // Skip assignability check below, but continue to validate remaining args
- continue;
- }
- }
-
- if (!paramType.isAssignableFrom(arg.getClass())) {
- return false;
- }
- }
- return true;
+ return CelFunctionOverload.canHandle(arguments, getParameterTypes(), isStrict());
}
}
diff --git a/runtime/src/main/java/dev/cel/runtime/CelStandardFunctions.java b/runtime/src/main/java/dev/cel/runtime/CelStandardFunctions.java
index bedd41728..39797e086 100644
--- a/runtime/src/main/java/dev/cel/runtime/CelStandardFunctions.java
+++ b/runtime/src/main/java/dev/cel/runtime/CelStandardFunctions.java
@@ -15,12 +15,15 @@
package dev.cel.runtime;
import static com.google.common.base.Preconditions.checkNotNull;
+import static com.google.common.collect.ImmutableSet.toImmutableSet;
import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.ImmutableSet;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
import com.google.errorprone.annotations.Immutable;
import dev.cel.common.CelOptions;
+import dev.cel.common.Operator;
import dev.cel.common.annotations.Internal;
import dev.cel.runtime.standard.AddOperator;
import dev.cel.runtime.standard.AddOperator.AddOverload;
@@ -104,6 +107,8 @@
import dev.cel.runtime.standard.TimestampFunction.TimestampOverload;
import dev.cel.runtime.standard.UintFunction;
import dev.cel.runtime.standard.UintFunction.UintOverload;
+import java.util.Collection;
+import java.util.Map;
/** Runtime function bindings for the standard functions in CEL. */
@Immutable
@@ -135,7 +140,7 @@ public final class CelStandardFunctions {
GreaterEqualsOverload.GREATER_EQUALS_UINT64_DOUBLE,
GreaterEqualsOverload.GREATER_EQUALS_DOUBLE_UINT64);
- private final ImmutableSet standardOverloads;
+ private final ImmutableMultimap standardOverloads;
public static final ImmutableSet ALL_STANDARD_FUNCTIONS =
ImmutableSet.of(
@@ -187,13 +192,15 @@ public final class CelStandardFunctions {
* special-cased, and does not appear in this enum.
*/
public enum StandardFunction {
- LOGICAL_NOT(LogicalNotOverload.LOGICAL_NOT),
- IN(InOverload.IN_LIST, InOverload.IN_MAP),
- NOT_STRICTLY_FALSE(NotStrictlyFalseOverload.NOT_STRICTLY_FALSE),
- EQUALS(EqualsOverload.EQUALS),
- NOT_EQUALS(NotEqualsOverload.NOT_EQUALS),
- BOOL(BoolOverload.BOOL_TO_BOOL, BoolOverload.STRING_TO_BOOL),
+ LOGICAL_NOT(Operator.LOGICAL_NOT.getFunction(), LogicalNotOverload.LOGICAL_NOT),
+ IN(Operator.IN.getFunction(), InOverload.IN_LIST, InOverload.IN_MAP),
+ NOT_STRICTLY_FALSE(
+ Operator.NOT_STRICTLY_FALSE.getFunction(), NotStrictlyFalseOverload.NOT_STRICTLY_FALSE),
+ EQUALS(Operator.EQUALS.getFunction(), EqualsOverload.EQUALS),
+ NOT_EQUALS(Operator.NOT_EQUALS.getFunction(), NotEqualsOverload.NOT_EQUALS),
+ BOOL("bool", BoolOverload.BOOL_TO_BOOL, BoolOverload.STRING_TO_BOOL),
ADD(
+ Operator.ADD.getFunction(),
AddOverload.ADD_INT64,
AddOverload.ADD_UINT64,
AddOverload.ADD_DOUBLE,
@@ -204,6 +211,7 @@ public enum StandardFunction {
AddOverload.ADD_DURATION_TIMESTAMP,
AddOverload.ADD_DURATION_DURATION),
SUBTRACT(
+ Operator.SUBTRACT.getFunction(),
SubtractOverload.SUBTRACT_INT64,
SubtractOverload.SUBTRACT_TIMESTAMP_TIMESTAMP,
SubtractOverload.SUBTRACT_TIMESTAMP_DURATION,
@@ -211,14 +219,22 @@ public enum StandardFunction {
SubtractOverload.SUBTRACT_DOUBLE,
SubtractOverload.SUBTRACT_DURATION_DURATION),
MULTIPLY(
+ Operator.MULTIPLY.getFunction(),
MultiplyOverload.MULTIPLY_INT64,
MultiplyOverload.MULTIPLY_DOUBLE,
MultiplyOverload.MULTIPLY_UINT64),
- DIVIDE(DivideOverload.DIVIDE_DOUBLE, DivideOverload.DIVIDE_INT64, DivideOverload.DIVIDE_UINT64),
- MODULO(ModuloOverload.MODULO_INT64, ModuloOverload.MODULO_UINT64),
- NEGATE(NegateOverload.NEGATE_INT64, NegateOverload.NEGATE_DOUBLE),
- INDEX(IndexOverload.INDEX_LIST, IndexOverload.INDEX_MAP),
+ DIVIDE(
+ Operator.DIVIDE.getFunction(),
+ DivideOverload.DIVIDE_DOUBLE,
+ DivideOverload.DIVIDE_INT64,
+ DivideOverload.DIVIDE_UINT64),
+ MODULO(
+ Operator.MODULO.getFunction(), ModuloOverload.MODULO_INT64, ModuloOverload.MODULO_UINT64),
+ NEGATE(
+ Operator.NEGATE.getFunction(), NegateOverload.NEGATE_INT64, NegateOverload.NEGATE_DOUBLE),
+ INDEX(Operator.INDEX.getFunction(), IndexOverload.INDEX_LIST, IndexOverload.INDEX_MAP),
SIZE(
+ "size",
SizeOverload.SIZE_STRING,
SizeOverload.SIZE_BYTES,
SizeOverload.SIZE_LIST,
@@ -228,22 +244,26 @@ public enum StandardFunction {
SizeOverload.LIST_SIZE,
SizeOverload.MAP_SIZE),
INT(
+ "int",
IntOverload.INT64_TO_INT64,
IntOverload.UINT64_TO_INT64,
IntOverload.DOUBLE_TO_INT64,
IntOverload.STRING_TO_INT64,
IntOverload.TIMESTAMP_TO_INT64),
UINT(
+ "uint",
UintOverload.UINT64_TO_UINT64,
UintOverload.INT64_TO_UINT64,
UintOverload.DOUBLE_TO_UINT64,
UintOverload.STRING_TO_UINT64),
DOUBLE(
+ "double",
DoubleOverload.DOUBLE_TO_DOUBLE,
DoubleOverload.INT64_TO_DOUBLE,
DoubleOverload.STRING_TO_DOUBLE,
DoubleOverload.UINT64_TO_DOUBLE),
STRING(
+ "string",
StringOverload.STRING_TO_STRING,
StringOverload.INT64_TO_STRING,
StringOverload.DOUBLE_TO_STRING,
@@ -252,51 +272,67 @@ public enum StandardFunction {
StringOverload.TIMESTAMP_TO_STRING,
StringOverload.DURATION_TO_STRING,
StringOverload.UINT64_TO_STRING),
- BYTES(BytesOverload.BYTES_TO_BYTES, BytesOverload.STRING_TO_BYTES),
- DURATION(DurationOverload.DURATION_TO_DURATION, DurationOverload.STRING_TO_DURATION),
+ BYTES("bytes", BytesOverload.BYTES_TO_BYTES, BytesOverload.STRING_TO_BYTES),
+ DURATION(
+ "duration", DurationOverload.DURATION_TO_DURATION, DurationOverload.STRING_TO_DURATION),
TIMESTAMP(
+ "timestamp",
TimestampOverload.STRING_TO_TIMESTAMP,
TimestampOverload.TIMESTAMP_TO_TIMESTAMP,
TimestampOverload.INT64_TO_TIMESTAMP),
- DYN(DynOverload.TO_DYN),
- MATCHES(MatchesOverload.MATCHES, MatchesOverload.MATCHES_STRING),
- CONTAINS(ContainsOverload.CONTAINS_STRING),
- ENDS_WITH(EndsWithOverload.ENDS_WITH_STRING),
- STARTS_WITH(StartsWithOverload.STARTS_WITH_STRING),
+ DYN("dyn", DynOverload.TO_DYN),
+ MATCHES("matches", MatchesOverload.MATCHES, MatchesOverload.MATCHES_STRING),
+ CONTAINS("contains", ContainsOverload.CONTAINS_STRING),
+ ENDS_WITH("endsWith", EndsWithOverload.ENDS_WITH_STRING),
+ STARTS_WITH("startsWith", StartsWithOverload.STARTS_WITH_STRING),
// Date/time Functions
GET_FULL_YEAR(
- GetFullYearOverload.TIMESTAMP_TO_YEAR, GetFullYearOverload.TIMESTAMP_TO_YEAR_WITH_TZ),
- GET_MONTH(GetMonthOverload.TIMESTAMP_TO_MONTH, GetMonthOverload.TIMESTAMP_TO_MONTH_WITH_TZ),
+ "getFullYear",
+ GetFullYearOverload.TIMESTAMP_TO_YEAR,
+ GetFullYearOverload.TIMESTAMP_TO_YEAR_WITH_TZ),
+ GET_MONTH(
+ "getMonth",
+ GetMonthOverload.TIMESTAMP_TO_MONTH,
+ GetMonthOverload.TIMESTAMP_TO_MONTH_WITH_TZ),
GET_DAY_OF_YEAR(
+ "getDayOfYear",
GetDayOfYearOverload.TIMESTAMP_TO_DAY_OF_YEAR,
GetDayOfYearOverload.TIMESTAMP_TO_DAY_OF_YEAR_WITH_TZ),
GET_DAY_OF_MONTH(
+ "getDayOfMonth",
GetDayOfMonthOverload.TIMESTAMP_TO_DAY_OF_MONTH,
GetDayOfMonthOverload.TIMESTAMP_TO_DAY_OF_MONTH_WITH_TZ),
GET_DATE(
+ "getDate",
GetDateOverload.TIMESTAMP_TO_DAY_OF_MONTH_1_BASED,
GetDateOverload.TIMESTAMP_TO_DAY_OF_MONTH_1_BASED_WITH_TZ),
GET_DAY_OF_WEEK(
+ "getDayOfWeek",
GetDayOfWeekOverload.TIMESTAMP_TO_DAY_OF_WEEK,
GetDayOfWeekOverload.TIMESTAMP_TO_DAY_OF_WEEK_WITH_TZ),
GET_HOURS(
+ "getHours",
GetHoursOverload.TIMESTAMP_TO_HOURS,
GetHoursOverload.TIMESTAMP_TO_HOURS_WITH_TZ,
GetHoursOverload.DURATION_TO_HOURS),
GET_MINUTES(
+ "getMinutes",
GetMinutesOverload.TIMESTAMP_TO_MINUTES,
GetMinutesOverload.TIMESTAMP_TO_MINUTES_WITH_TZ,
GetMinutesOverload.DURATION_TO_MINUTES),
GET_SECONDS(
+ "getSeconds",
GetSecondsOverload.TIMESTAMP_TO_SECONDS,
GetSecondsOverload.TIMESTAMP_TO_SECONDS_WITH_TZ,
GetSecondsOverload.DURATION_TO_SECONDS),
GET_MILLISECONDS(
+ "getMilliseconds",
GetMillisecondsOverload.TIMESTAMP_TO_MILLISECONDS,
GetMillisecondsOverload.TIMESTAMP_TO_MILLISECONDS_WITH_TZ,
GetMillisecondsOverload.DURATION_TO_MILLISECONDS),
LESS(
+ Operator.LESS.getFunction(),
LessOverload.LESS_BOOL,
LessOverload.LESS_INT64,
LessOverload.LESS_UINT64,
@@ -312,6 +348,7 @@ public enum StandardFunction {
LessOverload.LESS_UINT64_DOUBLE,
LessOverload.LESS_DOUBLE_UINT64),
LESS_EQUALS(
+ Operator.LESS_EQUALS.getFunction(),
LessEqualsOverload.LESS_EQUALS_BOOL,
LessEqualsOverload.LESS_EQUALS_INT64,
LessEqualsOverload.LESS_EQUALS_UINT64,
@@ -327,6 +364,7 @@ public enum StandardFunction {
LessEqualsOverload.LESS_EQUALS_UINT64_DOUBLE,
LessEqualsOverload.LESS_EQUALS_DOUBLE_UINT64),
GREATER(
+ Operator.GREATER.getFunction(),
GreaterOverload.GREATER_BOOL,
GreaterOverload.GREATER_INT64,
GreaterOverload.GREATER_UINT64,
@@ -342,6 +380,7 @@ public enum StandardFunction {
GreaterOverload.GREATER_UINT64_DOUBLE,
GreaterOverload.GREATER_DOUBLE_UINT64),
GREATER_EQUALS(
+ Operator.GREATER_EQUALS.getFunction(),
GreaterEqualsOverload.GREATER_EQUALS_BOOL,
GreaterEqualsOverload.GREATER_EQUALS_BYTES,
GreaterEqualsOverload.GREATER_EQUALS_DOUBLE,
@@ -357,9 +396,11 @@ public enum StandardFunction {
GreaterEqualsOverload.GREATER_EQUALS_UINT64_DOUBLE,
GreaterEqualsOverload.GREATER_EQUALS_DOUBLE_UINT64);
+ private final String functionName;
private final ImmutableSet standardOverloads;
- StandardFunction(CelStandardOverload... overloads) {
+ StandardFunction(String functionName, CelStandardOverload... overloads) {
+ this.functionName = functionName;
this.standardOverloads = ImmutableSet.copyOf(overloads);
}
@@ -371,15 +412,25 @@ ImmutableSet getOverloads() {
@VisibleForTesting
ImmutableSet getOverloads() {
- return standardOverloads;
+ return ImmutableSet.copyOf(standardOverloads.values());
}
@Internal
public ImmutableSet newFunctionBindings(
RuntimeEquality runtimeEquality, CelOptions celOptions) {
ImmutableSet.Builder builder = ImmutableSet.builder();
- for (CelStandardOverload overload : standardOverloads) {
- builder.add(overload.newFunctionBinding(celOptions, runtimeEquality));
+
+ for (Map.Entry> entry :
+ standardOverloads.asMap().entrySet()) {
+ String functionName = entry.getKey();
+ Collection overloads = entry.getValue();
+
+ ImmutableSet bindings =
+ overloads.stream()
+ .map(o -> o.newFunctionBinding(celOptions, runtimeEquality))
+ .collect(toImmutableSet());
+
+ builder.addAll(CelFunctionBinding.fromOverloads(functionName, bindings));
}
return builder.build();
@@ -454,39 +505,36 @@ public CelStandardFunctions build() {
"You may only populate one of the following builder methods: includeFunctions,"
+ " excludeFunctions or filterFunctions");
- ImmutableSet.Builder standardOverloadBuilder = ImmutableSet.builder();
+ ImmutableMultimap.Builder standardOverloadBuilder =
+ ImmutableMultimap.builder();
for (StandardFunction standardFunction : StandardFunction.values()) {
if (hasIncludeFunctions) {
if (this.includeFunctions.contains(standardFunction)) {
- standardOverloadBuilder.addAll(standardFunction.standardOverloads);
+ standardOverloadBuilder.putAll(
+ standardFunction.functionName, standardFunction.standardOverloads);
}
continue;
}
if (hasExcludeFunctions) {
if (!this.excludeFunctions.contains(standardFunction)) {
- standardOverloadBuilder.addAll(standardFunction.standardOverloads);
+ standardOverloadBuilder.putAll(
+ standardFunction.functionName, standardFunction.standardOverloads);
}
continue;
}
if (hasFilterFunction) {
- ImmutableSet.Builder filteredOverloadsBuilder =
- ImmutableSet.builder();
for (CelStandardOverload standardOverload : standardFunction.standardOverloads) {
boolean includeOverload = functionFilter.include(standardFunction, standardOverload);
if (includeOverload) {
- standardOverloadBuilder.add(standardOverload);
+ standardOverloadBuilder.put(standardFunction.functionName, standardOverload);
}
}
- ImmutableSet filteredOverloads = filteredOverloadsBuilder.build();
- if (!filteredOverloads.isEmpty()) {
- standardOverloadBuilder.addAll(filteredOverloads);
- }
-
continue;
}
- standardOverloadBuilder.addAll(standardFunction.standardOverloads);
+ standardOverloadBuilder.putAll(
+ standardFunction.functionName, standardFunction.standardOverloads);
}
return new CelStandardFunctions(standardOverloadBuilder.build());
@@ -511,7 +559,7 @@ static boolean isHeterogeneousComparison(CelStandardOverload overload) {
return HETEROGENEOUS_COMPARISON_OPERATORS.contains(overload);
}
- private CelStandardFunctions(ImmutableSet standardOverloads) {
+ private CelStandardFunctions(ImmutableMultimap standardOverloads) {
this.standardOverloads = standardOverloads;
}
}
diff --git a/runtime/src/main/java/dev/cel/runtime/DefaultDispatcher.java b/runtime/src/main/java/dev/cel/runtime/DefaultDispatcher.java
index 35ce243b3..b5e5ca55c 100644
--- a/runtime/src/main/java/dev/cel/runtime/DefaultDispatcher.java
+++ b/runtime/src/main/java/dev/cel/runtime/DefaultDispatcher.java
@@ -19,11 +19,13 @@
import com.google.auto.value.AutoBuilder;
import com.google.common.base.Joiner;
+import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
import com.google.errorprone.annotations.Immutable;
import dev.cel.common.CelErrorCode;
import dev.cel.common.annotations.Internal;
+import dev.cel.runtime.FunctionBindingImpl.DynamicDispatchOverload;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
@@ -127,7 +129,7 @@ public abstract static class Builder {
@CanIgnoreReturnValue
public Builder addOverload(
String overloadId,
- List> argTypes,
+ ImmutableList> argTypes,
boolean isStrict,
CelFunctionOverload overload) {
checkNotNull(overloadId);
@@ -136,13 +138,36 @@ public Builder addOverload(
checkNotNull(overload);
overloadsBuilder()
- .put(overloadId, CelResolvedOverload.of(overloadId, overload, isStrict, argTypes));
+ .put(
+ overloadId,
+ CelResolvedOverload.of(
+ overloadId,
+ args -> guardedOp(overloadId, args, argTypes, isStrict, overload),
+ isStrict,
+ argTypes));
return this;
}
public abstract DefaultDispatcher build();
}
+ /** Creates an invocation guard around the overload definition. */
+ private static Object guardedOp(
+ String functionName,
+ Object[] args,
+ ImmutableList> argTypes,
+ boolean isStrict,
+ CelFunctionOverload overload)
+ throws CelEvaluationException {
+ // Argument checking for DynamicDispatch is handled inside the overload's apply method itself.
+ if (overload instanceof DynamicDispatchOverload
+ || CelFunctionOverload.canHandle(args, argTypes, isStrict)) {
+ return overload.apply(args);
+ }
+
+ throw new IllegalArgumentException("No matching overload for function: " + functionName);
+ }
+
DefaultDispatcher(ImmutableMap overloads) {
this.overloads = overloads;
}
diff --git a/runtime/src/main/java/dev/cel/runtime/FunctionBindingImpl.java b/runtime/src/main/java/dev/cel/runtime/FunctionBindingImpl.java
index b554ce41a..4dcdcb74f 100644
--- a/runtime/src/main/java/dev/cel/runtime/FunctionBindingImpl.java
+++ b/runtime/src/main/java/dev/cel/runtime/FunctionBindingImpl.java
@@ -15,6 +15,8 @@
package dev.cel.runtime;
import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Iterables;
import com.google.errorprone.annotations.Immutable;
@Immutable
@@ -58,4 +60,88 @@ public boolean isStrict() {
this.definition = definition;
this.isStrict = isStrict;
}
+
+ static ImmutableSet groupOverloadsToFunction(
+ String functionName, ImmutableSet overloadBindings) {
+ ImmutableSet.Builder builder = ImmutableSet.builder();
+ builder.addAll(overloadBindings);
+
+ // If there is already a binding with the same name as the function, we treat it as a
+ // "Singleton" binding and do not create a dynamic dispatch wrapper for it.
+ // (Ex: "matches" function)
+ boolean hasSingletonBinding =
+ overloadBindings.stream().anyMatch(b -> b.getOverloadId().equals(functionName));
+
+ if (!hasSingletonBinding) {
+ if (overloadBindings.size() == 1) {
+ CelFunctionBinding singleBinding = Iterables.getOnlyElement(overloadBindings);
+ builder.add(
+ new FunctionBindingImpl(
+ functionName,
+ singleBinding.getArgTypes(),
+ singleBinding.getDefinition(),
+ singleBinding.isStrict()));
+ } else {
+ builder.add(new DynamicDispatchBinding(functionName, overloadBindings));
+ }
+ }
+
+ return builder.build();
+ }
+
+ @Immutable
+ static final class DynamicDispatchBinding implements CelFunctionBinding {
+
+ private final boolean isStrict;
+ private final DynamicDispatchOverload dynamicDispatchOverload;
+
+ @Override
+ public String getOverloadId() {
+ return dynamicDispatchOverload.functionName;
+ }
+
+ @Override
+ public ImmutableList> getArgTypes() {
+ return ImmutableList.of();
+ }
+
+ @Override
+ public CelFunctionOverload getDefinition() {
+ return dynamicDispatchOverload;
+ }
+
+ @Override
+ public boolean isStrict() {
+ return isStrict;
+ }
+
+ private DynamicDispatchBinding(
+ String functionName, ImmutableSet overloadBindings) {
+ this.isStrict = overloadBindings.stream().allMatch(CelFunctionBinding::isStrict);
+ this.dynamicDispatchOverload = new DynamicDispatchOverload(functionName, overloadBindings);
+ }
+ }
+
+ @Immutable
+ static final class DynamicDispatchOverload implements CelFunctionOverload {
+ private final String functionName;
+ private final ImmutableSet overloadBindings;
+
+ @Override
+ public Object apply(Object[] args) throws CelEvaluationException {
+ for (CelFunctionBinding overload : overloadBindings) {
+ if (CelFunctionOverload.canHandle(args, overload.getArgTypes(), overload.isStrict())) {
+ return overload.getDefinition().apply(args);
+ }
+ }
+
+ throw new IllegalArgumentException("No matching overload for function: " + functionName);
+ }
+
+ private DynamicDispatchOverload(
+ String functionName, ImmutableSet overloadBindings) {
+ this.functionName = functionName;
+ this.overloadBindings = overloadBindings;
+ }
+ }
}
diff --git a/runtime/src/main/java/dev/cel/runtime/standard/CelStandardFunction.java b/runtime/src/main/java/dev/cel/runtime/standard/CelStandardFunction.java
index 73d53287c..f9f919413 100644
--- a/runtime/src/main/java/dev/cel/runtime/standard/CelStandardFunction.java
+++ b/runtime/src/main/java/dev/cel/runtime/standard/CelStandardFunction.java
@@ -15,6 +15,7 @@
package dev.cel.runtime.standard;
import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.collect.ImmutableSet.toImmutableSet;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableSet;
@@ -34,12 +35,12 @@ public abstract class CelStandardFunction {
public ImmutableSet newFunctionBindings(
CelOptions celOptions, RuntimeEquality runtimeEquality) {
- ImmutableSet.Builder builder = ImmutableSet.builder();
- for (CelStandardOverload overload : overloads) {
- builder.add(overload.newFunctionBinding(celOptions, runtimeEquality));
- }
+ ImmutableSet overloadBindings =
+ overloads.stream()
+ .map(overload -> overload.newFunctionBinding(celOptions, runtimeEquality))
+ .collect(toImmutableSet());
- return builder.build();
+ return CelFunctionBinding.fromOverloads(name, overloadBindings);
}
CelStandardFunction(String name, ImmutableSet overloads) {
diff --git a/runtime/src/test/java/dev/cel/runtime/planner/BUILD.bazel b/runtime/src/test/java/dev/cel/runtime/planner/BUILD.bazel
index 2106fa5fe..6749be24f 100644
--- a/runtime/src/test/java/dev/cel/runtime/planner/BUILD.bazel
+++ b/runtime/src/test/java/dev/cel/runtime/planner/BUILD.bazel
@@ -20,7 +20,6 @@ java_library(
"//common:compiler_common",
"//common:container",
"//common:error_codes",
- "//common:operator",
"//common:options",
"//common/ast",
"//common/exceptions:divide_by_zero",
@@ -44,21 +43,10 @@ java_library(
"//runtime:dispatcher",
"//runtime:function_binding",
"//runtime:program",
- "//runtime:resolved_overload",
"//runtime:runtime_equality",
"//runtime:runtime_helpers",
+ "//runtime:standard_functions",
"//runtime/planner:program_planner",
- "//runtime/standard:add",
- "//runtime/standard:divide",
- "//runtime/standard:dyn",
- "//runtime/standard:equals",
- "//runtime/standard:greater",
- "//runtime/standard:greater_equals",
- "//runtime/standard:index",
- "//runtime/standard:less",
- "//runtime/standard:logical_not",
- "//runtime/standard:not_strictly_false",
- "//runtime/standard:standard_function",
"@cel_spec//proto/cel/expr/conformance/proto3:test_all_types_java_proto",
"@maven//:com_google_guava_guava",
"@maven//:com_google_testparameterinjector_test_parameter_injector",
diff --git a/runtime/src/test/java/dev/cel/runtime/planner/ProgramPlannerTest.java b/runtime/src/test/java/dev/cel/runtime/planner/ProgramPlannerTest.java
index 52ab5e417..ba0c8a548 100644
--- a/runtime/src/test/java/dev/cel/runtime/planner/ProgramPlannerTest.java
+++ b/runtime/src/test/java/dev/cel/runtime/planner/ProgramPlannerTest.java
@@ -25,7 +25,6 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
-import com.google.common.collect.Iterables;
import com.google.common.primitives.UnsignedLong;
import com.google.testing.junit.testparameterinjector.TestParameter;
import com.google.testing.junit.testparameterinjector.TestParameterInjector;
@@ -36,7 +35,6 @@
import dev.cel.common.CelErrorCode;
import dev.cel.common.CelOptions;
import dev.cel.common.CelSource;
-import dev.cel.common.Operator;
import dev.cel.common.ast.CelExpr;
import dev.cel.common.exceptions.CelDivideByZeroException;
import dev.cel.common.internal.CelDescriptorPool;
@@ -69,23 +67,12 @@
import dev.cel.parser.CelStandardMacro;
import dev.cel.runtime.CelEvaluationException;
import dev.cel.runtime.CelFunctionBinding;
-import dev.cel.runtime.CelFunctionOverload;
-import dev.cel.runtime.CelResolvedOverload;
+import dev.cel.runtime.CelStandardFunctions;
+import dev.cel.runtime.CelStandardFunctions.StandardFunction;
import dev.cel.runtime.DefaultDispatcher;
import dev.cel.runtime.Program;
import dev.cel.runtime.RuntimeEquality;
import dev.cel.runtime.RuntimeHelpers;
-import dev.cel.runtime.standard.AddOperator;
-import dev.cel.runtime.standard.CelStandardFunction;
-import dev.cel.runtime.standard.DivideOperator;
-import dev.cel.runtime.standard.DynFunction;
-import dev.cel.runtime.standard.EqualsOperator;
-import dev.cel.runtime.standard.GreaterEqualsOperator;
-import dev.cel.runtime.standard.GreaterOperator;
-import dev.cel.runtime.standard.IndexOperator;
-import dev.cel.runtime.standard.LessOperator;
-import dev.cel.runtime.standard.LogicalNotOperator;
-import dev.cel.runtime.standard.NotStrictlyFalseFunction;
import org.junit.Test;
import org.junit.runner.RunWith;
@@ -162,136 +149,88 @@ private static DefaultDispatcher newDispatcher() {
DefaultDispatcher.Builder builder = DefaultDispatcher.newBuilder();
// Subsetted StdLib
- addBindings(
- builder, Operator.INDEX.getFunction(), fromStandardFunction(IndexOperator.create()));
- addBindings(
- builder,
- Operator.LOGICAL_NOT.getFunction(),
- fromStandardFunction(LogicalNotOperator.create()));
- addBindings(builder, Operator.ADD.getFunction(), fromStandardFunction(AddOperator.create()));
- addBindings(
- builder, Operator.GREATER.getFunction(), fromStandardFunction(GreaterOperator.create()));
- addBindings(
- builder,
- Operator.GREATER_EQUALS.getFunction(),
- fromStandardFunction(GreaterEqualsOperator.create()));
- addBindings(builder, Operator.LESS.getFunction(), fromStandardFunction(LessOperator.create()));
- addBindings(
- builder, Operator.DIVIDE.getFunction(), fromStandardFunction(DivideOperator.create()));
- addBindings(
- builder, Operator.EQUALS.getFunction(), fromStandardFunction(EqualsOperator.create()));
- addBindings(
- builder,
- Operator.NOT_STRICTLY_FALSE.getFunction(),
- fromStandardFunction(NotStrictlyFalseFunction.create()));
- addBindings(builder, "dyn", fromStandardFunction(DynFunction.create()));
+ CelStandardFunctions stdFunctions =
+ CelStandardFunctions.newBuilder()
+ .includeFunctions(
+ StandardFunction.INDEX,
+ StandardFunction.LOGICAL_NOT,
+ StandardFunction.ADD,
+ StandardFunction.GREATER,
+ StandardFunction.GREATER_EQUALS,
+ StandardFunction.LESS,
+ StandardFunction.DIVIDE,
+ StandardFunction.EQUALS,
+ StandardFunction.NOT_STRICTLY_FALSE,
+ StandardFunction.DYN)
+ .build();
+ addBindingsToDispatcher(
+ builder, stdFunctions.newFunctionBindings(RUNTIME_EQUALITY, CEL_OPTIONS));
// Custom functions
- addBindings(
+ addBindingsToDispatcher(
builder,
- "zero",
- CelFunctionBinding.from("zero_overload", ImmutableList.of(), (unused) -> 0L));
- addBindings(
+ CelFunctionBinding.fromOverloads(
+ "zero", CelFunctionBinding.from("zero_overload", ImmutableList.of(), (unused) -> 0L)));
+
+ addBindingsToDispatcher(
builder,
- "error",
- CelFunctionBinding.from(
- "error_overload",
- ImmutableList.of(),
- (unused) -> {
- throw new IllegalArgumentException("Intentional error");
- }));
- addBindings(
+ CelFunctionBinding.fromOverloads(
+ "error",
+ CelFunctionBinding.from(
+ "error_overload",
+ ImmutableList.of(),
+ (unused) -> {
+ throw new IllegalArgumentException("Intentional error");
+ })));
+
+ addBindingsToDispatcher(
builder,
- "neg",
- CelFunctionBinding.from("neg_int", Long.class, arg -> -arg),
- CelFunctionBinding.from("neg_double", Double.class, arg -> -arg));
- addBindings(
+ CelFunctionBinding.fromOverloads(
+ "neg",
+ CelFunctionBinding.from("neg_int", Long.class, arg -> -arg),
+ CelFunctionBinding.from("neg_double", Double.class, arg -> -arg)));
+
+ addBindingsToDispatcher(
builder,
- "cel.expr.conformance.proto3.power",
- CelFunctionBinding.from(
- "power_int_int",
- Long.class,
- Long.class,
- (value, power) -> (long) Math.pow(value, power)));
- addBindings(
+ CelFunctionBinding.fromOverloads(
+ "cel.expr.conformance.proto3.power",
+ CelFunctionBinding.from(
+ "power_int_int",
+ Long.class,
+ Long.class,
+ (value, power) -> (long) Math.pow(value, power))));
+
+ addBindingsToDispatcher(
builder,
- "concat",
- CelFunctionBinding.from(
- "concat_bytes_bytes",
- CelByteString.class,
- CelByteString.class,
- ProgramPlannerTest::concatenateByteArrays),
- CelFunctionBinding.from(
- "bytes_concat_bytes",
- CelByteString.class,
- CelByteString.class,
- ProgramPlannerTest::concatenateByteArrays));
+ CelFunctionBinding.fromOverloads(
+ "concat",
+ CelFunctionBinding.from(
+ "concat_bytes_bytes",
+ CelByteString.class,
+ CelByteString.class,
+ ProgramPlannerTest::concatenateByteArrays),
+ CelFunctionBinding.from(
+ "bytes_concat_bytes",
+ CelByteString.class,
+ CelByteString.class,
+ ProgramPlannerTest::concatenateByteArrays)));
return builder.build();
}
- private static void addBindings(
- DefaultDispatcher.Builder builder,
- String functionName,
- CelFunctionBinding... functionBindings) {
- addBindings(builder, functionName, ImmutableSet.copyOf(functionBindings));
- }
-
- private static void addBindings(
- DefaultDispatcher.Builder builder,
- String functionName,
- ImmutableCollection overloadBindings) {
+ private static void addBindingsToDispatcher(
+ DefaultDispatcher.Builder builder, ImmutableCollection overloadBindings) {
if (overloadBindings.isEmpty()) {
throw new IllegalArgumentException("Invalid bindings");
}
- // TODO: Runtime top-level APIs currently does not allow grouping overloads with
- // the function name. This capability will have to be added.
- if (overloadBindings.size() == 1) {
- CelFunctionBinding singleBinding = Iterables.getOnlyElement(overloadBindings);
- builder.addOverload(
- functionName,
- singleBinding.getArgTypes(),
- singleBinding.isStrict(),
- args -> guardedOp(functionName, args, singleBinding));
- } else {
- overloadBindings.forEach(
- overload ->
- builder.addOverload(
- overload.getOverloadId(),
- overload.getArgTypes(),
- overload.isStrict(),
- args -> guardedOp(functionName, args, overload)));
-
- // Setup dynamic dispatch
- CelFunctionOverload dynamicDispatchDef =
- args -> {
- for (CelFunctionBinding overload : overloadBindings) {
- if (CelResolvedOverload.canHandle(
- args, overload.getArgTypes(), overload.isStrict())) {
- return overload.getDefinition().apply(args);
- }
- }
-
- throw new IllegalArgumentException(
- "No matching overload for function: " + functionName);
- };
-
- boolean allOverloadsStrict = overloadBindings.stream().allMatch(CelFunctionBinding::isStrict);
- builder.addOverload(
- functionName, ImmutableList.of(), /* isStrict= */ allOverloadsStrict, dynamicDispatchDef);
- }
- }
- /** Creates an invocation guard around the overload definition. */
- private static Object guardedOp(
- String functionName, Object[] args, CelFunctionBinding singleBinding)
- throws CelEvaluationException {
- if (!CelResolvedOverload.canHandle(
- args, singleBinding.getArgTypes(), singleBinding.isStrict())) {
- throw new IllegalArgumentException("No matching overload for function: " + functionName);
- }
-
- return singleBinding.getDefinition().apply(args);
+ overloadBindings.forEach(
+ overload ->
+ builder.addOverload(
+ overload.getOverloadId(),
+ overload.getArgTypes(),
+ overload.isStrict(),
+ overload.getDefinition()));
}
@TestParameter boolean isParseOnly;
@@ -872,11 +811,6 @@ private static CelByteString concatenateByteArrays(CelByteString bytes1, CelByte
return bytes1.concat(bytes2);
}
- private static ImmutableSet fromStandardFunction(
- CelStandardFunction standardFunction) {
- return standardFunction.newFunctionBindings(CEL_OPTIONS, RUNTIME_EQUALITY);
- }
-
@SuppressWarnings("ImmutableEnumChecker") // Test only
private enum ConstantTestCase {
NULL("null", NullValue.NULL_VALUE),