diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 080186431eb7d..31577e8028b47 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -175,6 +175,35 @@ class CodegenContext extends Logging { */ var currentVars: Seq[ExprCode] = null + /** + * A mapping from [[ExprId]] to [[ExprCode]] for lambda variables that are currently in scope. + * This is used by [[NamedLambdaVariable]] to look up pre-computed variable bindings set by + * enclosing higher-order functions during code generation. + * + * The enclosing higher-order function registers entries before generating the lambda body code, + * and restores the previous state after. This follows the same save/restore pattern as + * `currentVars`/`INPUT_ROW`. + * + * Note: Like other mutable state in CodegenContext (e.g., `currentVars`, `INPUT_ROW`), + * this is not thread-safe. Callers must ensure single-threaded access during code generation. + */ + var lambdaVariableMap: Map[ExprId, ExprCode] = Map.empty + + /** + * Registers lambda variable bindings, executes the given block, + * then restores the previous bindings. This ensures lambda variable scoping is correct + * for nested higher-order functions. + * + * Note: bindings from inner HOFs take precedence over outer ones via `Map.++`. + * This is safe because [[ExprId]]s are globally unique; inner and outer lambda + * variables will never share the same ExprId. + */ + def withLambdaVariableBindings[T](bindings: Map[ExprId, ExprCode])(f: => T): T = { + val oldBindings = lambdaVariableMap + lambdaVariableMap = lambdaVariableMap ++ bindings + try f finally { lambdaVariableMap = oldBindings } + } + /** * Holding expressions' inlined mutable states like `MonotonicallyIncreasingID.count` as a * 2-tuple: java type, variable name. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 2a5a38e93706c..9144e6bf8975d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -23,11 +23,14 @@ import java.util.concurrent.atomic.{AtomicInteger, AtomicReference} import scala.collection.mutable import scala.jdk.CollectionConverters.MapHasAsScala +import org.apache.spark.SparkException +import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedException} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.Cast._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.optimizer.NormalizeFloatingNumbers import org.apache.spark.sql.catalyst.trees.{BinaryLike, CurrentOrigin, QuaternaryLike, TernaryLike} import org.apache.spark.sql.catalyst.trees.TreePattern._ @@ -81,8 +84,9 @@ case class NamedLambdaVariable( exprId: ExprId = NamedExpression.newExprId, value: AtomicReference[Any] = new AtomicReference()) extends LeafExpression - with NamedExpression - with CodegenFallback { + with NamedExpression { + + final override val nodePatterns: Seq[TreePattern] = Seq(LAMBDA_VARIABLE) override def qualifier: Seq[String] = Seq.empty @@ -98,6 +102,42 @@ case class NamedLambdaVariable( override def eval(input: InternalRow): Any = value.get + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + ctx.lambdaVariableMap.get(exprId) match { + case Some(binding) => + // Lambda variable has been bound by an enclosing higher-order function. + // Return the binding directly -- it already contains the correct code, + // isNull, and value referencing the mutable state fields. + binding + case None => + // No binding found -- fall back to interpreted eval via references array. + // This is unexpected in normal operation (the enclosing HOF should have registered + // bindings), but we degrade gracefully rather than failing the query. + NamedLambdaVariable.warnNoCodegenBinding(name, exprId) + val idx = ctx.references.length + ctx.references += this + val objectTerm = ctx.freshName("lambdaValue") + val javaType = CodeGenerator.javaType(dataType) + // Pass null as the input row because NamedLambdaVariable.eval() ignores + // the input row entirely -- it reads its value from the AtomicReference + // set by the enclosing HOF's eval loop. + if (nullable) { + ev.copy(code = code""" + Object $objectTerm = ((Expression) references[$idx]).eval(null); + boolean ${ev.isNull} = $objectTerm == null; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.value} = (${CodeGenerator.boxedType(dataType)}) $objectTerm; + }""") + } else { + ev.copy(code = code""" + Object $objectTerm = ((Expression) references[$idx]).eval(null); + $javaType ${ev.value} = (${CodeGenerator.boxedType(dataType)}) $objectTerm; + """, isNull = FalseLiteral) + } + } + } + override def toString: String = s"lambda $name#${exprId.id}$typeSuffix" override def simpleString(maxFields: Int): String = { @@ -105,6 +145,16 @@ case class NamedLambdaVariable( } } +object NamedLambdaVariable extends Logging { + private[expressions] def warnNoCodegenBinding(name: String, exprId: ExprId): Unit = { + logWarning( + s"NamedLambdaVariable '$name#${exprId.id}' has no codegen binding, " + + "falling back to interpreted eval. This warning is emitted during code generation " + + "(not per row at runtime). " + + "Possible cause: missing binding in an enclosing higher-order function's doGenCode.") + } +} + /** * A lambda function and its arguments. A lambda function can be hidden when a user wants to * process an completely independent expression in a [[HigherOrderFunction]], the lambda function @@ -114,7 +164,7 @@ case class LambdaFunction( function: Expression, arguments: Seq[NamedExpression], hidden: Boolean = false) - extends Expression with CodegenFallback { + extends Expression { override def children: Seq[Expression] = function +: arguments override def dataType: DataType = function.dataType @@ -132,6 +182,26 @@ case class LambdaFunction( override def eval(input: InternalRow): Any = function.eval(input) + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + // LambdaFunction is a thin wrapper. The enclosing HOF is responsible for + // registering lambda variable bindings before this is called. + arguments.foreach { + case nlv: NamedLambdaVariable => + if (!ctx.lambdaVariableMap.contains(nlv.exprId)) { + throw SparkException.internalError( + s"Lambda variable '${nlv.name}#${nlv.exprId.id}' has no codegen binding. " + + s"Bound ids: [${ctx.lambdaVariableMap.keys.map(_.id).mkString(", ")}]") + } + case other => + // arguments should always be NamedLambdaVariable instances (bound by + // HigherOrderFunction.bind). When hidden=true, arguments is empty and + // this branch is unreachable. + throw SparkException.internalError( + s"Expected NamedLambdaVariable but got ${other.getClass.getName}") + } + function.genCode(ctx) + } + override protected def withNewChildrenInternal( newChildren: IndexedSeq[Expression]): LambdaFunction = copy( @@ -312,7 +382,7 @@ trait MapBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction { case class ArrayTransform( argument: Expression, function: Expression) - extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback { + extends ArrayBasedSimpleHigherOrderFunction { override def dataType: ArrayType = ArrayType(function.dataType, function.nullable) @@ -354,6 +424,174 @@ case class ArrayTransform( result } + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val argumentGen = argument.genCode(ctx) + val resultArray = ctx.freshName("resultArray") + val numElements = ctx.freshName("numElements") + val loopIndex = ctx.freshName("i") + val arrData = ctx.freshName("arrData") + + val elementType = elementVar.dataType + val javaElementType = CodeGenerator.javaType(elementType) + val elementDefault = CodeGenerator.defaultValue(elementType) + + // Use mutable state (class fields) for lambda variable bindings instead of local + // variables. This is critical because Expression.reduceCodeSize() may extract + // lambda body code into a separate private method, and local loop variables would + // not be accessible from such extracted methods. + // Concurrency note: each Spark task runs in its own thread with a separate generated + // class instance, so these fields are not shared across tasks. + val elementIsNull = ctx.addMutableState( + CodeGenerator.JAVA_BOOLEAN, "elementIsNull") + val elementValue = ctx.addMutableState(javaElementType, "elementValue") + + val elementExtract = if (elementVar.nullable) { + // For primitives, elementDefault provides a valid zero value (e.g. 0 for int) to avoid + // uninitialized reads. For non-primitives, it returns "null" -- logically redundant + // but harmless, and keeps the generated code pattern uniform across all types. + // The isNull flag guards all downstream reads (setElemAtomicRef checks isNull before + // boxing, and lambdaBodyGen propagates isNull through the lambda variable binding). + s""" + |$elementIsNull = $arrData.isNullAt($loopIndex); + |$elementValue = $elementIsNull ? + | $elementDefault : (${CodeGenerator.getValue(arrData, elementType, loopIndex)}); + """.stripMargin + } else { + s""" + |$elementIsNull = false; + |$elementValue = + | ${CodeGenerator.getValue(arrData, elementType, loopIndex)}; + """.stripMargin + } + + // Recursively check whether any sub-expression in the lambda body is a CodegenFallback. + // LambdaFunction and NamedLambdaVariable themselves no longer extend CodegenFallback, + // so this targets genuinely un-codegen'd sub-expressions (e.g., ArrayFilter). + // If none found, we can skip AtomicReference writes entirely, avoiding per-element + // boxing overhead. + val lambdaBodyHasFallback = function.exists(_.isInstanceOf[CodegenFallback]) + + // Also set the AtomicReference on the lambda variable so that any CodegenFallback + // expressions nested inside the lambda body (e.g., ArrayExists, ArrayFilter that + // haven't been given codegen yet) can read the correct value via + // NamedLambdaVariable.eval(). This is NOT redundant with the mutable state bindings + // above -- the mutable state is for the codegen path, while AtomicReference is for + // CodegenFallback sub-expressions that call eval() at runtime. + val setElemAtomicRef = if (lambdaBodyHasFallback) { + val elemAtomicRefTerm = ctx.addReferenceObj( + "elementVarRef", elementVar.value, + "java.util.concurrent.atomic.AtomicReference") + // Explicitly box primitive values to ensure the AtomicReference contains the + // correct boxed type (e.g., Byte for ByteType, Short for ShortType), matching + // what ArrayData.get() returns in the interpreted path. + val boxedElementType = CodeGenerator.boxedType(elementType) + if (elementVar.nullable) { + s"$elemAtomicRefTerm.set($elementIsNull ? null : ($boxedElementType) $elementValue);" + } else { + s"$elemAtomicRefTerm.set(($boxedElementType) $elementValue);" + } + } else { + "" + } + + // Build lambda variable bindings using the mutable state variables. + val elementCode = ExprCode( + code = EmptyBlock, + isNull = if (elementVar.nullable) JavaCode.isNullVariable(elementIsNull) + else FalseLiteral, + value = JavaCode.variable(elementValue, elementType)) + + val (indexExtract, indexBinding) = indexVar match { + case Some(iv) => + val indexValue = ctx.addMutableState(CodeGenerator.JAVA_INT, "indexValue") + val indexCode = ExprCode( + code = EmptyBlock, + isNull = FalseLiteral, + value = JavaCode.variable(indexValue, IntegerType)) + val idxAtomicRefUpdate = if (lambdaBodyHasFallback) { + val idxAtomicRefTerm = ctx.addReferenceObj( + "indexVarRef", iv.value, + "java.util.concurrent.atomic.AtomicReference") + val boxedIndexType = CodeGenerator.boxedType(iv.dataType) + s"\n$idxAtomicRefTerm.set(($boxedIndexType) $loopIndex);" + } else { + "" + } + val extract = + s""" + |$indexValue = $loopIndex;$idxAtomicRefUpdate + """.stripMargin + (extract, Some(iv.exprId -> indexCode)) + case None => + ("", None) + } + + val bindings = Map(elementVar.exprId -> elementCode) ++ indexBinding + + // Generate code for the lambda body with bindings registered. + // Call function.genCode (not lf.function.genCode) so that LambdaFunction.doGenCode + // is exercised, including its binding validation. + val lambdaBodyGen = ctx.withLambdaVariableBindings(bindings) { + function.genCode(ctx) + } + + // Determine the output element type and write strategy. + val outputElementType = function.dataType + val isPrimitive = CodeGenerator.isPrimitiveType(outputElementType) + val isNullOpt = if (function.nullable) Some(lambdaBodyGen.isNull.toString) else None + + // For primitives, setArrayElement handles null check internally. + // For non-primitives, we must copy to avoid memory aliasing with mutable types + // (e.g., UnsafeRow, GenericArrayData). copyValue is a no-op for immutable types + // (e.g., UTF8String, Decimal), so the overhead is negligible. + val setResultElement = if (isPrimitive) { + CodeGenerator.setArrayElement( + resultArray, outputElementType, loopIndex, lambdaBodyGen.value.toString, + isNullOpt) + } else if (function.nullable) { + s""" + |if (${lambdaBodyGen.isNull}) { + | $resultArray.setNullAt($loopIndex); + |} else { + | $resultArray.update($loopIndex, + | InternalRow.copyValue(${lambdaBodyGen.value})); + |} + """.stripMargin + } else { + s"$resultArray.update($loopIndex, InternalRow.copyValue(${lambdaBodyGen.value}));" + } + + val allocation = CodeGenerator.createArrayData( + resultArray, outputElementType, numElements, + " ArrayTransform failed.") + + // argumentGen.value is guaranteed to be ArrayData for ArrayType expressions. + val loopCode = + s""" + |ArrayData $arrData = (ArrayData) ${argumentGen.value}; + |int $numElements = $arrData.numElements(); + |$allocation + |for (int $loopIndex = 0; $loopIndex < $numElements; $loopIndex++) { + | $elementExtract + | $setElemAtomicRef + | $indexExtract + | ${lambdaBodyGen.code} + | $setResultElement + |} + """.stripMargin + + // Null safety: if argument is null, output is null. + ev.copy(code = code""" + ${argumentGen.code} + boolean ${ev.isNull} = ${argumentGen.isNull}; + ArrayData ${ev.value} = null; + if (!${ev.isNull}) { + $loopCode + ${ev.value} = $resultArray; + } + """) + } + override def nodeName: String = "transform" override protected def withNewChildrenInternal( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index cc36cd73d6d77..31c19b5b67a6f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -885,4 +885,19 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper "actualType" -> toSQLType(StringType) ))) } + + test("LambdaFunction.doGenCode requires bindings for all lambda variables") { + import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext + + val lv = NamedLambdaVariable("x", IntegerType, nullable = false) + val lf = LambdaFunction(lv + Literal(1), Seq(lv)) + val ctx = new CodegenContext() + + // genCode without registering bindings should fail with SparkException + val e = intercept[SparkException] { + lf.genCode(ctx) + } + assert(e.getMessage.contains("has no codegen binding")) + assert(e.getMessage.contains("x#"), "Error message should include the variable name") + } } diff --git a/sql/core/benchmarks/HigherOrderFunctionBenchmark-jdk21-results.txt b/sql/core/benchmarks/HigherOrderFunctionBenchmark-jdk21-results.txt new file mode 100644 index 0000000000000..930409f209923 --- /dev/null +++ b/sql/core/benchmarks/HigherOrderFunctionBenchmark-jdk21-results.txt @@ -0,0 +1,103 @@ +================================================================================================ +transform on primitive (int) array +================================================================================================ + +OpenJDK 64-Bit Server VM 21.0.10+7-LTS on Linux 6.14.0-1017-azure +AMD EPYC 7763 64-Core Processor +transform int array: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +transform int array wholestage off 149 156 10 7.0 142.3 1.0X +transform int array wholestage on 102 118 9 10.2 97.6 1.5X + + +================================================================================================ +transform with index variable +================================================================================================ + +OpenJDK 64-Bit Server VM 21.0.10+7-LTS on Linux 6.14.0-1017-azure +AMD EPYC 7763 64-Core Processor +transform with index: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +transform with index wholestage off 101 109 11 10.4 96.5 1.0X +transform with index wholestage on 87 95 8 12.1 82.5 1.2X + + +================================================================================================ +transform on string array +================================================================================================ + +OpenJDK 64-Bit Server VM 21.0.10+7-LTS on Linux 6.14.0-1017-azure +AMD EPYC 7763 64-Core Processor +transform string array: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +transform string array wholestage off 640 641 1 1.6 610.6 1.0X +transform string array wholestage on 695 730 29 1.5 662.6 0.9X + + +================================================================================================ +transform on struct array +================================================================================================ + +OpenJDK 64-Bit Server VM 21.0.10+7-LTS on Linux 6.14.0-1017-azure +AMD EPYC 7763 64-Core Processor +transform struct array: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +transform struct array wholestage off 787 791 5 1.3 751.0 1.0X +transform struct array wholestage on 779 783 3 1.3 743.4 1.0X + + +================================================================================================ +transform on nullable element array +================================================================================================ + +OpenJDK 64-Bit Server VM 21.0.10+7-LTS on Linux 6.14.0-1017-azure +AMD EPYC 7763 64-Core Processor +transform nullable array: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +transform nullable array wholestage off 127 131 6 8.3 121.0 1.0X +transform nullable array wholestage on 109 113 6 9.6 103.8 1.2X + + +================================================================================================ +nested transform +================================================================================================ + +OpenJDK 64-Bit Server VM 21.0.10+7-LTS on Linux 6.14.0-1017-azure +AMD EPYC 7763 64-Core Processor +nested transform: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +nested transform wholestage off 109 112 5 9.6 103.8 1.0X +nested transform wholestage on 104 108 3 10.1 99.4 1.0X + + +================================================================================================ +transform with CodegenFallback body (filter) +================================================================================================ + +OpenJDK 64-Bit Server VM 21.0.10+7-LTS on Linux 6.14.0-1017-azure +AMD EPYC 7763 64-Core Processor +transform + filter (mixed codegen/fallback): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------------------------ +transform + filter (mixed codegen/fallback) wholestage off 9601 9644 61 0.1 9155.9 1.0X +transform + filter (mixed codegen/fallback) wholestage on 9692 9701 11 0.1 9242.7 1.0X + + +================================================================================================ +filter (CodegenFallback) vs transform (codegen) +================================================================================================ + +OpenJDK 64-Bit Server VM 21.0.10+7-LTS on Linux 6.14.0-1017-azure +AMD EPYC 7763 64-Core Processor +filter array (CodegenFallback): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +----------------------------------------------------------------------------------------------------------------------------- +filter array (CodegenFallback) wholestage off 315 316 2 3.3 299.9 1.0X +filter array (CodegenFallback) wholestage on 330 336 6 3.2 314.4 1.0X + +OpenJDK 64-Bit Server VM 21.0.10+7-LTS on Linux 6.14.0-1017-azure +AMD EPYC 7763 64-Core Processor +transform array (codegen): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +transform array (codegen) wholestage off 76 81 6 13.7 72.9 1.0X +transform array (codegen) wholestage on 86 88 2 12.2 82.3 0.9X + + diff --git a/sql/core/benchmarks/HigherOrderFunctionBenchmark-jdk25-results.txt b/sql/core/benchmarks/HigherOrderFunctionBenchmark-jdk25-results.txt new file mode 100644 index 0000000000000..92722151d1c13 --- /dev/null +++ b/sql/core/benchmarks/HigherOrderFunctionBenchmark-jdk25-results.txt @@ -0,0 +1,103 @@ +================================================================================================ +transform on primitive (int) array +================================================================================================ + +OpenJDK 64-Bit Server VM 25.0.2+10-LTS on Linux 6.14.0-1017-azure +AMD EPYC 7763 64-Core Processor +transform int array: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +transform int array wholestage off 111 116 7 9.5 105.6 1.0X +transform int array wholestage on 86 90 4 12.3 81.6 1.3X + + +================================================================================================ +transform with index variable +================================================================================================ + +OpenJDK 64-Bit Server VM 25.0.2+10-LTS on Linux 6.14.0-1017-azure +AMD EPYC 7763 64-Core Processor +transform with index: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +transform with index wholestage off 90 91 2 11.7 85.4 1.0X +transform with index wholestage on 82 86 3 12.7 78.6 1.1X + + +================================================================================================ +transform on string array +================================================================================================ + +OpenJDK 64-Bit Server VM 25.0.2+10-LTS on Linux 6.14.0-1017-azure +AMD EPYC 7763 64-Core Processor +transform string array: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +transform string array wholestage off 580 580 1 1.8 552.7 1.0X +transform string array wholestage on 564 572 5 1.9 538.2 1.0X + + +================================================================================================ +transform on struct array +================================================================================================ + +OpenJDK 64-Bit Server VM 25.0.2+10-LTS on Linux 6.14.0-1017-azure +AMD EPYC 7763 64-Core Processor +transform struct array: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +transform struct array wholestage off 737 740 3 1.4 703.1 1.0X +transform struct array wholestage on 708 716 5 1.5 675.1 1.0X + + +================================================================================================ +transform on nullable element array +================================================================================================ + +OpenJDK 64-Bit Server VM 25.0.2+10-LTS on Linux 6.14.0-1017-azure +AMD EPYC 7763 64-Core Processor +transform nullable array: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +transform nullable array wholestage off 150 151 1 7.0 143.3 1.0X +transform nullable array wholestage on 101 103 2 10.4 95.9 1.5X + + +================================================================================================ +nested transform +================================================================================================ + +OpenJDK 64-Bit Server VM 25.0.2+10-LTS on Linux 6.14.0-1017-azure +AMD EPYC 7763 64-Core Processor +nested transform: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +nested transform wholestage off 102 106 6 10.3 97.0 1.0X +nested transform wholestage on 95 100 4 11.0 90.8 1.1X + + +================================================================================================ +transform with CodegenFallback body (filter) +================================================================================================ + +OpenJDK 64-Bit Server VM 25.0.2+10-LTS on Linux 6.14.0-1017-azure +AMD EPYC 7763 64-Core Processor +transform + filter (mixed codegen/fallback): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------------------------ +transform + filter (mixed codegen/fallback) wholestage off 9008 9025 23 0.1 8591.0 1.0X +transform + filter (mixed codegen/fallback) wholestage on 9026 9051 20 0.1 8607.5 1.0X + + +================================================================================================ +filter (CodegenFallback) vs transform (codegen) +================================================================================================ + +OpenJDK 64-Bit Server VM 25.0.2+10-LTS on Linux 6.14.0-1017-azure +AMD EPYC 7763 64-Core Processor +filter array (CodegenFallback): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +----------------------------------------------------------------------------------------------------------------------------- +filter array (CodegenFallback) wholestage off 296 313 24 3.5 282.2 1.0X +filter array (CodegenFallback) wholestage on 309 314 5 3.4 294.5 1.0X + +OpenJDK 64-Bit Server VM 25.0.2+10-LTS on Linux 6.14.0-1017-azure +AMD EPYC 7763 64-Core Processor +transform array (codegen): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +transform array (codegen) wholestage off 81 83 2 12.9 77.5 1.0X +transform array (codegen) wholestage on 69 72 3 15.1 66.2 1.2X + + diff --git a/sql/core/benchmarks/HigherOrderFunctionBenchmark-results.txt b/sql/core/benchmarks/HigherOrderFunctionBenchmark-results.txt new file mode 100644 index 0000000000000..0e22b30361b3d --- /dev/null +++ b/sql/core/benchmarks/HigherOrderFunctionBenchmark-results.txt @@ -0,0 +1,103 @@ +================================================================================================ +transform on primitive (int) array +================================================================================================ + +OpenJDK 64-Bit Server VM 17.0.18+8-LTS on Linux 6.14.0-1017-azure +AMD EPYC 7763 64-Core Processor +transform int array: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +transform int array wholestage off 101 111 14 10.4 96.3 1.0X +transform int array wholestage on 102 108 5 10.3 97.1 1.0X + + +================================================================================================ +transform with index variable +================================================================================================ + +OpenJDK 64-Bit Server VM 17.0.18+8-LTS on Linux 6.14.0-1017-azure +AMD EPYC 7763 64-Core Processor +transform with index: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +transform with index wholestage off 93 98 6 11.2 89.2 1.0X +transform with index wholestage on 87 91 4 12.0 83.0 1.1X + + +================================================================================================ +transform on string array +================================================================================================ + +OpenJDK 64-Bit Server VM 17.0.18+8-LTS on Linux 6.14.0-1017-azure +AMD EPYC 7763 64-Core Processor +transform string array: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +transform string array wholestage off 621 626 6 1.7 592.4 1.0X +transform string array wholestage on 682 689 5 1.5 650.3 0.9X + + +================================================================================================ +transform on struct array +================================================================================================ + +OpenJDK 64-Bit Server VM 17.0.18+8-LTS on Linux 6.14.0-1017-azure +AMD EPYC 7763 64-Core Processor +transform struct array: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +transform struct array wholestage off 743 744 2 1.4 708.1 1.0X +transform struct array wholestage on 741 748 8 1.4 707.1 1.0X + + +================================================================================================ +transform on nullable element array +================================================================================================ + +OpenJDK 64-Bit Server VM 17.0.18+8-LTS on Linux 6.14.0-1017-azure +AMD EPYC 7763 64-Core Processor +transform nullable array: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +transform nullable array wholestage off 140 142 3 7.5 133.2 1.0X +transform nullable array wholestage on 107 110 3 9.8 102.4 1.3X + + +================================================================================================ +nested transform +================================================================================================ + +OpenJDK 64-Bit Server VM 17.0.18+8-LTS on Linux 6.14.0-1017-azure +AMD EPYC 7763 64-Core Processor +nested transform: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +nested transform wholestage off 112 112 0 9.3 107.0 1.0X +nested transform wholestage on 100 106 4 10.5 95.3 1.1X + + +================================================================================================ +transform with CodegenFallback body (filter) +================================================================================================ + +OpenJDK 64-Bit Server VM 17.0.18+8-LTS on Linux 6.14.0-1017-azure +AMD EPYC 7763 64-Core Processor +transform + filter (mixed codegen/fallback): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------------------------ +transform + filter (mixed codegen/fallback) wholestage off 9425 9435 13 0.1 8988.5 1.0X +transform + filter (mixed codegen/fallback) wholestage on 9408 9427 17 0.1 8971.8 1.0X + + +================================================================================================ +filter (CodegenFallback) vs transform (codegen) +================================================================================================ + +OpenJDK 64-Bit Server VM 17.0.18+8-LTS on Linux 6.14.0-1017-azure +AMD EPYC 7763 64-Core Processor +filter array (CodegenFallback): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +----------------------------------------------------------------------------------------------------------------------------- +filter array (CodegenFallback) wholestage off 237 237 0 4.4 226.1 1.0X +filter array (CodegenFallback) wholestage on 247 252 5 4.2 235.9 1.0X + +OpenJDK 64-Bit Server VM 17.0.18+8-LTS on Linux 6.14.0-1017-azure +AMD EPYC 7763 64-Core Processor +transform array (codegen): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +transform array (codegen) wholestage off 79 79 0 13.3 75.4 1.0X +transform array (codegen) wholestage on 81 83 3 12.9 77.4 1.0X + + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index a3cfdc5a240a1..e4fff3f25dc27 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -944,4 +944,78 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession } } } + + test("ArrayTransform should be included in WholeStageCodegen") { + withClue("basic transform") { + val df = spark.range(1).selectExpr("transform(array(1, 2, 3), x -> x + 1) as arr") + val plan = df.queryExecution.executedPlan + assert(plan.exists(_.isInstanceOf[WholeStageCodegenExec]), + s"Expected WholeStageCodegenExec in plan:\n$plan") + checkAnswer(df, Row(Seq(2, 3, 4))) + } + + withClue("nested transform") { + val df2 = spark.range(1).selectExpr( + "transform(transform(array(1, 2, 3), x -> x + 1), y -> y * 2) as arr") + val plan2 = df2.queryExecution.executedPlan + assert(plan2.exists(_.isInstanceOf[WholeStageCodegenExec]), + s"Expected WholeStageCodegenExec in plan:\n$plan2") + checkAnswer(df2, Row(Seq(4, 6, 8))) + } + + withClue("transform with index: x+i => 10+0=10, 20+1=21, 30+2=32") { + val df3 = spark.range(1).selectExpr( + "transform(array(10, 20, 30), (x, i) -> x + i) as arr") + checkAnswer(df3, Row(Seq(10, 21, 32))) + } + + withClue("transform with nullable elements") { + val df4 = spark.range(1).selectExpr( + "transform(array(1, cast(null as int), 3), x -> x + 1) as arr") + checkAnswer(df4, Row(Seq(2, null, 4))) + } + + withClue("empty array") { + val df5 = spark.range(1).selectExpr( + "transform(array(), x -> x + 1) as arr") + checkAnswer(df5, Row(Seq.empty)) + } + + withClue("nested CodegenFallback HOF (filter) in lambda body") { + // ArrayFilter still uses CodegenFallback, but ArrayTransform's codegen handles + // this via AtomicReference dual-write: the filter sub-expression calls eval() + // at runtime, while the outer transform runs in codegen. The whole stage still + // uses WholeStageCodegenExec because ArrayTransform itself supports codegen. + val df6 = spark.range(1).selectExpr( + "transform(array(array(1, 2, 3), array(4, 5, 6)), x -> filter(x, y -> y > 2)) as arr") + val plan6 = df6.queryExecution.executedPlan + assert(plan6.exists(_.isInstanceOf[WholeStageCodegenExec]), + s"Expected WholeStageCodegenExec in plan:\n$plan6") + // filter(array(1,2,3), y -> y > 2) => [3] + // filter(array(4,5,6), y -> y > 2) => [4, 5, 6] + checkAnswer(df6, Row(Seq(Seq(3), Seq(4, 5, 6)))) + } + + withClue("null argument") { + val df7 = spark.range(1).selectExpr( + "transform(cast(null as array), x -> x + 1) as arr") + val plan7 = df7.queryExecution.executedPlan + assert(plan7.exists(_.isInstanceOf[WholeStageCodegenExec]), + s"Expected WholeStageCodegenExec in plan:\n$plan7") + checkAnswer(df7, Row(null)) + } + + withClue("struct (non-primitive) element type") { + val df8 = spark.range(1).selectExpr( + "transform(array(named_struct('a', 1, 'b', 'x'), " + + "named_struct('a', 2, 'b', 'y')), s -> named_struct('a', s.a + 10, 'b', s.b)) as arr") + checkAnswer(df8, Row(Seq(Row(11, "x"), Row(12, "y")))) + } + + withClue("string (non-primitive) element type") { + val df9 = spark.range(1).selectExpr( + "transform(array('hello', 'world'), x -> upper(x)) as arr") + checkAnswer(df9, Row(Seq("HELLO", "WORLD"))) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/HigherOrderFunctionBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/HigherOrderFunctionBenchmark.scala new file mode 100644 index 0000000000000..e314e20050c53 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/HigherOrderFunctionBenchmark.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.benchmark + +/** + * Benchmark to measure higher-order function performance with codegen on/off. + * To run this benchmark: + * {{{ + * 1. without sbt: + * bin/spark-submit --class + * --jars , + * 2. build/sbt "sql/Test/runMain " + * 3. generate result: + * SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/Test/runMain " + * Results will be written to + * "benchmarks/HigherOrderFunctionBenchmark-results.txt". + * }}} + */ +object HigherOrderFunctionBenchmark extends SqlBasedBenchmark { + + private val numRows = 1 << 20 // ~1M rows + private val arraySize = 10 + + private def prepareArrayData(): Unit = { + spark.range(numRows).selectExpr( + "id", + s"array_repeat(cast(id as int), $arraySize) as int_arr", + s"array_repeat(cast(id as string), $arraySize) as str_arr", + s"array_repeat(named_struct('a', cast(id as int), 'b', cast(id as string)), " + + s"$arraySize) as struct_arr", + s"transform(array_repeat(cast(id as int), $arraySize), " + + s"(x, i) -> if(i % 3 = 0, null, x)) as nullable_arr" + ).createOrReplaceTempView("t") + } + + def transformPrimitiveArray(): Unit = { + runBenchmark("transform on primitive (int) array") { + codegenBenchmark("transform int array", numRows) { + spark.sql( + "select transform(int_arr, x -> x + 1) from t" + ).noop() + } + } + } + + def transformWithIndex(): Unit = { + runBenchmark("transform with index variable") { + codegenBenchmark("transform with index", numRows) { + spark.sql( + "select transform(int_arr, (x, i) -> x + i) from t" + ).noop() + } + } + } + + def transformStringArray(): Unit = { + runBenchmark("transform on string array") { + codegenBenchmark("transform string array", numRows) { + spark.sql( + "select transform(str_arr, x -> upper(x)) from t" + ).noop() + } + } + } + + def transformStructArray(): Unit = { + runBenchmark("transform on struct array") { + codegenBenchmark("transform struct array", numRows) { + spark.sql( + "select transform(struct_arr, x -> named_struct('a', x.a + 1, 'b', x.b)) from t" + ).noop() + } + } + } + + def transformNullableArray(): Unit = { + runBenchmark("transform on nullable element array") { + codegenBenchmark("transform nullable array", numRows) { + spark.sql( + "select transform(nullable_arr, x -> x + 1) from t" + ).noop() + } + } + } + + def nestedTransform(): Unit = { + runBenchmark("nested transform") { + codegenBenchmark("nested transform", numRows) { + spark.sql( + "select transform(transform(int_arr, x -> x + 1), y -> y * 2) from t" + ).noop() + } + } + } + + def transformWithCodegenFallbackBody(): Unit = { + runBenchmark("transform with CodegenFallback body (filter)") { + val arrExpr = s"array_repeat(array_repeat(cast(id as int), 5), $arraySize)" + spark.range(numRows).selectExpr("id", s"$arrExpr as nested_arr") + .createOrReplaceTempView("t_nested") + + codegenBenchmark("transform + filter (mixed codegen/fallback)", numRows) { + spark.sql( + "select transform(nested_arr, x -> filter(x, y -> y > 0)) from t_nested" + ).noop() + } + } + } + + def filterVsTransform(): Unit = { + runBenchmark("filter (CodegenFallback) vs transform (codegen)") { + codegenBenchmark("filter array (CodegenFallback)", numRows) { + spark.sql( + "select filter(int_arr, x -> x > 0) from t" + ).noop() + } + + codegenBenchmark("transform array (codegen)", numRows) { + spark.sql( + "select transform(int_arr, x -> x + 1) from t" + ).noop() + } + } + } + + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + prepareArrayData() + transformPrimitiveArray() + transformWithIndex() + transformStringArray() + transformStructArray() + transformNullableArray() + nestedTransform() + transformWithCodegenFallbackBody() + filterVsTransform() + } +}