From 85811c1f4b4d198376d69c42609f5df5cbf95ac7 Mon Sep 17 00:00:00 2001 From: mihailoale-db Date: Tue, 17 Mar 2026 11:19:20 +0100 Subject: [PATCH] initial commit Co-authored-by: Isaac --- .../sql/catalyst/plans/NormalizePlan.scala | 189 ++++++++++-------- .../catalyst/plans/NormalizePlanSuite.scala | 176 +++++++++++++++- 2 files changed, 280 insertions(+), 85 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/NormalizePlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/NormalizePlan.scala index 9341dee19c742..ff471cd6f00f8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/NormalizePlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/NormalizePlan.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.optimizer.ReplaceExpressions import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.trees.TreePattern.PLAN_EXPRESSION /** * Object that handles normalization of operators and expressions. Used when comparing plans. @@ -126,96 +127,118 @@ object NormalizePlan extends PredicateHelper { * - CTERelationRef ids will be remapped based on the new CTERelationDef IDs. This is possible, * because WithCTE returns cteDefs as first children, and the defs will be traversed before the * refs. - * - Normalizes inner [[Project]] nodes by sorting project lists alphabetically. - * - Normalizes inner [[Aggregate]] nodes by sorting aggregate expressions lists alphabetically. + * - Normalizes inner [[Project]] and [[Aggregate]] nodes by sorting their output lists + * alphabetically when they are descendants of another [[Project]] or [[Aggregate]], provided no + * schema boundary ([[SubqueryAlias]], [[View]], [[CTERelationDef]]) or non-unary node boundary + * intervenes. */ def normalizePlan(plan: LogicalPlan): LogicalPlan = { - val cteIdNormalizer = new CteIdNormalizer - plan.transformUpWithSubqueries { - case Filter(condition: Expression, child: LogicalPlan) => - Filter( - splitConjunctivePredicates(condition) - .map(rewriteBinaryComparison) - .sortBy(_.hashCode()) - .reduce(And), - child - ) - case sample: Sample => - sample.copy(seed = sample.seed.map(_ => 0L)) - case Join(left, right, joinType, condition, hint) if condition.isDefined => - val newJoinType = joinType match { - case ExistenceJoin(a: Attribute) => - val newAttr = AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0)) - ExistenceJoin(newAttr) - case other => other - } - - val newCondition = - splitConjunctivePredicates(condition.get) - .map(rewriteBinaryComparison) - .sortBy(_.hashCode()) - .reduce(And) - Join(left, right, newJoinType, Some(newCondition), hint) - case project: Project - if project.containsTag(ResolverTag.PROJECT_FOR_EXPRESSION_ID_DEDUPLICATION) => - project.child - - case aggregate @ Aggregate(_, _, innerProject: Project, _) => - aggregate.copy(child = normalizeProjectListOrder(innerProject)) - - case project @ Project(_, innerProject: Project) => - project.copy(child = normalizeProjectListOrder(innerProject)) + normalizeRecursive(plan, normalizeProjectList = false, new CteIdNormalizer()) + } - case project @ Project(_, innerAggregate: Aggregate) => - project.copy(child = normalizeAggregateListOrder(innerAggregate)) + /** + * Recursively normalizes the plan. The `normalizeProjectList` flag is propagated top-down: it + * starts as false. When a [[Project]] or [[Aggregate]] is encountered, the flag is set to true + * for children. The flag is reset to false at schema boundary nodes ([[SubqueryAlias]], [[View]], + * [[CTERelationDef]]) and at non-unary nodes (e.g. [[Join]], [[Union]], [[Intersect]], + * [[Except]]) where column position is semantically significant. Unary nodes like [[Sort]] and + * [[Filter]] pass the flag through unchanged. + * + * When the flag is true and we encounter a [[Project]] or [[Aggregate]], we normalize its output + * list order. + * + * Deduplication projects (tagged with [[ResolverTag.PROJECT_FOR_EXPRESSION_ID_DEDUPLICATION]]) + * are removed and recursion continues with their child. This case must be handled before child + * processing because the deduplication project should pass the flag through unchanged, whereas a + * regular [[Project]] would set it to true. + * + * Children are recursed into before the current node is normalized (bottom-up processing order), + * except for deduplication projects which are removed top-down before child processing. This + * bottom-up ordering preserves CTE ID normalization. Plan expressions within the current node's + * expressions are normalized with the flag reset to false (independent scope). + */ + private def normalizeRecursive( + plan: LogicalPlan, + normalizeProjectList: Boolean, + cteIdNormalizer: CteIdNormalizer): LogicalPlan = plan match { + case project: Project + if project.containsTag(ResolverTag.PROJECT_FOR_EXPRESSION_ID_DEDUPLICATION) => + normalizeRecursive(project.child, normalizeProjectList, cteIdNormalizer) + case _ => + val shouldNormalizeChildProjectList = plan match { + case _: Project | _: Aggregate => true + case _: SubqueryAlias | _: View | _: CTERelationDef => false + case _ if plan.children.length != 1 => false + case _ => normalizeProjectList + } - /** - * ORDER BY covered by an output-retaining project on top of GROUP BY - */ - case project @ Project(_, sort @ Sort(_, _, innerAggregate: Aggregate, _)) => - project.copy(child = sort.copy(child = normalizeAggregateListOrder(innerAggregate))) + val withNormalizedChildren = plan.mapChildren { child => + normalizeRecursive(child, shouldNormalizeChildProjectList, cteIdNormalizer) + } - /** - * HAVING covered by an output-retaining project on top of GROUP BY - */ - case project @ Project(_, filter @ Filter(_, innerAggregate: Aggregate)) => - project.copy(child = filter.copy(child = normalizeAggregateListOrder(innerAggregate))) + val withNormalizedSubqueries = + withNormalizedChildren.transformExpressionsWithPruning( + _.containsPattern(PLAN_EXPRESSION)) { + case subqueryExpression: SubqueryExpression => + subqueryExpression.withNewPlan( + normalizeRecursive(subqueryExpression.plan, normalizeProjectList = false, + cteIdNormalizer)) + } - /** - * HAVING ... ORDER BY covered by an output-retaining project on top of GROUP BY - */ - case project @ Project( - _, - sort @ Sort(_, _, filter @ Filter(_, innerAggregate: Aggregate), _) - ) => - project.copy( - child = - sort.copy(child = filter.copy(child = normalizeAggregateListOrder(innerAggregate))) - ) + withNormalizedSubqueries match { + case Filter(condition: Expression, child: LogicalPlan) => + Filter( + splitConjunctivePredicates(condition) + .map(rewriteBinaryComparison) + .sortBy(_.hashCode()) + .reduce(And), + child + ) + case sample: Sample => + sample.copy(seed = sample.seed.map(_ => 0L)) + case Join(left, right, joinType, condition, hint) if condition.isDefined => + val newJoinType = joinType match { + case ExistenceJoin(a: Attribute) => + val newAttr = + AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0)) + ExistenceJoin(newAttr) + case other => other + } - case c: KeepAnalyzedQuery => c.storeAnalyzedQuery() - case localRelation: LocalRelation if !localRelation.data.isEmpty => - /** - * A substitute for the [[LocalRelation.data]]. [[GenericInternalRow]] is incomparable for - * maps, because [[ArrayBasedMapData]] doesn't define [[equals]]. - */ - val unsafeProjection = UnsafeProjection.create(localRelation.schema) - localRelation.copy(data = localRelation.data.map { row => - unsafeProjection(row) - }) - case cteRelationDef: CTERelationDef => - cteIdNormalizer.normalizeDef(cteRelationDef) - case unionLoop: UnionLoop => - cteIdNormalizer.normalizeUnionLoop( - unionLoop.copy(outputAttrIds = Seq.fill(unionLoop.outputAttrIds.size)(ExprId(0))) - ) - case cteRelationRef: CTERelationRef => - cteIdNormalizer.normalizeRef(cteRelationRef) - case unionLoopRef: UnionLoopRef => - cteIdNormalizer.normalizeUnionLoopRef(unionLoopRef) - case normalizeableRelation: NormalizeableRelation => - normalizeableRelation.normalize() - } + val newCondition = + splitConjunctivePredicates(condition.get) + .map(rewriteBinaryComparison) + .sortBy(_.hashCode()) + .reduce(And) + Join(left, right, newJoinType, Some(newCondition), hint) + case project: Project if normalizeProjectList => + normalizeProjectListOrder(project) + case aggregate: Aggregate if normalizeProjectList => + normalizeAggregateListOrder(aggregate) + case c: KeepAnalyzedQuery => c.storeAnalyzedQuery() + case localRelation: LocalRelation if !localRelation.data.isEmpty => + /** + * A substitute for the [[LocalRelation.data]]. [[GenericInternalRow]] is incomparable + * for maps, because [[ArrayBasedMapData]] doesn't define [[equals]]. + */ + val unsafeProjection = UnsafeProjection.create(localRelation.schema) + localRelation.copy(data = localRelation.data.map { row => + unsafeProjection(row) + }) + case cteRelationDef: CTERelationDef => + cteIdNormalizer.normalizeDef(cteRelationDef) + case unionLoop: UnionLoop => + cteIdNormalizer.normalizeUnionLoop( + unionLoop.copy(outputAttrIds = Seq.fill(unionLoop.outputAttrIds.size)(ExprId(0))) + ) + case cteRelationRef: CTERelationRef => + cteIdNormalizer.normalizeRef(cteRelationRef) + case unionLoopRef: UnionLoopRef => + cteIdNormalizer.normalizeUnionLoopRef(unionLoopRef) + case normalizeableRelation: NormalizeableRelation => + normalizeableRelation.normalize() + case other => other + } } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/NormalizePlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/NormalizePlanSuite.scala index cbf7ce7948800..3d2cf464b6c1a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/NormalizePlanSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/NormalizePlanSuite.scala @@ -21,11 +21,13 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.SQLConfHelper +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{CTERelationDef, LocalRelation, LogicalPlan, UnionLoop, UnionLoopRef} -import org.apache.spark.sql.types.BooleanType +import org.apache.spark.sql.catalyst.plans.logical.{CTERelationDef, LocalRelation, LogicalPlan, UnionLoop, UnionLoopRef, View} +import org.apache.spark.sql.types.{BooleanType, IntegerType, StringType, StructField, StructType} class NormalizePlanSuite extends SparkFunSuite with SQLConfHelper { @@ -117,6 +119,176 @@ class NormalizePlanSuite extends SparkFunSuite with SQLConfHelper { assert(NormalizePlan(baselinePlan) == NormalizePlan(testPlan)) } + test("Normalize ordering in a project list of an inner Project under Project and Filter") { + val baselinePlan = + LocalRelation($"col1".int, $"col2".string) + .select($"col1", $"col2") + .where($"col1" === 1) + .select($"col1") + val testPlan = + LocalRelation($"col1".int, $"col2".string) + .select($"col2", $"col1") + .where($"col1" === 1) + .select($"col1") + + assert(baselinePlan != testPlan) + assert(NormalizePlan(baselinePlan) == NormalizePlan(testPlan)) + } + + test("SubqueryAlias resets normalizeProjectList flag for inner Project") { + // Project under SubqueryAlias should NOT be normalized even when above a Project. + // The SubqueryAlias boundary preserves the schema, so column order matters. + val baselinePlan = LocalRelation($"col1".int, $"col2".string) + .select($"col1", $"col2") + .subquery("t") + .select($"col1") + val testPlan = LocalRelation($"col1".int, $"col2".string) + .select($"col2", $"col1") + .subquery("t") + .select($"col1") + + // The inner Project has different column order, and SubqueryAlias resets the flag, + // so normalization should NOT make them equal. + assert(NormalizePlan(baselinePlan) != NormalizePlan(testPlan)) + } + + test("SubqueryAlias resets normalizeProjectList flag for inner Aggregate") { + // Aggregate under SubqueryAlias should NOT have its list normalized. + val baselinePlan = LocalRelation($"col1".int, $"col2".string) + .groupBy($"col1", $"col2")($"col1", $"col2") + .subquery("t") + .select($"col1") + val testPlan = LocalRelation($"col1".int, $"col2".string) + .groupBy($"col1", $"col2")($"col2", $"col1") + .subquery("t") + .select($"col1") + + assert(NormalizePlan(baselinePlan) != NormalizePlan(testPlan)) + } + + test("Nested SubqueryAlias resets flag even under multiple Projects") { + // Project -> Project -> SubqueryAlias -> Project: the innermost Project should NOT + // be normalized because SubqueryAlias resets the flag. + val baselinePlan = LocalRelation($"col1".int, $"col2".string) + .select($"col1", $"col2") + .subquery("t") + .select($"col1", $"col2") + .select($"col1") + val testPlan = LocalRelation($"col1".int, $"col2".string) + .select($"col2", $"col1") + .subquery("t") + .select($"col1", $"col2") + .select($"col1") + + // The inner Project (below SubqueryAlias) differs in order and should NOT be normalized. + assert(NormalizePlan(baselinePlan) != NormalizePlan(testPlan)) + } + + test("Project above SubqueryAlias IS normalized when under another Project") { + // Project -> Project -> SubqueryAlias -> relation + // The middle Project (between outer Project and SubqueryAlias) should still be normalized + // because it's under a Project, and SubqueryAlias only resets the flag for its children. + val baselinePlan = LocalRelation($"col1".int, $"col2".string) + .subquery("t") + .select($"col1", $"col2") + .select($"col1") + val testPlan = LocalRelation($"col1".int, $"col2".string) + .subquery("t") + .select($"col2", $"col1") + .select($"col1") + + assert(baselinePlan != testPlan) + assert(NormalizePlan(baselinePlan) == NormalizePlan(testPlan)) + } + + test("SubqueryAlias resets flag but Project above SubqueryAlias still normalizes with Filter") { + // Aggregate -> Filter -> SubqueryAlias -> Project: the Project under SubqueryAlias + // should NOT be normalized, but the structure above SubqueryAlias works normally. + val baselinePlan = LocalRelation($"col1".int, $"col2".string) + .select($"col1", $"col2") + .subquery("t") + .where($"col1" === 1) + .groupBy($"col1")($"col1") + val testPlan = LocalRelation($"col1".int, $"col2".string) + .select($"col2", $"col1") + .subquery("t") + .where($"col1" === 1) + .groupBy($"col1")($"col1") + + // Inner Project under SubqueryAlias is NOT normalized, so plans differ. + assert(NormalizePlan(baselinePlan) != NormalizePlan(testPlan)) + } + + test("Double SubqueryAlias both reset the flag independently") { + // Project -> SubqueryAlias -> Project -> SubqueryAlias -> Project + // Both inner Projects should NOT be normalized. + val baselinePlan = LocalRelation($"col1".int, $"col2".string) + .select($"col1", $"col2") + .subquery("inner") + .select($"col1", $"col2") + .subquery("outer") + .select($"col1") + val testPlan = LocalRelation($"col1".int, $"col2".string) + .select($"col2", $"col1") + .subquery("inner") + .select($"col2", $"col1") + .subquery("outer") + .select($"col1") + + assert(NormalizePlan(baselinePlan) != NormalizePlan(testPlan)) + } + + test("View resets normalizeProjectList flag for inner Project") { + val viewDesc = CatalogTable( + identifier = TableIdentifier("test_view"), + tableType = CatalogTableType.VIEW, + storage = CatalogStorageFormat.empty, + schema = StructType(Seq(StructField("col1", IntegerType), StructField("col2", StringType))) + ) + val baselinePlan = View( + desc = viewDesc, + isTempView = true, + child = LocalRelation($"col1".int, $"col2".string).select($"col1", $"col2") + ).select($"col1") + val testPlan = View( + desc = viewDesc, + isTempView = true, + child = LocalRelation($"col1".int, $"col2".string).select($"col2", $"col1") + ).select($"col1") + + // View is a schema boundary, so the inner Project should NOT be normalized. + assert(NormalizePlan(baselinePlan) != NormalizePlan(testPlan)) + } + + test("CTERelationDef resets normalizeProjectList flag for inner Project") { + val baselinePlan = CTERelationDef( + child = LocalRelation($"col1".int, $"col2".string).select($"col1", $"col2"), + id = 1L + ).select($"col1") + val testPlan = CTERelationDef( + child = LocalRelation($"col1".int, $"col2".string).select($"col2", $"col1"), + id = 1L + ).select($"col1") + + // CTERelationDef is a schema boundary, so the inner Project should NOT be normalized. + assert(NormalizePlan(baselinePlan) != NormalizePlan(testPlan)) + } + + test("Non-unary nodes reset normalizeProjectList flag") { + // Project -> Union -> [Project, Project]: the inner Projects under Union should NOT + // be normalized because Union matches columns by position. + val left1 = LocalRelation($"col1".int, $"col2".string).select($"col1", $"col2") + val right1 = LocalRelation($"col1".int, $"col2".string).select($"col1", $"col2") + val baselinePlan = left1.union(right1).select($"col1") + + val left2 = LocalRelation($"col1".int, $"col2".string).select($"col2", $"col1") + val right2 = LocalRelation($"col1".int, $"col2".string).select($"col2", $"col1") + val testPlan = left2.union(right2).select($"col1") + + // Inner Projects under Union have different column order and should NOT be normalized. + assert(NormalizePlan(baselinePlan) != NormalizePlan(testPlan)) + } + test("Normalize InheritAnalysisRules expressions") { val castWithoutTimezone = Cast(child = Literal(1), dataType = BooleanType, ansiEnabled = conf.ansiEnabled)