diff --git a/bundle/src/test/java/dev/cel/bundle/CelImplTest.java b/bundle/src/test/java/dev/cel/bundle/CelImplTest.java index 13c3392d3..bf0b9f0cd 100644 --- a/bundle/src/test/java/dev/cel/bundle/CelImplTest.java +++ b/bundle/src/test/java/dev/cel/bundle/CelImplTest.java @@ -35,6 +35,7 @@ import dev.cel.expr.Reference; import dev.cel.expr.Type; import dev.cel.expr.Type.PrimitiveType; +import com.google.common.collect.ImmutableCollection; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; @@ -63,7 +64,6 @@ import dev.cel.checker.CelCheckerLegacyImpl; import dev.cel.checker.DescriptorTypeProvider; import dev.cel.checker.ProtoTypeMask; -import dev.cel.checker.TypeProvider; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelContainer; import dev.cel.common.CelDescriptorUtil; @@ -82,12 +82,14 @@ import dev.cel.common.types.CelProtoMessageTypes; import dev.cel.common.types.CelProtoTypes; import dev.cel.common.types.CelType; +import dev.cel.common.types.CelTypeProvider; import dev.cel.common.types.EnumType; import dev.cel.common.types.ListType; import dev.cel.common.types.MapType; import dev.cel.common.types.OptionalType; import dev.cel.common.types.ProtoMessageTypeProvider; import dev.cel.common.types.SimpleType; +import dev.cel.common.types.StructType; import dev.cel.common.types.StructTypeReference; import dev.cel.common.values.CelByteString; import dev.cel.common.values.NullValue; @@ -123,7 +125,6 @@ import java.util.concurrent.Executors; import java.util.concurrent.Future; import java.util.concurrent.ThreadPoolExecutor; -import org.jspecify.annotations.Nullable; import org.junit.Assert; import org.junit.Test; import org.junit.runner.RunWith; @@ -297,13 +298,12 @@ public void compile_customTypeProvider() { @Test public void compile_customTypesWithAliasingCombinedProviders() throws Exception { - // The custom type provider sets up an alias from "Condition" to "google.type.Expr". // However, the first type resolution from the alias to the qualified type name won't be // sufficient as future checks will expect the resolved alias to also be a type. - TypeProvider customTypeProvider = + CelTypeProvider customTypeProvider = aliasingProvider( - ImmutableMap.of("Condition", CelProtoTypes.createMessage("google.type.Expr"))); + ImmutableMap.of("Condition", StructTypeReference.create("google.type.Expr"))); // The registration of the aliasing TypeProvider and the google.type.Expr descriptor // ensures that once the alias is resolved, the additional details about the Expr type @@ -329,15 +329,19 @@ public void compile_customTypesWithAliasingCombinedProviders() throws Exception @Test public void compile_customTypesWithAliasingSelfContainedProvider() throws Exception { - // The custom type provider sets up an alias from "Condition" to "google.type.Expr". - TypeProvider customTypeProvider = + StructType exprStruct = StructType.create( + "google.type.Expr", + ImmutableSet.of("expression"), + fieldName -> Optional.of(SimpleType.STRING) + ); + CelTypeProvider customTypeProvider = aliasingProvider( ImmutableMap.of( "Condition", - CelProtoTypes.createMessage("google.type.Expr"), + exprStruct, "google.type.Expr", - CelProtoTypes.createMessage("google.type.Expr"))); + exprStruct)); // The registration of the aliasing TypeProvider and the google.type.Expr descriptor // ensures that once the alias is resolved, the additional details about the Expr type @@ -1001,14 +1005,11 @@ public void program_protoActivation() throws Exception { } @Test - @TestParameters("{resolveTypeDependencies: false}") - @TestParameters("{resolveTypeDependencies: true}") - public void program_enumTypeDirectResolution(boolean resolveTypeDependencies) throws Exception { + public void program_enumTypeDirectResolution() throws Exception { Cel cel = standardCelBuilderWithMacros() .addFileTypes(StandaloneGlobalEnum.getDescriptor().getFile()) - .setOptions( - CelOptions.current().resolveTypeDependencies(resolveTypeDependencies).build()) + .setOptions(CelOptions.current().resolveTypeDependencies(true).build()) .setContainer( CelContainer.ofName("dev.cel.testing.testdata.proto3.StandaloneGlobalEnum")) .setResultType(SimpleType.BOOL) @@ -2193,28 +2194,16 @@ public void toBuilder_isImmutable() { assertThat(newRuntimeBuilder).isNotEqualTo(celImpl.toRuntimeBuilder()); } - private static TypeProvider aliasingProvider(ImmutableMap typeAliases) { - return new TypeProvider() { - @Override - public @Nullable Type lookupType(String typeName) { - Type alias = typeAliases.get(typeName); - if (alias != null) { - return CelProtoTypes.create(alias); - } - return null; - } - + private static CelTypeProvider aliasingProvider(ImmutableMap typeAliases) { + return new CelTypeProvider() { @Override - public @Nullable Integer lookupEnumValue(String enumName) { - return null; + public ImmutableCollection types() { + return typeAliases.values(); } @Override - public TypeProvider.@Nullable FieldType lookupFieldType(Type type, String fieldName) { - if (typeAliases.containsKey(type.getMessageType())) { - return TypeProvider.FieldType.of(CelProtoTypes.STRING); - } - return null; + public Optional findType(String typeName) { + return Optional.ofNullable(typeAliases.get(typeName)); } }; } diff --git a/checker/src/main/java/dev/cel/checker/BUILD.bazel b/checker/src/main/java/dev/cel/checker/BUILD.bazel index 6c486bd92..a1f4d235c 100644 --- a/checker/src/main/java/dev/cel/checker/BUILD.bazel +++ b/checker/src/main/java/dev/cel/checker/BUILD.bazel @@ -32,6 +32,7 @@ CHECKER_LEGACY_ENV_SOURCES = [ "InferenceContext.java", "TypeFormatter.java", "TypeProvider.java", + "TypeProviderLegacyImpl.java", "Types.java", ] @@ -70,7 +71,6 @@ java_library( ":checker_legacy_environment", ":proto_type_mask", ":standard_decl", - ":type_provider_legacy_impl", "//:auto_value", "//common:cel_ast", "//common:cel_descriptor_util", @@ -158,12 +158,9 @@ java_library( "//:auto_value", "//common/annotations", "//common/types", - "//common/types:cel_proto_types", "//common/types:type_providers", - "@cel_spec//proto/cel/expr:checked_java_proto", "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:com_google_guava_guava", - "@maven//:org_jspecify_jspecify", ], ) @@ -190,6 +187,7 @@ java_library( "//common/types", "//common/types:cel_proto_types", "//common/types:cel_types", + "//common/types:message_type_provider", "//common/types:type_providers", "//parser:macro", "@cel_spec//proto/cel/expr:checked_java_proto", diff --git a/checker/src/main/java/dev/cel/checker/CelCheckerLegacyImpl.java b/checker/src/main/java/dev/cel/checker/CelCheckerLegacyImpl.java index 41d1ca073..ec78eaff6 100644 --- a/checker/src/main/java/dev/cel/checker/CelCheckerLegacyImpl.java +++ b/checker/src/main/java/dev/cel/checker/CelCheckerLegacyImpl.java @@ -74,9 +74,6 @@ public final class CelCheckerLegacyImpl implements CelChecker, EnvVisitable { private final ImmutableSet functionDeclarations; private final Optional expectedResultType; - @SuppressWarnings("Immutable") - private final TypeProvider typeProvider; - private final CelTypeProvider celTypeProvider; private final boolean standardEnvironmentEnabled; @@ -163,11 +160,11 @@ public void accept(EnvVisitor envVisitor) { private Env getEnv(Errors errors) { Env env; if (standardEnvironmentEnabled) { - env = Env.standard(errors, typeProvider, celOptions); + env = Env.standard(errors, celTypeProvider, celOptions); } else if (overriddenStandardDeclarations != null) { - env = Env.standard(overriddenStandardDeclarations, errors, typeProvider, celOptions); + env = Env.standard(overriddenStandardDeclarations, errors, celTypeProvider, celOptions); } else { - env = Env.unconfigured(errors, typeProvider, celOptions); + env = Env.unconfigured(errors, celTypeProvider, celOptions); } identDeclarations.forEach(env::add); functionDeclarations.forEach(env::add); @@ -483,11 +480,10 @@ public CelCheckerLegacyImpl build() { messageTypeProvider = protoTypeMaskTypeProvider; } - TypeProvider legacyProvider = new TypeProviderLegacyImpl(messageTypeProvider); if (customTypeProvider != null) { - legacyProvider = - new TypeProvider.CombinedTypeProvider( - ImmutableList.of(customTypeProvider, legacyProvider)); + messageTypeProvider = + new CelTypeProvider.CombinedCelTypeProvider( + messageTypeProvider, new TypeProviderLegacyImpl(customTypeProvider)); } return new CelCheckerLegacyImpl( @@ -496,7 +492,7 @@ public CelCheckerLegacyImpl build() { identDeclarationSet, functionDeclarations.build(), Optional.fromNullable(expectedResultType), - legacyProvider, + customTypeProvider, messageTypeProvider, standardEnvironmentEnabled, standardDeclarations, @@ -535,7 +531,6 @@ private CelCheckerLegacyImpl( this.identDeclarations = identDeclarations; this.functionDeclarations = functionDeclarations; this.expectedResultType = expectedResultType; - this.typeProvider = typeProvider; this.celTypeProvider = celTypeProvider; this.standardEnvironmentEnabled = standardEnvironmentEnabled; this.overriddenStandardDeclarations = overriddenStandardDeclarations; diff --git a/checker/src/main/java/dev/cel/checker/DescriptorTypeProvider.java b/checker/src/main/java/dev/cel/checker/DescriptorTypeProvider.java index b5f849d50..174ee3198 100644 --- a/checker/src/main/java/dev/cel/checker/DescriptorTypeProvider.java +++ b/checker/src/main/java/dev/cel/checker/DescriptorTypeProvider.java @@ -18,6 +18,7 @@ import com.google.auto.value.AutoValue; import com.google.common.base.Ascii; import com.google.common.base.Joiner; +import com.google.common.collect.ImmutableCollection; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; @@ -32,14 +33,18 @@ import com.google.protobuf.Descriptors.FieldDescriptor; import com.google.protobuf.Descriptors.FileDescriptor; import com.google.protobuf.Descriptors.OneofDescriptor; -import dev.cel.common.annotations.Internal; import dev.cel.common.internal.FileDescriptorSetConverter; +import dev.cel.common.types.CelProtoTypes; +import dev.cel.common.types.CelType; +import dev.cel.common.types.ProtoMessageType; +import dev.cel.common.types.TypeType; import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import org.jspecify.annotations.Nullable; @@ -47,13 +52,10 @@ * The {@code DescriptorTypeProvider} provides type information for one or more {@link Descriptor} * instances of proto messages. * - *

TODO: Unify implementation across the runtime (i.e: DescriptorMessageProvider) - * and the compilation. This class can likely be eliminated as part of the work. - * - *

CEL Library Internals. Do Not Use. + * @deprecated Do not use. Migrate to {@code ProtoMessageTypeProvider). */ @Immutable -@Internal +@Deprecated public class DescriptorTypeProvider implements TypeProvider { @SuppressWarnings("Immutable") @@ -86,6 +88,45 @@ public DescriptorTypeProvider(Iterable descriptors) { return typeDef != null ? Types.create(Types.createMessage(typeDef.name())) : null; } + @Override + public Optional lookupCelType(String typeName) { + TypeDef typeDef = lookupMessageTypeDef(typeName); + if (typeDef == null) { + return Optional.empty(); + } + + ImmutableSet.Builder fieldsBuilder = ImmutableSet.builder(); + for (FieldDef fieldDef : typeDef.fields()) { + fieldsBuilder.add(fieldDef.name()); + } + + @SuppressWarnings("Immutable") + ProtoMessageType protoMessageType = + ProtoMessageType.create( + typeName, + fieldsBuilder.build(), + fieldName -> { + FieldDef fieldDef = typeDef.lookupField(fieldName); + if (fieldDef == null) { + return Optional.empty(); + } + + Type type = fieldDefToType(fieldDef); + return Optional.of(CelProtoTypes.typeToCelType(type)); + }, + extensionFieldName -> { + ExtensionFieldType extensionFieldType = + symbolTable.lookupExtension(extensionFieldName); + if (extensionFieldType == null) { + return Optional.empty(); + } + + return Optional.of(extensionFieldType.fieldType().celType()); + }); + + return Optional.of(TypeType.create(protoMessageType)); + } + @Override public @Nullable Integer lookupEnumValue(String enumName) { int dot = enumName.lastIndexOf('.'); @@ -339,6 +380,8 @@ private TypeDef buildTypeDef(EnumDescriptor descriptor, Map typ /** Value object for a proto-based primitive, message, or enum definition. */ @AutoValue + @AutoValue.CopyAnnotations + @SuppressWarnings("Immutable") protected abstract static class TypeDef { /** The qualified name of the message or enum. */ @@ -434,12 +477,28 @@ static TypeDef ofEnum(String name, Iterable enumValues) { } } + @Override + public ImmutableCollection types() { + ImmutableList.Builder typesBuilder = ImmutableList.builder(); + for (TypeDef typeDef : symbolTable.typeMap.values()) { + TypeType typeType = lookupCelType(typeDef.name()).orElse(null); + if (typeType == null) { + continue; + } + + typesBuilder.add(typeType.type()); + } + + return typesBuilder.build(); + } + /** * Value object for a proto-based field definition. * *

Only one of the {@link #type} or {@link #mapEntryType} may be set. */ @AutoValue + @AutoValue.CopyAnnotations protected abstract static class FieldDef { /** The field name. */ diff --git a/checker/src/main/java/dev/cel/checker/Env.java b/checker/src/main/java/dev/cel/checker/Env.java index 7029781e5..1d375173b 100644 --- a/checker/src/main/java/dev/cel/checker/Env.java +++ b/checker/src/main/java/dev/cel/checker/Env.java @@ -19,7 +19,6 @@ import dev.cel.expr.Decl.FunctionDecl.Overload; import dev.cel.expr.Expr; import dev.cel.expr.Type; -import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -43,8 +42,12 @@ import dev.cel.common.types.CelKind; import dev.cel.common.types.CelProtoTypes; import dev.cel.common.types.CelType; +import dev.cel.common.types.CelTypeProvider; import dev.cel.common.types.CelTypes; +import dev.cel.common.types.EnumType; +import dev.cel.common.types.ProtoMessageTypeProvider; import dev.cel.common.types.SimpleType; +import dev.cel.common.types.TypeType; import dev.cel.parser.CelStandardMacro; import java.util.ArrayList; import java.util.HashMap; @@ -78,7 +81,7 @@ public class Env { CelFunctionDecl.newBuilder().setName("*error*").build(); /** Type provider responsible for resolving CEL message references to strong types. */ - private final TypeProvider typeProvider; + private final CelTypeProvider typeProvider; /** * Stack of declaration groups where each entry in stack represents a scope capable of hinding @@ -105,7 +108,7 @@ public class Env { .build(); private Env( - Errors errors, TypeProvider typeProvider, DeclGroup declGroup, CelOptions celOptions) { + Errors errors, CelTypeProvider typeProvider, DeclGroup declGroup, CelOptions celOptions) { this.celOptions = celOptions; this.errors = Preconditions.checkNotNull(errors); this.typeProvider = Preconditions.checkNotNull(typeProvider); @@ -118,27 +121,10 @@ private Env( */ @Deprecated public static Env unconfigured(Errors errors) { - return unconfigured(errors, LEGACY_TYPE_CHECKER_OPTIONS); + return unconfigured(errors, new ProtoMessageTypeProvider(), LEGACY_TYPE_CHECKER_OPTIONS); } - /** - * Creates an unconfigured {@code Env} value without the standard CEL types, functions, and - * operators with a reference to the configured {@code celOptions}. - */ - @VisibleForTesting - static Env unconfigured(Errors errors, CelOptions celOptions) { - return unconfigured(errors, new DescriptorTypeProvider(), celOptions); - } - - /** - * Creates an unconfigured {@code Env} value without the standard CEL types, functions, and - * operators using a custom {@code typeProvider}. - * - * @deprecated Do not use. This exists for compatibility reasons. Migrate to CEL-Java fluent APIs. - * See {@code CelCompilerFactory}. - */ - @Deprecated - public static Env unconfigured(Errors errors, TypeProvider typeProvider, CelOptions celOptions) { + static Env unconfigured(Errors errors, CelTypeProvider typeProvider, CelOptions celOptions) { return new Env(errors, typeProvider, new DeclGroup(), celOptions); } @@ -148,7 +134,7 @@ public static Env unconfigured(Errors errors, TypeProvider typeProvider, CelOpti */ @Deprecated public static Env standard(Errors errors) { - return standard(errors, new DescriptorTypeProvider()); + return standard(errors, new ProtoMessageTypeProvider(), LEGACY_TYPE_CHECKER_OPTIONS); } /** @@ -173,6 +159,11 @@ public static Env standard(Errors errors, TypeProvider typeProvider) { */ @Deprecated public static Env standard(Errors errors, TypeProvider typeProvider, CelOptions celOptions) { + CelTypeProvider adapted = new TypeProviderLegacyImpl(typeProvider); + return standard(errors, adapted, celOptions); + } + + static Env standard(Errors errors, CelTypeProvider typeProvider, CelOptions celOptions) { CelStandardDeclarations celStandardDeclaration = CelStandardDeclarations.newBuilder() .filterFunctions( @@ -209,10 +200,10 @@ public static Env standard(Errors errors, TypeProvider typeProvider, CelOptions return standard(celStandardDeclaration, errors, typeProvider, celOptions); } - public static Env standard( + static Env standard( CelStandardDeclarations celStandardDeclaration, Errors errors, - TypeProvider typeProvider, + CelTypeProvider typeProvider, CelOptions celOptions) { Env env = Env.unconfigured(errors, typeProvider, celOptions); // Isolate the standard declarations into their own scope for forward compatibility. @@ -228,8 +219,8 @@ public Errors getErrorContext() { return errors; } - /** Returns the {@code TypeProvider}. */ - public TypeProvider getTypeProvider() { + /** Returns the {@code CelTypeProvider}. */ + public CelTypeProvider getTypeProvider() { return typeProvider; } @@ -491,30 +482,52 @@ public Env add(String name, Type type) { // Next try to import the name as a reference to a message type. // This is done via the type provider. - Optional type = typeProvider.lookupCelType(cand); + Optional type = typeProvider.findType(cand); if (type.isPresent()) { - decl = CelIdentDecl.newIdentDeclaration(cand, type.get()); + decl = CelIdentDecl.newIdentDeclaration(cand, TypeType.create(type.get())); decls.get(0).putIdent(decl); return decl; } // Next try to import this as an enum value by splitting the name in a type prefix and // the enum inside. - Integer enumValue = typeProvider.lookupEnumValue(cand); - if (enumValue != null) { + Optional enumValue = findEnumValue(cand); + if (enumValue.isPresent()) { decl = CelIdentDecl.newBuilder() .setName(cand) .setType(SimpleType.INT) - .setConstant(CelConstant.ofValue(enumValue)) + .setConstant(CelConstant.ofValue(enumValue.get())) .build(); decls.get(0).putIdent(decl); return decl; } + return null; } + private Optional findEnumValue(String fullyQualifiedEnumName) { + int dot = fullyQualifiedEnumName.lastIndexOf('.'); + if (dot <= 0) { + return Optional.empty(); + } + + String enumTypeName = fullyQualifiedEnumName.substring(0, dot); + EnumType enumType = + typeProvider + .findType(enumTypeName) + .filter(t -> t instanceof EnumType) + .map(EnumType.class::cast) + .orElse(null); + if (enumType == null) { + return Optional.empty(); + } + + String enumValueName = fullyQualifiedEnumName.substring(dot + 1); + return enumType.findNumberByName(enumValueName); + } + /** * Lookup a local identifier by name. This searches only comprehension scopes, bypassing standard * environment or user-defined environment. diff --git a/checker/src/main/java/dev/cel/checker/ExprChecker.java b/checker/src/main/java/dev/cel/checker/ExprChecker.java index 37b692ecf..6f4395459 100644 --- a/checker/src/main/java/dev/cel/checker/ExprChecker.java +++ b/checker/src/main/java/dev/cel/checker/ExprChecker.java @@ -39,11 +39,16 @@ import dev.cel.common.types.CelKind; import dev.cel.common.types.CelProtoTypes; import dev.cel.common.types.CelType; +import dev.cel.common.types.CelTypeProvider; import dev.cel.common.types.CelTypes; import dev.cel.common.types.ListType; import dev.cel.common.types.MapType; import dev.cel.common.types.OptionalType; +import dev.cel.common.types.ProtoMessageType; +import dev.cel.common.types.ProtoMessageType.Extension; import dev.cel.common.types.SimpleType; +import dev.cel.common.types.StructType; +import dev.cel.common.types.StructType.Field; import dev.cel.common.types.TypeType; import java.util.ArrayList; import java.util.HashSet; @@ -143,7 +148,7 @@ public static CelAbstractSyntaxTree typecheck( } private final Env env; - private final TypeProvider typeProvider; + private final CelTypeProvider typeProvider; private final CelContainer container; private final Map positionMap; private final InferenceContext inferenceContext; @@ -256,6 +261,7 @@ private CelExpr visit(CelExpr expr, CelExpr.CelIdent ident) { private CelExpr visit(CelExpr expr, CelExpr.CelSelect select) { // Before traversing down the tree, try to interpret as qualified name. String qname = asQualifiedName(expr); + if (qname != null) { CelIdentDecl decl = env.tryLookupCelIdent(container, qname); if (decl != null) { @@ -410,7 +416,7 @@ private CelExpr visit(CelExpr expr, CelExpr.CelStruct struct) { expr = replaceStructEntryValueSubtree(expr, visitedValueExpr, i); } CelType fieldType = - getFieldType(entry.id(), getPosition(entry), messageType, entry.fieldKey()).celType(); + getFieldType(entry.id(), getPosition(entry), messageType, entry.fieldKey()).type(); CelType valueType = env.getType(visitedValueExpr); if (entry.optionalEntry()) { if (valueType instanceof OptionalType) { @@ -716,7 +722,7 @@ private OverloadResolution resolveOverload( // Return value from visit is not needed as the subtree is not rewritten here. @SuppressWarnings("CheckReturnValue") private CelType visitSelectField( - CelExpr expr, CelExpr operand, String field, boolean isOptional) { + CelExpr expr, CelExpr operand, String fieldName, boolean isOptional) { CelType operandType = inferenceContext.specialize(env.getType(operand)); CelType resultType = SimpleType.ERROR; @@ -727,10 +733,8 @@ private CelType visitSelectField( if (!Types.isDynOrError(operandType)) { if (operandType.kind() == CelKind.STRUCT) { - TypeProvider.FieldType fieldType = - getFieldType(expr.id(), getPosition(expr), operandType, field); - // Type of the field - resultType = fieldType.celType(); + Field field = getFieldType(expr.id(), getPosition(expr), operandType, fieldName); + resultType = field.type(); } else if (operandType.kind() == CelKind.MAP) { resultType = ((MapType) operandType).valueType(); } else if (operandType.kind() == CelKind.TYPE_PARAM) { @@ -805,19 +809,30 @@ private CelExpr visitOptionalCall(CelExpr expr, CelExpr.CelCall call) { } /** Returns the field type give a type instance and field name. */ - private TypeProvider.FieldType getFieldType( - long exprId, int position, CelType type, String fieldName) { + private Field getFieldType(long exprId, int position, CelType type, String fieldName) { String typeName = type.name(); - if (typeProvider.lookupCelType(typeName).isPresent()) { - TypeProvider.FieldType fieldType = typeProvider.lookupFieldType(type, fieldName); - if (fieldType != null) { - return fieldType; + StructType structType = + typeProvider + .findType(typeName) + .filter(t -> t instanceof StructType) + .map(StructType.class::cast) + .orElse(null); + + if (structType != null) { + Field field = structType.findField(fieldName).orElse(null); + if (field != null) { + return field; } - TypeProvider.ExtensionFieldType extensionFieldType = - typeProvider.lookupExtensionType(fieldName); - if (extensionFieldType != null) { - return extensionFieldType.fieldType(); + + if (structType instanceof ProtoMessageType) { + Extension extensionField = + ((ProtoMessageType) structType).findExtension(fieldName).orElse(null); + + if (extensionField != null) { + return Field.of(extensionField.name(), extensionField.type()); + } } + env.reportError(exprId, position, "undefined field '%s'", fieldName); } else { // Proto message was added as a variable to the environment but the descriptor was not @@ -831,6 +846,7 @@ private TypeProvider.FieldType getFieldType( } env.reportError(exprId, position, errorMessage, fieldName, typeName); } + return ERROR; } @@ -892,8 +908,8 @@ public static OverloadResolution of(CelReference reference, CelType type) { } } - /** Helper object to represent a {@link TypeProvider.FieldType} lookup failure. */ - private static final TypeProvider.FieldType ERROR = TypeProvider.FieldType.of(Types.ERROR); + /** Helper object to represent a {@link CelTypeProvider#findType(String)} lookup failure. */ + private static final Field ERROR = Field.of(SimpleType.ERROR.name(), SimpleType.ERROR); private static CelExpr replaceIdentSubtree(CelExpr expr, String name) { CelExpr.CelIdent newIdent = CelExpr.CelIdent.newBuilder().setName(name).build(); diff --git a/checker/src/main/java/dev/cel/checker/TypeProvider.java b/checker/src/main/java/dev/cel/checker/TypeProvider.java index 2dd5261ab..1aff54d00 100644 --- a/checker/src/main/java/dev/cel/checker/TypeProvider.java +++ b/checker/src/main/java/dev/cel/checker/TypeProvider.java @@ -16,10 +16,12 @@ import dev.cel.expr.Type; import com.google.auto.value.AutoValue; +import com.google.common.collect.ImmutableCollection; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import dev.cel.common.types.CelProtoTypes; import dev.cel.common.types.CelType; +import dev.cel.common.types.TypeType; import java.util.Optional; import java.util.function.Function; import org.jspecify.annotations.Nullable; @@ -36,9 +38,13 @@ public interface TypeProvider { @Nullable Type lookupType(String typeName); /** Lookup the a {@link CelType} given a qualified {@code typeName}. Returns null if not found. */ - default Optional lookupCelType(String typeName) { + default Optional lookupCelType(String typeName) { Type type = lookupType(typeName); - return Optional.ofNullable(type).map(CelProtoTypes::typeToCelType); + return Optional.ofNullable(type).map(CelProtoTypes::typeToCelType).map(TypeType.class::cast); + } + + default ImmutableCollection types() { + return ImmutableList.of(); } /** Lookup the {@code Integer} enum value given an {@code enumName}. Returns null if not found. */ diff --git a/checker/src/main/java/dev/cel/checker/TypeProviderLegacyImpl.java b/checker/src/main/java/dev/cel/checker/TypeProviderLegacyImpl.java index b2ac51d95..e77614cf6 100644 --- a/checker/src/main/java/dev/cel/checker/TypeProviderLegacyImpl.java +++ b/checker/src/main/java/dev/cel/checker/TypeProviderLegacyImpl.java @@ -14,19 +14,13 @@ package dev.cel.checker; -import dev.cel.expr.Type; -import com.google.common.collect.ImmutableSet; +import com.google.common.collect.ImmutableCollection; import com.google.errorprone.annotations.CheckReturnValue; import dev.cel.common.annotations.Internal; -import dev.cel.common.types.CelProtoTypes; import dev.cel.common.types.CelType; import dev.cel.common.types.CelTypeProvider; -import dev.cel.common.types.EnumType; -import dev.cel.common.types.ProtoMessageType; -import dev.cel.common.types.StructType; import dev.cel.common.types.TypeType; import java.util.Optional; -import org.jspecify.annotations.Nullable; /** * The {@code TypeProviderLegacyImpl} acts as a bridge between the old and new type provider APIs @@ -35,87 +29,31 @@ */ @CheckReturnValue @Internal -final class TypeProviderLegacyImpl implements TypeProvider { +final class TypeProviderLegacyImpl implements CelTypeProvider { - private final CelTypeProvider celTypeProvider; + // Legacy typeProvider is immutable, just not marked as such + @SuppressWarnings("Immutable") + private final TypeProvider typeProvider; - TypeProviderLegacyImpl(CelTypeProvider celTypeProvider) { - this.celTypeProvider = celTypeProvider; + TypeProviderLegacyImpl(TypeProvider typeProvider) { + this.typeProvider = typeProvider; } @Override - public @Nullable Type lookupType(String typeName) { - return lookupCelType(typeName).map(CelProtoTypes::celTypeToType).orElse(null); + public ImmutableCollection types() { + return typeProvider.types(); } @Override - public Optional lookupCelType(String typeName) { - return celTypeProvider.findType(typeName).map(TypeType::create); - } - - @Override - public @Nullable FieldType lookupFieldType(CelType type, String fieldName) { - String messageType = type.name(); - StructType structType = - (StructType) - celTypeProvider.findType(messageType).filter(t -> t instanceof StructType).orElse(null); - if (structType == null) { - return null; - } - - return structType - .findField(fieldName) - .map(f -> FieldType.of(CelProtoTypes.celTypeToType(f.type()))) + public Optional findType(String typeName) { + TypeType type = typeProvider + .lookupCelType(typeName) .orElse(null); - } - @Override - public @Nullable FieldType lookupFieldType(Type type, String fieldName) { - return lookupFieldType(CelProtoTypes.typeToCelType(type), fieldName); - } - - @Override - public @Nullable ImmutableSet lookupFieldNames(Type type) { - String messageType = type.getMessageType(); - return celTypeProvider - .findType(messageType) - .filter(t -> t instanceof StructType) - .map(t -> ((StructType) t).fieldNames()) - .orElse(null); - } - - @Override - public @Nullable Integer lookupEnumValue(String enumName) { - int dotIndex = enumName.lastIndexOf("."); - if (dotIndex < 0 || dotIndex == enumName.length() - 1) { - return null; + if (type == null) { + return Optional.empty(); } - String enumTypeName = enumName.substring(0, dotIndex); - String localEnumName = enumName.substring(dotIndex + 1); - return celTypeProvider - .findType(enumTypeName) - .filter(t -> t instanceof EnumType) - .flatMap(t -> ((EnumType) t).findNumberByName(localEnumName)) - .orElse(null); - } - - @Override - public @Nullable ExtensionFieldType lookupExtensionType(String extensionName) { - Optional extension = - celTypeProvider.types().stream() - .filter(t -> t instanceof ProtoMessageType) - .map(t -> (ProtoMessageType) t) - .map(t -> t.findExtension(extensionName)) - .filter(Optional::isPresent) - .map(Optional::get) - .findFirst(); - return extension - .map( - et -> - ExtensionFieldType.of( - CelProtoTypes.celTypeToType(et.type()), - CelProtoTypes.celTypeToType(et.messageType()))) - .orElse(null); + return Optional.of(type.type()); } } diff --git a/checker/src/test/java/dev/cel/checker/ExprCheckerTest.java b/checker/src/test/java/dev/cel/checker/ExprCheckerTest.java index d5d5d9a3a..46f75ad36 100644 --- a/checker/src/test/java/dev/cel/checker/ExprCheckerTest.java +++ b/checker/src/test/java/dev/cel/checker/ExprCheckerTest.java @@ -91,7 +91,7 @@ private void runTest() throws Exception { private void runErroneousTest(CelAbstractSyntaxTree parsedAst) { checkArgument(!parsedAst.isChecked()); Errors errors = new Errors("", source); - Env env = Env.unconfigured(errors, TEST_OPTIONS); + Env env = Env.unconfigured(errors, new ProtoMessageTypeProvider(), TEST_OPTIONS); ExprChecker.typecheck(env, container, parsedAst, Optional.absent()); testOutput().println(errors.getAllErrorsAsString()); testOutput().println(); diff --git a/checker/src/test/java/dev/cel/checker/TypeProviderLegacyImplTest.java b/checker/src/test/java/dev/cel/checker/TypeProviderLegacyImplTest.java index 4569877c3..6f274feb2 100644 --- a/checker/src/test/java/dev/cel/checker/TypeProviderLegacyImplTest.java +++ b/checker/src/test/java/dev/cel/checker/TypeProviderLegacyImplTest.java @@ -15,16 +15,13 @@ package dev.cel.checker; import static com.google.common.truth.Truth.assertThat; -import static com.google.common.truth.extensions.proto.ProtoTruth.assertThat; -import dev.cel.expr.Type; +import com.google.common.collect.ImmutableCollection; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableSet; import com.google.protobuf.Descriptors.Descriptor; -import dev.cel.common.types.CelProtoTypes; -import dev.cel.common.types.ProtoMessageTypeProvider; -import dev.cel.expr.conformance.proto2.Proto2ExtensionScopedMessage; +import dev.cel.common.types.CelType; import dev.cel.expr.conformance.proto2.TestAllTypes; +import java.util.Optional; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -33,101 +30,40 @@ public final class TypeProviderLegacyImplTest { private static final ImmutableList DESCRIPTORS = - ImmutableList.of(TestAllTypes.getDescriptor(), Proto2ExtensionScopedMessage.getDescriptor()); - - private final ProtoMessageTypeProvider proto2Provider = new ProtoMessageTypeProvider(DESCRIPTORS); - - private final DescriptorTypeProvider descriptorTypeProvider = - new DescriptorTypeProvider(DESCRIPTORS); - - private final TypeProviderLegacyImpl compatTypeProvider = - new TypeProviderLegacyImpl(proto2Provider); + ImmutableList.of(TestAllTypes.getDescriptor()); @Test - public void lookupType() { - assertThat(compatTypeProvider.lookupType("cel.expr.conformance.proto2.TestAllTypes")) - .isEqualTo(descriptorTypeProvider.lookupType("cel.expr.conformance.proto2.TestAllTypes")); - assertThat(compatTypeProvider.lookupType("not.registered.TypeName")) - .isEqualTo(descriptorTypeProvider.lookupType("not.registered.TypeName")); - } + public void findType_delegatesToLegacyLookup() { + DescriptorTypeProvider legacyProvider = new DescriptorTypeProvider(DESCRIPTORS); + TypeProviderLegacyImpl celTypeProvider = new TypeProviderLegacyImpl(legacyProvider); + String typeName = TestAllTypes.getDescriptor().getFullName(); - @Test - public void lookupFieldNames() { - Type nestedTestAllTypes = - compatTypeProvider.lookupType("cel.expr.conformance.proto2.NestedTestAllTypes").getType(); - ImmutableSet fieldNames = compatTypeProvider.lookupFieldNames(nestedTestAllTypes); - assertThat(fieldNames) - .containsExactlyElementsIn(descriptorTypeProvider.lookupFieldNames(nestedTestAllTypes)); - assertThat(fieldNames).containsExactly("payload", "child"); - } + Optional result = celTypeProvider.findType(typeName); - @Test - public void lookupFieldType() { - Type nestedTestAllTypes = - compatTypeProvider.lookupType("cel.expr.conformance.proto2.NestedTestAllTypes").getType(); - assertThat(compatTypeProvider.lookupFieldType(nestedTestAllTypes, "payload")) - .isEqualTo(descriptorTypeProvider.lookupFieldType(nestedTestAllTypes, "payload")); - assertThat(compatTypeProvider.lookupFieldType(nestedTestAllTypes, "child")) - .isEqualTo(descriptorTypeProvider.lookupFieldType(nestedTestAllTypes, "child")); + assertThat(result).isPresent(); + assertThat(result.get().name()).isEqualTo(typeName); } @Test - public void lookupFieldType_inputNotMessage() { - Type globalEnumType = - compatTypeProvider.lookupType("cel.expr.conformance.proto2.GlobalEnum").getType(); - assertThat(compatTypeProvider.lookupFieldType(globalEnumType, "payload")).isNull(); - assertThat(compatTypeProvider.lookupFieldType(globalEnumType, "payload")) - .isEqualTo(descriptorTypeProvider.lookupFieldType(globalEnumType, "payload")); - } + public void findType_returnsEmptyForUnknownType() { + DescriptorTypeProvider legacyProvider = new DescriptorTypeProvider(DESCRIPTORS); + TypeProviderLegacyImpl celTypeProvider = new TypeProviderLegacyImpl(legacyProvider); - @Test - public void lookupExtension() { - TypeProvider.ExtensionFieldType extensionType = - compatTypeProvider.lookupExtensionType("cel.expr.conformance.proto2.nested_enum_ext"); - assertThat(extensionType.messageType()) - .isEqualTo(CelProtoTypes.createMessage("cel.expr.conformance.proto2.TestAllTypes")); - assertThat(extensionType.fieldType().type()).isEqualTo(CelProtoTypes.INT64); - assertThat(extensionType) - .isEqualTo( - descriptorTypeProvider.lookupExtensionType( - "cel.expr.conformance.proto2.nested_enum_ext")); - } + Optional result = celTypeProvider.findType("unknown.Type"); - @Test - public void lookupEnumValue() { - Integer enumValue = - compatTypeProvider.lookupEnumValue("cel.expr.conformance.proto2.GlobalEnum.GAR"); - assertThat(enumValue).isEqualTo(1); - assertThat(enumValue) - .isEqualTo( - descriptorTypeProvider.lookupEnumValue("cel.expr.conformance.proto2.GlobalEnum.GAR")); + assertThat(result).isEmpty(); } @Test - public void lookupEnumValue_notFoundValue() { - Integer enumValue = - compatTypeProvider.lookupEnumValue("cel.expr.conformance.proto2.GlobalEnum.BAR"); - assertThat(enumValue).isNull(); - assertThat(enumValue) - .isEqualTo( - descriptorTypeProvider.lookupEnumValue("cel.expr.conformance.proto2.GlobalEnum.BAR")); - } + public void types_delegatesToLegacyTypes() { + DescriptorTypeProvider legacyProvider = new DescriptorTypeProvider(DESCRIPTORS); + TypeProviderLegacyImpl celTypeProvider = new TypeProviderLegacyImpl(legacyProvider); - @Test - public void lookupEnumValue_notFoundEnumType() { - Integer enumValue = - compatTypeProvider.lookupEnumValue("cel.expr.conformance.proto2.InvalidEnum.TEST"); - assertThat(enumValue).isNull(); - assertThat(enumValue) - .isEqualTo( - descriptorTypeProvider.lookupEnumValue("cel.expr.conformance.proto2.InvalidEnum.TEST")); - } + ImmutableCollection types = celTypeProvider.types(); - @Test - public void lookupEnumValue_notFoundBadEnumName() { - assertThat(compatTypeProvider.lookupEnumValue("TEST")).isNull(); - assertThat(compatTypeProvider.lookupEnumValue("TEST.")).isNull(); - assertThat(descriptorTypeProvider.lookupEnumValue("TEST")).isNull(); - assertThat(descriptorTypeProvider.lookupEnumValue("TEST.")).isNull(); + assertThat(types).isNotEmpty(); + assertThat(types).hasSize(legacyProvider.types().size()); + assertThat(types.stream().map(CelType::name)) + .contains(TestAllTypes.getDescriptor().getFullName()); } -} +} \ No newline at end of file diff --git a/common/src/main/java/dev/cel/common/types/CelTypeProvider.java b/common/src/main/java/dev/cel/common/types/CelTypeProvider.java index c452a9dee..c3680ab8c 100644 --- a/common/src/main/java/dev/cel/common/types/CelTypeProvider.java +++ b/common/src/main/java/dev/cel/common/types/CelTypeProvider.java @@ -48,6 +48,7 @@ public interface CelTypeProvider { final class CombinedCelTypeProvider implements CelTypeProvider { private final ImmutableMap allTypes; + private final ImmutableList typeProviders; public CombinedCelTypeProvider(CelTypeProvider first, CelTypeProvider second) { this(ImmutableList.of(first, second)); @@ -59,6 +60,7 @@ public CombinedCelTypeProvider(ImmutableList typeProviders) { typeProvider -> typeProvider.types().forEach(type -> allTypes.putIfAbsent(type.name(), type))); this.allTypes = ImmutableMap.copyOf(allTypes); + this.typeProviders = typeProviders; } @Override @@ -68,7 +70,14 @@ public ImmutableCollection types() { @Override public Optional findType(String typeName) { - return Optional.ofNullable(allTypes.get(typeName)); + for (CelTypeProvider typeProvider : typeProviders) { + Optional foundType = typeProvider.findType(typeName); + if (foundType.isPresent()) { + return foundType; + } + } + + return Optional.empty(); } } }