Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions core/src/main/java/org/apache/calcite/plan/hep/HepPlanner.java
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@
import org.apache.calcite.util.graph.Graphs;
import org.apache.calcite.util.graph.TopologicalOrderIterator;

import com.google.common.collect.HashMultimap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Multimap;

import org.checkerframework.checker.nullness.qual.Nullable;

Expand All @@ -66,6 +68,7 @@
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import java.util.stream.Collectors;

import static com.google.common.base.Preconditions.checkArgument;

Expand Down Expand Up @@ -114,6 +117,30 @@ public class HepPlanner extends AbstractRelOptPlanner {
private final List<RelOptMaterialization> materializations =
new ArrayList<>();

/**
* Cache of rules that have already been fired for a specific operand match,
* to avoid firing the same rule repeatedly.
*
* <p>Key: the list of matched {@link RelNode} IDs (operand match).
*
* <p>Value: the set of {@link RelOptRule}s already fired for that exact ID list.
*/
private final Multimap<List<Integer>, RelOptRule> firedRulesCache = HashMultimap.create();

/**
* Reverse index for {@link #firedRulesCache}, used for cleanup/GC:
* maps a single {@link RelNode} ID to all match-key ID lists that include it,
* so related cache entries can be removed efficiently when a node is discarded.
*
* <p>Key: {@link RelNode} ID.
*
* <p>Value: match-key ID lists in {@link #firedRulesCache} that contain the key ID.
*/
private final Multimap<Integer, List<Integer>> firedRulesCacheIndex = HashMultimap.create();


private boolean enableFiredRulesCache = false;

//~ Constructors -----------------------------------------------------------

/**
Expand Down Expand Up @@ -173,6 +200,8 @@ public HepPlanner(
removeRule(rule);
}
this.materializations.clear();
this.firedRulesCache.clear();
this.firedRulesCacheIndex.clear();
}

@Override public RelNode changeTraits(RelNode rel, RelTraitSet toTraits) {
Expand All @@ -195,6 +224,17 @@ public HepPlanner(
return buildFinalPlan(requireNonNull(root, "'root' must not be null"));
}

/**
* Enables or disables the fire-rule cache.
*
* <p> If enabled, a rule will not fire twice on the same {@code RelNode::getId()}.
*
* @param enable true to enable; false is default value.
*/
public void setEnableFiredRulesCache(boolean enable) {
enableFiredRulesCache = enable;
}

/** Top-level entry point for a program. Initializes state and then invokes
* the program. */
private void executeProgram(HepProgram program) {
Expand Down Expand Up @@ -519,13 +559,28 @@ private Iterator<HepRelVertex> getGraphIterator(
nodeChildren,
parents);

List<Integer> relIds = null;
if (enableFiredRulesCache) {
relIds = call.getRelList().stream().map(RelNode::getId).collect(Collectors.toList());
if (firedRulesCache.get(relIds).contains(rule)) {
return null;
}
}

// Allow the rule to apply its own side-conditions.
if (!rule.matches(call)) {
return null;
}

fireRule(call);

if (relIds != null) {
firedRulesCache.put(relIds, rule);
for (Integer relId : relIds) {
firedRulesCacheIndex.put(relId, relIds);
}
}

if (!call.getResults().isEmpty()) {
return applyTransformationResults(
vertex,
Expand Down Expand Up @@ -982,6 +1037,15 @@ private void collectGarbage() {

// Clean up metadata cache too.
sweepSet.forEach(this::clearCache);

if (enableFiredRulesCache) {
sweepSet.forEach(rel -> {
for (List<Integer> relIds : firedRulesCacheIndex.get(rel.getCurrentRel().getId())) {
firedRulesCache.removeAll(relIds);
}
firedRulesCacheIndex.removeAll(rel.getCurrentRel().getId());
});
}
}

private void assertNoCycles() {
Expand Down
29 changes: 24 additions & 5 deletions core/src/test/java/org/apache/calcite/test/HepPlannerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -366,11 +366,29 @@ private void assertIncludesExactlyOnce(String message, String digest,
}

@Test void testRuleApplyCount() {
final long applyTimes1 = checkRuleApplyCount(HepMatchOrder.ARBITRARY);
assertThat(applyTimes1, is(316L));
long applyTimes = checkRuleApplyCount(HepMatchOrder.ARBITRARY, false);
assertThat(applyTimes, is(316L));

final long applyTimes2 = checkRuleApplyCount(HepMatchOrder.DEPTH_FIRST);
assertThat(applyTimes2, is(87L));
applyTimes = checkRuleApplyCount(HepMatchOrder.DEPTH_FIRST, false);
assertThat(applyTimes, is(87L));

applyTimes = checkRuleApplyCount(HepMatchOrder.TOP_DOWN, false);
assertThat(applyTimes, is(295L));

applyTimes = checkRuleApplyCount(HepMatchOrder.BOTTOM_UP, false);
assertThat(applyTimes, is(296L));

applyTimes = checkRuleApplyCount(HepMatchOrder.ARBITRARY, true);
assertThat(applyTimes, is(65L));

applyTimes = checkRuleApplyCount(HepMatchOrder.DEPTH_FIRST, true);
assertThat(applyTimes, is(65L));

applyTimes = checkRuleApplyCount(HepMatchOrder.TOP_DOWN, true);
assertThat(applyTimes, is(65L));

applyTimes = checkRuleApplyCount(HepMatchOrder.BOTTOM_UP, true);
assertThat(applyTimes, is(65L));
}

@Test void testMaterialization() {
Expand All @@ -387,7 +405,7 @@ private void assertIncludesExactlyOnce(String message, String digest,
assertThat(planner.getMaterializations(), empty());
}

private long checkRuleApplyCount(HepMatchOrder matchOrder) {
private long checkRuleApplyCount(HepMatchOrder matchOrder, boolean enableFiredRulesCache) {
final HepProgramBuilder programBuilder = HepProgram.builder();
programBuilder.addMatchOrder(matchOrder);
programBuilder.addRuleInstance(CoreRules.FILTER_REDUCE_EXPRESSIONS);
Expand All @@ -397,6 +415,7 @@ private long checkRuleApplyCount(HepMatchOrder matchOrder) {
HepPlanner planner = new HepPlanner(programBuilder.build());
planner.addListener(listener);
planner.setRoot(sql(COMPLEX_UNION_TREE).toRel());
planner.setEnableFiredRulesCache(enableFiredRulesCache);
planner.findBestExp();
return listener.getApplyTimes();
}
Expand Down
Loading