Commit 959cd88 — "Fix ras and only fallback bhj"
Parent: 52fdee6
4 files changed: 40 additions and 19 deletions

File 1 of 4: gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/RewriteSparkPlanRulesManager.scala
(4 additions, 2 deletions)
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder}
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.{LeafExecNode, ProjectExec, SparkPlan}
-import org.apache.spark.sql.execution.joins.BaseJoinExec
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec}

 case class RewrittenNodeWall(originalChild: SparkPlan) extends LeafExecNode {
   override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException()
@@ -109,7 +109,9 @@ class RewriteSparkPlanRulesManager private (
       FallbackTags.add(origin, FallbackTags.getOption(rewriteNode).get)
       origin
     } else if (
-      rewriteNode.isInstanceOf[BaseJoinExec] && allFallbackTags.exists(_.isDefined)
+      (rewriteNode.isInstanceOf[BroadcastHashJoinExec] ||
+        rewriteNode.isInstanceOf[BroadcastNestedLoopJoinExec]) &&
+      allFallbackTags.exists(_.isDefined)
     ) {
       // If the inserted projects for join is not transformable, return the original plan.
       val reason = allFallbackTags.collect { case Some(s) => s.reason() }.mkString(", ")

File 2 of 4: gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffload.scala
(18 additions, 3 deletions)
@@ -18,7 +18,7 @@ package org.apache.gluten.extension.columnar.enumerated

 import org.apache.gluten.execution.{GlutenPlan, ValidatablePlan}
 import org.apache.gluten.extension.columnar.FallbackTags
-import org.apache.gluten.extension.columnar.offload.OffloadSingleNode
+import org.apache.gluten.extension.columnar.offload.{OffloadOthers, OffloadSingleNode}
 import org.apache.gluten.extension.columnar.rewrite.RewriteSingleNode
 import org.apache.gluten.extension.columnar.validator.Validator
 import org.apache.gluten.ras.path.Pattern
@@ -27,7 +27,8 @@ import org.apache.gluten.ras.rule.{RasRule, Shape}
 import org.apache.gluten.ras.rule.Shapes.pattern

 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{ProjectExec, SparkPlan}
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec}

 import scala.reflect.{classTag, ClassTag}

@@ -129,7 +130,21 @@ object RasOffload {
       case t: ValidatablePlan => t
     }
     val outComes = offloadedNodes.map(_.doValidate()).filter(!_.ok())
-    if (outComes.nonEmpty) {
+    // 4.1 Validate pre project of broadcast join
+    val notOffload = from match {
+      case _: BroadcastHashJoinExec | _: BroadcastNestedLoopJoinExec =>
+        val projectOffload = RasOffload.from[ProjectExec](OffloadOthers())
+        from
+          .collect {
+            case preProject: ProjectExec => projectOffload.offload(preProject)
+          }
+          .exists {
+            case t: ValidatablePlan => !t.doValidate().ok()
+            case plan if !plan.isInstanceOf[GlutenPlan] => true
+          }
+      case _ => false
+    }
+    if (outComes.nonEmpty || notOffload) {
       // 5. If native validation fails on at least one of the offloaded nodes, return
       // the original one.
       //

File 3 of 4: gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/rewrite/PullOutPreProject.scala
(1 addition, 13 deletions)
@@ -17,7 +17,6 @@
 package org.apache.gluten.extension.columnar.rewrite

 import org.apache.gluten.backendsapi.BackendsApiManager
-import org.apache.gluten.config.GlutenConfig
 import org.apache.gluten.extension.columnar.heuristic.RewrittenNodeWall
 import org.apache.gluten.sql.shims.SparkShimLoader
 import org.apache.gluten.utils.PullOutProjectHelper
@@ -26,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Partial}
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, TypedAggregateExpression}
-import org.apache.spark.sql.execution.joins.{BaseJoinExec, BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, HashJoin}
+import org.apache.spark.sql.execution.joins.{BaseJoinExec, HashJoin}
 import org.apache.spark.sql.execution.python.ArrowEvalPythonExec
 import org.apache.spark.sql.execution.window.WindowExec

@@ -295,17 +294,6 @@ object PullOutPreProject extends RewriteSingleNode with PullOutProjectHelper {
         arrowEvalPythonExec)

     case join: BaseJoinExec if needsPreProject(join) =>
-      join match {
-        case _: BroadcastHashJoinExec | _: BroadcastNestedLoopJoinExec
-            if !GlutenConfig.get.enableColumnarProject =>
-          // If columnar project is disabled, we cannot pull out project for join, since ProjectExec
-          // not override doExecuteBroadcast methods, we cannot add project between broadcast join
-          // and broadcast exchange.
-          throw new UnsupportedOperationException("columnar project is disabled, " +
-            "broadcast join operator does not support pull out pre-project, and it will fallback.")
-        case _ =>
-      }
-
       // Spark has an improvement which would patch integer joins keys to a Long value.
       // But this improvement would cause adding extra project before hash join in velox,
       // disabling this improvement as below would help reduce the project.

File 4 of 4: gluten-substrait/src/main/scala/org/apache/gluten/utils/PullOutProjectHelper.scala
(17 additions, 1 deletion)
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Complete, Partial}
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.aggregate._
-import org.apache.spark.sql.execution.joins.{BaseJoinExec, BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
+import org.apache.spark.sql.execution.joins.{BaseJoinExec, BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, CartesianProductExec, ShuffledHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.execution.window.WindowExec
 import org.apache.spark.sql.types.{ByteType, DateType, IntegerType, LongType, ShortType}

@@ -175,6 +175,22 @@ trait PullOutProjectHelper {
           condition = newCondition)
         newSmj.copyTagsFrom(smj)
         newSmj
+      case nestedLoopJoin: BroadcastNestedLoopJoinExec =>
+        val newNestedLoopJoin = nestedLoopJoin.copy(
+          left = newLeft,
+          right = newRight,
+          condition = newCondition
+        )
+        newNestedLoopJoin.copyTagsFrom(nestedLoopJoin)
+        newNestedLoopJoin
+      case cartesianProduct: CartesianProductExec =>
+        val newCartesianProduct = cartesianProduct.copy(
+          left = newLeft,
+          right = newRight,
+          condition = newCondition
+        )
+        newCartesianProduct.copyTagsFrom(cartesianProduct)
+        newCartesianProduct
       case _ =>
         throw new UnsupportedOperationException(s"Unsupported join $join")
     }

Commit comments: 0