Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -418,25 +418,36 @@ case class Cosh(child: Expression) extends UnaryMathExpression(math.cosh, "COSH"
since = "3.0.0",
group = "math_funcs")
case class Acosh(child: Expression)
extends UnaryMathExpression((x: Double) => x match {
// in case of large values, the square would lead to Infinity; also, - 1 would be ignored due
// to numeric precision. So log(x + sqrt(x * x - 1)) becomes log(2x) = log(2) + log(x) for
// positive values.
case x if x >= Math.sqrt(Double.MaxValue) =>
StrictMath.log(2) + StrictMath.log(x)
case x if x < 1 =>
extends UnaryMathExpression((x: Double) => {
// fdlibm e_acosh.c algorithm
if (x < 1.0) {
Double.NaN
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

again, as a question, do you know why fdmlib here has (x - x) / (x - x)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(x - x) / (x - x) is an fdlibm idiom to produce NaN while also raising the IEEE 754 "invalid operation" exception signal. In C, code can detect this via fetestexcept(FE_INVALID). The JVM has no IEEE 754 exception flag mechanism, (x - x) / (x - x) and Double.NaN are functionally identical in Java/Scala. So we use Double.NaN here for clarity.

case _ => StrictMath.log(x + math.sqrt(x * x - 1.0)) }, "ACOSH") {
} else if (x >= (1 << 28)) {
StrictMath.log(x) + StrictMath.log(2.0)
} else if (x == 1.0) {
0.0
} else if (x > 2.0) {
StrictMath.log(2.0 * x - 1.0 / (x + math.sqrt(x * x - 1.0)))
} else {
val t = x - 1.0
StrictMath.log1p(t + math.sqrt(2.0 * t + t * t))
}
}, "ACOSH") {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, c => {
val sm = "java.lang.StrictMath"
s"""
|if ($c >= ${Math.sqrt(Double.MaxValue)}) {
| ${ev.value} = $sm.log($c) + $sm.log(2);
|} else if ($c < 1) {
|if ($c < 1.0) {
| ${ev.value} = java.lang.Double.NaN;
|} else if ($c >= ${1 << 28}.0) {
| ${ev.value} = $sm.log($c) + $sm.log(2.0);
|} else if ($c == 1.0) {
| ${ev.value} = 0.0;
|} else if ($c > 2.0) {
| ${ev.value} = $sm.log(2.0 * $c - 1.0 / ($c + java.lang.Math.sqrt($c * $c - 1.0)));
|} else {
| ${ev.value} = $sm.log($c + java.lang.Math.sqrt($c * $c - 1.0));
| double t = $c - 1.0;
| ${ev.value} = $sm.log1p(t + java.lang.Math.sqrt(2.0 * t + t * t));
|}
|""".stripMargin
})
Expand Down Expand Up @@ -865,20 +876,43 @@ case class Sinh(child: Expression) extends UnaryMathExpression(math.sinh, "SINH"
since = "3.0.0",
group = "math_funcs")
case class Asinh(child: Expression)
extends UnaryMathExpression((x: Double) => x match {
// in case of large values, the square would lead to Infinity; also, + 1 would be ignored due
// to numeric precision. So log(x + sqrt(x * x + 1)) becomes log(2x) = log(2) + log(x) for
// positive values. Since the function is symmetric, for large values we can use
// signum(x) + log(2|x|)
case x if Math.abs(x) >= Math.sqrt(Double.MaxValue) - 1 =>
Math.signum(x) * (StrictMath.log(2) + StrictMath.log(Math.abs(x)))
case _ => StrictMath.log(x + math.sqrt(x * x + 1.0)) }, "ASINH") {
extends UnaryMathExpression((x: Double) => {
// fdlibm s_asinh.c algorithm
val ax = Math.abs(x)
val w = if (ax.isInfinite || ax.isNaN) {
ax
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in fbmlib they have x + x here IIUC. I do not see the rationale though.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In fdlibm, x + x for the inf/NaN case serves two purposes: (1) it propagates signaling NaN signals in C, and (2) it preserves the sign of infinity (since asinh is an odd function).
In our implementation, we take Math.abs(x) which strips the sign, but Math.copySign(w, x) at the end restores it correctly for both +Inf and -Inf. For NaN, Math.copySign(NaN, x) still returns NaN. The signaling NaN distinction doesn't apply in the JVM, so the result is equivalent, just a different way to achieve the same thing.

} else if (ax < 1.0 / (1 << 28)) {
ax
} else if (ax > (1 << 28)) {
StrictMath.log(ax) + StrictMath.log(2.0)
} else if (ax > 2.0) {
StrictMath.log(2.0 * ax + 1.0 / (math.sqrt(x * x + 1.0) + ax))
} else {
val t = x * x
StrictMath.log1p(ax + t / (1.0 + math.sqrt(1.0 + t)))
}
Math.copySign(w, x)
}, "ASINH") {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
defineCodeGen(ctx, ev, c => {
nullSafeCodeGen(ctx, ev, c => {
val sm = "java.lang.StrictMath"
s"$sm.abs($c) >= ${Math.sqrt(Double.MaxValue) - 1} ? " +
s"$sm.signum($c) * ($sm.log($sm.abs($c)) + $sm.log(2)) :" +
s"$sm.log($c + java.lang.Math.sqrt($c * $c + 1.0))"
s"""
|double ax = java.lang.Math.abs($c);
|double w;
|if (java.lang.Double.isInfinite(ax) || java.lang.Double.isNaN(ax)) {
| w = ax;
|} else if (ax < ${1.0 / (1 << 28)}) {
| w = ax;
|} else if (ax > ${1 << 28}.0) {
| w = $sm.log(ax) + $sm.log(2.0);
|} else if (ax > 2.0) {
| w = $sm.log(2.0 * ax + 1.0 / (java.lang.Math.sqrt($c * $c + 1.0) + ax));
|} else {
| double t = $c * $c;
| w = $sm.log1p(ax + t / (1.0 + java.lang.Math.sqrt(1.0 + t)));
|}
|${ev.value} = java.lang.Math.copySign(w, $c);
|""".stripMargin
})
}
override protected def withNewChildInternal(newChild: Expression): Asinh = copy(child = newChild)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,23 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}

test("asinh") {
testUnary(Asinh, (x: Double) => StrictMath.log(x + math.sqrt(x * x + 1.0)))
// fdlibm-aligned reference function
def asinhRef(x: Double): Double = {
val ax = Math.abs(x)
val w = if (ax.isInfinite || ax.isNaN) {
ax
} else if (ax < 1.0 / (1 << 28)) {
ax
} else if (ax > (1 << 28)) {
StrictMath.log(ax) + StrictMath.log(2.0)
} else if (ax > 2.0) {
StrictMath.log(2.0 * ax + 1.0 / (math.sqrt(x * x + 1.0) + ax))
} else {
StrictMath.log1p(ax + x * x / (1.0 + math.sqrt(1.0 + x * x)))
}
Math.copySign(w, x)
}
testUnary(Asinh, asinhRef)
checkConsistencyBetweenInterpretedAndCodegen(Asinh, DoubleType)

checkEvaluation(Asinh(Double.NegativeInfinity), Double.NegativeInfinity)
Expand Down Expand Up @@ -280,10 +296,25 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}

test("acosh") {
def f: (Double) => Double = (x: Double) => StrictMath.log(x + math.sqrt(x * x - 1.0))
testUnary(Acosh, f, (10 to 20).map(_ * 0.1))
testUnary(Acosh, f, (-20 to 9).map(_ * 0.1), expectNaN = true)
checkConsistencyBetweenInterpretedAndCodegen(Cosh, DoubleType)
// fdlibm-aligned reference function
def acoshRef(x: Double): Double = {
if (x < 1.0) {
Double.NaN
} else if (x >= (1 << 28)) {
StrictMath.log(x) + StrictMath.log(2.0)
} else if (x == 1.0) {
0.0
} else if (x > 2.0) {
val t = x * x
StrictMath.log(2.0 * x - 1.0 / (x + math.sqrt(t - 1.0)))
} else {
val t = x - 1.0
StrictMath.log1p(t + math.sqrt(2.0 * t + t * t))
}
}
testUnary(Acosh, acoshRef, (10 to 20).map(_ * 0.1))
testUnary(Acosh, acoshRef, (-20 to 9).map(_ * 0.1), expectNaN = true)
checkConsistencyBetweenInterpretedAndCodegen(Acosh, DoubleType)

val nullLit = Literal.create(null, NullType)
val doubleNullLit = Literal.create(null, DoubleType)
Expand Down Expand Up @@ -1025,4 +1056,35 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkNaN(Acosh(-Math.sqrt(Double.MaxValue) + 2))
}

test("SPARK-56089: asinh/acosh fdlibm algorithm coverage") {
// asinh: hardcoded reference values cross-verified against C libm (glibc/musl fdlibm)
checkEvaluation(Asinh(Literal(0.5)), 0.48121182505960347, EmptyRow)
checkEvaluation(Asinh(Literal(1.0)), 0.881373587019543, EmptyRow)
checkEvaluation(Asinh(Literal(2.0)), 1.4436354751788103, EmptyRow)
checkEvaluation(Asinh(Literal(10.0)), 2.99822295029797, EmptyRow)
checkEvaluation(Asinh(Literal(1e8)), 19.11382792451231, EmptyRow)
// |x| < 2^-28 (identity branch)
checkEvaluation(Asinh(Literal(1.0e-10)), 1.0e-10, EmptyRow)
// |x| > 2^28 branch
val asinhExpected = Math.log(Double.MaxValue) + StrictMath.log(2.0)
checkEvaluation(Asinh(Literal(Double.MaxValue)), asinhExpected, EmptyRow)
checkEvaluation(Asinh(Literal(-Double.MaxValue)), -asinhExpected, EmptyRow)
// infinity
checkEvaluation(Asinh(Literal(Double.PositiveInfinity)), Double.PositiveInfinity, EmptyRow)
checkEvaluation(Asinh(Literal(Double.NegativeInfinity)), Double.NegativeInfinity, EmptyRow)

// acosh: hardcoded reference values cross-verified against C libm (glibc/musl fdlibm)
checkEvaluation(Acosh(Literal(1.0)), 0.0, EmptyRow)
checkEvaluation(Acosh(Literal(1.5)), 0.9624236501192069, EmptyRow)
checkEvaluation(Acosh(Literal(2.0)), 1.3169578969248166, EmptyRow)
checkEvaluation(Acosh(Literal(10.0)), 2.993222846126381, EmptyRow)
checkEvaluation(Acosh(Literal(1e8)), 19.11382792451231, EmptyRow)
// x >= 2^28 branch
val acoshExpected = Math.log(Double.MaxValue) + StrictMath.log(2.0)
checkEvaluation(Acosh(Literal(Double.MaxValue)), acoshExpected, EmptyRow)
checkEvaluation(Acosh(Literal(Double.PositiveInfinity)), Double.PositiveInfinity)
// x < 1 => NaN
checkEvaluation(Acosh(Literal(0.5)), Double.NaN, EmptyRow)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,7 @@ SELECT asinh(double('1'))
-- !query schema
struct<ASINH(1):double>
-- !query output
0.8813735870195429
0.881373587019543
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Checked with PostgreSQL.



-- !query
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,15 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession {

test("asinh") {
testOneToOneMathFunction(asinh,
(x: Double) => math.log(x + math.sqrt(x * x + 1)) )
(x: Double) => {
val ax = Math.abs(x)
val w = if (ax.isInfinite || ax.isNaN) ax
else if (ax < 1.0 / (1 << 28)) ax
else if (ax > (1 << 28)) StrictMath.log(ax) + StrictMath.log(2.0)
else if (ax > 2.0) StrictMath.log(2.0 * ax + 1.0 / (math.sqrt(x * x + 1.0) + ax))
else StrictMath.log1p(ax + x * x / (1.0 + math.sqrt(1.0 + x * x)))
Math.copySign(w, x)
})
}

test("cos") {
Expand All @@ -167,7 +175,13 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession {

test("acosh") {
testOneToOneMathFunction(acosh,
(x: Double) => math.log(x + math.sqrt(x * x - 1)) )
(x: Double) => {
if (x < 1.0) Double.NaN
else if (x >= (1 << 28)) StrictMath.log(x) + StrictMath.log(2.0)
else if (x == 1.0) 0.0
else if (x > 2.0) StrictMath.log(2.0 * x - 1.0 / (x + math.sqrt(x * x - 1.0)))
else { val t = x - 1.0; StrictMath.log1p(t + math.sqrt(2.0 * t + t * t)) }
})
}

test("tan") {
Expand Down