diff --git a/bundle/src/main/java/dev/cel/bundle/CelEnvironment.java b/bundle/src/main/java/dev/cel/bundle/CelEnvironment.java index d7d0dea00..1832b20c8 100644 --- a/bundle/src/main/java/dev/cel/bundle/CelEnvironment.java +++ b/bundle/src/main/java/dev/cel/bundle/CelEnvironment.java @@ -44,10 +44,12 @@ import dev.cel.common.types.TypeParamType; import dev.cel.compiler.CelCompiler; import dev.cel.compiler.CelCompilerBuilder; +import dev.cel.compiler.CelCompilerLibrary; import dev.cel.extensions.CelExtensions; import dev.cel.extensions.CelOptionalLibrary; import dev.cel.parser.CelStandardMacro; import dev.cel.runtime.CelRuntimeBuilder; +import dev.cel.runtime.CelRuntimeLibrary; import java.util.Arrays; import java.util.Optional; @@ -237,7 +239,11 @@ private void addAllCompilerExtensions( // TODO: Add capability to accept user defined exceptions for (ExtensionConfig extensionConfig : extensions()) { CanonicalCelExtension extension = getExtensionOrThrow(extensionConfig.name()); - extension.addCompilerExtension(celCompilerBuilder, celOptions); + if (extension.compilerExtensionProvider() != null) { + CelCompilerLibrary celCompilerLibrary = extension.compilerExtensionProvider() + .getCelCompilerLibrary(celOptions, extensionConfig.version()); + celCompilerBuilder.addLibraries(celCompilerLibrary); + } } } @@ -245,7 +251,11 @@ private void addAllRuntimeExtensions(CelRuntimeBuilder celRuntimeBuilder, CelOpt // TODO: Add capability to accept user defined exceptions for (ExtensionConfig extensionConfig : extensions()) { CanonicalCelExtension extension = getExtensionOrThrow(extensionConfig.name()); - extension.addRuntimeExtension(celRuntimeBuilder, celOptions); + if (extension.runtimeExtensionProvider() != null) { + CelRuntimeLibrary celRuntimeLibrary = extension.runtimeExtensionProvider() + .getCelRuntimeLibrary(celOptions, extensionConfig.version()); + celRuntimeBuilder.addLibraries(celRuntimeLibrary); + } } } @@ -656,64 +666,68 @@ public static ExtensionConfig of(String name) { public static ExtensionConfig of(String name, int version) { return newBuilder().setName(name).setVersion(version).build(); } + + /** Create a new extension config with the specified name and the latest version. */ + public static ExtensionConfig latest(String name) { + return of(name, Integer.MAX_VALUE); + } } @VisibleForTesting enum CanonicalCelExtension { - BINDINGS((compilerBuilder, options) -> compilerBuilder.addLibraries(CelExtensions.bindings())), - PROTOS((compilerBuilder, options) -> compilerBuilder.addLibraries(CelExtensions.protos())), + BINDINGS((options, version) -> CelExtensions.bindings()), + PROTOS((options, version) -> CelExtensions.protos()), ENCODERS( - (compilerBuilder, options) -> compilerBuilder.addLibraries(CelExtensions.encoders()), - (runtimeBuilder, options) -> runtimeBuilder.addLibraries(CelExtensions.encoders())), + (options, version) -> CelExtensions.encoders(), + (options, version) -> CelExtensions.encoders()), MATH( - (compilerBuilder, options) -> compilerBuilder.addLibraries(CelExtensions.math(options)), - (runtimeBuilder, options) -> runtimeBuilder.addLibraries(CelExtensions.math(options))), + (options, version) -> CelExtensions.math(options, version), + (options, version) -> CelExtensions.math(options, version)), OPTIONAL( - (compilerBuilder, options) -> compilerBuilder.addLibraries(CelOptionalLibrary.INSTANCE), - (runtimeBuilder, options) -> runtimeBuilder.addLibraries(CelOptionalLibrary.INSTANCE)), + (options, version) -> CelOptionalLibrary.INSTANCE, + (options, version) -> CelOptionalLibrary.INSTANCE), STRINGS( - (compilerBuilder, options) -> compilerBuilder.addLibraries(CelExtensions.strings()), - (runtimeBuilder, options) -> runtimeBuilder.addLibraries(CelExtensions.strings())), + (options, version) -> CelExtensions.strings(), + (options, version) -> CelExtensions.strings()), SETS( - (compilerBuilder, options) -> compilerBuilder.addLibraries(CelExtensions.sets(options)), - (runtimeBuilder, options) -> runtimeBuilder.addLibraries(CelExtensions.sets(options))), + (options, version) -> CelExtensions.sets(options), + (options, version) -> CelExtensions.sets(options)), LISTS( - (compilerBuilder, options) -> compilerBuilder.addLibraries(CelExtensions.lists()), - (runtimeBuilder, options) -> runtimeBuilder.addLibraries(CelExtensions.lists())); + (options, version) -> CelExtensions.lists(), + (options, version) -> CelExtensions.lists()); @SuppressWarnings("ImmutableEnumChecker") - private final CompilerExtensionApplier compilerExtensionApplier; + private final CompilerExtensionProvider compilerExtensionProvider; @SuppressWarnings("ImmutableEnumChecker") - private final RuntimeExtensionApplier runtimeExtensionApplier; + private final RuntimeExtensionProvider runtimeExtensionProvider; - interface CompilerExtensionApplier { - void apply(CelCompilerBuilder compilerBuilder, CelOptions options); + interface CompilerExtensionProvider { + CelCompilerLibrary getCelCompilerLibrary(CelOptions options, int version); } - interface RuntimeExtensionApplier { - void apply(CelRuntimeBuilder runtimeBuilder, CelOptions options); + interface RuntimeExtensionProvider { + CelRuntimeLibrary getCelRuntimeLibrary(CelOptions options, int version); } - void addCompilerExtension(CelCompilerBuilder compilerBuilder, CelOptions options) { - compilerExtensionApplier.apply(compilerBuilder, options); + CompilerExtensionProvider compilerExtensionProvider() { + return compilerExtensionProvider; } - void addRuntimeExtension(CelRuntimeBuilder runtimeBuilder, CelOptions options) { - runtimeExtensionApplier.apply(runtimeBuilder, options); + RuntimeExtensionProvider runtimeExtensionProvider() { + return runtimeExtensionProvider; } - CanonicalCelExtension(CompilerExtensionApplier compilerExtensionApplier) { - this( - compilerExtensionApplier, - (runtimeBuilder, options) -> {}); // no-op. Not all extensions augment the runtime. + CanonicalCelExtension(CompilerExtensionProvider compilerExtensionProvider) { + this.compilerExtensionProvider = compilerExtensionProvider; + this.runtimeExtensionProvider = null; // Not all extensions augment the runtime. } CanonicalCelExtension( - CompilerExtensionApplier compilerExtensionApplier, - RuntimeExtensionApplier runtimeExtensionApplier) { - this.compilerExtensionApplier = compilerExtensionApplier; - this.runtimeExtensionApplier = runtimeExtensionApplier; + CompilerExtensionProvider compilerExtensionProvider, + RuntimeExtensionProvider runtimeExtensionProvider) { + this.compilerExtensionProvider = compilerExtensionProvider; + this.runtimeExtensionProvider = runtimeExtensionProvider; } } diff --git a/bundle/src/test/java/dev/cel/bundle/CelEnvironmentTest.java b/bundle/src/test/java/dev/cel/bundle/CelEnvironmentTest.java index 76b479828..bafa20ea6 100644 --- a/bundle/src/test/java/dev/cel/bundle/CelEnvironmentTest.java +++ b/bundle/src/test/java/dev/cel/bundle/CelEnvironmentTest.java @@ -25,6 +25,7 @@ import dev.cel.bundle.CelEnvironment.LibrarySubset.FunctionSelector; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelOptions; +import dev.cel.common.CelValidationException; import dev.cel.common.CelValidationResult; import dev.cel.compiler.CelCompiler; import dev.cel.compiler.CelCompilerFactory; @@ -52,14 +53,14 @@ public void newBuilder_defaults() { public void extend_allExtensions() throws Exception { ImmutableSet extensionConfigs = ImmutableSet.of( - ExtensionConfig.of("bindings"), - ExtensionConfig.of("encoders"), - ExtensionConfig.of("lists"), - ExtensionConfig.of("math"), - ExtensionConfig.of("optional"), - ExtensionConfig.of("protos"), - ExtensionConfig.of("sets"), - ExtensionConfig.of("strings")); + ExtensionConfig.latest("bindings"), + ExtensionConfig.latest("encoders"), + ExtensionConfig.latest("lists"), + ExtensionConfig.latest("math"), + ExtensionConfig.latest("optional"), + ExtensionConfig.latest("protos"), + ExtensionConfig.latest("sets"), + ExtensionConfig.latest("strings")); CelEnvironment environment = CelEnvironment.newBuilder().addExtensions(extensionConfigs).build(); @@ -76,6 +77,54 @@ public void extend_allExtensions() throws Exception { assertThat(result).isTrue(); } + @Test + public void extensionVersion_specific() throws Exception { + CelEnvironment environment = + CelEnvironment.newBuilder().addExtensions(ExtensionConfig.of("math", 1)).build(); + + Cel cel = environment.extend(CelFactory.standardCelBuilder().build(), CelOptions.DEFAULT); + CelAbstractSyntaxTree ast1 = cel.compile("math.abs(-4)").getAst(); + assertThat(cel.createProgram(ast1).eval()).isEqualTo(4); + + // Version 1 of the 'math' extension does not include sqrt + assertThat( + assertThrows( + CelValidationException.class, + () -> { + cel.compile("math.sqrt(4)").getAst(); + })) + .hasMessageThat() + .contains("undeclared reference to 'sqrt'"); + } + + @Test + public void extensionVersion_latest() throws Exception { + CelEnvironment environment = + CelEnvironment.newBuilder() + .addExtensions(ExtensionConfig.latest("math")) + .build(); + + Cel cel = environment.extend(CelFactory.standardCelBuilder().build(), CelOptions.DEFAULT); + CelAbstractSyntaxTree ast = cel.compile("math.sqrt(4)").getAst(); + double result = (double) cel.createProgram(ast).eval(); + assertThat(result).isEqualTo(2.0); + } + + @Test + public void extensionVersion_unsupportedVersion_throws() { + CelEnvironment environment = + CelEnvironment.newBuilder().addExtensions(ExtensionConfig.of("math", -5)).build(); + + assertThat( + assertThrows( + CelEnvironmentException.class, + () -> { + environment.extend(CelFactory.standardCelBuilder().build(), CelOptions.DEFAULT); + })) + .hasMessageThat() + .contains("Unsupported 'math' extension version -5"); + } + @Test public void stdlibSubset_bothIncludeExcludeSet_throws() { assertThat( diff --git a/extensions/src/main/java/dev/cel/extensions/CelExtensions.java b/extensions/src/main/java/dev/cel/extensions/CelExtensions.java index 3e08f1d15..76fcda436 100644 --- a/extensions/src/main/java/dev/cel/extensions/CelExtensions.java +++ b/extensions/src/main/java/dev/cel/extensions/CelExtensions.java @@ -103,13 +103,22 @@ public static CelProtoExtensions protos() { * *

This will include all functions denoted in {@link CelMathExtensions.Function}, including any * future additions. To expose only a subset of these, use {@link #math(CelOptions, - * CelMathExtensions.Function...)} instead. + * CelMathExtensions.Function...)} or {@link #math(CelOptions,int)} instead. * * @param celOptions CelOptions to configure CelMathExtension with. This should be the same * options object used to configure the compilation/runtime environments. */ public static CelMathExtensions math(CelOptions celOptions) { - return new CelMathExtensions(celOptions); + return new CelMathExtensions(celOptions, Integer.MAX_VALUE); + } + + /** + * Returns the specified version of the 'math' extension. + * + *

Refer to README.md for functions available in each version. + */ + public static CelMathExtensions math(CelOptions celOptions, int version) { + return new CelMathExtensions(celOptions, version); } /** diff --git a/extensions/src/main/java/dev/cel/extensions/CelMathExtensions.java b/extensions/src/main/java/dev/cel/extensions/CelMathExtensions.java index 9558dbee3..65f202ac0 100644 --- a/extensions/src/main/java/dev/cel/extensions/CelMathExtensions.java +++ b/extensions/src/main/java/dev/cel/extensions/CelMathExtensions.java @@ -89,6 +89,8 @@ final class CelMathExtensions implements CelCompilerLibrary, CelRuntimeLibrary { private static final String MATH_BIT_LEFT_SHIFT_FUNCTION = "math.bitShiftLeft"; private static final String MATH_BIT_RIGHT_SHIFT_FUNCTION = "math.bitShiftRight"; + private static final String MATH_SQRT_FUNCTION = "math.sqrt"; + private static final int MAX_BIT_SHIFT = 63; /** @@ -614,7 +616,80 @@ enum Function { "math_bitShiftRight_uint_int", UnsignedLong.class, Long.class, - CelMathExtensions::uintBitShiftRight))); + CelMathExtensions::uintBitShiftRight))), + SQRT( + CelFunctionDecl.newFunctionDeclaration( + MATH_SQRT_FUNCTION, + CelOverloadDecl.newGlobalOverload( + "math_sqrt_double", + "Computes square root of the double value.", + SimpleType.DOUBLE, + SimpleType.DOUBLE), + CelOverloadDecl.newGlobalOverload( + "math_sqrt_int", + "Computes square root of the int value.", + SimpleType.DOUBLE, + SimpleType.INT), + CelOverloadDecl.newGlobalOverload( + "math_sqrt_uint", + "Computes square root of the unsigned value.", + SimpleType.DOUBLE, + SimpleType.UINT)), + ImmutableSet.of( + CelFunctionBinding.from( + "math_sqrt_double", Double.class, CelMathExtensions::sqrtDouble), + CelFunctionBinding.from( + "math_sqrt_int", Long.class, CelMathExtensions::sqrtInt), + CelFunctionBinding.from( + "math_sqrt_uint", UnsignedLong.class, CelMathExtensions::sqrtUint))); + + private static final ImmutableSet VERSION_0 = ImmutableSet.of( + MIN, + MAX); + + private static final ImmutableSet VERSION_1 = + ImmutableSet.builder() + .addAll(VERSION_0) + .add( + CEIL, + FLOOR, + ROUND, + TRUNC, + ISINF, + ISNAN, + ISFINITE, + ABS, + SIGN, + BITAND, + BITOR, + BITXOR, + BITNOT, + BITSHIFTLEFT, + BITSHIFTRIGHT) + .build(); + + private static final ImmutableSet VERSION_2 = + ImmutableSet.builder() + .addAll(VERSION_1) + .add(SQRT) + .build(); + + private static final ImmutableSet VERSION_LATEST = VERSION_2; + + private static ImmutableSet byVersion(int version) { + switch (version) { + case 0: + return Function.VERSION_0; + case 1: + return Function.VERSION_1; + case 2: + return Function.VERSION_2; + case Integer.MAX_VALUE: + return Function.VERSION_LATEST; + default: + throw new IllegalArgumentException("Unsupported 'math' extension version " + version); + } + } private final CelFunctionDecl functionDecl; private final ImmutableSet functionBindings; @@ -644,8 +719,8 @@ String getFunction() { private final boolean enableUnsignedLongs; private final ImmutableSet functions; - CelMathExtensions(CelOptions celOptions) { - this(celOptions, ImmutableSet.copyOf(Function.values())); + CelMathExtensions(CelOptions celOptions, int version) { + this(celOptions, Function.byVersion(version)); } CelMathExtensions(CelOptions celOptions, Set functions) { @@ -880,6 +955,18 @@ private static UnsignedLong uintBitShiftRight(UnsignedLong value, long shiftAmou return UnsignedLong.fromLongBits(value.longValue() >>> shiftAmount); } + private static Double sqrtDouble(double x) { + return Math.sqrt(x); + } + + private static Double sqrtInt(Long x) { + return sqrtDouble(x.doubleValue()); + } + + private static Double sqrtUint(UnsignedLong x) { + return sqrtDouble(x.doubleValue()); + } + private static Comparable minList(List list) { if (list.isEmpty()) { throw new IllegalStateException("math.@min(list) argument must not be empty"); diff --git a/extensions/src/main/java/dev/cel/extensions/README.md b/extensions/src/main/java/dev/cel/extensions/README.md index 57fbdf397..a3b522f2b 100644 --- a/extensions/src/main/java/dev/cel/extensions/README.md +++ b/extensions/src/main/java/dev/cel/extensions/README.md @@ -334,6 +334,23 @@ Examples: math.isFinite(0.0/0.0) // returns false math.isFinite(1.2) // returns true +### Math.sqrt + +Introduced at version: 2 + +Returns the square root of the numeric type provided as input. If the value is +NaN, the output is NaN. If the input is negative, the output is NaN. + + math.sqrt() -> + math.sqrt() -> + math.sqrt() -> + +Examples: + + math.sqrt(81.0) // returns 9.0 + math.sqrt(4) // returns 2.0 + math.sqrt(-4) // returns NaN + ## Protos Extended macros and functions for proto manipulation. diff --git a/extensions/src/test/java/dev/cel/extensions/CelExtensionsTest.java b/extensions/src/test/java/dev/cel/extensions/CelExtensionsTest.java index 26edd95bf..2ee8ccf27 100644 --- a/extensions/src/test/java/dev/cel/extensions/CelExtensionsTest.java +++ b/extensions/src/test/java/dev/cel/extensions/CelExtensionsTest.java @@ -162,6 +162,7 @@ public void getAllFunctionNames() { "math.bitNot", "math.bitShiftLeft", "math.bitShiftRight", + "math.sqrt", "charAt", "indexOf", "join", diff --git a/extensions/src/test/java/dev/cel/extensions/CelMathExtensionsTest.java b/extensions/src/test/java/dev/cel/extensions/CelMathExtensionsTest.java index 742cd10f4..bcdfb0a21 100644 --- a/extensions/src/test/java/dev/cel/extensions/CelMathExtensionsTest.java +++ b/extensions/src/test/java/dev/cel/extensions/CelMathExtensionsTest.java @@ -674,6 +674,7 @@ public void least_nonProtoNamespace_success(String expr) throws Exception { @TestParameters("{expr: 'math.isNaN(-1.0/0.0)', expectedResult: false}") @TestParameters("{expr: 'math.isNaN(math.round(0.0/0.0))', expectedResult: true}") @TestParameters("{expr: 'math.isNaN(math.sign(0.0/0.0))', expectedResult: true}") + @TestParameters("{expr: 'math.isNaN(math.sqrt(-4))', expectedResult: true}") public void isNaN_success(String expr, boolean expectedResult) throws Exception { CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expr).getAst(); @@ -1164,4 +1165,19 @@ public void bitShiftRight_invalidArgs_throwsException(String expr) throws Except assertThat(e).hasCauseThat().isInstanceOf(IllegalArgumentException.class); assertThat(e).hasCauseThat().hasMessageThat().contains("math.bitShiftRight() negative offset"); } + + @Test + @TestParameters("{expr: 'math.sqrt(49.0)', expectedResult: 7.0}") + @TestParameters("{expr: 'math.sqrt(82)', expectedResult: 9.055385138137417}") + @TestParameters("{expr: 'math.sqrt(25u)', expectedResult: 5.0}") + @TestParameters("{expr: 'math.sqrt(0.0/0.0)', expectedResult: NaN}") + @TestParameters("{expr: 'math.sqrt(1.0/0.0)', expectedResult: Infinity}") + @TestParameters("{expr: 'math.sqrt(-1)', expectedResult: NaN}") + public void sqrt_success(String expr, double expectedResult) throws Exception { + CelAbstractSyntaxTree ast = CEL_UNSIGNED_COMPILER.compile(expr).getAst(); + + Object result = CEL_UNSIGNED_RUNTIME.createProgram(ast).eval(); + + assertThat(result).isEqualTo(expectedResult); + } }