diff --git a/src/main/java/build/buf/protovalidate/EvaluatorBuilder.java b/src/main/java/build/buf/protovalidate/EvaluatorBuilder.java index c3fe31e3..f44d0603 100644 --- a/src/main/java/build/buf/protovalidate/EvaluatorBuilder.java +++ b/src/main/java/build/buf/protovalidate/EvaluatorBuilder.java @@ -33,9 +33,11 @@ import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Set; import org.jspecify.annotations.Nullable; import org.projectnessie.cel.Env; import org.projectnessie.cel.EnvOption; @@ -184,6 +186,31 @@ private void buildMessage(Descriptor desc, MessageEvaluator msgEval) } } + private void collectDependencies(Set dependencyTypes, Descriptor desc) { + dependencyTypes.add(desc); + for (FieldDescriptor field : desc.getFields()) { + if (field.getJavaType() != FieldDescriptor.JavaType.MESSAGE) { + continue; + } + Descriptor submessageDesc = field.getMessageType(); + if (dependencyTypes.contains(submessageDesc)) { + continue; + } + collectDependencies(dependencyTypes, submessageDesc); + } + } + + private Message[] getTypesForMessage(Message message) { + Set dependencyTypes = new HashSet<>(); + collectDependencies(dependencyTypes, message.getDescriptorForType()); + Message[] dependencyTypeMessages = new Message[dependencyTypes.size()]; + int i = 0; + for (Descriptor dependencyType : dependencyTypes) { + dependencyTypeMessages[i++] = DynamicMessage.newBuilder(dependencyType).buildPartial(); + } + return dependencyTypeMessages; + } + private void processMessageExpressions( Descriptor desc, MessageRules msgRules, MessageEvaluator msgEval, DynamicMessage message) throws CompilationException { @@ -193,7 +220,7 @@ private void processMessageExpressions( } Env finalEnv = env.extend( - EnvOption.types(message), + EnvOption.types((Object[]) getTypesForMessage(message)), EnvOption.declarations( Decls.newVar(Variable.THIS_NAME, Decls.newObjectType(desc.getFullName())))); List compiledPrograms = compileRules(celList, finalEnv, false); @@ -350,7 +377,7 @@ private void processFieldExpressions( try { DynamicMessage defaultInstance = DynamicMessage.parseFrom(fieldDescriptor.getMessageType(), new byte[0]); - opts.add(EnvOption.types(defaultInstance)); + opts.add(EnvOption.types((Object[]) getTypesForMessage(defaultInstance))); } catch (InvalidProtocolBufferException e) { throw new CompilationException("field descriptor type is invalid " + e.getMessage(), e); } diff --git a/src/test/java/build/buf/protovalidate/ValidatorImportTest.java b/src/test/java/build/buf/protovalidate/ValidatorImportTest.java new file mode 100644 index 00000000..84479f33 --- /dev/null +++ b/src/test/java/build/buf/protovalidate/ValidatorImportTest.java @@ -0,0 +1,166 @@ +// Copyright 2023-2024 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package build.buf.protovalidate; + +import static org.assertj.core.api.Assertions.assertThat; + +import com.example.imports.validationtest.ExampleImportMessage; +import com.example.imports.validationtest.ExampleImportMessageFieldRule; +import com.example.imports.validationtest.ExampleImportMessageInMap; +import com.example.imports.validationtest.ExampleImportMessageInMapFieldRule; +import com.example.imports.validationtest.ExampleImportedMessage; +import org.junit.jupiter.api.Test; + +public class ValidatorImportTest { + @Test + public void testImportedMessageFromAnotherFile() throws Exception { + com.example.imports.validationtest.ExampleImportMessage valid = + ExampleImportMessage.newBuilder() + .setImportedSubmessage( + ExampleImportedMessage.newBuilder().setHexString("0123456789abcdef").build()) + .build(); + assertThat( + ValidatorFactory.newBuilder() + .build() + .validate(valid) + .toProto() + .getViolationsList() + .size()) + .isEqualTo(0); + + com.example.imports.validationtest.ExampleImportMessage invalid = + ExampleImportMessage.newBuilder() + .setImportedSubmessage(ExampleImportedMessage.newBuilder().setHexString("zyx").build()) + .build(); + assertThat( + ValidatorFactory.newBuilder() + .build() + .validate(invalid) + .toProto() + .getViolationsList() + .size()) + .isEqualTo(1); + } + + @Test + public void testImportedMessageFromAnotherFileInField() throws Exception { + com.example.imports.validationtest.ExampleImportMessageFieldRule valid = + ExampleImportMessageFieldRule.newBuilder() + .setMessageWithImport( + ExampleImportMessage.newBuilder() + .setImportedSubmessage( + ExampleImportedMessage.newBuilder() + .setHexString("0123456789abcdef") + .build()) + .build()) + .build(); + assertThat( + ValidatorFactory.newBuilder() + .build() + .validate(valid) + .toProto() + .getViolationsList() + .size()) + .isEqualTo(0); + + com.example.imports.validationtest.ExampleImportMessageFieldRule invalid = + ExampleImportMessageFieldRule.newBuilder() + .setMessageWithImport( + ExampleImportMessage.newBuilder() + .setImportedSubmessage( + ExampleImportedMessage.newBuilder().setHexString("zyx").build()) + .build()) + .build(); + assertThat( + ValidatorFactory.newBuilder() + .build() + .validate(invalid) + .toProto() + .getViolationsList() + .size()) + .isEqualTo(1); + } + + @Test + public void testImportedMessageFromAnotherFileInMap() throws Exception { + com.example.imports.validationtest.ExampleImportMessageInMap valid = + ExampleImportMessageInMap.newBuilder() + .putImportedSubmessage( + 0, ExampleImportedMessage.newBuilder().setHexString("0123456789abcdef").build()) + .build(); + assertThat( + ValidatorFactory.newBuilder() + .build() + .validate(valid) + .toProto() + .getViolationsList() + .size()) + .isEqualTo(0); + + com.example.imports.validationtest.ExampleImportMessageInMap invalid = + ExampleImportMessageInMap.newBuilder() + .putImportedSubmessage( + 0, ExampleImportedMessage.newBuilder().setHexString("zyx").build()) + .build(); + assertThat( + ValidatorFactory.newBuilder() + .build() + .validate(invalid) + .toProto() + .getViolationsList() + .size()) + .isEqualTo(1); + } + + @Test + public void testImportedMessageFromAnotherFileInMapInField() throws Exception { + com.example.imports.validationtest.ExampleImportMessageInMapFieldRule valid = + ExampleImportMessageInMapFieldRule.newBuilder() + .setMessageWithImport( + ExampleImportMessageInMap.newBuilder() + .putImportedSubmessage( + 0, + ExampleImportedMessage.newBuilder() + .setHexString("0123456789abcdef") + .build()) + .build()) + .build(); + assertThat( + ValidatorFactory.newBuilder() + .build() + .validate(valid) + .toProto() + .getViolationsList() + .size()) + .isEqualTo(0); + + com.example.imports.validationtest.ExampleImportMessageInMapFieldRule invalid = + ExampleImportMessageInMapFieldRule.newBuilder() + .setMessageWithImport( + ExampleImportMessageInMap.newBuilder() + .putImportedSubmessage( + 0, ExampleImportedMessage.newBuilder().setHexString("zyx").build()) + .build()) + .build(); + assertThat( + ValidatorFactory.newBuilder() + .build() + .validate(invalid) + .toProto() + .getViolationsList() + .size()) + .isEqualTo(1); + } +} diff --git a/src/test/resources/proto/validationtest/import_test.proto b/src/test/resources/proto/validationtest/import_test.proto new file mode 100644 index 00000000..8f518ac5 --- /dev/null +++ b/src/test/resources/proto/validationtest/import_test.proto @@ -0,0 +1,23 @@ +// Copyright 2023-2024 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package validationtest; + +import "buf/validate/validate.proto"; + +message ExampleImportedMessage { + string hex_string = 1 [(buf.validate.field).string.pattern = "^[0-9a-fA-F]+$"]; +} diff --git a/src/test/resources/proto/validationtest/predefined.proto b/src/test/resources/proto/validationtest/predefined.proto index b5d4af70..925a4765 100644 --- a/src/test/resources/proto/validationtest/predefined.proto +++ b/src/test/resources/proto/validationtest/predefined.proto @@ -19,12 +19,10 @@ package validationtest; import "buf/validate/validate.proto"; extend buf.validate.StringRules { - optional bool is_ident = 1161 [ - (buf.validate.predefined).cel = { - id: "string.is_ident", - expression: "(rule && !this.matches('^[a-z0-9]{1,9}$')) ? 'invalid identifier' : ''", - } - ]; + optional bool is_ident = 1161 [(buf.validate.predefined).cel = { + id: "string.is_ident" + expression: "(rule && !this.matches('^[a-z0-9]{1,9}$')) ? 'invalid identifier' : ''" + }]; } message ExamplePredefinedFieldRules { diff --git a/src/test/resources/proto/validationtest/validationtest.proto b/src/test/resources/proto/validationtest/validationtest.proto index b73409b4..188e3760 100644 --- a/src/test/resources/proto/validationtest/validationtest.proto +++ b/src/test/resources/proto/validationtest/validationtest.proto @@ -17,6 +17,7 @@ syntax = "proto3"; package validationtest; import "buf/validate/validate.proto"; +import "validationtest/import_test.proto"; message ExampleFieldRules { string regex_string_field = 1 [(buf.validate.field).string.pattern = "^[a-z0-9]{1,9}$"]; @@ -42,7 +43,7 @@ message ExampleOneofRules { message ExampleMessageRules { option (buf.validate.message).cel = { - id: "secondary_email_depends_on_primary", + id: "secondary_email_depends_on_primary" expression: "has(this.secondary_email) && !has(this.primary_email)" "? 'cannot set a secondary email without setting a primary one'" @@ -66,3 +67,57 @@ message FieldExpressionMapInt32 { expression: "this.all(k, this[k] == 1)" }]; } + +message ExampleImportMessage { + option (buf.validate.message) = { + cel: { + id: "imported_submessage_must_not_be_null" + expression: "this.imported_submessage != null" + } + cel: { + id: "hex_string_must_not_be_empty" + expression: "this.imported_submessage.hex_string != ''" + } + }; + ExampleImportedMessage imported_submessage = 1; +} + +message ExampleImportMessageFieldRule { + ExampleImportMessage message_with_import = 1 [ + (buf.validate.field).cel = { + id: "field_must_not_be_null" + expression: "this.imported_submessage != null" + }, + (buf.validate.field).cel = { + id: "field_string_must_not_be_empty" + expression: "this.imported_submessage.hex_string != ''" + } + ]; +} + +message ExampleImportMessageInMap { + option (buf.validate.message) = { + cel: { + id: "imported_submessage_must_not_be_null" + expression: "this.imported_submessage[0] != null" + } + cel: { + id: "hex_string_must_not_be_empty" + expression: "this.imported_submessage[0].hex_string != ''" + } + }; + map imported_submessage = 1; +} + +message ExampleImportMessageInMapFieldRule { + ExampleImportMessageInMap message_with_import = 1 [ + (buf.validate.field).cel = { + id: "field_must_not_be_null" + expression: "this.imported_submessage[0] != null" + }, + (buf.validate.field).cel = { + id: "field_string_must_not_be_empty" + expression: "this.imported_submessage[0].hex_string != ''" + } + ]; +}