diff --git a/core/src/main/java/org/apache/calcite/plan/RelOptUtil.java b/core/src/main/java/org/apache/calcite/plan/RelOptUtil.java index 274e91f14de..c1ea964e033 100644 --- a/core/src/main/java/org/apache/calcite/plan/RelOptUtil.java +++ b/core/src/main/java/org/apache/calcite/plan/RelOptUtil.java @@ -2957,7 +2957,8 @@ public static boolean classifyFilters( joinFields, nTotalFields, leftFields, - filter); + filter, + joinRel.getInput(0)); leftFilters.add(shiftedFilter); } @@ -2975,7 +2976,8 @@ public static boolean classifyFilters( joinFields, nTotalFields, rightFields, - filter); + filter, + joinRel.getInput(1)); rightFilters.add(shiftedFilter); } filtersToRemove.add(filter); @@ -3079,7 +3081,8 @@ public static boolean classifyFilters( joinFields, nTotalFields, leftFields, - filter); + filter, + joinRel.getInput(0)); leftFilters.add(shiftedFilter); } @@ -3105,7 +3108,8 @@ public static boolean classifyFilters( joinFields, nTotalFields, rightFields, - filter); + filter, + joinRel.getInput(1)); rightFilters.add(shiftedFilter); } filtersToRemove.add(filter); @@ -3140,7 +3144,8 @@ private static RexNode shiftFilter( List joinFields, int nTotalFields, List rightFields, - RexNode filter) { + RexNode filter, + RelNode child) { int[] adjustments = new int[nTotalFields]; for (int i = start; i < end; i++) { adjustments[i] = offset; @@ -3150,7 +3155,9 @@ private static RexNode shiftFilter( rexBuilder, joinFields, rightFields, - adjustments)); + adjustments, + offset, + child)); } /** @@ -4752,6 +4759,17 @@ public ImmutableBitSet build() { } return super.visitCall(call); } + + @Override public Void visitSubQuery(RexSubQuery subQuery) { + final Set variablesSet = RelOptUtil.getVariablesUsed(subQuery.rel); + for (CorrelationId id : variablesSet) { + ImmutableBitSet requiredColumns = RelOptUtil.correlationColumns(id, subQuery.rel); + for (int index : requiredColumns) { + bitBuilder.set(index); + } + } + return super.visitSubQuery(subQuery); + } } /** @@ -4766,6 +4784,8 @@ public static class RexInputConverter extends RexShuttle { private final @Nullable List rightDestFields; private final int nLeftDestFields; private final int[] adjustments; + private final int offset; + private final @Nullable RelNode correlateVariableChild; /** * Creates a RexInputConverter. @@ -4784,6 +4804,13 @@ public static class RexInputConverter extends RexShuttle { * @param rightDestFields in the case where the destination is a join, * these are the fields from the right join input * @param adjustments the amount to adjust each field by + * @param offset the amount to shift field accesses by when + * rewriting correlated subqueries + * @param correlateVariableChild the child relation providing the + * correlated variable; if non-null, subqueries + * referencing a correlation variable will have + * their field accesses shifted by {@code offset} + * relative to this child */ private RexInputConverter( RexBuilder rexBuilder, @@ -4791,7 +4818,9 @@ private RexInputConverter( @Nullable List destFields, @Nullable List leftDestFields, @Nullable List rightDestFields, - int[] adjustments) { + int[] adjustments, + int offset, + @Nullable RelNode correlateVariableChild) { this.rexBuilder = rexBuilder; this.srcFields = srcFields; this.destFields = destFields; @@ -4804,6 +4833,8 @@ private RexInputConverter( assert destFields == null; nLeftDestFields = leftDestFields.size(); } + this.offset = offset; + this.correlateVariableChild = correlateVariableChild; } public RexInputConverter( @@ -4818,7 +4849,9 @@ public RexInputConverter( null, leftDestFields, rightDestFields, - adjustments); + adjustments, + 0, + null); } public RexInputConverter( @@ -4826,14 +4859,51 @@ public RexInputConverter( @Nullable List srcFields, @Nullable List destFields, int[] adjustments) { - this(rexBuilder, srcFields, destFields, null, null, adjustments); + this(rexBuilder, srcFields, destFields, null, null, adjustments, 0, null); } public RexInputConverter( RexBuilder rexBuilder, @Nullable List srcFields, int[] adjustments) { - this(rexBuilder, srcFields, null, null, null, adjustments); + this(rexBuilder, srcFields, null, null, null, adjustments, 0, null); + } + + public RexInputConverter( + RexBuilder rexBuilder, + @Nullable List srcFields, + @Nullable List destFields, + int[] adjustments, + int offset, + RelNode child) { + this(rexBuilder, srcFields, destFields, null, null, adjustments, offset, child); + } + + @Override public RexNode visitSubQuery(RexSubQuery subQuery) { + boolean[] update = {false}; + List clonedOperands = visitList(subQuery.operands, update); + if (update[0]) { + subQuery = subQuery.clone(subQuery.getType(), clonedOperands); + } + final Set variablesSet = + RelOptUtil.getVariablesUsed(subQuery.rel); + if (!variablesSet.isEmpty() && correlateVariableChild != null) { + for (CorrelationId id : variablesSet) { + RelNode newSubQueryRel = + subQuery.rel.accept(new RelHomogeneousShuttle() { + @Override public RelNode visit(RelNode other) { + RelNode node = + RexUtil.shiftFieldAccess(rexBuilder, other, id, + correlateVariableChild, offset); + return super.visit(node); + } + }); + if (newSubQueryRel != subQuery.rel) { + subQuery = subQuery.clone(newSubQueryRel); + } + } + } + return subQuery; } @Override public RexNode visitInputRef(RexInputRef var) { diff --git a/core/src/main/java/org/apache/calcite/rel/rules/FilterJoinRule.java b/core/src/main/java/org/apache/calcite/rel/rules/FilterJoinRule.java index d4e1473ccbf..c014f6f88c6 100644 --- a/core/src/main/java/org/apache/calcite/rel/rules/FilterJoinRule.java +++ b/core/src/main/java/org/apache/calcite/rel/rules/FilterJoinRule.java @@ -20,6 +20,7 @@ import org.apache.calcite.plan.RelOptUtil; import org.apache.calcite.plan.RelRule; import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.CorrelationId; import org.apache.calcite.rel.core.Filter; import org.apache.calcite.rel.core.Join; import org.apache.calcite.rel.core.JoinRelType; @@ -29,7 +30,9 @@ import org.apache.calcite.rex.RexCall; import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexSubQuery; import org.apache.calcite.rex.RexUtil; +import org.apache.calcite.rex.RexVisitorImpl; import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.tools.RelBuilder; @@ -198,14 +201,35 @@ protected void perform(RelOptRuleCall call, @Nullable Filter filter, return; } + Set leftVariablesSet = new LinkedHashSet<>(); + Set rightVariablesSet = new LinkedHashSet<>(); + + for (RexNode condition : leftFilters) { + condition.accept(new RexVisitorImpl(true) { + @Override public Void visitSubQuery(RexSubQuery subQuery) { + leftVariablesSet.addAll(RelOptUtil.getVariablesUsed(subQuery.rel)); + return super.visitSubQuery(subQuery); + } + }); + } + + for (RexNode condition : rightFilters) { + condition.accept(new RexVisitorImpl(true) { + @Override public Void visitSubQuery(RexSubQuery subQuery) { + rightVariablesSet.addAll(RelOptUtil.getVariablesUsed(subQuery.rel)); + return super.visitSubQuery(subQuery); + } + }); + } + // create Filters on top of the children if any filters were // pushed to them final RexBuilder rexBuilder = join.getCluster().getRexBuilder(); final RelBuilder relBuilder = call.builder(); final RelNode leftRel = - relBuilder.push(join.getLeft()).filter(leftFilters).build(); + relBuilder.push(join.getLeft()).filter(leftVariablesSet, leftFilters).build(); final RelNode rightRel = - relBuilder.push(join.getRight()).filter(rightFilters).build(); + relBuilder.push(join.getRight()).filter(rightVariablesSet, rightFilters).build(); // create the new join node referencing the new children and // containing its new join filters (if there are any) diff --git a/core/src/test/java/org/apache/calcite/sql2rel/RelDecorrelatorTest.java b/core/src/test/java/org/apache/calcite/sql2rel/RelDecorrelatorTest.java index 2b406c3c015..a1951569910 100644 --- a/core/src/test/java/org/apache/calcite/sql2rel/RelDecorrelatorTest.java +++ b/core/src/test/java/org/apache/calcite/sql2rel/RelDecorrelatorTest.java @@ -1830,4 +1830,160 @@ public static Frameworks.ConfigBuilder config() { + " LogicalTableScan(table=[[scott, EMP]])\n"; assertThat(after, hasTree(planAfter)); } + + /** Test case for [CALCITE-7442] + * Getting Wrong index of Correlated variable inside Subquery after FilterJoinRule. */ + @Test void testCorrelatedVariableIndexForInClause() { + final FrameworkConfig frameworkConfig = config().build(); + final RelBuilder builder = RelBuilder.create(frameworkConfig); + final RelOptCluster cluster = builder.getCluster(); + final Planner planner = Frameworks.getPlanner(frameworkConfig); + final String sql = "select e.empno, d.dname, b.ename\n" + + "from emp e\n" + + "inner join dept d\n" + + " on d.deptno = e.deptno\n" + + "inner join bonus b\n" + + " on e.ename = b.ename\n" + + " and b.job in (\n" + + " select b2.job\n" + + " from bonus b2\n" + + " where b2.ename = b.ename)\n" + + "where e.sal > 1000 and d.dname = 'SALES'"; + + final RelNode originalRel; + try { + final SqlNode parse = planner.parse(sql); + final SqlNode validate = planner.validate(parse); + originalRel = planner.rel(validate).rel; + } catch (Exception e) { + throw TestUtil.rethrow(e); + } + + final HepProgram hepProgram = HepProgram.builder() + .addRuleCollection( + ImmutableList.of( + CoreRules.FILTER_INTO_JOIN, + CoreRules.FILTER_SUB_QUERY_TO_CORRELATE)) + .build(); + final Program program = + Programs.of(hepProgram, true, + requireNonNull(cluster.getMetadataProvider())); + final RelNode before = + program.run(cluster.getPlanner(), originalRel, cluster.traitSet(), + Collections.emptyList(), Collections.emptyList()); + + final String planBefore = "LogicalProject(EMPNO=[$0], DNAME=[$9], ENAME=[$11])\n" + + " LogicalJoin(condition=[=($1, $11)], joinType=[inner], variablesSet=[[$cor0]])\n" + + " LogicalJoin(condition=[=($8, $7)], joinType=[inner])\n" + + " LogicalFilter(condition=[>(CAST($5):DECIMAL(12, 2), 1000.00)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalFilter(condition=[=($1, 'SALES')])\n" + + " LogicalTableScan(table=[[scott, DEPT]])\n" + + " LogicalProject(ENAME=[$0], JOB=[$1], SAL=[$2], COMM=[$3])\n" + + " LogicalFilter(condition=[=($1, $4)])\n" + + " LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0}])\n" + + " LogicalTableScan(table=[[scott, BONUS]])\n" + + " LogicalProject(JOB=[$1])\n" + + " LogicalFilter(condition=[=($0, $cor0.ENAME)])\n" + + " LogicalTableScan(table=[[scott, BONUS]])\n"; + assertThat(before, hasTree(planBefore)); + + final RelNode after = + RelDecorrelator.decorrelateQuery(before, builder, RuleSets.ofList(Collections.emptyList()), + RuleSets.ofList(Collections.emptyList())); + final String planAfter = "LogicalProject(EMPNO=[$0], DNAME=[$9], ENAME=[$11])\n" + + " LogicalJoin(condition=[=($1, $11)], joinType=[inner])\n" + + " LogicalJoin(condition=[=($8, $7)], joinType=[inner])\n" + + " LogicalFilter(condition=[>(CAST($5):DECIMAL(12, 2), 1000.00)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalFilter(condition=[=($1, 'SALES')])\n" + + " LogicalTableScan(table=[[scott, DEPT]])\n" + + " LogicalProject(ENAME=[$0], JOB=[$1], SAL=[$2], COMM=[$3])\n" + + " LogicalJoin(condition=[AND(=($0, $5), =($1, $4))], joinType=[inner])\n" + + " LogicalTableScan(table=[[scott, BONUS]])\n" + + " LogicalProject(JOB=[$1], ENAME=[$0])\n" + + " LogicalFilter(condition=[IS NOT NULL($0)])\n" + + " LogicalTableScan(table=[[scott, BONUS]])\n"; + assertThat(after, hasTree(planAfter)); + } + + /** Test case for [CALCITE-7442] + * Getting Wrong index of Correlated variable inside Subquery after FilterJoinRule. + * Same as {@link #testCorrelatedVariableIndexForInClause()} but uses EXISTS + * instead of IN. */ + @Test void testCorrelatedVariableIndexForExistsClause() { + final FrameworkConfig frameworkConfig = config().build(); + final RelBuilder builder = RelBuilder.create(frameworkConfig); + final RelOptCluster cluster = builder.getCluster(); + final Planner planner = Frameworks.getPlanner(frameworkConfig); + final String sql = "select e.empno, d.dname, b.ename\n" + + "from emp e\n" + + "inner join dept d\n" + + " on d.deptno = e.deptno\n" + + "inner join bonus b\n" + + " on e.ename = b.ename\n" + + " and exists (\n" + + " select b2.job\n" + + " from bonus b2\n" + + " where b2.ename = b.ename\n" + + " and b2.job = b.job)\n" + + "where e.sal > 1000 and d.dname = 'SALES'"; + + final RelNode originalRel; + try { + final SqlNode parse = planner.parse(sql); + final SqlNode validate = planner.validate(parse); + originalRel = planner.rel(validate).rel; + } catch (Exception e) { + throw TestUtil.rethrow(e); + } + + final HepProgram hepProgram = HepProgram.builder() + .addRuleCollection( + ImmutableList.of( + CoreRules.FILTER_INTO_JOIN, + CoreRules.FILTER_SUB_QUERY_TO_CORRELATE)) + .build(); + final Program program = + Programs.of(hepProgram, true, + requireNonNull(cluster.getMetadataProvider())); + final RelNode before = + program.run(cluster.getPlanner(), originalRel, cluster.traitSet(), + Collections.emptyList(), Collections.emptyList()); + + final String planBefore = "LogicalProject(EMPNO=[$0], DNAME=[$9], ENAME=[$11])\n" + + " LogicalJoin(condition=[=($1, $11)], joinType=[inner], variablesSet=[[$cor0]])\n" + + " LogicalJoin(condition=[=($8, $7)], joinType=[inner])\n" + + " LogicalFilter(condition=[>(CAST($5):DECIMAL(12, 2), 1000.00)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalFilter(condition=[=($1, 'SALES')])\n" + + " LogicalTableScan(table=[[scott, DEPT]])\n" + + " LogicalProject(ENAME=[$0], JOB=[$1], SAL=[$2], COMM=[$3])\n" + + " LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0, 1}])\n" + + " LogicalTableScan(table=[[scott, BONUS]])\n" + + " LogicalAggregate(group=[{0}])\n" + + " LogicalProject(i=[true])\n" + + " LogicalFilter(condition=[AND(=($0, $cor0.ENAME), =($1, $cor0.JOB))])\n" + + " LogicalTableScan(table=[[scott, BONUS]])\n"; + + assertThat(before, hasTree(planBefore)); + + final RelNode after = + RelDecorrelator.decorrelateQuery(before, builder, RuleSets.ofList(Collections.emptyList()), + RuleSets.ofList(Collections.emptyList())); + final String planAfter = "LogicalProject(EMPNO=[$0], DNAME=[$9], ENAME=[$11])\n" + + " LogicalJoin(condition=[=($1, $11)], joinType=[inner])\n" + + " LogicalJoin(condition=[=($8, $7)], joinType=[inner])\n" + + " LogicalFilter(condition=[>(CAST($5):DECIMAL(12, 2), 1000.00)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalFilter(condition=[=($1, 'SALES')])\n" + + " LogicalTableScan(table=[[scott, DEPT]])\n" + + " LogicalProject(ENAME=[$0], JOB=[$1], SAL=[$2], COMM=[$3])\n" + + " LogicalJoin(condition=[AND(=($0, $4), =($1, $5))], joinType=[inner])\n" + + " LogicalTableScan(table=[[scott, BONUS]])\n" + + " LogicalProject(ENAME=[$0], JOB=[$1], $f2=[true])\n" + + " LogicalFilter(condition=[AND(IS NOT NULL($0), IS NOT NULL($1))])\n" + + " LogicalTableScan(table=[[scott, BONUS]])\n"; + assertThat(after, hasTree(planAfter)); + } }