diff --git a/core/src/main/java/org/apache/calcite/prepare/Prepare.java b/core/src/main/java/org/apache/calcite/prepare/Prepare.java
index 586cce83853..1c6d8875cb1 100644
--- a/core/src/main/java/org/apache/calcite/prepare/Prepare.java
+++ b/core/src/main/java/org/apache/calcite/prepare/Prepare.java
@@ -256,11 +256,14 @@ public PreparedResult prepareSql(
RelRoot root =
sqlToRelConverter.convertQuery(sqlQuery, needsValidation, true);
- if (this.context.config().conformance().checkedArithmetic()) {
- ConvertToChecked checkedConv = new ConvertToChecked(root.rel.getCluster().getRexBuilder());
- RelNode rel = checkedConv.visit(root.rel);
- root = root.withRel(rel);
- }
+ boolean convertToChecked = this.context.config().conformance().checkedArithmetic();
+ // Convert some operations to use checked arithmetic:
+ // - all arithmetic operations on exact types if the conformance requires checked arithmetic
+ // - all arithmetic that produces INTERVAL results, regardless of the conformance
+ ConvertToChecked checkedConv =
+ new ConvertToChecked(root.rel.getCluster().getRexBuilder(), convertToChecked);
+ RelNode rel = checkedConv.visit(root.rel);
+ root = root.withRel(rel);
Hook.CONVERTED.run(root.rel);
if (timingTracer != null) {
diff --git a/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java b/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java
index e2adce8ea1c..dc0d1295610 100644
--- a/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java
+++ b/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java
@@ -2971,10 +2971,14 @@ public static long checkedDivide(long b0, long b1) {
if ((b0 & b1 & q) >= 0) {
return q;
} else {
- throw new ArithmeticException("integer overflow");
+ throw new ArithmeticException("long overflow");
}
}
+ public static double checkedDivide(int b0, double b1) {
+ return b0 / b1;
+ }
+
public static UByte checkedDivide(UByte b0, UByte b1) {
return UByte.valueOf(b0.intValue() / b1.intValue());
}
@@ -2991,6 +2995,16 @@ public static ULong checkedDivide(ULong b0, ULong b1) {
return ULong.valueOf(UnsignedType.toBigInteger(b0).divide(UnsignedType.toBigInteger(b1)));
}
+ // The definition of this function must match the divide function with the same signature
+ public static int checkedDivide(int b0, BigDecimal b1) {
+ return BigDecimal.valueOf(b0)
+ .divide(b1, RoundingMode.HALF_DOWN).intValueExact();
+ }
+
+ public static BigDecimal checkedDivide(BigDecimal b0, BigDecimal b1) {
+ return b0.divide(b1, RoundingMode.HALF_DOWN);
+ }
+
// *
/** SQL * operator applied to int values. */
@@ -3113,6 +3127,10 @@ public static ULong checkedMultiply(ULong b0, ULong b1) {
return ULong.valueOf(UnsignedType.toBigInteger(b0).multiply(UnsignedType.toBigInteger(b1)));
}
+ public static BigDecimal checkedMultiply(BigDecimal b0, long b1) {
+ return b0.multiply(BigDecimal.valueOf(b1));
+ }
+
/** SQL SAFE_ADD function applied to long values. */
public static @Nullable Long safeAdd(long b0, long b1) {
try {
diff --git a/core/src/main/java/org/apache/calcite/sql2rel/ConvertToChecked.java b/core/src/main/java/org/apache/calcite/sql2rel/ConvertToChecked.java
index a5e658e54a8..f64791f9d19 100644
--- a/core/src/main/java/org/apache/calcite/sql2rel/ConvertToChecked.java
+++ b/core/src/main/java/org/apache/calcite/sql2rel/ConvertToChecked.java
@@ -38,8 +38,8 @@
public class ConvertToChecked extends RelHomogeneousShuttle {
final ConvertRexToChecked converter;
- public ConvertToChecked(RexBuilder builder) {
- this.converter = new ConvertRexToChecked(builder);
+ public ConvertToChecked(RexBuilder builder, boolean allArithmetic) {
+ this.converter = new ConvertRexToChecked(builder, allArithmetic);
}
@Override public RelNode visit(RelNode other) {
@@ -48,14 +48,25 @@ public ConvertToChecked(RexBuilder builder) {
}
/**
- * Visitor which rewrites an expression tree such that all
- * arithmetic operations that produce numeric values use checked arithmetic.
+ * Visitor which rewrites an expression tree such that arithmetic operations
+ * use checked arithmetic.
*/
class ConvertRexToChecked extends RexShuttle {
private final RexBuilder builder;
+ // If true all arithmetic operations are converted.
+ // Otherwise, only arithmetic operations on INTERVAL values is checked.
+ private final boolean allArithmetic;
- ConvertRexToChecked(RexBuilder builder) {
+ /** Create a converter that replaces arithmetic with checked arithmetic.
+ *
+ * @param builder RexBuilder to use.
+ * @param allArithmetic If true all exact arithmetic operations are converted to checked.
+ * If false, only operations that produce INTERVAL-typed results
+ * are converted to checked.
+ */
+ ConvertRexToChecked(RexBuilder builder, boolean allArithmetic) {
this.builder = builder;
+ this.allArithmetic = allArithmetic;
}
@Override public RexNode visitSubQuery(RexSubQuery subQuery) {
@@ -72,6 +83,22 @@ class ConvertRexToChecked extends RexShuttle {
List clonedOperands = visitList(call.operands, update);
SqlKind kind = call.getKind();
SqlOperator operator = call.getOperator();
+ SqlTypeName resultType = call.getType().getSqlTypeName();
+ boolean anyOperandIsInterval = false;
+ for (RexNode op : call.getOperands()) {
+ if (SqlTypeName.INTERVAL_TYPES.contains(op.getType().getSqlTypeName())) {
+ anyOperandIsInterval = true;
+ break;
+ }
+ }
+ boolean resultIsInterval = SqlTypeName.INTERVAL_TYPES.contains(resultType);
+ boolean rewrite =
+ // Do not rewrite operator if the type is e.g., DOUBLE or DATE
+ (this.allArithmetic && SqlTypeName.EXACT_TYPES.contains(resultType))
+ // But always rewrite if the type is an INTERVAL and any operand is INTERVAL
+ // This will not rewrite date subtraction, for example
+ || (resultIsInterval && anyOperandIsInterval);
+
switch (kind) {
case PLUS:
operator = SqlStdOperatorTable.CHECKED_PLUS;
@@ -91,8 +118,7 @@ class ConvertRexToChecked extends RexShuttle {
default:
break;
}
- SqlTypeName resultType = call.getType().getSqlTypeName();
- if (resultType == SqlTypeName.DECIMAL) {
+ if (resultType == SqlTypeName.DECIMAL && this.allArithmetic) {
// Checked decimal arithmetic is implemented using unchecked
// arithmetic followed by a CAST, which is always checked
RexCall result;
@@ -102,8 +128,9 @@ class ConvertRexToChecked extends RexShuttle {
result = call;
}
return builder.makeCast(call.getParserPosition(), call.getType(), result);
- } else if (!SqlTypeName.EXACT_TYPES.contains(resultType)) {
- // Do not rewrite operator if the type is e.g., DOUBLE or DATE
+ }
+
+ if (!rewrite) {
operator = call.getOperator();
}
update[0] = update[0] || operator != call.getOperator();
diff --git a/core/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java b/core/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java
index e8b1bcbd7a0..739c3e30fae 100644
--- a/core/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java
+++ b/core/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java
@@ -552,7 +552,7 @@ private static RexNode convertInterval(SqlRexContext cx, SqlCall call) {
SqlLiteral.createInterval(1, "1", intervalQualifier,
call.getParserPosition());
final SqlCall multiply =
- SqlStdOperatorTable.MULTIPLY.createCall(call.getParserPosition(), n,
+ SqlStdOperatorTable.CHECKED_MULTIPLY.createCall(call.getParserPosition(), n,
literal);
return cx.convertExpression(multiply);
}
diff --git a/core/src/test/resources/sql/scalar.iq b/core/src/test/resources/sql/scalar.iq
index d82d69f5890..e4ef19c0d3f 100644
--- a/core/src/test/resources/sql/scalar.iq
+++ b/core/src/test/resources/sql/scalar.iq
@@ -18,6 +18,32 @@
!set outputformat mysql
!use scott
+# 5 test cases for [CALCITE-7443] Incorrect simplification for large interval
+SELECT -(INTERVAL -2147483648 months);
+java.lang.ArithmeticException: integer overflow
+
+!error
+
+SELECT INTERVAL 2147483647 years;
+java.lang.ArithmeticException: integer overflow
+
+!error
+
+SELECT -(INTERVAL -9223372036854775.808 SECONDS);
+java.lang.ArithmeticException: long overflow
+
+!error
+
+SELECT INTERVAL 3000000 months * 1000;
+java.lang.ArithmeticException: integer overflow
+
+!error
+
+SELECT INTERVAL 3000000 months / .0001;
+java.lang.ArithmeticException: Overflow
+
+!error
+
select deptno, (select min(empno) from "scott".emp where deptno = dept.deptno) as x from "scott".dept;
+--------+------+
| DEPTNO | X |