diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/linearRegression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/linearRegression.scala index 352aa76397be0..08cc7827b4779 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/linearRegression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/linearRegression.scala @@ -147,8 +147,10 @@ case class RegrR2(y: Expression, x: Expression) extends PearsonCorrelation(y, x, override def prettyName: String = "regr_r2" override val evaluateExpression: Expression = { val corr = ck / sqrt(xMk * yMk) - If(xMk === 0.0, Literal.create(null, DoubleType), - If(yMk === 0.0, Literal.create(1.0, DoubleType), corr * corr)) + // In PearsonCorrelation, x and y are swapped, so here xMk refers to the dependent variable + // and yMk to the independent variable + If(yMk === 0.0, Literal.create(null, DoubleType), + If(xMk === 0.0, Literal.create(1.0, DoubleType), corr * corr)) } override protected def withNewChildrenInternal( newLeft: Expression, newRight: Expression): RegrR2 = diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/linear-regression.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/linear-regression.sql.out index 3a33dd7c84ed2..fa87a63e7f13d 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/linear-regression.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/linear-regression.sql.out @@ -407,3 +407,29 @@ Aggregate [k#x], [k#x, regr_intercept(cast(y#x as double), cast(x#x as double)) +- Project [k#x, y#x, x#x] +- SubqueryAlias testRegression +- LocalRelation [k#x, y#x, x#x] + + +-- !query +SELECT regr_r2(k, x) FROM testRegression where k=2 +-- !query analysis +Aggregate [regr_r2(cast(k#x as double), cast(x#x as double)) AS regr_r2(k, x)#x] ++- Filter (k#x = 2) + +- SubqueryAlias testregression + +- View (`testRegression`, [k#x, y#x, x#x]) + +- Project [cast(k#x as int) AS k#x, cast(y#x as int) AS y#x, cast(x#x as int) AS x#x] + +- Project [k#x, y#x, x#x] + +- SubqueryAlias testRegression + +- LocalRelation [k#x, y#x, x#x] + + +-- !query +SELECT regr_r2(y, k) FROM testRegression where k=2 +-- !query analysis +Aggregate [regr_r2(cast(y#x as double), cast(k#x as double)) AS regr_r2(y, k)#x] ++- Filter (k#x = 2) + +- SubqueryAlias testregression + +- View (`testRegression`, [k#x, y#x, x#x]) + +- Project [cast(k#x as int) AS k#x, cast(y#x as int) AS y#x, cast(x#x as int) AS x#x] + +- Project [k#x, y#x, x#x] + +- SubqueryAlias testRegression + +- LocalRelation [k#x, y#x, x#x] diff --git a/sql/core/src/test/resources/sql-tests/inputs/linear-regression.sql b/sql/core/src/test/resources/sql-tests/inputs/linear-regression.sql index df286d2a9b0a9..a3fa6d4c4cd49 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/linear-regression.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/linear-regression.sql @@ -50,3 +50,7 @@ SELECT regr_intercept(y, x) FROM testRegression; SELECT regr_intercept(y, x) FROM testRegression WHERE x IS NOT NULL AND y IS NOT NULL; SELECT k, regr_intercept(y, x) FROM testRegression GROUP BY k; SELECT k, regr_intercept(y, x) FROM testRegression WHERE x IS NOT NULL AND y IS NOT NULL GROUP BY k; + +-- SPARK-55969: regr_r2 should treat first param as dependent variable +SELECT regr_r2(k, x) FROM testRegression where k=2; +SELECT regr_r2(y, k) FROM testRegression where k=2; diff --git a/sql/core/src/test/resources/sql-tests/results/linear-regression.sql.out b/sql/core/src/test/resources/sql-tests/results/linear-regression.sql.out index e511ea75aae5a..96b2aa08884ef 100644 --- a/sql/core/src/test/resources/sql-tests/results/linear-regression.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/linear-regression.sql.out @@ -274,3 +274,19 @@ SELECT k, regr_intercept(y, x) FROM testRegression WHERE x IS NOT NULL AND y IS struct -- !query output 2 1.1547344110854496 + + +-- !query +SELECT regr_r2(k, x) FROM testRegression where k=2 +-- !query schema +struct +-- !query output +1.0 + + +-- !query +SELECT regr_r2(y, k) FROM testRegression where k=2 +-- !query schema +struct +-- !query output +NULL