Skip to content

Commit d4d64c7

Browse files
committed
Track conditionally evaluated expressions to resolve as subexpressions for cases they are already being evaluated
1 parent 450b415 commit d4d64c7

3 files changed

Lines changed: 131 additions & 64 deletions

File tree

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala

Lines changed: 53 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import scala.collection.mutable
2222
import org.apache.spark.TaskContext
2323
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
2424
import org.apache.spark.sql.catalyst.expressions.objects.LambdaVariable
25+
import org.apache.spark.sql.internal.SQLConf
2526

2627
/**
2728
* This class is used to compute equality of (sub)expression trees. Expressions can be added
@@ -43,6 +44,9 @@ class EquivalentExpressions {
4344

4445
// For each expression, the set of equivalent expressions.
4546
private val equivalenceMap = mutable.HashMap.empty[Expr, mutable.ArrayBuffer[Expression]]
47+
// Maintain optionally evaluated expressions that we can resolve as well
48+
private val conditionalEquivalenceMap =
49+
mutable.HashMap.empty[Expr, mutable.ArrayBuffer[Expression]]
4650

4751
/**
4852
* Adds each expression to this data structure, grouping them with existing equivalent
@@ -65,6 +69,24 @@ class EquivalentExpressions {
6569
}
6670
}
6771

72+
private def addConditionalExpr(commonExprs: mutable.Set[Expr])(expr: Expression): Boolean = {
73+
// If we've extract it as a common expression, we don't want to double count it as conditional
74+
val inCommonExprs = commonExprs.contains(Expr(expr))
75+
if (expr.deterministic && !inCommonExprs) {
76+
val e: Expr = Expr(expr)
77+
val f = conditionalEquivalenceMap.get(e)
78+
if (f.isDefined) {
79+
f.get += expr
80+
true
81+
} else {
82+
conditionalEquivalenceMap.put(e, mutable.ArrayBuffer(expr))
83+
equivalenceMap.contains(e)
84+
}
85+
} else {
86+
inCommonExprs
87+
}
88+
}
89+
6890
private def addExprToSet(expr: Expression, set: mutable.Set[Expr]): Boolean = {
6991
if (expr.deterministic) {
7092
val e = Expr(expr)
@@ -91,9 +113,9 @@ class EquivalentExpressions {
91113
* For example, if `((a + b) + c)` and `(a + b)` are common expressions, we only add
92114
* `((a + b) + c)`.
93115
*/
94-
private def addCommonExprs(
116+
private def findCommonExprs(
95117
exprs: Seq[Expression],
96-
addFunc: Expression => Boolean = addExpr): Unit = {
118+
addFunc: Expression => Boolean = addExpr): mutable.Set[Expr] = {
97119
val exprSetForAll = mutable.Set[Expr]()
98120
addExprTree(exprs.head, addExprToSet(_, exprSetForAll))
99121

@@ -105,13 +127,11 @@ class EquivalentExpressions {
105127

106128
// Not all expressions in the set should be added. We should filter out the related
107129
// children nodes.
108-
val commonExprSet = candidateExprs.filter { candidateExpr =>
130+
candidateExprs.filter { candidateExpr =>
109131
candidateExprs.forall { expr =>
110132
expr == candidateExpr || expr.e.find(_.semanticEquals(candidateExpr.e)).isEmpty
111133
}
112134
}
113-
114-
commonExprSet.foreach(expr => addExprTree(expr.e, addFunc))
115135
}
116136

117137
// There are some special expressions that we should not recurse into all of its children.
@@ -134,24 +154,18 @@ class EquivalentExpressions {
134154

135155
// For some special expressions we cannot just recurse into all of its children, but we can
136156
// recursively add the common expressions shared between all of its children.
137-
private def commonChildrenToRecurse(expr: Expression): Seq[Seq[Expression]] = expr match {
138-
case i: If => Seq(Seq(i.trueValue, i.falseValue))
139-
case c: CaseWhen =>
140-
// We look at subexpressions in conditions and values of `CaseWhen` separately. It is
141-
// because a subexpression in conditions will be run no matter which condition is matched
142-
// if it is shared among conditions, but it doesn't need to be shared in values. Similarly,
143-
// a subexpression among values doesn't need to be in conditions because no matter which
144-
// condition is true, it will be evaluated.
145-
val conditions = c.branches.tail.map(_._1)
146-
// For an expression to be in all branch values of a CaseWhen statement, it must also be in
147-
// the elseValue.
148-
val values = if (c.elseValue.nonEmpty) {
149-
c.branches.map(_._2) ++ c.elseValue
150-
} else {
151-
Nil
152-
}
153-
Seq(conditions, values)
154-
case c: Coalesce => Seq(c.children.tail)
157+
private def commonChildrenToRecurse(expr: Expression): Seq[Expression] = expr match {
158+
case i: If => Seq(i.trueValue, i.falseValue)
159+
case c: CaseWhen if c.elseValue.nonEmpty => c.branches.map(_._2) ++ c.elseValue
160+
case _ => Nil
161+
}
162+
163+
// Finds expressions that are conditionally evaluated, so that if they are definitely evaluated
164+
// elsewhere, we can create a subexpression to optimize the conditional case.
165+
private def conditionallyEvaluatedChildren(expr: Expression): Seq[Expression] = expr match {
166+
case i: If => Seq(i.trueValue, i.falseValue)
167+
case c: CaseWhen => c.branches.tail.map(_._1) ++ c.branches.map(_._2) ++ c.elseValue
168+
case c: Coalesce => c.children.tail
155169
case _ => Nil
156170
}
157171

@@ -172,7 +186,19 @@ class EquivalentExpressions {
172186

173187
if (!skip && !addFunc(expr)) {
174188
childrenToRecurse(expr).foreach(addExprTree(_, addFunc))
175-
commonChildrenToRecurse(expr).filter(_.nonEmpty).foreach(addCommonExprs(_, addFunc))
189+
190+
val commonChildrenCandidates = commonChildrenToRecurse(expr)
191+
val commonChildrenExprs = if (commonChildrenCandidates.nonEmpty) {
192+
findCommonExprs(commonChildrenCandidates)
193+
} else {
194+
mutable.Set.empty[Expr]
195+
}
196+
commonChildrenExprs.foreach(e => addExprTree(e.e, addFunc))
197+
198+
if (SQLConf.get.subexpressionEliminationConditionalsEnabled) {
199+
val conditionallyEvaluatedExprs = conditionallyEvaluatedChildren(expr)
200+
conditionallyEvaluatedExprs.foreach(addExprTree(_, addConditionalExpr(commonChildrenExprs)))
201+
}
176202
}
177203
}
178204

@@ -189,7 +215,9 @@ class EquivalentExpressions {
189215
* times.
190216
*/
191217
def getAllEquivalentExprs(repeatTimes: Int = 0): Seq[Seq[Expression]] = {
192-
equivalenceMap.values.map(_.toSeq).filter(_.size > repeatTimes).toSeq
218+
equivalenceMap.values.map(_.toSeq)
219+
.map(exprs => exprs ++ conditionalEquivalenceMap.getOrElse(Expr(exprs.head), Seq.empty))
220+
.filter(_.size > repeatTimes).toSeq
193221
.sortBy(_.head)(new ExpressionContainmentOrdering)
194222
}
195223

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -626,6 +626,14 @@ object SQLConf {
626626
.checkValue(_ >= 0, "The maximum must not be negative")
627627
.createWithDefault(100)
628628

629+
val SUBEXPRESSION_ELIMINATION_CONDITIONALS_ENABLED =
630+
buildConf("spark.sql.subexpressionElimination.conditionals.enabled")
631+
.internal()
632+
.doc("When true, common conditional subexpressions will be eliminated.")
633+
.version("3.2.0")
634+
.booleanConf
635+
.createWithDefault(false)
636+
629637
val CASE_SENSITIVE = buildConf("spark.sql.caseSensitive")
630638
.internal()
631639
.doc("Whether the query analyzer should be case sensitive or not. " +
@@ -3599,6 +3607,9 @@ class SQLConf extends Serializable with Logging {
35993607
def subexpressionEliminationCacheMaxEntries: Int =
36003608
getConf(SUBEXPRESSION_ELIMINATION_CACHE_MAX_ENTRIES)
36013609

3610+
def subexpressionEliminationConditionalsEnabled: Boolean =
3611+
getConf(SUBEXPRESSION_ELIMINATION_CONDITIONALS_ENABLED)
3612+
36023613
def autoBroadcastJoinThreshold: Long = getConf(AUTO_BROADCASTJOIN_THRESHOLD)
36033614

36043615
def limitScaleUpFactor: Int = getConf(LIMIT_SCALE_UP_FACTOR)

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala

Lines changed: 67 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -193,18 +193,6 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
193193
test("Children of conditional expressions: CaseWhen") {
194194
val add1 = Add(Literal(1), Literal(2))
195195
val add2 = Add(Literal(2), Literal(3))
196-
val conditions1 = (GreaterThan(add2, Literal(3)), add1) ::
197-
(GreaterThan(add2, Literal(4)), add1) ::
198-
(GreaterThan(add2, Literal(5)), add1) :: Nil
199-
200-
val caseWhenExpr1 = CaseWhen(conditions1, None)
201-
val equivalence1 = new EquivalentExpressions
202-
equivalence1.addExprTree(caseWhenExpr1)
203-
204-
// `add2` is repeatedly in all conditions.
205-
assert(equivalence1.getAllEquivalentExprs().count(_.size == 2) == 1)
206-
assert(equivalence1.getAllEquivalentExprs().filter(_.size == 2).head == Seq(add2, add2))
207-
208196
val conditions2 = (GreaterThan(add1, Literal(3)), add1) ::
209197
(GreaterThan(add2, Literal(4)), add1) ::
210198
(GreaterThan(add2, Literal(5)), add1) :: Nil
@@ -229,30 +217,31 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
229217
}
230218

231219
test("Children of conditional expressions: Coalesce") {
232-
val add1 = Add(Literal(1), Literal(2))
233-
val add2 = Add(Literal(2), Literal(3))
234-
val conditions1 = GreaterThan(add2, Literal(3)) ::
235-
GreaterThan(add2, Literal(4)) ::
236-
GreaterThan(add2, Literal(5)) :: Nil
237-
238-
val coalesceExpr1 = Coalesce(conditions1)
239-
val equivalence1 = new EquivalentExpressions
240-
equivalence1.addExprTree(coalesceExpr1)
241-
242-
// `add2` is repeatedly in all conditions.
243-
assert(equivalence1.getAllEquivalentExprs().count(_.size == 2) == 1)
244-
assert(equivalence1.getAllEquivalentExprs().filter(_.size == 2).head == Seq(add2, add2))
245-
246-
// Negative case. `add1` and `add2` both are not used in all branches.
247-
val conditions2 = GreaterThan(add1, Literal(3)) ::
248-
GreaterThan(add2, Literal(4)) ::
249-
GreaterThan(add2, Literal(5)) :: Nil
250-
251-
val coalesceExpr2 = Coalesce(conditions2)
252-
val equivalence2 = new EquivalentExpressions
253-
equivalence2.addExprTree(coalesceExpr2)
254-
255-
assert(equivalence2.getAllEquivalentExprs().count(_.size == 2) == 0)
220+
withSQLConf(SQLConf.SUBEXPRESSION_ELIMINATION_CONDITIONALS_ENABLED.key -> "true") {
221+
val add1 = Add(Literal(1), Literal(2))
222+
val add2 = Add(Literal(2), Literal(3))
223+
val conditions1 = GreaterThan(add2, Literal(3)) ::
224+
GreaterThan(add2, Literal(4)) :: Nil
225+
226+
val coalesceExpr1 = Coalesce(conditions1)
227+
val equivalence1 = new EquivalentExpressions
228+
equivalence1.addExprTree(coalesceExpr1)
229+
230+
// `add2` is repeatedly in all conditions.
231+
assert(equivalence1.getAllEquivalentExprs(1).size == 1)
232+
assert(equivalence1.getAllEquivalentExprs(1).head == Seq(add2, add2))
233+
234+
// Negative case. `add1` and `add2` both are not used in all branches.
235+
val conditions2 = GreaterThan(add1, Literal(3)) ::
236+
GreaterThan(add2, Literal(4)) ::
237+
GreaterThan(add2, Literal(5)) :: Nil
238+
239+
val coalesceExpr2 = Coalesce(conditions2)
240+
val equivalence2 = new EquivalentExpressions
241+
equivalence2.addExprTree(coalesceExpr2)
242+
243+
assert(equivalence2.getAllEquivalentExprs(1).size == 0)
244+
}
256245
}
257246

258247
test("SPARK-34723: Correct parameter type for subexpression elimination under whole-stage") {
@@ -359,9 +348,10 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
359348
+ "elseValue") {
360349
val add1 = Add(Literal(1), Literal(2))
361350
val add2 = Add(Literal(2), Literal(3))
362-
val conditions = (GreaterThan(add1, Literal(3)), add1) ::
363-
(GreaterThan(add2, Literal(4)), add1) ::
364-
(GreaterThan(add2, Literal(5)), add1) :: Nil
351+
val add3 = Add(Literal(3), Literal(4))
352+
val conditions = (GreaterThan(add2, Literal(3)), add1) ::
353+
(GreaterThan(add3, Literal(4)), add1) ::
354+
(GreaterThan(add3, Literal(5)), add1) :: Nil
365355

366356
val caseWhenExpr = CaseWhen(conditions, None)
367357
val equivalence = new EquivalentExpressions
@@ -371,6 +361,44 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
371361
assert(equivalence.getAllEquivalentExprs().count(_.size == 2) == 0)
372362
}
373363

364+
test("SPARK-35564: Subexpressions should be extracted from conditional values if that value "
365+
+ "will always be evaluated elsewhere") {
366+
withSQLConf(SQLConf.SUBEXPRESSION_ELIMINATION_CONDITIONALS_ENABLED.key -> "true") {
367+
val add1 = Add(Literal(1), Literal(2))
368+
val add2 = Add(Literal(2), Literal(3))
369+
370+
val conditions1 = (GreaterThan(add1, Literal(3)), add1) :: Nil
371+
val caseWhenExpr1 = CaseWhen(conditions1, None)
372+
val equivalence1 = new EquivalentExpressions
373+
equivalence1.addExprTree(caseWhenExpr1)
374+
375+
// `add1` is evaluated once in the first condition, and optionally in the first value
376+
assert(equivalence1.getAllEquivalentExprs(1).size == 1)
377+
378+
val ifExpr = If(GreaterThan(add1, Literal(3)), add1, add2)
379+
val equivalence2 = new EquivalentExpressions
380+
equivalence2.addExprTree(ifExpr)
381+
382+
// `add1` is evaluated once in the condition, and optionally in the true value
383+
assert(equivalence2.getAllEquivalentExprs(1).size == 1)
384+
}
385+
}
386+
387+
test("SPARK-35564: Don't double count conditional expressions if present in all branches") {
388+
withSQLConf(SQLConf.SUBEXPRESSION_ELIMINATION_CONDITIONALS_ENABLED.key -> "true") {
389+
val add1 = Add(Literal(1), Literal(2))
390+
val add2 = Add(Literal(2), Literal(3))
391+
val add3 = Add(add2, Literal(4))
392+
393+
val caseWhenExpr1 = CaseWhen((GreaterThan(add1, Literal(3)), add3) :: Nil, add2)
394+
val equivalence1 = new EquivalentExpressions
395+
equivalence1.addExprTree(caseWhenExpr1)
396+
397+
// `add2` will only be evaluated once so don't create a subexpression
398+
assert(equivalence1.getAllEquivalentExprs(1).size == 0)
399+
}
400+
}
401+
374402
test("SPARK-35439: sort exprs with ExpressionContainmentOrdering") {
375403
val exprOrdering = new ExpressionContainmentOrdering
376404

0 commit comments

Comments
 (0)