diff --git a/conjure-java-core/src/main/java/com/palantir/conjure/java/types/SafetyEvaluator.java b/conjure-java-core/src/main/java/com/palantir/conjure/java/types/SafetyEvaluator.java index 1f3ad20a0..0996a3123 100644 --- a/conjure-java-core/src/main/java/com/palantir/conjure/java/types/SafetyEvaluator.java +++ b/conjure-java-core/src/main/java/com/palantir/conjure/java/types/SafetyEvaluator.java @@ -38,6 +38,7 @@ import com.palantir.conjure.spec.TypeName; import com.palantir.conjure.spec.UnionDefinition; import com.palantir.logsafe.Preconditions; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -65,6 +66,10 @@ public final class SafetyEvaluator { private final Map definitionMap; + // Memoization cache shared across all evaluate() calls on this instance. Avoids redundant recursive + // traversals of the type graph, which otherwise dominate generation time for large definitions. + private final Map> cache = new HashMap<>(); + public SafetyEvaluator(ConjureDefinition definition) { this(TypeFunctions.toTypesMap(definition)); } @@ -75,12 +80,12 @@ public SafetyEvaluator(Map definitionMap) { public Optional evaluate(TypeDefinition def) { return Preconditions.checkNotNull(def, "TypeDefinition is required") - .accept(new TypeDefinitionSafetyVisitor(definitionMap, new HashSet<>())); + .accept(new TypeDefinitionSafetyVisitor(definitionMap, cache, new HashSet<>())); } public Optional evaluate(Type type) { return Preconditions.checkNotNull(type, "TypeDefinition is required") - .accept(new TypeDefinitionSafetyVisitor(definitionMap, new HashSet<>()).fieldVisitor); + .accept(new TypeDefinitionSafetyVisitor(definitionMap, cache, new HashSet<>()).fieldVisitor); } public Optional evaluate(Type type, Optional declaredSafety) { @@ -124,10 +129,19 @@ public Optional getUsageTimeSafety(FieldDefinition field) { } private static final class TypeDefinitionSafetyVisitor implements TypeDefinition.Visitor> { + private final Map> cache; private final Set inProgress; private final Type.Visitor> fieldVisitor; - private TypeDefinitionSafetyVisitor(Map definitionMap, Set inProgress) { + // Tracks whether cycle-breaking (the SAFE fallback for back-edges) was used anywhere + // in the current evaluation subtree. Used to decide whether a result is safe to cache. + private boolean encounteredCycle; + + private TypeDefinitionSafetyVisitor( + Map definitionMap, + Map> cache, + Set inProgress) { + this.cache = cache; this.inProgress = inProgress; this.fieldVisitor = new FieldSafetyVisitor(definitionMap, this); } @@ -170,15 +184,42 @@ public Optional visitUnknown(String unknownType) { } private Optional with(TypeName typeName, Supplier> task) { + // Return memoized result if this type has already been fully evaluated. + // Note: cache values are Optional which may be Optional.empty(), + // so we check for null (absent key) rather than emptiness. + Optional cached = cache.get(typeName); + if (cached != null) { + return cached; + } if (!inProgress.add(typeName)) { // Given recursive evaluation, we return the least restrictive type: SAFE. + // Mark that this subtree's result depends on cycle-breaking. + encounteredCycle = true; return OPTIONAL_OF_SAFE; } + + // Save and reset cycle state so we can detect cycles within this type's subtree only. + boolean previousCycleState = encounteredCycle; + encounteredCycle = false; + Optional result = task.get(); + + boolean subtreeHadCycle = encounteredCycle; + // Propagate cycle detection upward: if this subtree had a cycle, callers should know. + encounteredCycle = previousCycleState || subtreeHadCycle; + if (!inProgress.remove(typeName)) { throw new IllegalStateException( "Failed to remove " + typeName + " from in-progress, something is very wrong!"); } + + // Only cache results where no cycle was encountered in the subtree. + // When a cycle is broken with the SAFE heuristic, the result depends on which type + // was the entry point, so caching it would produce incorrect results if the same + // type is later evaluated from a different starting point. + if (!subtreeHadCycle) { + cache.put(typeName, result); + } return result; } diff --git a/conjure-java-core/src/test/java/com/palantir/conjure/java/types/SafetyEvaluatorTest.java b/conjure-java-core/src/test/java/com/palantir/conjure/java/types/SafetyEvaluatorTest.java index 922ec9459..7f176be68 100644 --- a/conjure-java-core/src/test/java/com/palantir/conjure/java/types/SafetyEvaluatorTest.java +++ b/conjure-java-core/src/test/java/com/palantir/conjure/java/types/SafetyEvaluatorTest.java @@ -366,6 +366,119 @@ void testEmptyEnum() { .hasValue(LogSafety.SAFE); } + @Test + void testCyclicTypes_bothSafe() { + // Foo has a field referencing Bar, Bar has a field referencing Foo. + // Both fields are marked safe, so both types should evaluate as safe. + TypeDefinition foo = TypeDefinition.object(ObjectDefinition.builder() + .typeName(FOO) + .fields(FieldDefinition.builder() + .fieldName(FieldName.of("bar")) + .type(Type.reference(BAR)) + .build()) + .build()); + TypeDefinition bar = TypeDefinition.object(ObjectDefinition.builder() + .typeName(BAR) + .fields(FieldDefinition.builder() + .fieldName(FieldName.of("foo")) + .type(Type.reference(FOO)) + .build()) + .build()); + ConjureDefinition conjureDef = + ConjureDefinition.builder().version(1).types(foo).types(bar).build(); + SafetyEvaluator evaluator = new SafetyEvaluator(conjureDef); + // Evaluation order should not matter for the result + assertThat(evaluator.evaluate(foo)).hasValue(LogSafety.SAFE); + assertThat(evaluator.evaluate(bar)).hasValue(LogSafety.SAFE); + } + + @Test + void testCyclicTypes_withUnsafeField() { + // Foo references Bar, Bar references Foo and also has an unsafe string field. + // The unsafe field should propagate through the cycle. + TypeDefinition foo = TypeDefinition.object(ObjectDefinition.builder() + .typeName(FOO) + .fields(FieldDefinition.builder() + .fieldName(FieldName.of("bar")) + .type(Type.reference(BAR)) + .build()) + .build()); + TypeDefinition bar = TypeDefinition.object(ObjectDefinition.builder() + .typeName(BAR) + .fields(FieldDefinition.builder() + .fieldName(FieldName.of("foo")) + .type(Type.reference(FOO)) + .build()) + .fields(FieldDefinition.builder() + .fieldName(FieldName.of("unsafeField")) + .type(Type.primitive(PrimitiveType.STRING)) + .safety(LogSafety.UNSAFE) + .build()) + .build()); + ConjureDefinition conjureDef = + ConjureDefinition.builder().version(1).types(foo).types(bar).build(); + SafetyEvaluator evaluator = new SafetyEvaluator(conjureDef); + assertThat(evaluator.evaluate(bar)).hasValue(LogSafety.UNSAFE); + assertThat(evaluator.evaluate(foo)).hasValue(LogSafety.UNSAFE); + } + + @Test + void testCyclicTypes_evaluationOrderIndependent() { + // Verifies that the cache does not cause different results depending on which + // type in a cycle is evaluated first. + TypeDefinition foo = TypeDefinition.object(ObjectDefinition.builder() + .typeName(FOO) + .fields(FieldDefinition.builder() + .fieldName(FieldName.of("bar")) + .type(Type.reference(BAR)) + .build()) + .fields(FieldDefinition.builder() + .fieldName(FieldName.of("unsafeField")) + .type(Type.primitive(PrimitiveType.STRING)) + .safety(LogSafety.UNSAFE) + .build()) + .build()); + TypeDefinition bar = TypeDefinition.object(ObjectDefinition.builder() + .typeName(BAR) + .fields(FieldDefinition.builder() + .fieldName(FieldName.of("foo")) + .type(Type.reference(FOO)) + .build()) + .build()); + ConjureDefinition conjureDef = + ConjureDefinition.builder().version(1).types(foo).types(bar).build(); + + // Evaluate foo first, then bar + SafetyEvaluator evaluator1 = new SafetyEvaluator(conjureDef); + Optional fooFirst = evaluator1.evaluate(foo); + Optional barAfterFoo = evaluator1.evaluate(bar); + + // Evaluate bar first, then foo + SafetyEvaluator evaluator2 = new SafetyEvaluator(conjureDef); + Optional barFirst = evaluator2.evaluate(bar); + Optional fooAfterBar = evaluator2.evaluate(foo); + + // Results should be identical regardless of evaluation order + assertThat(fooFirst).isEqualTo(fooAfterBar); + assertThat(barAfterFoo).isEqualTo(barFirst); + } + + @Test + void testSelfReferentialType() { + // A type that references itself (e.g. a linked list node) + TypeDefinition node = TypeDefinition.object(ObjectDefinition.builder() + .typeName(FOO) + .fields(FieldDefinition.builder() + .fieldName(FieldName.of("next")) + .type(Type.reference(FOO)) + .build()) + .build()); + ConjureDefinition conjureDef = + ConjureDefinition.builder().version(1).types(node).build(); + SafetyEvaluator evaluator = new SafetyEvaluator(conjureDef); + assertThat(evaluator.evaluate(node)).hasValue(LogSafety.SAFE); + } + private static Stream getTypes(Type externalReference) { TypeDefinition objectType = TypeDefinition.object(ObjectDefinition.builder() .typeName(FOO)