Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,10 @@ abstract class Optimizer(catalogManager: CatalogManager)
// this batch.
Batch("Early Filter and Projection Push-Down", Once, earlyScanPushDownRules: _*),
Batch("Update CTE Relation Stats", Once, UpdateCTERelationStats),
// Must run after "Early Filter and Projection Push-Down" because it relies on
// accurate stats (e.g., DSv2 relations only report stats after V2ScanRelationPushDown).
Batch("Push Down Join Through Union", Once,
PushDownJoinThroughUnion),
// Since join costs in AQP can change between multiple runs, there is no reason that we have an
// idempotence enforcement on this batch. We thus make it FixedPoint(1) instead of Once.
Batch("Join Reorder", FixedPoint(1),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.analysis.DeduplicateRelations
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.{JOIN, UNION}

/**
 * Pushes down `Join` through `Union` when the right side of the join is small enough
 * to broadcast.
 *
 * This rule transforms the pattern:
 * {{{
 *   Join(Union(c1, c2, ..., cN), right, joinType, cond)
 * }}}
 * into:
 * {{{
 *   Union(Join(c1, right, joinType, cond1), Join(c2, right, joinType, cond2), ...)
 * }}}
 *
 * where each `condK` has the Union output attributes rewritten to the corresponding child's
 * output attributes.
 *
 * This is beneficial when the right side is small enough to broadcast, because it avoids
 * shuffling the (potentially very large) Union result before the Join. Instead, each Union
 * branch joins independently with the broadcasted right side.
 *
 * Applicable join types: Inner, LeftOuter.
 */
object PushDownJoinThroughUnion
  extends Rule[LogicalPlan]
  with JoinSelectionHelper {

  def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
    _.containsAllPatterns(JOIN, UNION), ruleId) {

    case join @ Join(u: Union, right, joinType, joinCond, hint)
        if (joinType == Inner || joinType == LeftOuter) &&
          canPlanAsBroadcastHashJoin(join, conf) &&
          // Exclude right subtrees containing subqueries, as DeduplicateRelations
          // may not correctly handle correlated references when cloning.
          !right.exists(_.expressions.exists(SubqueryExpression.hasSubquery)) =>

      // Union.output reuses the ExprIds of the first child's output, so the join
      // condition's left-side references resolve against the head child's attributes.
      val unionHeadOutput = u.children.head.output
      val newChildren = u.children.zipWithIndex.map { case (child, idx) =>
        // The first branch can reuse `right` as-is; every other branch needs a deep
        // copy with fresh ExprIds to avoid duplicate attributes across Union children.
        val newRight = if (idx == 0) right else dedupRight(right)
        // Map Union-level (head child) attributes to this branch's own attributes.
        val leftRewrites = AttributeMap(unionHeadOutput.zip(child.output))
        // Map original right-side attributes to the deduplicated copy's attributes.
        val rightRewrites = if (idx == 0) {
          AttributeMap.empty[Attribute]
        } else {
          AttributeMap(right.output.zip(newRight.output))
        }
        val newCond = joinCond.map(_.transform {
          case a: Attribute if leftRewrites.contains(a) => leftRewrites(a)
          case a: Attribute if rightRewrites.contains(a) => rightRewrites(a)
        })
        Join(child, newRight, joinType, newCond, hint)
      }
      u.withNewChildren(newChildren)
  }

  /**
   * Creates a copy of `plan` with fresh ExprIds on all output attributes,
   * using the same "fake self-join + DeduplicateRelations" pattern as InlineCTE.
   *
   * Throws an internal error if DeduplicateRelations unexpectedly changes the
   * Join plan shape (a should-never-happen defensive guard).
   */
  private def dedupRight(plan: LogicalPlan): LogicalPlan = {
    DeduplicateRelations(
      Join(plan, plan, Inner, None, JoinHint.NONE)
    ) match {
      case Join(_, deduped, _, _, _) => deduped
      case other =>
        throw SparkException.internalError(
          s"Unexpected plan shape after DeduplicateRelations: ${other.getClass.getName}")
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ object RuleIdCollection {
"org.apache.spark.sql.catalyst.optimizer.Optimizer$OptimizeSubqueries" ::
"org.apache.spark.sql.catalyst.optimizer.PropagateEmptyRelation" ::
"org.apache.spark.sql.catalyst.optimizer.PruneFilters" ::
"org.apache.spark.sql.catalyst.optimizer.PushDownJoinThroughUnion" ::
"org.apache.spark.sql.catalyst.optimizer.PushDownLeftSemiAntiJoin" ::
"org.apache.spark.sql.catalyst.optimizer.PushExtraPredicateThroughJoin" ::
"org.apache.spark.sql.catalyst.optimizer.PushFoldableIntoBranches" ::
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,259 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Explode
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Union}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.IntegerType

class PushDownJoinThroughUnionSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("PushDownJoinThroughUnion", FixedPoint(10),
      PushDownJoinThroughUnion) :: Nil
  }

  val testRelation1 = LocalRelation($"a".int, $"b".int)
  val testRelation2 = LocalRelation($"c".int, $"d".int)
  val testRelation3 = LocalRelation($"e".int, $"f".int)
  val testRelation4 = LocalRelation($"g".int, $"h".int)

  /**
   * Asserts that the rule fired (the optimized plan's root is a Union) and that
   * no two Union children share ExprIds in their top-level output — each branch
   * must get an independently deduplicated copy of the join's right side.
   */
  private def assertPushedDownWithDistinctExprIds(optimized: LogicalPlan): Unit = {
    assert(optimized.isInstanceOf[Union])
    val childOutputs = optimized.asInstanceOf[Union].children.map(_.output)
    for (i <- childOutputs.indices; j <- (i + 1) until childOutputs.length) {
      val idsI = childOutputs(i).map(_.exprId).toSet
      val idsJ = childOutputs(j).map(_.exprId).toSet
      assert(idsI.intersect(idsJ).isEmpty,
        s"Union children $i and $j share ExprIds: ${idsI.intersect(idsJ)}")
    }
  }

  test("Push down Inner Join through Union when right side is small") {
    withSQLConf(
      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000") {
      val union = Union(testRelation1, testRelation2)
      val query = union.join(testRelation3, Inner, Some($"a" === $"e"))
      val optimized = Optimize.execute(query.analyze)

      val expected = Union(
        testRelation1.join(testRelation3, Inner, Some($"a" === $"e")),
        testRelation2.join(testRelation3, Inner, Some($"c" === $"e"))
      ).analyze

      comparePlans(optimized, expected)
    }
  }

  test("Push down Left Outer Join through Union when right side is small") {
    withSQLConf(
      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000") {
      val union = Union(testRelation1, testRelation2)
      val query = union.join(testRelation3, LeftOuter, Some($"a" === $"e"))
      val optimized = Optimize.execute(query.analyze)

      val expected = Union(
        testRelation1.join(testRelation3, LeftOuter, Some($"a" === $"e")),
        testRelation2.join(testRelation3, LeftOuter, Some($"c" === $"e"))
      ).analyze

      comparePlans(optimized, expected)
    }
  }

  test("Do not push down when right side is too large (broadcast disabled)") {
    withSQLConf(
      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
      val union = Union(testRelation1, testRelation2)
      val query = union.join(testRelation3, Inner, Some($"a" === $"e"))
      val optimized = Optimize.execute(query.analyze)

      comparePlans(optimized, query.analyze)
    }
  }

  test("Correctly rewrite attributes in join condition") {
    withSQLConf(
      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000") {
      val union = Union(testRelation1, testRelation2)
      val query = union.join(testRelation3, Inner, Some($"a" === $"e" && $"b" > 10))
      val optimized = Optimize.execute(query.analyze)

      val expected = Union(
        testRelation1.join(testRelation3, Inner, Some($"a" === $"e" && $"b" > 10)),
        testRelation2.join(testRelation3, Inner, Some($"c" === $"e" && $"d" > 10))
      ).analyze

      comparePlans(optimized, expected)
    }
  }

  test("Push down Inner Join through 3-way Union (TPC-DS pattern)") {
    withSQLConf(
      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000") {
      val union = Union(Seq(testRelation1, testRelation2, testRelation4))
      val query = union.join(testRelation3, Inner, Some($"a" === $"e"))
      val optimized = Optimize.execute(query.analyze)

      val expected = Union(Seq(
        testRelation1.join(testRelation3, Inner, Some($"a" === $"e")),
        testRelation2.join(testRelation3, Inner, Some($"c" === $"e")),
        testRelation4.join(testRelation3, Inner, Some($"g" === $"e"))
      )).analyze

      comparePlans(optimized, expected)
    }
  }

  test("Do not push down unsupported join types") {
    withSQLConf(
      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000") {
      val union = Union(testRelation1, testRelation2)
      Seq(RightOuter, FullOuter, LeftSemi, LeftAnti).foreach { joinType =>
        val query = union.join(testRelation3, joinType, Some($"a" === $"e"))
        val optimized = Optimize.execute(query.analyze)
        comparePlans(optimized, query.analyze)
      }
    }
  }

  test("Do not push down Cross Join (no join condition)") {
    withSQLConf(
      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000") {
      val union = Union(testRelation1, testRelation2)
      val query = union.join(testRelation3, Inner, None)
      val optimized = Optimize.execute(query.analyze)

      comparePlans(optimized, query.analyze)
    }
  }

  test("Do not push down when Union is on the right side") {
    withSQLConf(
      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000") {
      val union = Union(testRelation1, testRelation2)
      val query = testRelation3.join(union, Inner, Some($"e" === $"a"))
      val optimized = Optimize.execute(query.analyze)

      comparePlans(optimized, query.analyze)
    }
  }

  test("Push down when right side is a complex subplan") {
    withSQLConf(
      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000") {
      val complexRight = testRelation3
        .where($"f" > 0)
        .select($"e", ($"f" + 1).as("f_plus_1"))
      val union = Union(testRelation1, testRelation2)
      val query = union.join(complexRight, Inner, Some($"a" === $"e"))
      val optimized = Optimize.execute(query.analyze)

      assertPushedDownWithDistinctExprIds(optimized)
    }
  }

  test("Push down when right side contains Generate (Explode)") {
    withSQLConf(
      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000") {
      val arrayRelation = LocalRelation($"k".int, $"arr".array(IntegerType))
      val rightWithGenerate = arrayRelation
        .generate(Explode($"arr"), outputNames = Seq("exploded_val"))
        .select($"k", $"exploded_val")
      val union = Union(testRelation1, testRelation2)
      val query = union.join(rightWithGenerate, Inner, Some($"a" === $"k"))
      val optimized = Optimize.execute(query.analyze)

      assertPushedDownWithDistinctExprIds(optimized)
    }
  }

  test("Push down when right side contains SubqueryAlias") {
    withSQLConf(
      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000") {
      val rightWithAlias = testRelation3.subquery("dim")
      val union = Union(testRelation1, testRelation2)
      val query = union.join(rightWithAlias, Inner, Some($"a" === $"e"))
      val optimized = Optimize.execute(query.analyze)

      assertPushedDownWithDistinctExprIds(optimized)
    }
  }

  test("Push down when right side contains Project with Alias") {
    withSQLConf(
      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000") {
      val rightWithAlias = testRelation3
        .select($"e", ($"f" + 1).as("f_plus_1"))
      val union = Union(testRelation1, testRelation2)
      val query = union.join(rightWithAlias, Inner, Some($"a" === $"e"))
      val optimized = Optimize.execute(query.analyze)

      assertPushedDownWithDistinctExprIds(optimized)
    }
  }

  test("Push down when right side contains Aggregate") {
    withSQLConf(
      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000") {
      val rightWithAgg = testRelation3
        .groupBy($"e")(count($"f").as("cnt"), $"e")
      val union = Union(testRelation1, testRelation2)
      val query = union.join(rightWithAgg, Inner, Some($"a" === $"e"))
      val optimized = Optimize.execute(query.analyze)

      assertPushedDownWithDistinctExprIds(optimized)
    }
  }
}
Loading