From 233430f86fdd7144b836e29b3b1c3b0a1101a034 Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Tue, 6 Jan 2026 15:33:49 -0800 Subject: [PATCH] Group overloads to a single function to facilitate dynamic dispatch PiperOrigin-RevId: 852961755 --- .../src/main/java/dev/cel/runtime/BUILD.bazel | 12 +- .../dev/cel/runtime/CelFunctionBinding.java | 25 ++- .../dev/cel/runtime/CelFunctionOverload.java | 37 ++++ .../dev/cel/runtime/CelResolvedOverload.java | 38 +--- .../dev/cel/runtime/CelStandardFunctions.java | 122 +++++++---- .../dev/cel/runtime/DefaultDispatcher.java | 29 ++- .../dev/cel/runtime/FunctionBindingImpl.java | 86 ++++++++ .../runtime/standard/CelStandardFunction.java | 11 +- .../java/dev/cel/runtime/planner/BUILD.bazel | 14 +- .../runtime/planner/ProgramPlannerTest.java | 204 ++++++------------ 10 files changed, 346 insertions(+), 232 deletions(-) diff --git a/runtime/src/main/java/dev/cel/runtime/BUILD.bazel b/runtime/src/main/java/dev/cel/runtime/BUILD.bazel index ac5599ba9..3b0e2762a 100644 --- a/runtime/src/main/java/dev/cel/runtime/BUILD.bazel +++ b/runtime/src/main/java/dev/cel/runtime/BUILD.bazel @@ -122,6 +122,7 @@ java_library( deps = [ ":evaluation_exception", ":evaluation_exception_builder", + ":function_binding", ":function_overload", ":function_resolver", ":resolved_overload", @@ -141,6 +142,7 @@ cel_android_library( deps = [ ":evaluation_exception", ":evaluation_exception_builder", + ":function_binding_android", ":function_overload_android", ":function_resolver_android", ":resolved_overload_android", @@ -609,6 +611,7 @@ java_library( deps = [ ":function_binding", ":runtime_equality", + "//common:operator", "//common:options", "//common/annotations", "//runtime/standard:add", @@ -666,6 +669,7 @@ cel_android_library( deps = [ ":function_binding_android", ":runtime_equality_android", + "//common:operator_android", "//common:options", "//common/annotations", "//runtime/standard:add_android", @@ -723,6 +727,7 @@ java_library( tags = [ ], deps = [ + ":evaluation_exception", ":function_overload", "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:com_google_guava_guava", @@ -735,6 +740,7 @@ cel_android_library( tags = [ ], deps = [ + ":evaluation_exception", ":function_overload_android", "@maven//:com_google_errorprone_error_prone_annotations", "@maven_android//:com_google_guava_guava", @@ -774,7 +780,9 @@ java_library( ], deps = [ ":evaluation_exception", + "//runtime:unknown_attributes", "@maven//:com_google_errorprone_error_prone_annotations", + "@maven//:com_google_guava_guava", ], ) @@ -785,7 +793,9 @@ cel_android_library( ], deps = [ ":evaluation_exception", + "//runtime:unknown_attributes_android", "@maven//:com_google_errorprone_error_prone_annotations", + "@maven_android//:com_google_guava_guava", ], ) @@ -1173,7 +1183,6 @@ java_library( ], deps = [ ":function_overload", - ":unknown_attributes", "//:auto_value", "//common/annotations", "@maven//:com_google_errorprone_error_prone_annotations", @@ -1188,7 +1197,6 @@ cel_android_library( ], deps = [ ":function_overload_android", - ":unknown_attributes_android", "//:auto_value", "//common/annotations", "@maven//:com_google_errorprone_error_prone_annotations", diff --git a/runtime/src/main/java/dev/cel/runtime/CelFunctionBinding.java b/runtime/src/main/java/dev/cel/runtime/CelFunctionBinding.java index 79b0f3f54..c7b63926b 100644 --- a/runtime/src/main/java/dev/cel/runtime/CelFunctionBinding.java +++ b/runtime/src/main/java/dev/cel/runtime/CelFunctionBinding.java @@ -14,8 +14,13 @@ package dev.cel.runtime; +import static com.google.common.base.Preconditions.checkArgument; + +import com.google.common.base.Strings; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; import com.google.errorprone.annotations.Immutable; +import java.util.Collection; /** * Binding consisting of an overload id, a Java-native argument signature, and an overload @@ -35,7 +40,6 @@ * *

Examples: string_startsWith_string, mathMax_list, lessThan_money_money */ - @Immutable public interface CelFunctionBinding { String getOverloadId(); @@ -70,4 +74,23 @@ static CelFunctionBinding from( return new FunctionBindingImpl( overloadId, ImmutableList.copyOf(argTypes), impl, /* isStrict= */ true); } + + /** See {@link #fromOverloads(String, Collection)}. */ + static ImmutableSet fromOverloads( + String functionName, CelFunctionBinding... overloadBindings) { + return fromOverloads(functionName, ImmutableList.copyOf(overloadBindings)); + } + + /** + * Creates a set of bindings for a function, enabling dynamic dispatch logic to select the correct + * overload at runtime based on argument types. + */ + static ImmutableSet fromOverloads( + String functionName, Collection overloadBindings) { + checkArgument(!Strings.isNullOrEmpty(functionName), "Function name cannot be null or empty"); + checkArgument(!overloadBindings.isEmpty(), "You must provide at least one binding."); + + return FunctionBindingImpl.groupOverloadsToFunction( + functionName, ImmutableSet.copyOf(overloadBindings)); + } } diff --git a/runtime/src/main/java/dev/cel/runtime/CelFunctionOverload.java b/runtime/src/main/java/dev/cel/runtime/CelFunctionOverload.java index a1341cb21..3e30a2146 100644 --- a/runtime/src/main/java/dev/cel/runtime/CelFunctionOverload.java +++ b/runtime/src/main/java/dev/cel/runtime/CelFunctionOverload.java @@ -14,7 +14,9 @@ package dev.cel.runtime; +import com.google.common.collect.ImmutableList; import com.google.errorprone.annotations.Immutable; +import java.util.Map; /** Interface describing the general signature of all CEL custom function implementations. */ @Immutable @@ -43,4 +45,39 @@ interface Unary { interface Binary { Object apply(T1 arg1, T2 arg2) throws CelEvaluationException; } + + /** + * Returns true if the overload's expected argument types match the types of the given arguments. + */ + static boolean canHandle( + Object[] arguments, ImmutableList> parameterTypes, boolean isStrict) { + if (parameterTypes.size() != arguments.length) { + return false; + } + for (int i = 0; i < parameterTypes.size(); i++) { + Class paramType = parameterTypes.get(i); + Object arg = arguments[i]; + if (arg == null) { + // null can be assigned to messages, maps, and to objects. + // TODO: Remove null special casing + if (paramType != Object.class && !Map.class.isAssignableFrom(paramType)) { + return false; + } + continue; + } + + if (arg instanceof Exception || arg instanceof CelUnknownSet) { + // Only non-strict functions can accept errors/unknowns as arguments to a function + if (!isStrict) { + // Skip assignability check below, but continue to validate remaining args + continue; + } + } + + if (!paramType.isAssignableFrom(arg.getClass())) { + return false; + } + } + return true; + } } diff --git a/runtime/src/main/java/dev/cel/runtime/CelResolvedOverload.java b/runtime/src/main/java/dev/cel/runtime/CelResolvedOverload.java index f6a0c4f99..2bcdf3a2d 100644 --- a/runtime/src/main/java/dev/cel/runtime/CelResolvedOverload.java +++ b/runtime/src/main/java/dev/cel/runtime/CelResolvedOverload.java @@ -19,7 +19,6 @@ import com.google.errorprone.annotations.Immutable; import dev.cel.common.annotations.Internal; import java.util.List; -import java.util.Map; /** * Representation of a function overload which has been resolved to a specific set of argument types @@ -80,41 +79,6 @@ public static CelResolvedOverload of( * Returns true if the overload's expected argument types match the types of the given arguments. */ boolean canHandle(Object[] arguments) { - return canHandle(arguments, getParameterTypes(), isStrict()); - } - - /** - * Returns true if the overload's expected argument types match the types of the given arguments. - */ - public static boolean canHandle( - Object[] arguments, ImmutableList> parameterTypes, boolean isStrict) { - if (parameterTypes.size() != arguments.length) { - return false; - } - for (int i = 0; i < parameterTypes.size(); i++) { - Class paramType = parameterTypes.get(i); - Object arg = arguments[i]; - if (arg == null) { - // null can be assigned to messages, maps, and to objects. - // TODO: Remove null special casing - if (paramType != Object.class && !Map.class.isAssignableFrom(paramType)) { - return false; - } - continue; - } - - if (arg instanceof Exception || arg instanceof CelUnknownSet) { - // Only non-strict functions can accept errors/unknowns as arguments to a function - if (!isStrict) { - // Skip assignability check below, but continue to validate remaining args - continue; - } - } - - if (!paramType.isAssignableFrom(arg.getClass())) { - return false; - } - } - return true; + return CelFunctionOverload.canHandle(arguments, getParameterTypes(), isStrict()); } } diff --git a/runtime/src/main/java/dev/cel/runtime/CelStandardFunctions.java b/runtime/src/main/java/dev/cel/runtime/CelStandardFunctions.java index bedd41728..39797e086 100644 --- a/runtime/src/main/java/dev/cel/runtime/CelStandardFunctions.java +++ b/runtime/src/main/java/dev/cel/runtime/CelStandardFunctions.java @@ -15,12 +15,15 @@ package dev.cel.runtime; import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.collect.ImmutableSet.toImmutableSet; import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableMultimap; import com.google.common.collect.ImmutableSet; import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.errorprone.annotations.Immutable; import dev.cel.common.CelOptions; +import dev.cel.common.Operator; import dev.cel.common.annotations.Internal; import dev.cel.runtime.standard.AddOperator; import dev.cel.runtime.standard.AddOperator.AddOverload; @@ -104,6 +107,8 @@ import dev.cel.runtime.standard.TimestampFunction.TimestampOverload; import dev.cel.runtime.standard.UintFunction; import dev.cel.runtime.standard.UintFunction.UintOverload; +import java.util.Collection; +import java.util.Map; /** Runtime function bindings for the standard functions in CEL. */ @Immutable @@ -135,7 +140,7 @@ public final class CelStandardFunctions { GreaterEqualsOverload.GREATER_EQUALS_UINT64_DOUBLE, GreaterEqualsOverload.GREATER_EQUALS_DOUBLE_UINT64); - private final ImmutableSet standardOverloads; + private final ImmutableMultimap standardOverloads; public static final ImmutableSet ALL_STANDARD_FUNCTIONS = ImmutableSet.of( @@ -187,13 +192,15 @@ public final class CelStandardFunctions { * special-cased, and does not appear in this enum. */ public enum StandardFunction { - LOGICAL_NOT(LogicalNotOverload.LOGICAL_NOT), - IN(InOverload.IN_LIST, InOverload.IN_MAP), - NOT_STRICTLY_FALSE(NotStrictlyFalseOverload.NOT_STRICTLY_FALSE), - EQUALS(EqualsOverload.EQUALS), - NOT_EQUALS(NotEqualsOverload.NOT_EQUALS), - BOOL(BoolOverload.BOOL_TO_BOOL, BoolOverload.STRING_TO_BOOL), + LOGICAL_NOT(Operator.LOGICAL_NOT.getFunction(), LogicalNotOverload.LOGICAL_NOT), + IN(Operator.IN.getFunction(), InOverload.IN_LIST, InOverload.IN_MAP), + NOT_STRICTLY_FALSE( + Operator.NOT_STRICTLY_FALSE.getFunction(), NotStrictlyFalseOverload.NOT_STRICTLY_FALSE), + EQUALS(Operator.EQUALS.getFunction(), EqualsOverload.EQUALS), + NOT_EQUALS(Operator.NOT_EQUALS.getFunction(), NotEqualsOverload.NOT_EQUALS), + BOOL("bool", BoolOverload.BOOL_TO_BOOL, BoolOverload.STRING_TO_BOOL), ADD( + Operator.ADD.getFunction(), AddOverload.ADD_INT64, AddOverload.ADD_UINT64, AddOverload.ADD_DOUBLE, @@ -204,6 +211,7 @@ public enum StandardFunction { AddOverload.ADD_DURATION_TIMESTAMP, AddOverload.ADD_DURATION_DURATION), SUBTRACT( + Operator.SUBTRACT.getFunction(), SubtractOverload.SUBTRACT_INT64, SubtractOverload.SUBTRACT_TIMESTAMP_TIMESTAMP, SubtractOverload.SUBTRACT_TIMESTAMP_DURATION, @@ -211,14 +219,22 @@ public enum StandardFunction { SubtractOverload.SUBTRACT_DOUBLE, SubtractOverload.SUBTRACT_DURATION_DURATION), MULTIPLY( + Operator.MULTIPLY.getFunction(), MultiplyOverload.MULTIPLY_INT64, MultiplyOverload.MULTIPLY_DOUBLE, MultiplyOverload.MULTIPLY_UINT64), - DIVIDE(DivideOverload.DIVIDE_DOUBLE, DivideOverload.DIVIDE_INT64, DivideOverload.DIVIDE_UINT64), - MODULO(ModuloOverload.MODULO_INT64, ModuloOverload.MODULO_UINT64), - NEGATE(NegateOverload.NEGATE_INT64, NegateOverload.NEGATE_DOUBLE), - INDEX(IndexOverload.INDEX_LIST, IndexOverload.INDEX_MAP), + DIVIDE( + Operator.DIVIDE.getFunction(), + DivideOverload.DIVIDE_DOUBLE, + DivideOverload.DIVIDE_INT64, + DivideOverload.DIVIDE_UINT64), + MODULO( + Operator.MODULO.getFunction(), ModuloOverload.MODULO_INT64, ModuloOverload.MODULO_UINT64), + NEGATE( + Operator.NEGATE.getFunction(), NegateOverload.NEGATE_INT64, NegateOverload.NEGATE_DOUBLE), + INDEX(Operator.INDEX.getFunction(), IndexOverload.INDEX_LIST, IndexOverload.INDEX_MAP), SIZE( + "size", SizeOverload.SIZE_STRING, SizeOverload.SIZE_BYTES, SizeOverload.SIZE_LIST, @@ -228,22 +244,26 @@ public enum StandardFunction { SizeOverload.LIST_SIZE, SizeOverload.MAP_SIZE), INT( + "int", IntOverload.INT64_TO_INT64, IntOverload.UINT64_TO_INT64, IntOverload.DOUBLE_TO_INT64, IntOverload.STRING_TO_INT64, IntOverload.TIMESTAMP_TO_INT64), UINT( + "uint", UintOverload.UINT64_TO_UINT64, UintOverload.INT64_TO_UINT64, UintOverload.DOUBLE_TO_UINT64, UintOverload.STRING_TO_UINT64), DOUBLE( + "double", DoubleOverload.DOUBLE_TO_DOUBLE, DoubleOverload.INT64_TO_DOUBLE, DoubleOverload.STRING_TO_DOUBLE, DoubleOverload.UINT64_TO_DOUBLE), STRING( + "string", StringOverload.STRING_TO_STRING, StringOverload.INT64_TO_STRING, StringOverload.DOUBLE_TO_STRING, @@ -252,51 +272,67 @@ public enum StandardFunction { StringOverload.TIMESTAMP_TO_STRING, StringOverload.DURATION_TO_STRING, StringOverload.UINT64_TO_STRING), - BYTES(BytesOverload.BYTES_TO_BYTES, BytesOverload.STRING_TO_BYTES), - DURATION(DurationOverload.DURATION_TO_DURATION, DurationOverload.STRING_TO_DURATION), + BYTES("bytes", BytesOverload.BYTES_TO_BYTES, BytesOverload.STRING_TO_BYTES), + DURATION( + "duration", DurationOverload.DURATION_TO_DURATION, DurationOverload.STRING_TO_DURATION), TIMESTAMP( + "timestamp", TimestampOverload.STRING_TO_TIMESTAMP, TimestampOverload.TIMESTAMP_TO_TIMESTAMP, TimestampOverload.INT64_TO_TIMESTAMP), - DYN(DynOverload.TO_DYN), - MATCHES(MatchesOverload.MATCHES, MatchesOverload.MATCHES_STRING), - CONTAINS(ContainsOverload.CONTAINS_STRING), - ENDS_WITH(EndsWithOverload.ENDS_WITH_STRING), - STARTS_WITH(StartsWithOverload.STARTS_WITH_STRING), + DYN("dyn", DynOverload.TO_DYN), + MATCHES("matches", MatchesOverload.MATCHES, MatchesOverload.MATCHES_STRING), + CONTAINS("contains", ContainsOverload.CONTAINS_STRING), + ENDS_WITH("endsWith", EndsWithOverload.ENDS_WITH_STRING), + STARTS_WITH("startsWith", StartsWithOverload.STARTS_WITH_STRING), // Date/time Functions GET_FULL_YEAR( - GetFullYearOverload.TIMESTAMP_TO_YEAR, GetFullYearOverload.TIMESTAMP_TO_YEAR_WITH_TZ), - GET_MONTH(GetMonthOverload.TIMESTAMP_TO_MONTH, GetMonthOverload.TIMESTAMP_TO_MONTH_WITH_TZ), + "getFullYear", + GetFullYearOverload.TIMESTAMP_TO_YEAR, + GetFullYearOverload.TIMESTAMP_TO_YEAR_WITH_TZ), + GET_MONTH( + "getMonth", + GetMonthOverload.TIMESTAMP_TO_MONTH, + GetMonthOverload.TIMESTAMP_TO_MONTH_WITH_TZ), GET_DAY_OF_YEAR( + "getDayOfYear", GetDayOfYearOverload.TIMESTAMP_TO_DAY_OF_YEAR, GetDayOfYearOverload.TIMESTAMP_TO_DAY_OF_YEAR_WITH_TZ), GET_DAY_OF_MONTH( + "getDayOfMonth", GetDayOfMonthOverload.TIMESTAMP_TO_DAY_OF_MONTH, GetDayOfMonthOverload.TIMESTAMP_TO_DAY_OF_MONTH_WITH_TZ), GET_DATE( + "getDate", GetDateOverload.TIMESTAMP_TO_DAY_OF_MONTH_1_BASED, GetDateOverload.TIMESTAMP_TO_DAY_OF_MONTH_1_BASED_WITH_TZ), GET_DAY_OF_WEEK( + "getDayOfWeek", GetDayOfWeekOverload.TIMESTAMP_TO_DAY_OF_WEEK, GetDayOfWeekOverload.TIMESTAMP_TO_DAY_OF_WEEK_WITH_TZ), GET_HOURS( + "getHours", GetHoursOverload.TIMESTAMP_TO_HOURS, GetHoursOverload.TIMESTAMP_TO_HOURS_WITH_TZ, GetHoursOverload.DURATION_TO_HOURS), GET_MINUTES( + "getMinutes", GetMinutesOverload.TIMESTAMP_TO_MINUTES, GetMinutesOverload.TIMESTAMP_TO_MINUTES_WITH_TZ, GetMinutesOverload.DURATION_TO_MINUTES), GET_SECONDS( + "getSeconds", GetSecondsOverload.TIMESTAMP_TO_SECONDS, GetSecondsOverload.TIMESTAMP_TO_SECONDS_WITH_TZ, GetSecondsOverload.DURATION_TO_SECONDS), GET_MILLISECONDS( + "getMilliseconds", GetMillisecondsOverload.TIMESTAMP_TO_MILLISECONDS, GetMillisecondsOverload.TIMESTAMP_TO_MILLISECONDS_WITH_TZ, GetMillisecondsOverload.DURATION_TO_MILLISECONDS), LESS( + Operator.LESS.getFunction(), LessOverload.LESS_BOOL, LessOverload.LESS_INT64, LessOverload.LESS_UINT64, @@ -312,6 +348,7 @@ public enum StandardFunction { LessOverload.LESS_UINT64_DOUBLE, LessOverload.LESS_DOUBLE_UINT64), LESS_EQUALS( + Operator.LESS_EQUALS.getFunction(), LessEqualsOverload.LESS_EQUALS_BOOL, LessEqualsOverload.LESS_EQUALS_INT64, LessEqualsOverload.LESS_EQUALS_UINT64, @@ -327,6 +364,7 @@ public enum StandardFunction { LessEqualsOverload.LESS_EQUALS_UINT64_DOUBLE, LessEqualsOverload.LESS_EQUALS_DOUBLE_UINT64), GREATER( + Operator.GREATER.getFunction(), GreaterOverload.GREATER_BOOL, GreaterOverload.GREATER_INT64, GreaterOverload.GREATER_UINT64, @@ -342,6 +380,7 @@ public enum StandardFunction { GreaterOverload.GREATER_UINT64_DOUBLE, GreaterOverload.GREATER_DOUBLE_UINT64), GREATER_EQUALS( + Operator.GREATER_EQUALS.getFunction(), GreaterEqualsOverload.GREATER_EQUALS_BOOL, GreaterEqualsOverload.GREATER_EQUALS_BYTES, GreaterEqualsOverload.GREATER_EQUALS_DOUBLE, @@ -357,9 +396,11 @@ public enum StandardFunction { GreaterEqualsOverload.GREATER_EQUALS_UINT64_DOUBLE, GreaterEqualsOverload.GREATER_EQUALS_DOUBLE_UINT64); + private final String functionName; private final ImmutableSet standardOverloads; - StandardFunction(CelStandardOverload... overloads) { + StandardFunction(String functionName, CelStandardOverload... overloads) { + this.functionName = functionName; this.standardOverloads = ImmutableSet.copyOf(overloads); } @@ -371,15 +412,25 @@ ImmutableSet getOverloads() { @VisibleForTesting ImmutableSet getOverloads() { - return standardOverloads; + return ImmutableSet.copyOf(standardOverloads.values()); } @Internal public ImmutableSet newFunctionBindings( RuntimeEquality runtimeEquality, CelOptions celOptions) { ImmutableSet.Builder builder = ImmutableSet.builder(); - for (CelStandardOverload overload : standardOverloads) { - builder.add(overload.newFunctionBinding(celOptions, runtimeEquality)); + + for (Map.Entry> entry : + standardOverloads.asMap().entrySet()) { + String functionName = entry.getKey(); + Collection overloads = entry.getValue(); + + ImmutableSet bindings = + overloads.stream() + .map(o -> o.newFunctionBinding(celOptions, runtimeEquality)) + .collect(toImmutableSet()); + + builder.addAll(CelFunctionBinding.fromOverloads(functionName, bindings)); } return builder.build(); @@ -454,39 +505,36 @@ public CelStandardFunctions build() { "You may only populate one of the following builder methods: includeFunctions," + " excludeFunctions or filterFunctions"); - ImmutableSet.Builder standardOverloadBuilder = ImmutableSet.builder(); + ImmutableMultimap.Builder standardOverloadBuilder = + ImmutableMultimap.builder(); for (StandardFunction standardFunction : StandardFunction.values()) { if (hasIncludeFunctions) { if (this.includeFunctions.contains(standardFunction)) { - standardOverloadBuilder.addAll(standardFunction.standardOverloads); + standardOverloadBuilder.putAll( + standardFunction.functionName, standardFunction.standardOverloads); } continue; } if (hasExcludeFunctions) { if (!this.excludeFunctions.contains(standardFunction)) { - standardOverloadBuilder.addAll(standardFunction.standardOverloads); + standardOverloadBuilder.putAll( + standardFunction.functionName, standardFunction.standardOverloads); } continue; } if (hasFilterFunction) { - ImmutableSet.Builder filteredOverloadsBuilder = - ImmutableSet.builder(); for (CelStandardOverload standardOverload : standardFunction.standardOverloads) { boolean includeOverload = functionFilter.include(standardFunction, standardOverload); if (includeOverload) { - standardOverloadBuilder.add(standardOverload); + standardOverloadBuilder.put(standardFunction.functionName, standardOverload); } } - ImmutableSet filteredOverloads = filteredOverloadsBuilder.build(); - if (!filteredOverloads.isEmpty()) { - standardOverloadBuilder.addAll(filteredOverloads); - } - continue; } - standardOverloadBuilder.addAll(standardFunction.standardOverloads); + standardOverloadBuilder.putAll( + standardFunction.functionName, standardFunction.standardOverloads); } return new CelStandardFunctions(standardOverloadBuilder.build()); @@ -511,7 +559,7 @@ static boolean isHeterogeneousComparison(CelStandardOverload overload) { return HETEROGENEOUS_COMPARISON_OPERATORS.contains(overload); } - private CelStandardFunctions(ImmutableSet standardOverloads) { + private CelStandardFunctions(ImmutableMultimap standardOverloads) { this.standardOverloads = standardOverloads; } } diff --git a/runtime/src/main/java/dev/cel/runtime/DefaultDispatcher.java b/runtime/src/main/java/dev/cel/runtime/DefaultDispatcher.java index 35ce243b3..b5e5ca55c 100644 --- a/runtime/src/main/java/dev/cel/runtime/DefaultDispatcher.java +++ b/runtime/src/main/java/dev/cel/runtime/DefaultDispatcher.java @@ -19,11 +19,13 @@ import com.google.auto.value.AutoBuilder; import com.google.common.base.Joiner; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.errorprone.annotations.Immutable; import dev.cel.common.CelErrorCode; import dev.cel.common.annotations.Internal; +import dev.cel.runtime.FunctionBindingImpl.DynamicDispatchOverload; import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -127,7 +129,7 @@ public abstract static class Builder { @CanIgnoreReturnValue public Builder addOverload( String overloadId, - List> argTypes, + ImmutableList> argTypes, boolean isStrict, CelFunctionOverload overload) { checkNotNull(overloadId); @@ -136,13 +138,36 @@ public Builder addOverload( checkNotNull(overload); overloadsBuilder() - .put(overloadId, CelResolvedOverload.of(overloadId, overload, isStrict, argTypes)); + .put( + overloadId, + CelResolvedOverload.of( + overloadId, + args -> guardedOp(overloadId, args, argTypes, isStrict, overload), + isStrict, + argTypes)); return this; } public abstract DefaultDispatcher build(); } + /** Creates an invocation guard around the overload definition. */ + private static Object guardedOp( + String functionName, + Object[] args, + ImmutableList> argTypes, + boolean isStrict, + CelFunctionOverload overload) + throws CelEvaluationException { + // Argument checking for DynamicDispatch is handled inside the overload's apply method itself. + if (overload instanceof DynamicDispatchOverload + || CelFunctionOverload.canHandle(args, argTypes, isStrict)) { + return overload.apply(args); + } + + throw new IllegalArgumentException("No matching overload for function: " + functionName); + } + DefaultDispatcher(ImmutableMap overloads) { this.overloads = overloads; } diff --git a/runtime/src/main/java/dev/cel/runtime/FunctionBindingImpl.java b/runtime/src/main/java/dev/cel/runtime/FunctionBindingImpl.java index b554ce41a..4dcdcb74f 100644 --- a/runtime/src/main/java/dev/cel/runtime/FunctionBindingImpl.java +++ b/runtime/src/main/java/dev/cel/runtime/FunctionBindingImpl.java @@ -15,6 +15,8 @@ package dev.cel.runtime; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Iterables; import com.google.errorprone.annotations.Immutable; @Immutable @@ -58,4 +60,88 @@ public boolean isStrict() { this.definition = definition; this.isStrict = isStrict; } + + static ImmutableSet groupOverloadsToFunction( + String functionName, ImmutableSet overloadBindings) { + ImmutableSet.Builder builder = ImmutableSet.builder(); + builder.addAll(overloadBindings); + + // If there is already a binding with the same name as the function, we treat it as a + // "Singleton" binding and do not create a dynamic dispatch wrapper for it. + // (Ex: "matches" function) + boolean hasSingletonBinding = + overloadBindings.stream().anyMatch(b -> b.getOverloadId().equals(functionName)); + + if (!hasSingletonBinding) { + if (overloadBindings.size() == 1) { + CelFunctionBinding singleBinding = Iterables.getOnlyElement(overloadBindings); + builder.add( + new FunctionBindingImpl( + functionName, + singleBinding.getArgTypes(), + singleBinding.getDefinition(), + singleBinding.isStrict())); + } else { + builder.add(new DynamicDispatchBinding(functionName, overloadBindings)); + } + } + + return builder.build(); + } + + @Immutable + static final class DynamicDispatchBinding implements CelFunctionBinding { + + private final boolean isStrict; + private final DynamicDispatchOverload dynamicDispatchOverload; + + @Override + public String getOverloadId() { + return dynamicDispatchOverload.functionName; + } + + @Override + public ImmutableList> getArgTypes() { + return ImmutableList.of(); + } + + @Override + public CelFunctionOverload getDefinition() { + return dynamicDispatchOverload; + } + + @Override + public boolean isStrict() { + return isStrict; + } + + private DynamicDispatchBinding( + String functionName, ImmutableSet overloadBindings) { + this.isStrict = overloadBindings.stream().allMatch(CelFunctionBinding::isStrict); + this.dynamicDispatchOverload = new DynamicDispatchOverload(functionName, overloadBindings); + } + } + + @Immutable + static final class DynamicDispatchOverload implements CelFunctionOverload { + private final String functionName; + private final ImmutableSet overloadBindings; + + @Override + public Object apply(Object[] args) throws CelEvaluationException { + for (CelFunctionBinding overload : overloadBindings) { + if (CelFunctionOverload.canHandle(args, overload.getArgTypes(), overload.isStrict())) { + return overload.getDefinition().apply(args); + } + } + + throw new IllegalArgumentException("No matching overload for function: " + functionName); + } + + private DynamicDispatchOverload( + String functionName, ImmutableSet overloadBindings) { + this.functionName = functionName; + this.overloadBindings = overloadBindings; + } + } } diff --git a/runtime/src/main/java/dev/cel/runtime/standard/CelStandardFunction.java b/runtime/src/main/java/dev/cel/runtime/standard/CelStandardFunction.java index 73d53287c..f9f919413 100644 --- a/runtime/src/main/java/dev/cel/runtime/standard/CelStandardFunction.java +++ b/runtime/src/main/java/dev/cel/runtime/standard/CelStandardFunction.java @@ -15,6 +15,7 @@ package dev.cel.runtime.standard; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableSet.toImmutableSet; import com.google.common.base.Strings; import com.google.common.collect.ImmutableSet; @@ -34,12 +35,12 @@ public abstract class CelStandardFunction { public ImmutableSet newFunctionBindings( CelOptions celOptions, RuntimeEquality runtimeEquality) { - ImmutableSet.Builder builder = ImmutableSet.builder(); - for (CelStandardOverload overload : overloads) { - builder.add(overload.newFunctionBinding(celOptions, runtimeEquality)); - } + ImmutableSet overloadBindings = + overloads.stream() + .map(overload -> overload.newFunctionBinding(celOptions, runtimeEquality)) + .collect(toImmutableSet()); - return builder.build(); + return CelFunctionBinding.fromOverloads(name, overloadBindings); } CelStandardFunction(String name, ImmutableSet overloads) { diff --git a/runtime/src/test/java/dev/cel/runtime/planner/BUILD.bazel b/runtime/src/test/java/dev/cel/runtime/planner/BUILD.bazel index 2106fa5fe..6749be24f 100644 --- a/runtime/src/test/java/dev/cel/runtime/planner/BUILD.bazel +++ b/runtime/src/test/java/dev/cel/runtime/planner/BUILD.bazel @@ -20,7 +20,6 @@ java_library( "//common:compiler_common", "//common:container", "//common:error_codes", - "//common:operator", "//common:options", "//common/ast", "//common/exceptions:divide_by_zero", @@ -44,21 +43,10 @@ java_library( "//runtime:dispatcher", "//runtime:function_binding", "//runtime:program", - "//runtime:resolved_overload", "//runtime:runtime_equality", "//runtime:runtime_helpers", + "//runtime:standard_functions", "//runtime/planner:program_planner", - "//runtime/standard:add", - "//runtime/standard:divide", - "//runtime/standard:dyn", - "//runtime/standard:equals", - "//runtime/standard:greater", - "//runtime/standard:greater_equals", - "//runtime/standard:index", - "//runtime/standard:less", - "//runtime/standard:logical_not", - "//runtime/standard:not_strictly_false", - "//runtime/standard:standard_function", "@cel_spec//proto/cel/expr/conformance/proto3:test_all_types_java_proto", "@maven//:com_google_guava_guava", "@maven//:com_google_testparameterinjector_test_parameter_injector", diff --git a/runtime/src/test/java/dev/cel/runtime/planner/ProgramPlannerTest.java b/runtime/src/test/java/dev/cel/runtime/planner/ProgramPlannerTest.java index 52ab5e417..ba0c8a548 100644 --- a/runtime/src/test/java/dev/cel/runtime/planner/ProgramPlannerTest.java +++ b/runtime/src/test/java/dev/cel/runtime/planner/ProgramPlannerTest.java @@ -25,7 +25,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; -import com.google.common.collect.Iterables; import com.google.common.primitives.UnsignedLong; import com.google.testing.junit.testparameterinjector.TestParameter; import com.google.testing.junit.testparameterinjector.TestParameterInjector; @@ -36,7 +35,6 @@ import dev.cel.common.CelErrorCode; import dev.cel.common.CelOptions; import dev.cel.common.CelSource; -import dev.cel.common.Operator; import dev.cel.common.ast.CelExpr; import dev.cel.common.exceptions.CelDivideByZeroException; import dev.cel.common.internal.CelDescriptorPool; @@ -69,23 +67,12 @@ import dev.cel.parser.CelStandardMacro; import dev.cel.runtime.CelEvaluationException; import dev.cel.runtime.CelFunctionBinding; -import dev.cel.runtime.CelFunctionOverload; -import dev.cel.runtime.CelResolvedOverload; +import dev.cel.runtime.CelStandardFunctions; +import dev.cel.runtime.CelStandardFunctions.StandardFunction; import dev.cel.runtime.DefaultDispatcher; import dev.cel.runtime.Program; import dev.cel.runtime.RuntimeEquality; import dev.cel.runtime.RuntimeHelpers; -import dev.cel.runtime.standard.AddOperator; -import dev.cel.runtime.standard.CelStandardFunction; -import dev.cel.runtime.standard.DivideOperator; -import dev.cel.runtime.standard.DynFunction; -import dev.cel.runtime.standard.EqualsOperator; -import dev.cel.runtime.standard.GreaterEqualsOperator; -import dev.cel.runtime.standard.GreaterOperator; -import dev.cel.runtime.standard.IndexOperator; -import dev.cel.runtime.standard.LessOperator; -import dev.cel.runtime.standard.LogicalNotOperator; -import dev.cel.runtime.standard.NotStrictlyFalseFunction; import org.junit.Test; import org.junit.runner.RunWith; @@ -162,136 +149,88 @@ private static DefaultDispatcher newDispatcher() { DefaultDispatcher.Builder builder = DefaultDispatcher.newBuilder(); // Subsetted StdLib - addBindings( - builder, Operator.INDEX.getFunction(), fromStandardFunction(IndexOperator.create())); - addBindings( - builder, - Operator.LOGICAL_NOT.getFunction(), - fromStandardFunction(LogicalNotOperator.create())); - addBindings(builder, Operator.ADD.getFunction(), fromStandardFunction(AddOperator.create())); - addBindings( - builder, Operator.GREATER.getFunction(), fromStandardFunction(GreaterOperator.create())); - addBindings( - builder, - Operator.GREATER_EQUALS.getFunction(), - fromStandardFunction(GreaterEqualsOperator.create())); - addBindings(builder, Operator.LESS.getFunction(), fromStandardFunction(LessOperator.create())); - addBindings( - builder, Operator.DIVIDE.getFunction(), fromStandardFunction(DivideOperator.create())); - addBindings( - builder, Operator.EQUALS.getFunction(), fromStandardFunction(EqualsOperator.create())); - addBindings( - builder, - Operator.NOT_STRICTLY_FALSE.getFunction(), - fromStandardFunction(NotStrictlyFalseFunction.create())); - addBindings(builder, "dyn", fromStandardFunction(DynFunction.create())); + CelStandardFunctions stdFunctions = + CelStandardFunctions.newBuilder() + .includeFunctions( + StandardFunction.INDEX, + StandardFunction.LOGICAL_NOT, + StandardFunction.ADD, + StandardFunction.GREATER, + StandardFunction.GREATER_EQUALS, + StandardFunction.LESS, + StandardFunction.DIVIDE, + StandardFunction.EQUALS, + StandardFunction.NOT_STRICTLY_FALSE, + StandardFunction.DYN) + .build(); + addBindingsToDispatcher( + builder, stdFunctions.newFunctionBindings(RUNTIME_EQUALITY, CEL_OPTIONS)); // Custom functions - addBindings( + addBindingsToDispatcher( builder, - "zero", - CelFunctionBinding.from("zero_overload", ImmutableList.of(), (unused) -> 0L)); - addBindings( + CelFunctionBinding.fromOverloads( + "zero", CelFunctionBinding.from("zero_overload", ImmutableList.of(), (unused) -> 0L))); + + addBindingsToDispatcher( builder, - "error", - CelFunctionBinding.from( - "error_overload", - ImmutableList.of(), - (unused) -> { - throw new IllegalArgumentException("Intentional error"); - })); - addBindings( + CelFunctionBinding.fromOverloads( + "error", + CelFunctionBinding.from( + "error_overload", + ImmutableList.of(), + (unused) -> { + throw new IllegalArgumentException("Intentional error"); + }))); + + addBindingsToDispatcher( builder, - "neg", - CelFunctionBinding.from("neg_int", Long.class, arg -> -arg), - CelFunctionBinding.from("neg_double", Double.class, arg -> -arg)); - addBindings( + CelFunctionBinding.fromOverloads( + "neg", + CelFunctionBinding.from("neg_int", Long.class, arg -> -arg), + CelFunctionBinding.from("neg_double", Double.class, arg -> -arg))); + + addBindingsToDispatcher( builder, - "cel.expr.conformance.proto3.power", - CelFunctionBinding.from( - "power_int_int", - Long.class, - Long.class, - (value, power) -> (long) Math.pow(value, power))); - addBindings( + CelFunctionBinding.fromOverloads( + "cel.expr.conformance.proto3.power", + CelFunctionBinding.from( + "power_int_int", + Long.class, + Long.class, + (value, power) -> (long) Math.pow(value, power)))); + + addBindingsToDispatcher( builder, - "concat", - CelFunctionBinding.from( - "concat_bytes_bytes", - CelByteString.class, - CelByteString.class, - ProgramPlannerTest::concatenateByteArrays), - CelFunctionBinding.from( - "bytes_concat_bytes", - CelByteString.class, - CelByteString.class, - ProgramPlannerTest::concatenateByteArrays)); + CelFunctionBinding.fromOverloads( + "concat", + CelFunctionBinding.from( + "concat_bytes_bytes", + CelByteString.class, + CelByteString.class, + ProgramPlannerTest::concatenateByteArrays), + CelFunctionBinding.from( + "bytes_concat_bytes", + CelByteString.class, + CelByteString.class, + ProgramPlannerTest::concatenateByteArrays))); return builder.build(); } - private static void addBindings( - DefaultDispatcher.Builder builder, - String functionName, - CelFunctionBinding... functionBindings) { - addBindings(builder, functionName, ImmutableSet.copyOf(functionBindings)); - } - - private static void addBindings( - DefaultDispatcher.Builder builder, - String functionName, - ImmutableCollection overloadBindings) { + private static void addBindingsToDispatcher( + DefaultDispatcher.Builder builder, ImmutableCollection overloadBindings) { if (overloadBindings.isEmpty()) { throw new IllegalArgumentException("Invalid bindings"); } - // TODO: Runtime top-level APIs currently does not allow grouping overloads with - // the function name. This capability will have to be added. - if (overloadBindings.size() == 1) { - CelFunctionBinding singleBinding = Iterables.getOnlyElement(overloadBindings); - builder.addOverload( - functionName, - singleBinding.getArgTypes(), - singleBinding.isStrict(), - args -> guardedOp(functionName, args, singleBinding)); - } else { - overloadBindings.forEach( - overload -> - builder.addOverload( - overload.getOverloadId(), - overload.getArgTypes(), - overload.isStrict(), - args -> guardedOp(functionName, args, overload))); - - // Setup dynamic dispatch - CelFunctionOverload dynamicDispatchDef = - args -> { - for (CelFunctionBinding overload : overloadBindings) { - if (CelResolvedOverload.canHandle( - args, overload.getArgTypes(), overload.isStrict())) { - return overload.getDefinition().apply(args); - } - } - - throw new IllegalArgumentException( - "No matching overload for function: " + functionName); - }; - - boolean allOverloadsStrict = overloadBindings.stream().allMatch(CelFunctionBinding::isStrict); - builder.addOverload( - functionName, ImmutableList.of(), /* isStrict= */ allOverloadsStrict, dynamicDispatchDef); - } - } - /** Creates an invocation guard around the overload definition. */ - private static Object guardedOp( - String functionName, Object[] args, CelFunctionBinding singleBinding) - throws CelEvaluationException { - if (!CelResolvedOverload.canHandle( - args, singleBinding.getArgTypes(), singleBinding.isStrict())) { - throw new IllegalArgumentException("No matching overload for function: " + functionName); - } - - return singleBinding.getDefinition().apply(args); + overloadBindings.forEach( + overload -> + builder.addOverload( + overload.getOverloadId(), + overload.getArgTypes(), + overload.isStrict(), + overload.getDefinition())); } @TestParameter boolean isParseOnly; @@ -872,11 +811,6 @@ private static CelByteString concatenateByteArrays(CelByteString bytes1, CelByte return bytes1.concat(bytes2); } - private static ImmutableSet fromStandardFunction( - CelStandardFunction standardFunction) { - return standardFunction.newFunctionBindings(CEL_OPTIONS, RUNTIME_EQUALITY); - } - @SuppressWarnings("ImmutableEnumChecker") // Test only private enum ConstantTestCase { NULL("null", NullValue.NULL_VALUE),