From fd8d43d9d4b9977afd52a5debdd138bef308a815 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Tue, 24 Mar 2026 15:44:44 -0700 Subject: [PATCH 1/2] Map divisible outer splits across merge outputs in almost-exact graph Add mapDivisibleMergeSplits after mapDivisibleSplits: for a Merge (e.g. merging reshape), when the merge output and merge outer input each have a divisible outer Split with the same factor, map the two outer IterDomains. Factor out isDivisible(Split*) for shared use with mapDivisibleSplits. Add IdModelTest.MergingReshapeOuterSplit_Mapped. Made-with: Cursor --- csrc/id_model/id_model.cpp | 79 +++++++++++++++++++++++++++++++++---- tests/cpp/test_id_model.cpp | 18 +++++++++ 2 files changed, 90 insertions(+), 7 deletions(-) diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index b1baad0d187..8fd9b3f0f36 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -482,6 +482,11 @@ std::vector> getTriviallyMappedIds(Expr* expr) { return mapped_ids; } +// True when expr simplification proves split->isDivisible(). +bool isDivisible(Split* split) { + return simplifyExpr(split->isDivisible())->isTrue(); +} + // The following is a subpattern of // https://github.com/NVIDIA/Fuser/blob/main/doc/reading/iterdomain.md#2-properties-of-iterdomain-transformations // @@ -492,17 +497,13 @@ std::vector> getTriviallyMappedIds(Expr* expr) { // If outermost_grand and outer' have the same extent, map them. // The splits must be divisible for this mapping to be valid. void mapDivisibleSplits(ValGraph& graph) { - auto is_divisible = [](Split* s) { - return simplifyExpr(s->isDivisible())->isTrue(); - }; - std::vector> ids_to_map; for (const ValGroup& root : graph.disjointValSets().disjointSets()) { const ExprGroups& uses_of_root = graph.getUses(root); std::vector outermost_grands; for (const ExprGroup& use_of_root : uses_of_root) { auto* split0 = dynamic_cast(use_of_root->front()); - if (split0 == nullptr || !is_divisible(split0)) { + if (split0 == nullptr || !isDivisible(split0)) { continue; } // Only follow the outer output of the first split; outer and inner @@ -510,7 +511,7 @@ void mapDivisibleSplits(ValGraph& graph) { const ValGroup& outer = graph.toGroup(split0->outer()); for (const ExprGroup& use_of_outer : graph.getUses(outer)) { auto* split1 = dynamic_cast(use_of_outer->front()); - if (split1 == nullptr || !is_divisible(split1)) { + if (split1 == nullptr || !isDivisible(split1)) { continue; } const ValGroup& outermost_grand = graph.toGroup(split1->outer()); @@ -524,7 +525,7 @@ void mapDivisibleSplits(ValGraph& graph) { for (const ExprGroup& use_of_root : uses_of_root) { auto* split = dynamic_cast(use_of_root->front()); - if (split == nullptr || !is_divisible(split)) { + if (split == nullptr || !isDivisible(split)) { continue; } @@ -542,6 +543,69 @@ void mapDivisibleSplits(ValGraph& graph) { } } +// Outer-split only: for a Merge (e.g. merging reshape), match the outer output +// of the consumer split of merge->out() (use_m_out) with the outer output of +// the producer split of merge->outer() (use_merge_outer), when factors match +// and splits are divisible — e.g. DIDx shard axes across reshape sides. +void mapDivisibleMergeSplits(ValGraph& graph) { + std::vector> ids_to_map; + // Given + // + // merge_outer merge_inner + // \ / + // [merge] + // | + // merge_out + // | + // [split_merge] + // / \. + // split_merge->outer() split_merge->inner() + // + // and + // + // merge_outer + // | + // [split_outer] + // / \. + // split_outer->outer() split_outer->inner() + // + // map split_merge->outer() and split_outer->outer() under certain + // divisibility conditions. + for (const ExprGroup& merge_group : graph.disjointExprSets().disjointSets()) { + auto* merge = dynamic_cast(merge_group->front()); + if (merge == nullptr) { + continue; + } + + const ValGroup& merge_out_group = graph.toGroup(merge->out()); + for (const ExprGroup& split_merge_group : graph.getUses(merge_out_group)) { + auto* split_merge = dynamic_cast(split_merge_group->front()); + if (split_merge == nullptr || !isDivisible(split_merge) || + split_merge->innerSplit()) { + continue; + } + + const ValGroup& merge_outer_group = graph.toGroup(merge->outer()); + for (const ExprGroup& split_outer_group : + graph.getUses(merge_outer_group)) { + auto* split_outer = dynamic_cast(split_outer_group->front()); + if (split_outer == nullptr || !isDivisible(split_outer) || + split_outer->innerSplit()) { + continue; + } + if (!split_merge->factor()->sameAs(split_outer->factor())) { + continue; + } + ids_to_map.emplace_back(split_merge->outer(), split_outer->outer()); + } + } + } + + for (const auto& [id1, id2] : ids_to_map) { + graph.mapVals(id1, id2); + } +} + } // namespace ValGraph& IdModel::buildAlmostExactGraph() { @@ -602,6 +666,7 @@ ValGraph& IdModel::buildAlmostExactGraph() { } mapDivisibleSplits(almost_exact_graph); + mapDivisibleMergeSplits(almost_exact_graph); almost_exact_graph.validateConsistency(); diff --git a/tests/cpp/test_id_model.cpp b/tests/cpp/test_id_model.cpp index 76ebf293638..56f02051ed0 100644 --- a/tests/cpp/test_id_model.cpp +++ b/tests/cpp/test_id_model.cpp @@ -3385,4 +3385,22 @@ TEST_F(IdModelTest, NonDivisibleSplits_NotMapped) { in->axis(0), out->axis(0))); } +TEST_F(IdModelTest, MergingReshapeOuterSplit_Mapped) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* in = makeContigConcreteTensor({2LL * 2, 2}); + fusion.addInput(in); + TensorView* out = reshape(in, {2LL * 2, 2}, {2LL * 2 * 2}); + fusion.addOutput(out); + + in->outer_split(0, 2); + out->outer_split(0, 2); + + IdModel id_model(&fusion); + const ValGraph& almost_exact_graph = id_model.buildAlmostExactGraph(); + EXPECT_TRUE(almost_exact_graph.disjointValSets().strictAreMapped( + in->axis(0), out->axis(0))); +} + } // namespace nvfuser From eb48d1df29e82f2010b66fc5bc75543af736c9ef Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Tue, 24 Mar 2026 16:27:12 -0700 Subject: [PATCH 2/2] Fix test --- csrc/id_model/id_model.cpp | 4 ---- tests/cpp/test_indexing.cpp | 4 ++-- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index 8fd9b3f0f36..6b666126eed 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -543,10 +543,6 @@ void mapDivisibleSplits(ValGraph& graph) { } } -// Outer-split only: for a Merge (e.g. merging reshape), match the outer output -// of the consumer split of merge->out() (use_m_out) with the outer output of -// the producer split of merge->outer() (use_merge_outer), when factors match -// and splits are divisible — e.g. DIDx shard axes across reshape sides. void mapDivisibleMergeSplits(ValGraph& graph) { std::vector> ids_to_map; // Given diff --git a/tests/cpp/test_indexing.cpp b/tests/cpp/test_indexing.cpp index 1c960ae5359..8108d5647f1 100644 --- a/tests/cpp/test_indexing.cpp +++ b/tests/cpp/test_indexing.cpp @@ -860,9 +860,9 @@ TEST_F(IndexingTest, Reshape) { // to provide the extent of the group. However, since everything // should be deterministic, string match should also work. return std::string( - "( ( ( ( ( i126 * 20 ) + ( ( i127 * 10 ) + i128 ) ) / 25 ) * 25 " + "( ( ( ( ( i130 * 20 ) + ( ( i131 * 10 ) + i132 ) ) / 25 ) * 25 " ") " - "+ ( ( ( i126 * 20 ) + ( ( i127 * 10 ) + i128 ) ) % 25 ) )"); + "+ ( ( ( i130 * 20 ) + ( ( i131 * 10 ) + i132 ) ) % 25 ) )"); } default: return std::string();