diff --git a/checker/src/main/java/dev/cel/checker/CelStandardDeclarations.java b/checker/src/main/java/dev/cel/checker/CelStandardDeclarations.java index 8a84c2a48..02c5f6097 100644 --- a/checker/src/main/java/dev/cel/checker/CelStandardDeclarations.java +++ b/checker/src/main/java/dev/cel/checker/CelStandardDeclarations.java @@ -93,6 +93,8 @@ enum StandardFunction { Overload.Arithmetic.ADD_STRING, Overload.Arithmetic.ADD_BYTES, Overload.Arithmetic.ADD_LIST, + Overload.Arithmetic.MAP_INSERT_KEY_VALUE, + Overload.Arithmetic.MAP_INSERT_MAP, Overload.Arithmetic.ADD_TIMESTAMP_DURATION, Overload.Arithmetic.ADD_DURATION_TIMESTAMP, Overload.Arithmetic.ADD_DURATION_DURATION), @@ -410,6 +412,12 @@ public enum Arithmetic implements StandardOverload { ADD_LIST( CelOverloadDecl.newGlobalOverload( "add_list", "list concatenation", LIST_OF_A, LIST_OF_A, LIST_OF_A)), + MAP_INSERT_KEY_VALUE( + CelOverloadDecl.newGlobalOverload( + "mapInsert_map_key_value", "map insertion", MAP_OF_AB, MAP_OF_AB, TYPE_PARAM_A, TYPE_PARAM_B)), + MAP_INSERT_MAP( + CelOverloadDecl.newGlobalOverload( + "mapInsert_map_map", "map insertion", MAP_OF_AB, MAP_OF_AB, MAP_OF_AB)), ADD_TIMESTAMP_DURATION( CelOverloadDecl.newGlobalOverload( "add_timestamp_duration", diff --git a/runtime/src/main/java/dev/cel/runtime/CelStandardFunctions.java b/runtime/src/main/java/dev/cel/runtime/CelStandardFunctions.java index 3aff6105d..5ef065f14 100644 --- a/runtime/src/main/java/dev/cel/runtime/CelStandardFunctions.java +++ b/runtime/src/main/java/dev/cel/runtime/CelStandardFunctions.java @@ -17,6 +17,7 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.primitives.Ints; import com.google.common.primitives.UnsignedLong; @@ -83,6 +84,8 @@ public enum StandardFunction { Arithmetic.ADD_STRING, Arithmetic.ADD_BYTES, Arithmetic.ADD_LIST, + Arithmetic.MAP_INSERT_MAP_MAP, + Arithmetic.MAP_INSERT_MAP_KEY_VALUE, Arithmetic.ADD_TIMESTAMP_DURATION, Arithmetic.ADD_DURATION_TIMESTAMP, Arithmetic.ADD_DURATION_DURATION), @@ -373,6 +376,16 @@ public enum Arithmetic implements StandardOverload { (bindingHelper) -> CelFunctionBinding.from( "add_list", List.class, List.class, RuntimeHelpers::concat)), + MAP_INSERT_MAP_MAP( + (bindingHelper) -> + CelFunctionBinding.from( + "mapInsert_map_map", Map.class, Map.class, RuntimeHelpers::mapInsert)), + MAP_INSERT_MAP_KEY_VALUE( + (bindingHelper) -> + CelFunctionBinding.from( + "mapInsert_map_key_value", + ImmutableList.of(Map.class, Object.class, Object.class), + RuntimeHelpers::mapInsert)), SUBTRACT_INT64( (bindingHelper) -> CelFunctionBinding.from( diff --git a/runtime/src/main/java/dev/cel/runtime/RuntimeHelpers.java b/runtime/src/main/java/dev/cel/runtime/RuntimeHelpers.java index 1c20017d1..1e59969f0 100644 --- a/runtime/src/main/java/dev/cel/runtime/RuntimeHelpers.java +++ b/runtime/src/main/java/dev/cel/runtime/RuntimeHelpers.java @@ -30,7 +30,9 @@ import dev.cel.common.internal.Converter; import java.time.format.DateTimeParseException; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Optional; import org.threeten.extra.AmountFormats; @@ -96,6 +98,31 @@ static List concat(List first, List second) { return result; } + /** Concatenates two maps into a new map */ + static Map mapInsert(Map first, Map second) { + // TODO: return a mutable map instead of an actual copy. + Map result = new HashMap<>(first.size() + second.size()); + result.putAll(first); + result.putAll(second); + return result; + } + + /** Add new key value pair to an existing map. */ + static Map mapInsert(Object[] args) { + Map map = (Map) args[0]; + Object key = args[1]; + Object value = args[2]; + // TODO: return a mutable map instead of an actual copy. + if (map.containsKey(key)) { + throw new IllegalArgumentException( + String.format("insert failed: key %s already exists", key)); + } + Map result = new HashMap<>(map.size() + 1); + result.putAll(map); + result.put(key, value); + return result; + } + // Collections // ===========