@@ -193,18 +193,6 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
193193 test(" Children of conditional expressions: CaseWhen" ) {
194194 val add1 = Add (Literal (1 ), Literal (2 ))
195195 val add2 = Add (Literal (2 ), Literal (3 ))
196- val conditions1 = (GreaterThan (add2, Literal (3 )), add1) ::
197- (GreaterThan (add2, Literal (4 )), add1) ::
198- (GreaterThan (add2, Literal (5 )), add1) :: Nil
199-
200- val caseWhenExpr1 = CaseWhen (conditions1, None )
201- val equivalence1 = new EquivalentExpressions
202- equivalence1.addExprTree(caseWhenExpr1)
203-
204- // `add2` is repeatedly in all conditions.
205- assert(equivalence1.getAllEquivalentExprs().count(_.size == 2 ) == 1 )
206- assert(equivalence1.getAllEquivalentExprs().filter(_.size == 2 ).head == Seq (add2, add2))
207-
208196 val conditions2 = (GreaterThan (add1, Literal (3 )), add1) ::
209197 (GreaterThan (add2, Literal (4 )), add1) ::
210198 (GreaterThan (add2, Literal (5 )), add1) :: Nil
@@ -229,30 +217,31 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
229217 }
230218
231219 test(" Children of conditional expressions: Coalesce" ) {
232- val add1 = Add (Literal (1 ), Literal (2 ))
233- val add2 = Add (Literal (2 ), Literal (3 ))
234- val conditions1 = GreaterThan (add2, Literal (3 )) ::
235- GreaterThan (add2, Literal (4 )) ::
236- GreaterThan (add2, Literal (5 )) :: Nil
237-
238- val coalesceExpr1 = Coalesce (conditions1)
239- val equivalence1 = new EquivalentExpressions
240- equivalence1.addExprTree(coalesceExpr1)
241-
242- // `add2` is repeatedly in all conditions.
243- assert(equivalence1.getAllEquivalentExprs().count(_.size == 2 ) == 1 )
244- assert(equivalence1.getAllEquivalentExprs().filter(_.size == 2 ).head == Seq (add2, add2))
245-
246- // Negative case. `add1` and `add2` both are not used in all branches.
247- val conditions2 = GreaterThan (add1, Literal (3 )) ::
248- GreaterThan (add2, Literal (4 )) ::
249- GreaterThan (add2, Literal (5 )) :: Nil
250-
251- val coalesceExpr2 = Coalesce (conditions2)
252- val equivalence2 = new EquivalentExpressions
253- equivalence2.addExprTree(coalesceExpr2)
254-
255- assert(equivalence2.getAllEquivalentExprs().count(_.size == 2 ) == 0 )
220+ withSQLConf(SQLConf .SUBEXPRESSION_ELIMINATION_CONDITIONALS_ENABLED .key -> " true" ) {
221+ val add1 = Add (Literal (1 ), Literal (2 ))
222+ val add2 = Add (Literal (2 ), Literal (3 ))
223+ val conditions1 = GreaterThan (add2, Literal (3 )) ::
224+ GreaterThan (add2, Literal (4 )) :: Nil
225+
226+ val coalesceExpr1 = Coalesce (conditions1)
227+ val equivalence1 = new EquivalentExpressions
228+ equivalence1.addExprTree(coalesceExpr1)
229+
230+ // `add2` is repeatedly in all conditions.
231+ assert(equivalence1.getAllEquivalentExprs(1 ).size == 1 )
232+ assert(equivalence1.getAllEquivalentExprs(1 ).head == Seq (add2, add2))
233+
234+ // Negative case. `add1` and `add2` both are not used in all branches.
235+ val conditions2 = GreaterThan (add1, Literal (3 )) ::
236+ GreaterThan (add2, Literal (4 )) ::
237+ GreaterThan (add2, Literal (5 )) :: Nil
238+
239+ val coalesceExpr2 = Coalesce (conditions2)
240+ val equivalence2 = new EquivalentExpressions
241+ equivalence2.addExprTree(coalesceExpr2)
242+
243+ assert(equivalence2.getAllEquivalentExprs(1 ).size == 0 )
244+ }
256245 }
257246
258247 test(" SPARK-34723: Correct parameter type for subexpression elimination under whole-stage" ) {
@@ -359,9 +348,10 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
359348 + " elseValue" ) {
360349 val add1 = Add (Literal (1 ), Literal (2 ))
361350 val add2 = Add (Literal (2 ), Literal (3 ))
362- val conditions = (GreaterThan (add1, Literal (3 )), add1) ::
363- (GreaterThan (add2, Literal (4 )), add1) ::
364- (GreaterThan (add2, Literal (5 )), add1) :: Nil
351+ val add3 = Add (Literal (3 ), Literal (4 ))
352+ val conditions = (GreaterThan (add2, Literal (3 )), add1) ::
353+ (GreaterThan (add3, Literal (4 )), add1) ::
354+ (GreaterThan (add3, Literal (5 )), add1) :: Nil
365355
366356 val caseWhenExpr = CaseWhen (conditions, None )
367357 val equivalence = new EquivalentExpressions
@@ -371,6 +361,44 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
371361 assert(equivalence.getAllEquivalentExprs().count(_.size == 2 ) == 0 )
372362 }
373363
364+ test(" SPARK-35564: Subexpressions should be extracted from conditional values if that value "
365+ + " will always be evaluated elsewhere" ) {
366+ withSQLConf(SQLConf .SUBEXPRESSION_ELIMINATION_CONDITIONALS_ENABLED .key -> " true" ) {
367+ val add1 = Add (Literal (1 ), Literal (2 ))
368+ val add2 = Add (Literal (2 ), Literal (3 ))
369+
370+ val conditions1 = (GreaterThan (add1, Literal (3 )), add1) :: Nil
371+ val caseWhenExpr1 = CaseWhen (conditions1, None )
372+ val equivalence1 = new EquivalentExpressions
373+ equivalence1.addExprTree(caseWhenExpr1)
374+
375+ // `add1` is evaluated once in the first condition, and optionally in the first value
376+ assert(equivalence1.getAllEquivalentExprs(1 ).size == 1 )
377+
378+ val ifExpr = If (GreaterThan (add1, Literal (3 )), add1, add2)
379+ val equivalence2 = new EquivalentExpressions
380+ equivalence2.addExprTree(ifExpr)
381+
382+ // `add1` is evaluated once in the condition, and optionally in the true value
383+ assert(equivalence2.getAllEquivalentExprs(1 ).size == 1 )
384+ }
385+ }
386+
387+ test(" SPARK-35564: Don't double count conditional expressions if present in all branches" ) {
388+ withSQLConf(SQLConf .SUBEXPRESSION_ELIMINATION_CONDITIONALS_ENABLED .key -> " true" ) {
389+ val add1 = Add (Literal (1 ), Literal (2 ))
390+ val add2 = Add (Literal (2 ), Literal (3 ))
391+ val add3 = Add (add2, Literal (4 ))
392+
393+ val caseWhenExpr1 = CaseWhen ((GreaterThan (add1, Literal (3 )), add3) :: Nil , add2)
394+ val equivalence1 = new EquivalentExpressions
395+ equivalence1.addExprTree(caseWhenExpr1)
396+
397+ // `add2` will only be evaluated once so don't create a subexpression
398+ assert(equivalence1.getAllEquivalentExprs(1 ).size == 0 )
399+ }
400+ }
401+
374402 test(" SPARK-35439: sort exprs with ExpressionContainmentOrdering" ) {
375403 val exprOrdering = new ExpressionContainmentOrdering
376404
0 commit comments