diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 9ea38b8bfe91..f40dae8dc768 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -418,25 +418,36 @@ case class Cosh(child: Expression) extends UnaryMathExpression(math.cosh, "COSH" since = "3.0.0", group = "math_funcs") case class Acosh(child: Expression) - extends UnaryMathExpression((x: Double) => x match { - // in case of large values, the square would lead to Infinity; also, - 1 would be ignored due - // to numeric precision. So log(x + sqrt(x * x - 1)) becomes log(2x) = log(2) + log(x) for - // positive values. - case x if x >= Math.sqrt(Double.MaxValue) => - StrictMath.log(2) + StrictMath.log(x) - case x if x < 1 => + extends UnaryMathExpression((x: Double) => { + // fdlibm e_acosh.c algorithm + if (x < 1.0) { Double.NaN - case _ => StrictMath.log(x + math.sqrt(x * x - 1.0)) }, "ACOSH") { + } else if (x >= (1 << 28)) { + StrictMath.log(x) + StrictMath.log(2.0) + } else if (x == 1.0) { + 0.0 + } else if (x > 2.0) { + StrictMath.log(2.0 * x - 1.0 / (x + math.sqrt(x * x - 1.0))) + } else { + val t = x - 1.0 + StrictMath.log1p(t + math.sqrt(2.0 * t + t * t)) + } + }, "ACOSH") { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, c => { val sm = "java.lang.StrictMath" s""" - |if ($c >= ${Math.sqrt(Double.MaxValue)}) { - | ${ev.value} = $sm.log($c) + $sm.log(2); - |} else if ($c < 1) { + |if ($c < 1.0) { | ${ev.value} = java.lang.Double.NaN; + |} else if ($c >= ${1 << 28}.0) { + | ${ev.value} = $sm.log($c) + $sm.log(2.0); + |} else if ($c == 1.0) { + | ${ev.value} = 0.0; + |} else if ($c > 2.0) { + | ${ev.value} = $sm.log(2.0 * $c - 1.0 / ($c + java.lang.Math.sqrt($c * $c - 1.0))); |} else { - | ${ev.value} = $sm.log($c + java.lang.Math.sqrt($c * $c - 1.0)); + | double t = $c - 1.0; + | ${ev.value} = $sm.log1p(t + java.lang.Math.sqrt(2.0 * t + t * t)); |} |""".stripMargin }) @@ -865,20 +876,43 @@ case class Sinh(child: Expression) extends UnaryMathExpression(math.sinh, "SINH" since = "3.0.0", group = "math_funcs") case class Asinh(child: Expression) - extends UnaryMathExpression((x: Double) => x match { - // in case of large values, the square would lead to Infinity; also, + 1 would be ignored due - // to numeric precision. So log(x + sqrt(x * x + 1)) becomes log(2x) = log(2) + log(x) for - // positive values. Since the function is symmetric, for large values we can use - // signum(x) + log(2|x|) - case x if Math.abs(x) >= Math.sqrt(Double.MaxValue) - 1 => - Math.signum(x) * (StrictMath.log(2) + StrictMath.log(Math.abs(x))) - case _ => StrictMath.log(x + math.sqrt(x * x + 1.0)) }, "ASINH") { + extends UnaryMathExpression((x: Double) => { + // fdlibm s_asinh.c algorithm + val ax = Math.abs(x) + val w = if (ax.isInfinite || ax.isNaN) { + ax + } else if (ax < 1.0 / (1 << 28)) { + ax + } else if (ax > (1 << 28)) { + StrictMath.log(ax) + StrictMath.log(2.0) + } else if (ax > 2.0) { + StrictMath.log(2.0 * ax + 1.0 / (math.sqrt(x * x + 1.0) + ax)) + } else { + val t = x * x + StrictMath.log1p(ax + t / (1.0 + math.sqrt(1.0 + t))) + } + Math.copySign(w, x) + }, "ASINH") { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, c => { + nullSafeCodeGen(ctx, ev, c => { val sm = "java.lang.StrictMath" - s"$sm.abs($c) >= ${Math.sqrt(Double.MaxValue) - 1} ? " + - s"$sm.signum($c) * ($sm.log($sm.abs($c)) + $sm.log(2)) :" + - s"$sm.log($c + java.lang.Math.sqrt($c * $c + 1.0))" + s""" + |double ax = java.lang.Math.abs($c); + |double w; + |if (java.lang.Double.isInfinite(ax) || java.lang.Double.isNaN(ax)) { + | w = ax; + |} else if (ax < ${1.0 / (1 << 28)}) { + | w = ax; + |} else if (ax > ${1 << 28}.0) { + | w = $sm.log(ax) + $sm.log(2.0); + |} else if (ax > 2.0) { + | w = $sm.log(2.0 * ax + 1.0 / (java.lang.Math.sqrt($c * $c + 1.0) + ax)); + |} else { + | double t = $c * $c; + | w = $sm.log1p(ax + t / (1.0 + java.lang.Math.sqrt(1.0 + t))); + |} + |${ev.value} = java.lang.Math.copySign(w, $c); + |""".stripMargin }) } override protected def withNewChildInternal(newChild: Expression): Asinh = copy(child = newChild) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala index f6a40406e668..63bfec0090ec 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala @@ -238,7 +238,23 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("asinh") { - testUnary(Asinh, (x: Double) => StrictMath.log(x + math.sqrt(x * x + 1.0))) + // fdlibm-aligned reference function + def asinhRef(x: Double): Double = { + val ax = Math.abs(x) + val w = if (ax.isInfinite || ax.isNaN) { + ax + } else if (ax < 1.0 / (1 << 28)) { + ax + } else if (ax > (1 << 28)) { + StrictMath.log(ax) + StrictMath.log(2.0) + } else if (ax > 2.0) { + StrictMath.log(2.0 * ax + 1.0 / (math.sqrt(x * x + 1.0) + ax)) + } else { + StrictMath.log1p(ax + x * x / (1.0 + math.sqrt(1.0 + x * x))) + } + Math.copySign(w, x) + } + testUnary(Asinh, asinhRef) checkConsistencyBetweenInterpretedAndCodegen(Asinh, DoubleType) checkEvaluation(Asinh(Double.NegativeInfinity), Double.NegativeInfinity) @@ -280,10 +296,25 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("acosh") { - def f: (Double) => Double = (x: Double) => StrictMath.log(x + math.sqrt(x * x - 1.0)) - testUnary(Acosh, f, (10 to 20).map(_ * 0.1)) - testUnary(Acosh, f, (-20 to 9).map(_ * 0.1), expectNaN = true) - checkConsistencyBetweenInterpretedAndCodegen(Cosh, DoubleType) + // fdlibm-aligned reference function + def acoshRef(x: Double): Double = { + if (x < 1.0) { + Double.NaN + } else if (x >= (1 << 28)) { + StrictMath.log(x) + StrictMath.log(2.0) + } else if (x == 1.0) { + 0.0 + } else if (x > 2.0) { + val t = x * x + StrictMath.log(2.0 * x - 1.0 / (x + math.sqrt(t - 1.0))) + } else { + val t = x - 1.0 + StrictMath.log1p(t + math.sqrt(2.0 * t + t * t)) + } + } + testUnary(Acosh, acoshRef, (10 to 20).map(_ * 0.1)) + testUnary(Acosh, acoshRef, (-20 to 9).map(_ * 0.1), expectNaN = true) + checkConsistencyBetweenInterpretedAndCodegen(Acosh, DoubleType) val nullLit = Literal.create(null, NullType) val doubleNullLit = Literal.create(null, DoubleType) @@ -1025,4 +1056,35 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkNaN(Acosh(-Math.sqrt(Double.MaxValue) + 2)) } + test("SPARK-56089: asinh/acosh fdlibm algorithm coverage") { + // asinh: hardcoded reference values cross-verified against C libm (glibc/musl fdlibm) + checkEvaluation(Asinh(Literal(0.5)), 0.48121182505960347, EmptyRow) + checkEvaluation(Asinh(Literal(1.0)), 0.881373587019543, EmptyRow) + checkEvaluation(Asinh(Literal(2.0)), 1.4436354751788103, EmptyRow) + checkEvaluation(Asinh(Literal(10.0)), 2.99822295029797, EmptyRow) + checkEvaluation(Asinh(Literal(1e8)), 19.11382792451231, EmptyRow) + // |x| < 2^-28 (identity branch) + checkEvaluation(Asinh(Literal(1.0e-10)), 1.0e-10, EmptyRow) + // |x| > 2^28 branch + val asinhExpected = Math.log(Double.MaxValue) + StrictMath.log(2.0) + checkEvaluation(Asinh(Literal(Double.MaxValue)), asinhExpected, EmptyRow) + checkEvaluation(Asinh(Literal(-Double.MaxValue)), -asinhExpected, EmptyRow) + // infinity + checkEvaluation(Asinh(Literal(Double.PositiveInfinity)), Double.PositiveInfinity, EmptyRow) + checkEvaluation(Asinh(Literal(Double.NegativeInfinity)), Double.NegativeInfinity, EmptyRow) + + // acosh: hardcoded reference values cross-verified against C libm (glibc/musl fdlibm) + checkEvaluation(Acosh(Literal(1.0)), 0.0, EmptyRow) + checkEvaluation(Acosh(Literal(1.5)), 0.9624236501192069, EmptyRow) + checkEvaluation(Acosh(Literal(2.0)), 1.3169578969248166, EmptyRow) + checkEvaluation(Acosh(Literal(10.0)), 2.993222846126381, EmptyRow) + checkEvaluation(Acosh(Literal(1e8)), 19.11382792451231, EmptyRow) + // x >= 2^28 branch + val acoshExpected = Math.log(Double.MaxValue) + StrictMath.log(2.0) + checkEvaluation(Acosh(Literal(Double.MaxValue)), acoshExpected, EmptyRow) + checkEvaluation(Acosh(Literal(Double.PositiveInfinity)), Double.PositiveInfinity) + // x < 1 => NaN + checkEvaluation(Acosh(Literal(0.5)), Double.NaN, EmptyRow) + } + } diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/float8.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/float8.sql.out index e1b880f34370..a8f48f7cb4a9 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/float8.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/float8.sql.out @@ -577,7 +577,7 @@ SELECT asinh(double('1')) -- !query schema struct -- !query output -0.8813735870195429 +0.881373587019543 -- !query diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala index 42f15a04fd58..fbfa84a2631d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala @@ -145,7 +145,15 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession { test("asinh") { testOneToOneMathFunction(asinh, - (x: Double) => math.log(x + math.sqrt(x * x + 1)) ) + (x: Double) => { + val ax = Math.abs(x) + val w = if (ax.isInfinite || ax.isNaN) ax + else if (ax < 1.0 / (1 << 28)) ax + else if (ax > (1 << 28)) StrictMath.log(ax) + StrictMath.log(2.0) + else if (ax > 2.0) StrictMath.log(2.0 * ax + 1.0 / (math.sqrt(x * x + 1.0) + ax)) + else StrictMath.log1p(ax + x * x / (1.0 + math.sqrt(1.0 + x * x))) + Math.copySign(w, x) + }) } test("cos") { @@ -167,7 +175,13 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession { test("acosh") { testOneToOneMathFunction(acosh, - (x: Double) => math.log(x + math.sqrt(x * x - 1)) ) + (x: Double) => { + if (x < 1.0) Double.NaN + else if (x >= (1 << 28)) StrictMath.log(x) + StrictMath.log(2.0) + else if (x == 1.0) 0.0 + else if (x > 2.0) StrictMath.log(2.0 * x - 1.0 / (x + math.sqrt(x * x - 1.0))) + else { val t = x - 1.0; StrictMath.log1p(t + math.sqrt(2.0 * t + t * t)) } + }) } test("tan") {