Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 68 additions & 7 deletions csrc/id_model/id_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,11 @@ std::vector<std::vector<Val*>> getTriviallyMappedIds(Expr* expr) {
return mapped_ids;
}

// True when expr simplification proves split->isDivisible().
bool isDivisible(Split* split) {
return simplifyExpr(split->isDivisible())->isTrue();
}

// The following is a subpattern of
// https://github.com/NVIDIA/Fuser/blob/main/doc/reading/iterdomain.md#2-properties-of-iterdomain-transformations
//
Expand All @@ -492,25 +497,21 @@ std::vector<std::vector<Val*>> getTriviallyMappedIds(Expr* expr) {
// If outermost_grand and outer' have the same extent, map them.
// The splits must be divisible for this mapping to be valid.
void mapDivisibleSplits(ValGraph& graph) {
auto is_divisible = [](Split* s) {
return simplifyExpr(s->isDivisible())->isTrue();
};

std::vector<std::pair<Val*, Val*>> ids_to_map;
for (const ValGroup& root : graph.disjointValSets().disjointSets()) {
const ExprGroups& uses_of_root = graph.getUses(root);
std::vector<ValGroup> outermost_grands;
for (const ExprGroup& use_of_root : uses_of_root) {
auto* split0 = dynamic_cast<Split*>(use_of_root->front());
if (split0 == nullptr || !is_divisible(split0)) {
if (split0 == nullptr || !isDivisible(split0)) {
continue;
}
// Only follow the outer output of the first split; outer and inner
// must not be conflated.
const ValGroup& outer = graph.toGroup(split0->outer());
for (const ExprGroup& use_of_outer : graph.getUses(outer)) {
auto* split1 = dynamic_cast<Split*>(use_of_outer->front());
if (split1 == nullptr || !is_divisible(split1)) {
if (split1 == nullptr || !isDivisible(split1)) {
continue;
}
const ValGroup& outermost_grand = graph.toGroup(split1->outer());
Expand All @@ -524,7 +525,7 @@ void mapDivisibleSplits(ValGraph& graph) {

for (const ExprGroup& use_of_root : uses_of_root) {
auto* split = dynamic_cast<Split*>(use_of_root->front());
if (split == nullptr || !is_divisible(split)) {
if (split == nullptr || !isDivisible(split)) {
continue;
}

Expand All @@ -542,6 +543,65 @@ void mapDivisibleSplits(ValGraph& graph) {
}
}

void mapDivisibleMergeSplits(ValGraph& graph) {
std::vector<std::pair<Val*, Val*>> ids_to_map;
// Given
//
// merge_outer merge_inner
// \ /
// [merge]
// |
// merge_out
// |
// [split_merge]
// / \.
// split_merge->outer() split_merge->inner()
//
// and
//
// merge_outer
// |
// [split_outer]
// / \.
// split_outer->outer() split_outer->inner()
//
// map split_merge->outer() and split_outer->outer() under certain
// divisibility conditions.
for (const ExprGroup& merge_group : graph.disjointExprSets().disjointSets()) {
auto* merge = dynamic_cast<Merge*>(merge_group->front());
if (merge == nullptr) {
continue;
}

const ValGroup& merge_out_group = graph.toGroup(merge->out());
for (const ExprGroup& split_merge_group : graph.getUses(merge_out_group)) {
auto* split_merge = dynamic_cast<Split*>(split_merge_group->front());
if (split_merge == nullptr || !isDivisible(split_merge) ||
split_merge->innerSplit()) {
continue;
}

const ValGroup& merge_outer_group = graph.toGroup(merge->outer());
for (const ExprGroup& split_outer_group :
graph.getUses(merge_outer_group)) {
auto* split_outer = dynamic_cast<Split*>(split_outer_group->front());
if (split_outer == nullptr || !isDivisible(split_outer) ||
split_outer->innerSplit()) {
continue;
}
if (!split_merge->factor()->sameAs(split_outer->factor())) {
continue;
}
ids_to_map.emplace_back(split_merge->outer(), split_outer->outer());
}
}
}

for (const auto& [id1, id2] : ids_to_map) {
graph.mapVals(id1, id2);
}
}

} // namespace

ValGraph& IdModel::buildAlmostExactGraph() {
Expand Down Expand Up @@ -602,6 +662,7 @@ ValGraph& IdModel::buildAlmostExactGraph() {
}

mapDivisibleSplits(almost_exact_graph);
mapDivisibleMergeSplits(almost_exact_graph);

almost_exact_graph.validateConsistency();

Expand Down
18 changes: 18 additions & 0 deletions tests/cpp/test_id_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3385,4 +3385,22 @@ TEST_F(IdModelTest, NonDivisibleSplits_NotMapped) {
in->axis(0), out->axis(0)));
}

TEST_F(IdModelTest, MergingReshapeOuterSplit_Mapped) {
Fusion fusion;
FusionGuard fg(&fusion);

TensorView* in = makeContigConcreteTensor({2LL * 2, 2});
fusion.addInput(in);
TensorView* out = reshape(in, {2LL * 2, 2}, {2LL * 2 * 2});
fusion.addOutput(out);

in->outer_split(0, 2);
out->outer_split(0, 2);

IdModel id_model(&fusion);
const ValGraph& almost_exact_graph = id_model.buildAlmostExactGraph();
EXPECT_TRUE(almost_exact_graph.disjointValSets().strictAreMapped(
in->axis(0), out->axis(0)));
}
Comment on lines +3388 to +3404
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Missing negative tests for mapDivisibleMergeSplits

The PR adds one positive test (MergingReshapeOuterSplit_Mapped) but no negative counterparts exercising the guard conditions inside mapDivisibleMergeSplits. The following cases are silently assumed to not map, but are not explicitly verified:

  1. Inner split instead of outer splitout->inner_split(0, 2) should NOT map in->axis(0) and out->axis(0) (the !split_merge->innerSplit() guard).
  2. Mismatched factorsin->outer_split(0, 2) plus out->outer_split(0, 4) (different factors) should NOT map.
  3. Non-divisible outer split — factors that don't evenly divide the dimension size should NOT map (the isDivisible guard).

The existing NonDivisibleSplits_NotMapped test covers mapDivisibleSplits, but there is no equivalent for mapDivisibleMergeSplits. Consider adding at least one negative test to document these invariants and prevent future regressions.


} // namespace nvfuser
4 changes: 2 additions & 2 deletions tests/cpp/test_indexing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -860,9 +860,9 @@ TEST_F(IndexingTest, Reshape) {
// to provide the extent of the group. However, since everything
// should be deterministic, string match should also work.
return std::string(
"( ( ( ( ( i126 * 20 ) + ( ( i127 * 10 ) + i128 ) ) / 25 ) * 25 "
"( ( ( ( ( i130 * 20 ) + ( ( i131 * 10 ) + i132 ) ) / 25 ) * 25 "
") "
"+ ( ( ( i126 * 20 ) + ( ( i127 * 10 ) + i128 ) ) % 25 ) )");
"+ ( ( ( i130 * 20 ) + ( ( i131 * 10 ) + i132 ) ) % 25 ) )");
}
default:
return std::string();
Expand Down
Loading