diff --git a/conformance/src/test/java/dev/cel/conformance/ConformanceTest.java b/conformance/src/test/java/dev/cel/conformance/ConformanceTest.java index 86f9b8f29..db57ccb79 100644 --- a/conformance/src/test/java/dev/cel/conformance/ConformanceTest.java +++ b/conformance/src/test/java/dev/cel/conformance/ConformanceTest.java @@ -16,8 +16,6 @@ import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.extensions.proto.ProtoTruth.assertThat; -import static dev.cel.testing.utils.ExprValueUtils.DEFAULT_EXTENSION_REGISTRY; -import static dev.cel.testing.utils.ExprValueUtils.DEFAULT_TYPE_REGISTRY; import static dev.cel.testing.utils.ExprValueUtils.fromValue; import static dev.cel.testing.utils.ExprValueUtils.toExprValue; @@ -29,6 +27,8 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.protobuf.ExtensionRegistry; +import com.google.protobuf.TypeRegistry; import dev.cel.checker.CelChecker; import dev.cel.common.CelContainer; import dev.cel.common.CelOptions; @@ -84,6 +84,21 @@ public final class ConformanceTest extends Statement { CelExtensions.strings(), CelOptionalLibrary.INSTANCE); + static final TypeRegistry CONFORMANCE_TYPE_REGISTRY = + TypeRegistry.newBuilder() + .add(dev.cel.expr.conformance.proto2.TestAllTypes.getDescriptor()) + .add(dev.cel.expr.conformance.proto3.TestAllTypes.getDescriptor()) + .build(); + + static final ExtensionRegistry CONFORMANCE_EXTENSION_REGISTRY = + createConformanceExtensionRegistry(); + + private static ExtensionRegistry createConformanceExtensionRegistry() { + ExtensionRegistry extensionRegistry = ExtensionRegistry.newInstance(); + dev.cel.expr.conformance.proto2.TestAllTypesExtensions.registerAllExtensions(extensionRegistry); + return extensionRegistry; + } + private static final CelParser PARSER_WITH_MACROS = CelParserFactory.standardCelParserBuilder() .setOptions(OPTIONS) @@ -106,7 +121,7 @@ private static CelChecker getChecker(SimpleTest test) throws Exception { ImmutableList.Builder decls = ImmutableList.builderWithExpectedSize(test.getTypeEnvCount()); for (dev.cel.expr.Decl decl : test.getTypeEnvList()) { - decls.add(Decl.parseFrom(decl.toByteArray(), DEFAULT_EXTENSION_REGISTRY)); + decls.add(Decl.parseFrom(decl.toByteArray(), CONFORMANCE_EXTENSION_REGISTRY)); } return CelCompilerFactory.standardCelCheckerBuilder() .setOptions(OPTIONS) @@ -127,7 +142,7 @@ private static CelRuntime getRuntime(SimpleTest test, boolean usePlanner) { // CEL-Internal-2 .setOptions(OPTIONS) .addLibraries(CANONICAL_RUNTIME_EXTENSIONS) - .setExtensionRegistry(DEFAULT_EXTENSION_REGISTRY) + .setExtensionRegistry(CONFORMANCE_EXTENSION_REGISTRY) .addMessageTypes(dev.cel.expr.conformance.proto2.TestAllTypes.getDescriptor()) .addMessageTypes(dev.cel.expr.conformance.proto3.TestAllTypes.getDescriptor()) .addFileTypes(dev.cel.expr.conformance.proto2.TestAllTypesExtensions.getDescriptor()); @@ -151,7 +166,8 @@ private static ImmutableMap getBindings(SimpleTest test) throws private static Object fromExprValue(ExprValue value) throws Exception { switch (value.getKindCase()) { case VALUE: - return fromValue(value.getValue()); + return fromValue( + value.getValue(), CONFORMANCE_TYPE_REGISTRY, CONFORMANCE_EXTENSION_REGISTRY); default: throw new IllegalArgumentException( String.format("Unexpected binding value kind: %s", value.getKindCase())); @@ -224,7 +240,7 @@ public void evaluate() throws Throwable { assertThat(result) .ignoringRepeatedFieldOrderOfFieldDescriptors( MapValue.getDescriptor().findFieldByName("entries")) - .unpackingAnyUsing(DEFAULT_TYPE_REGISTRY, DEFAULT_EXTENSION_REGISTRY) + .unpackingAnyUsing(CONFORMANCE_TYPE_REGISTRY, CONFORMANCE_EXTENSION_REGISTRY) .isEqualTo(ExprValue.newBuilder().setValue(test.getValue()).build()); break; case EVAL_ERROR: @@ -237,7 +253,7 @@ public void evaluate() throws Throwable { assertThat(result) .ignoringRepeatedFieldOrderOfFieldDescriptors( MapValue.getDescriptor().findFieldByName("entries")) - .unpackingAnyUsing(DEFAULT_TYPE_REGISTRY, DEFAULT_EXTENSION_REGISTRY) + .unpackingAnyUsing(CONFORMANCE_TYPE_REGISTRY, CONFORMANCE_EXTENSION_REGISTRY) .isEqualTo(ExprValue.newBuilder().setValue(test.getTypedResult().getResult()).build()); assertThat(resultType).isEqualTo(test.getTypedResult().getDeducedType()); break; diff --git a/conformance/src/test/java/dev/cel/conformance/ConformanceTestRunner.java b/conformance/src/test/java/dev/cel/conformance/ConformanceTestRunner.java index dc3d5021e..4c3631d31 100644 --- a/conformance/src/test/java/dev/cel/conformance/ConformanceTestRunner.java +++ b/conformance/src/test/java/dev/cel/conformance/ConformanceTestRunner.java @@ -14,8 +14,7 @@ package dev.cel.conformance; -import static dev.cel.testing.utils.ExprValueUtils.DEFAULT_EXTENSION_REGISTRY; -import static dev.cel.testing.utils.ExprValueUtils.DEFAULT_TYPE_REGISTRY; +import static dev.cel.conformance.ConformanceTest.CONFORMANCE_EXTENSION_REGISTRY; import com.google.common.base.Preconditions; import com.google.common.base.Splitter; @@ -50,14 +49,16 @@ private static ImmutableSortedMap loadTestFiles() { SPLITTER.splitToList(System.getProperty("dev.cel.conformance.ConformanceTests.tests")); try { TextFormat.Parser parser = - TextFormat.Parser.newBuilder().setTypeRegistry(DEFAULT_TYPE_REGISTRY).build(); + TextFormat.Parser.newBuilder() + .setTypeRegistry(ConformanceTest.CONFORMANCE_TYPE_REGISTRY) + .build(); ImmutableSortedMap.Builder testFiles = ImmutableSortedMap.naturalOrder(); for (String testPath : testPaths) { SimpleTestFile.Builder fileBuilder = SimpleTestFile.newBuilder(); try (BufferedReader input = Files.newBufferedReader(Paths.get(testPath), StandardCharsets.UTF_8)) { - parser.merge(input, DEFAULT_EXTENSION_REGISTRY, fileBuilder); + parser.merge(input, CONFORMANCE_EXTENSION_REGISTRY, fileBuilder); } SimpleTestFile testFile = fileBuilder.build(); testFiles.put(testFile.getName(), testFile); diff --git a/policy/src/main/java/dev/cel/policy/CelPolicy.java b/policy/src/main/java/dev/cel/policy/CelPolicy.java index 9980d0cad..9e442a2e7 100644 --- a/policy/src/main/java/dev/cel/policy/CelPolicy.java +++ b/policy/src/main/java/dev/cel/policy/CelPolicy.java @@ -27,6 +27,7 @@ import java.util.Arrays; import java.util.Collection; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -77,8 +78,7 @@ public abstract static class Builder { public abstract Builder setPolicySource(CelPolicySource policySource); - // This should stay package-private to encourage add/set methods to be used instead. - abstract ImmutableMap.Builder metadataBuilder(); + private final HashMap metadata = new HashMap<>(); public abstract Builder setMetadata(ImmutableMap value); @@ -90,6 +90,10 @@ public List imports() { return Collections.unmodifiableList(importList); } + public Map metadata() { + return Collections.unmodifiableMap(metadata); + } + @CanIgnoreReturnValue public Builder addImport(Import value) { importList.add(value); @@ -104,13 +108,13 @@ public Builder addImports(Collection values) { @CanIgnoreReturnValue public Builder putMetadata(String key, Object value) { - metadataBuilder().put(key, value); + metadata.put(key, value); return this; } @CanIgnoreReturnValue public Builder putMetadata(Map map) { - metadataBuilder().putAll(map); + metadata.putAll(map); return this; } @@ -118,6 +122,7 @@ public Builder putMetadata(Map map) { public CelPolicy build() { setImports(ImmutableList.copyOf(importList)); + setMetadata(ImmutableMap.copyOf(metadata)); return autoBuild(); } } diff --git a/testing/src/main/java/dev/cel/testing/testrunner/BUILD.bazel b/testing/src/main/java/dev/cel/testing/testrunner/BUILD.bazel index 6924f753f..5af0665f9 100644 --- a/testing/src/main/java/dev/cel/testing/testrunner/BUILD.bazel +++ b/testing/src/main/java/dev/cel/testing/testrunner/BUILD.bazel @@ -92,6 +92,7 @@ java_library( "//bundle:environment", "//bundle:environment_yaml_parser", "//common:cel_ast", + "//common:cel_descriptor_util", "//common:compiler_common", "//common:options", "//common:proto_ast", @@ -134,6 +135,7 @@ java_library( ":cel_test_suite", ":cel_test_suite_exception", "//common:compiler_common", + "//common/annotations", "//common/formats:file_source", "//common/formats:parser_context", "//common/formats:yaml_helper", @@ -163,10 +165,14 @@ java_library( ":result_matcher", "//:auto_value", "//bundle:cel", + "//common:cel_descriptor_util", "//common:options", "//policy:parser", "//runtime", + "//testing/testrunner:proto_descriptor_utils", + "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:com_google_guava_guava", + "@maven//:com_google_protobuf_protobuf_java", ], ) @@ -223,6 +229,7 @@ java_library( ":cel_test_suite", ":cel_test_suite_exception", ":registry_utils", + "//common/annotations", "@cel_spec//proto/cel/expr:expr_java_proto", "@cel_spec//proto/cel/expr/conformance/test:suite_java_proto", "@maven//:com_google_guava_guava", diff --git a/testing/src/main/java/dev/cel/testing/testrunner/CelTestContext.java b/testing/src/main/java/dev/cel/testing/testrunner/CelTestContext.java index aa0d4b34f..5635b6152 100644 --- a/testing/src/main/java/dev/cel/testing/testrunner/CelTestContext.java +++ b/testing/src/main/java/dev/cel/testing/testrunner/CelTestContext.java @@ -14,12 +14,23 @@ package dev.cel.testing.testrunner; import com.google.auto.value.AutoValue; +import com.google.auto.value.extension.memoized.Memoized; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.errorprone.annotations.CanIgnoreReturnValue; +import com.google.protobuf.Descriptors.Descriptor; +import com.google.protobuf.Descriptors.FileDescriptor; +import com.google.protobuf.ExtensionRegistry; +import com.google.protobuf.TypeRegistry; import dev.cel.bundle.Cel; import dev.cel.bundle.CelFactory; +import dev.cel.common.CelDescriptorUtil; import dev.cel.common.CelOptions; import dev.cel.policy.CelPolicyParser; import dev.cel.runtime.CelLateFunctionBindings; +import dev.cel.testing.utils.ProtoDescriptorUtils; +import java.io.IOException; +import java.util.Arrays; import java.util.Map; import java.util.Optional; @@ -63,6 +74,19 @@ public abstract class CelTestContext { */ public abstract Optional celLateFunctionBindings(); + /** Interface for transforming bindings before evaluation. */ + @FunctionalInterface + public interface BindingTransformer { + ImmutableMap transform(ImmutableMap bindings) throws Exception; + } + + /** + * The binding transformer for the CEL test. + * + *

This transformer is used to transform the bindings before evaluation. + */ + public abstract Optional bindingTransformer(); + /** * The variable bindings for the CEL test. * @@ -99,6 +123,34 @@ public abstract class CelTestContext { */ public abstract Optional fileDescriptorSetPath(); + abstract ImmutableSet fileTypes(); + + @Memoized + public Optional typeRegistry() { + if (fileTypes().isEmpty() && !fileDescriptorSetPath().isPresent()) { + return Optional.empty(); + } + TypeRegistry.Builder builder = TypeRegistry.newBuilder(); + if (!fileTypes().isEmpty()) { + builder.add( + CelDescriptorUtil.getAllDescriptorsFromFileDescriptor(fileTypes()) + .messageTypeDescriptors()); + } + if (fileDescriptorSetPath().isPresent()) { + try { + builder.add( + ProtoDescriptorUtils.getAllDescriptorsFromJvm(fileDescriptorSetPath().get()) + .messageTypeDescriptors()); + } catch (IOException e) { + throw new IllegalStateException( + "Failed to load descriptors from path: " + fileDescriptorSetPath().get(), e); + } + } + return Optional.of(builder.build()); + } + + public abstract Optional extensionRegistry(); + /** Returns a builder for {@link CelTestContext} with the current instance's values. */ public abstract Builder toBuilder(); @@ -123,6 +175,8 @@ public abstract static class Builder { public abstract Builder setCelLateFunctionBindings( CelLateFunctionBindings celLateFunctionBindings); + public abstract Builder setBindingTransformer(BindingTransformer bindingTransformer); + public abstract Builder setVariableBindings(Map variableBindings); public abstract Builder setResultMatcher(ResultMatcher resultMatcher); @@ -133,6 +187,34 @@ public abstract Builder setCelLateFunctionBindings( public abstract Builder setFileDescriptorSetPath(String fileDescriptorSetPath); + abstract ImmutableSet.Builder fileTypesBuilder(); + + @CanIgnoreReturnValue + public Builder addMessageTypes(Descriptor... descriptors) { + return addMessageTypes(Arrays.asList(descriptors)); + } + + @CanIgnoreReturnValue + public Builder addMessageTypes(Iterable descriptors) { + for (Descriptor descriptor : descriptors) { + addFileTypes(descriptor.getFile()); + } + return this; + } + + @CanIgnoreReturnValue + public Builder addFileTypes(FileDescriptor... fileDescriptors) { + return addFileTypes(Arrays.asList(fileDescriptors)); + } + + @CanIgnoreReturnValue + public Builder addFileTypes(Iterable fileDescriptors) { + fileTypesBuilder().addAll(fileDescriptors); + return this; + } + + public abstract Builder setExtensionRegistry(ExtensionRegistry extensionRegistry); + public abstract CelTestContext build(); } } diff --git a/testing/src/main/java/dev/cel/testing/testrunner/CelTestSuite.java b/testing/src/main/java/dev/cel/testing/testrunner/CelTestSuite.java index e6086f128..a8869a8fb 100644 --- a/testing/src/main/java/dev/cel/testing/testrunner/CelTestSuite.java +++ b/testing/src/main/java/dev/cel/testing/testrunner/CelTestSuite.java @@ -93,7 +93,7 @@ public abstract static class Builder { public abstract Builder toBuilder(); public static Builder newBuilder() { - return new AutoValue_CelTestSuite_CelTestSection.Builder(); + return new AutoValue_CelTestSuite_CelTestSection.Builder().setDescription(""); } /** Class representing a CEL test case within a test section. */ @@ -237,7 +237,8 @@ public abstract static class Builder { public static Builder newBuilder() { return new AutoValue_CelTestSuite_CelTestSection_CelTestCase.Builder() - .setInput(Input.ofNoInput()); // Default input to no input. + .setInput(Input.ofNoInput()) // Default input to no input. + .setDescription(""); } } } diff --git a/testing/src/main/java/dev/cel/testing/testrunner/CelTestSuiteTextProtoParser.java b/testing/src/main/java/dev/cel/testing/testrunner/CelTestSuiteTextProtoParser.java index 3819e38d2..5e7e62498 100644 --- a/testing/src/main/java/dev/cel/testing/testrunner/CelTestSuiteTextProtoParser.java +++ b/testing/src/main/java/dev/cel/testing/testrunner/CelTestSuiteTextProtoParser.java @@ -22,6 +22,7 @@ import com.google.protobuf.TextFormat; import com.google.protobuf.TextFormat.ParseException; import com.google.protobuf.TypeRegistry; +import dev.cel.common.annotations.Internal; import dev.cel.expr.conformance.test.InputValue; import dev.cel.expr.conformance.test.TestCase; import dev.cel.expr.conformance.test.TestSection; @@ -35,23 +36,40 @@ /** * CelTestSuiteTextProtoParser intakes a textproto document that describes the structure of a CEL * test suite, parses it then creates a {@link CelTestSuite}. + * + *

CEL Library Internals. Do Not Use. */ -final class CelTestSuiteTextProtoParser { +@Internal +public final class CelTestSuiteTextProtoParser { /** Creates a new instance of {@link CelTestSuiteTextProtoParser}. */ - static CelTestSuiteTextProtoParser newInstance() { + public static CelTestSuiteTextProtoParser newInstance() { return new CelTestSuiteTextProtoParser(); } - CelTestSuite parse(String textProto) throws IOException, CelTestSuiteException { - TestSuite testSuite = parseTestSuite(textProto); + public CelTestSuite parse(String textProto) throws IOException, CelTestSuiteException { + return parse( + textProto, TypeRegistry.getEmptyTypeRegistry(), ExtensionRegistry.getEmptyRegistry()); + } + + public CelTestSuite parse(String textProto, TypeRegistry customTypeRegistry) + throws IOException, CelTestSuiteException { + return parse(textProto, customTypeRegistry, ExtensionRegistry.getEmptyRegistry()); + } + + public CelTestSuite parse( + String textProto, TypeRegistry customTypeRegistry, ExtensionRegistry customExtensionRegistry) + throws IOException, CelTestSuiteException { + TestSuite testSuite = parseTestSuite(textProto, customTypeRegistry, customExtensionRegistry); return parseCelTestSuite(testSuite); } - private TestSuite parseTestSuite(String textProto) throws IOException { + private TestSuite parseTestSuite( + String textProto, TypeRegistry customTypeRegistry, ExtensionRegistry customExtensionRegistry) + throws IOException { String fileDescriptorSetPath = System.getProperty("file_descriptor_set_path"); - TypeRegistry typeRegistry = TypeRegistry.getEmptyTypeRegistry(); - ExtensionRegistry extensionRegistry = ExtensionRegistry.getEmptyRegistry(); + TypeRegistry typeRegistry = customTypeRegistry; + ExtensionRegistry extensionRegistry = customExtensionRegistry; if (fileDescriptorSetPath != null) { extensionRegistry = RegistryUtils.getExtensionRegistry(fileDescriptorSetPath); typeRegistry = RegistryUtils.getTypeRegistry(fileDescriptorSetPath); diff --git a/testing/src/main/java/dev/cel/testing/testrunner/CelTestSuiteYamlParser.java b/testing/src/main/java/dev/cel/testing/testrunner/CelTestSuiteYamlParser.java index d1a3d6615..71c4b9231 100644 --- a/testing/src/main/java/dev/cel/testing/testrunner/CelTestSuiteYamlParser.java +++ b/testing/src/main/java/dev/cel/testing/testrunner/CelTestSuiteYamlParser.java @@ -25,6 +25,7 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import dev.cel.common.CelIssue; +import dev.cel.common.annotations.Internal; import dev.cel.common.formats.CelFileSource; import dev.cel.common.formats.ParserContext; import dev.cel.common.formats.YamlHelper.YamlNodeType; @@ -43,15 +44,18 @@ /** * CelTestSuiteYamlParser intakes a YAML document that describes the structure of a CEL test suite, * parses it then creates a {@link CelTestSuite}. + * + *

CEL Library Internals. Do Not Use. */ -final class CelTestSuiteYamlParser { +@Internal +public final class CelTestSuiteYamlParser { /** Creates a new instance of {@link CelTestSuiteYamlParser}. */ - static CelTestSuiteYamlParser newInstance() { + public static CelTestSuiteYamlParser newInstance() { return new CelTestSuiteYamlParser(); } - CelTestSuite parse(String celTestSuiteYamlContent) throws CelTestSuiteException { + public CelTestSuite parse(String celTestSuiteYamlContent) throws CelTestSuiteException { return parseYaml(celTestSuiteYamlContent, ""); } @@ -110,6 +114,7 @@ private CelTestSuite.Builder parseTestSuite(ParserContext ctx, Node node) case "description": builder.setDescription(newString(ctx, valueNode)); break; + case "section": case "sections": builder.setSections(parseSections(ctx, valueNode)); break; diff --git a/testing/src/main/java/dev/cel/testing/testrunner/TestRunnerLibrary.java b/testing/src/main/java/dev/cel/testing/testrunner/TestRunnerLibrary.java index a5e912ccb..2465d330e 100644 --- a/testing/src/main/java/dev/cel/testing/testrunner/TestRunnerLibrary.java +++ b/testing/src/main/java/dev/cel/testing/testrunner/TestRunnerLibrary.java @@ -31,11 +31,13 @@ import com.google.protobuf.ExtensionRegistry; import com.google.protobuf.Message; import com.google.protobuf.TextFormat; +import com.google.protobuf.TypeRegistry; import dev.cel.bundle.Cel; import dev.cel.bundle.CelEnvironment; import dev.cel.bundle.CelEnvironment.ExtensionConfig; import dev.cel.bundle.CelEnvironmentYamlParser; import dev.cel.common.CelAbstractSyntaxTree; +import dev.cel.common.CelDescriptorUtil; import dev.cel.common.CelOptions; import dev.cel.common.CelProtoAbstractSyntaxTree; import dev.cel.common.CelValidationException; @@ -104,6 +106,13 @@ public static void runTest( } } + /** Runs the test with the provided AST. */ + public static void runTest( + CelAbstractSyntaxTree ast, CelTestCase testCase, CelTestContext celTestContext) + throws Exception { + evaluate(ast, testCase, celTestContext, /* celCoverageIndex= */ null); + } + @VisibleForTesting static void evaluateTestCase(CelTestCase testCase, CelTestContext celTestContext) throws Exception { @@ -205,6 +214,16 @@ private static Cel extendCel(CelTestContext celTestContext, CelOptions celOption .build(); } + if (!celTestContext.fileTypes().isEmpty()) { + extendedCel = + extendedCel + .toCelBuilder() + .addMessageTypes( + CelDescriptorUtil.getAllDescriptorsFromFileDescriptor(celTestContext.fileTypes()) + .messageTypeDescriptors()) + .build(); + } + CelEnvironment environment = CelEnvironment.newBuilder().build(); // Extend the cel object with the config file if provided. @@ -302,8 +321,15 @@ private static Object getEvaluationResult( return getEvaluationResultWithMessage( getEvaluatedContextExpr(testCase, celTestContext), program, celCoverageIndex); case BINDINGS: - return getEvaluationResultWithBindings( - getBindings(testCase, celTestContext), program, celCoverageIndex); + ImmutableMap bindings = getBindings(testCase, celTestContext); + if (celTestContext.bindingTransformer().isPresent()) { + try { + bindings = celTestContext.bindingTransformer().get().transform(bindings); + } catch (Exception e) { + throw new CelEvaluationException("Binding transformation failed: " + e.getMessage(), e); + } + } + return getEvaluationResultWithBindings(bindings, program, celCoverageIndex); case NO_INPUT: ImmutableMap.Builder newBindings = ImmutableMap.builder(); for (Map.Entry entry : celTestContext.variableBindings().entrySet()) { @@ -396,10 +422,23 @@ private static Object evaluateInput(Cel cel, String expr) private static Object getValueFromBinding(Object value, CelTestContext celTestContext) throws IOException { if (value instanceof Value) { - if (celTestContext.fileDescriptorSetPath().isPresent()) { - return fromValue((Value) value, celTestContext.fileDescriptorSetPath().get()); + if (celTestContext.typeRegistry().isPresent() + || celTestContext.extensionRegistry().isPresent()) { + if (celTestContext.typeRegistry().isPresent()) { + ExtensionRegistry extensionRegistry = + celTestContext.extensionRegistry().orElse(ExtensionRegistry.getEmptyRegistry()); + return fromValue((Value) value, celTestContext.typeRegistry().get(), extensionRegistry); + } else if (celTestContext.extensionRegistry().isPresent()) { + return fromValue( + (Value) value, + TypeRegistry.newBuilder().build(), + celTestContext.extensionRegistry().get()); + } else if (celTestContext.fileDescriptorSetPath().isPresent()) { + return fromValue((Value) value, celTestContext.fileDescriptorSetPath().get()); + } } - return fromValue((Value) value); + return fromValue( + (Value) value, TypeRegistry.newBuilder().build(), ExtensionRegistry.getEmptyRegistry()); } return value; } diff --git a/testing/src/main/java/dev/cel/testing/utils/ExprValueUtils.java b/testing/src/main/java/dev/cel/testing/utils/ExprValueUtils.java index 041c0f52d..9bccecc95 100644 --- a/testing/src/main/java/dev/cel/testing/utils/ExprValueUtils.java +++ b/testing/src/main/java/dev/cel/testing/utils/ExprValueUtils.java @@ -28,8 +28,6 @@ import com.google.protobuf.Message; import com.google.protobuf.NullValue; import com.google.protobuf.TypeRegistry; -import dev.cel.common.CelDescriptorUtil; -import dev.cel.common.CelDescriptors; import dev.cel.common.internal.DefaultInstanceMessageFactory; import dev.cel.common.internal.ProtoTimeUtils; import dev.cel.common.types.CelType; @@ -55,8 +53,6 @@ public final class ExprValueUtils { private ExprValueUtils() {} - public static final TypeRegistry DEFAULT_TYPE_REGISTRY = newDefaultTypeRegistry(); - public static final ExtensionRegistry DEFAULT_EXTENSION_REGISTRY = newDefaultExtensionRegistry(); /** * Converts a {@link Value} to a Java native object using the given file descriptor set to parse @@ -68,10 +64,9 @@ private ExprValueUtils() {} * @throws IOException If there's an error during conversion. */ public static Object fromValue(Value value, String fileDescriptorSetPath) throws IOException { - if (value.getKindCase().equals(Value.KindCase.OBJECT_VALUE)) { - return parseAny(value.getObjectValue(), fileDescriptorSetPath); - } - return toNativeObject(value); + TypeRegistry typeRegistry = RegistryUtils.getTypeRegistry(fileDescriptorSetPath); + ExtensionRegistry extensionRegistry = RegistryUtils.getExtensionRegistry(fileDescriptorSetPath); + return fromValue(value, typeRegistry, extensionRegistry); } /** @@ -81,19 +76,38 @@ public static Object fromValue(Value value, String fileDescriptorSetPath) throws * @return The converted Java object. * @throws IOException If there's an error during conversion. */ - public static Object fromValue(Value value) throws IOException { + + /** + * Converts a {@link Value} to a Java native object using custom registries. + * + * @param value The {@link Value} to convert. + * @param typeRegistry The type registry to use for object resolution. + * @param extensionRegistry The extension registry to use for object resolution. + * @return The converted Java object. + * @throws IOException If there's an error during conversion. + */ + public static Object fromValue( + Value value, TypeRegistry typeRegistry, ExtensionRegistry extensionRegistry) + throws IOException { if (value.getKindCase().equals(Value.KindCase.OBJECT_VALUE)) { Descriptor descriptor = - DEFAULT_TYPE_REGISTRY.getDescriptorForTypeUrl(value.getObjectValue().getTypeUrl()); + typeRegistry.getDescriptorForTypeUrl(value.getObjectValue().getTypeUrl()); + if (descriptor == null) { + throw new IOException( + "Unknown type, descriptor was not found in registry: " + + value.getObjectValue().getTypeUrl()); + } Message prototype = getDefaultInstance(descriptor); return prototype .getParserForType() - .parseFrom(value.getObjectValue().getValue(), DEFAULT_EXTENSION_REGISTRY); + .parseFrom(value.getObjectValue().getValue(), extensionRegistry); } - return toNativeObject(value); + return toNativeObject(value, typeRegistry, extensionRegistry); } - private static Object toNativeObject(Value value) throws IOException { + private static Object toNativeObject( + Value value, TypeRegistry typeRegistry, ExtensionRegistry extensionRegistry) + throws IOException { switch (value.getKindCase()) { case NULL_VALUE: return dev.cel.common.values.NullValue.NULL_VALUE; @@ -118,7 +132,9 @@ private static Object toNativeObject(Value value) throws IOException { ImmutableMap.Builder builder = ImmutableMap.builderWithExpectedSize(map.getEntriesCount()); for (MapValue.Entry entry : map.getEntriesList()) { - builder.put(fromValue(entry.getKey()), fromValue(entry.getValue())); + builder.put( + fromValue(entry.getKey(), typeRegistry, extensionRegistry), + fromValue(entry.getValue(), typeRegistry, extensionRegistry)); } return builder.buildOrThrow(); } @@ -128,7 +144,7 @@ private static Object toNativeObject(Value value) throws IOException { ImmutableList.Builder builder = ImmutableList.builderWithExpectedSize(list.getValuesCount()); for (Value element : list.getValuesList()) { - builder.add(fromValue(element)); + builder.add(fromValue(element, typeRegistry, extensionRegistry)); } return builder.build(); } @@ -181,7 +197,7 @@ public static Value toValue(Object object, CelType type) throws Exception { if (object instanceof dev.cel.expr.Value) { object = Value.parseFrom( - ((dev.cel.expr.Value) object).toByteArray(), DEFAULT_EXTENSION_REGISTRY); + ((dev.cel.expr.Value) object).toByteArray(), ExtensionRegistry.getEmptyRegistry()); } if (object instanceof Value) { return (Value) object; @@ -287,19 +303,6 @@ public static Value toValue(Object object, CelType type) throws Exception { String.format("Unexpected result type: %s", object.getClass())); } - private static Message parseAny(Any value, String fileDescriptorSetPath) throws IOException { - TypeRegistry typeRegistry = RegistryUtils.getTypeRegistry(fileDescriptorSetPath); - ExtensionRegistry extensionRegistry = RegistryUtils.getExtensionRegistry(fileDescriptorSetPath); - Descriptor descriptor = typeRegistry.getDescriptorForTypeUrl(value.getTypeUrl()); - return unpackAny(value, descriptor, extensionRegistry); - } - - private static Message unpackAny( - Any value, Descriptor descriptor, ExtensionRegistry extensionRegistry) throws IOException { - Message defaultInstance = getDefaultInstance(descriptor); - return defaultInstance.getParserForType().parseFrom(value.getValue(), extensionRegistry); - } - private static Message getDefaultInstance(Descriptor descriptor) { return DefaultInstanceMessageFactory.getInstance() .getPrototype(descriptor) @@ -309,20 +312,6 @@ private static Message getDefaultInstance(Descriptor descriptor) { "Could not find a default message for: " + descriptor.getFullName())); } - private static ExtensionRegistry newDefaultExtensionRegistry() { - ExtensionRegistry extensionRegistry = ExtensionRegistry.newInstance(); - dev.cel.expr.conformance.proto2.TestAllTypesExtensions.registerAllExtensions(extensionRegistry); - return extensionRegistry; - } - private static TypeRegistry newDefaultTypeRegistry() { - CelDescriptors allDescriptors = - CelDescriptorUtil.getAllDescriptorsFromFileDescriptor( - ImmutableList.of( - dev.cel.expr.conformance.proto2.TestAllTypes.getDescriptor().getFile(), - dev.cel.expr.conformance.proto3.TestAllTypes.getDescriptor().getFile())); - - return TypeRegistry.newBuilder().add(allDescriptors.messageTypeDescriptors()).build(); - } } diff --git a/testing/src/test/java/dev/cel/testing/testrunner/TestRunnerLibraryTest.java b/testing/src/test/java/dev/cel/testing/testrunner/TestRunnerLibraryTest.java index d5a5248a8..b83375b35 100644 --- a/testing/src/test/java/dev/cel/testing/testrunner/TestRunnerLibraryTest.java +++ b/testing/src/test/java/dev/cel/testing/testrunner/TestRunnerLibraryTest.java @@ -26,6 +26,7 @@ import dev.cel.common.types.SimpleType; import dev.cel.expr.conformance.proto3.TestAllTypes; import dev.cel.testing.testrunner.CelTestSuite.CelTestSection.CelTestCase; +import java.util.Map; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -281,4 +282,74 @@ public void triggerRunTest_evaluateRawExpr_withCoverage() throws Exception { .build(), celCoverageIndex); } + + @Test + public void runTest_withBindingTransformer() throws Exception { + CelTestCase testCase = + CelTestCase.newBuilder() + .setName("binding_transformer_test") + .setDescription("Test binding transformer") + .setInput( + CelTestCase.Input.ofBindings( + ImmutableMap.of("x", CelTestCase.Input.Binding.ofValue(1L)))) + .setOutput(CelTestCase.Output.ofResultValue(3L)) // 1 + 1 (transformed) + 1 (expr) = 3 + .build(); + + TestRunnerLibrary.evaluateTestCase( + testCase, + CelTestContext.newBuilder() + .setCelExpression(CelExpressionSource.fromRawExpr("x + 1")) + .setCel(CelFactory.standardCelBuilder().addVar("x", SimpleType.INT).build()) + .setBindingTransformer( + bindings -> { + ImmutableMap.Builder transformed = ImmutableMap.builder(); + for (Map.Entry entry : bindings.entrySet()) { + if (entry.getKey().equals("x")) { + transformed.put("x", (Long) entry.getValue() + 1L); + } else { + transformed.put(entry); + } + } + return transformed.buildOrThrow(); + }) + .build()); + } + + @Test + public void runTest_withMessageTypes() throws Exception { + CelTestCase testCase = + CelTestCase.newBuilder() + .setName("message_types_consolidation_test") + .setDescription("Test message types consolidation") + .setOutput(CelTestCase.Output.ofResultValue(true)) + .build(); + + TestRunnerLibrary.evaluateTestCase( + testCase, + CelTestContext.newBuilder() + .setCelExpression( + CelExpressionSource.fromRawExpr( + "cel.expr.conformance.proto3.TestAllTypes{single_int64: 1} ==" + + " cel.expr.conformance.proto3.TestAllTypes{single_int64: 1}")) + .addMessageTypes(TestAllTypes.getDescriptor()) + .build()); + } + + @Test + public void typeRegistry_withFileTypes() throws Exception { + CelTestContext celTestContext = + CelTestContext.newBuilder() + .setCelExpression(CelExpressionSource.fromRawExpr("true")) + .setCel(CelFactory.standardCelBuilder().build()) + .addMessageTypes(TestAllTypes.getDescriptor()) + .build(); + + assertThat( + celTestContext + .typeRegistry() + .get() + .find("cel.expr.conformance.proto3.TestAllTypes") + .getFullName()) + .isEqualTo("cel.expr.conformance.proto3.TestAllTypes"); + } }