TMA outer-reduction: Support multi-input fusions#6036

Open
tbqh wants to merge 3 commits into main from tbqh/multi_input_tma

Conversation

@tbqh
Collaborator

@tbqh tbqh commented Mar 17, 2026

Fix a canonicalization issue with reduction_outer_tma, and support fusions with multiple inputs. Update TmaOuterReductionTest to cover these new cases. Also disable the Welford op for inner+outer TMA, since supporting it is complicated.

@tbqh tbqh requested a review from liqiangxl March 17, 2026 10:52
@greptile-apps
Contributor

greptile-apps bot commented Mar 17, 2026

Greptile Summary

This PR fixes two distinct bugs in the reduction_outer_tma scheduler and extends it to handle multi-input fusions. The core correctness fix propagates the [R, I] canonical form to all tensors (including TMA cache TVs) before TMA tiling begins — without this step, tma_tv->split(0, tma_tile_r) was operating on the wrong axis when TMA TVs were still in [I, R] form. The multi-input support removes the n_tensor_inputs == 1 hard gate from mayUseTmaOuter and replaces it with a proper minimum-tile-fits-in-smem check; the heuristic then proportionally shrinks tiles to fit the per-input budget. A follow-on fix correctly falls back to ParallelType::Serial (and non-grouped grid reduction) when iter_unroll_factor == 1 instead of creating an invalid size-1 Vectorize axis. Welford rejection is also added to both mayUseTma and mayUseTmaOuter.

Key changes:

  • Canonicalization propagation fix: TransformPropagator now propagates the [R, I] reorder to all tensors immediately after canonicalizeReduction, ensuring TMA tiling splits target the correct axes.
  • Multi-input smem budget: getReductionHeuristics divides available smem by n_inputs and shrinks tma_tile_r / tma_tile_i in two while-loops; mayUseTmaOuter rejects fusions where minimum-size tiles cannot fit.
  • iter_unroll_factor = 1 guard: Axis 6 parallelization and use_iter_grouped_reduction flag now both condition on iter_unroll_factor > 1, preventing an illegal size-1 Vectorize and a spurious call to iterGroupedGridReduce.
  • rFactorHelper: Replaces the bare rFactor call; the helper is necessary for correctness in multi-output/multi-producer situations.
  • Test suite: Extended to a 4-tuple (outer_size, iter_size, n_inputs, dtype) with 15 explicit cases covering 1–5 inputs and Float/Half dtypes.
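The multi-input smem budget described above can be sketched as follows: divide the available shared memory across inputs, then halve tma_tile_r and then tma_tile_i until one input's tile fits its share. This is a minimal, hedged C++ sketch, not the actual heuristic code — the function names, the halving step, and the minimum tile size of 16 are illustrative assumptions.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical sketch of the smem-budget-driven tile shrinking the review
// describes. Names and constants are assumptions, not the nvFuser source.
struct TmaTiles {
  int64_t tile_r; // TMA tile extent along the reduction dimension
  int64_t tile_i; // TMA tile extent along the iteration dimension
};

constexpr int64_t kMinTile = 16; // assumed lower bound per dimension

// Per-input budget: usable smem bytes divided evenly across the inputs.
int64_t smemBudgetPerInput(int64_t smem_bytes, int64_t n_inputs) {
  return smem_bytes / n_inputs;
}

// Two shrink loops: halve tile_r first, then tile_i, until the tile's
// byte footprint fits the per-input budget (or the minimum size is hit).
TmaTiles shrinkTilesToFit(TmaTiles t, int64_t budget, int64_t dtype_size) {
  while (t.tile_r > kMinTile && t.tile_r * t.tile_i * dtype_size > budget) {
    t.tile_r /= 2;
  }
  while (t.tile_i > kMinTile && t.tile_r * t.tile_i * dtype_size > budget) {
    t.tile_i /= 2;
  }
  return t;
}
```

Under this sketch, a fusion with more inputs gets a smaller per-input budget, so tiles shrink proportionally; mayUseTmaOuter's job is then to reject the cases where even minimum-size tiles cannot fit.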

Confidence Score: 4/5

  • Safe to merge with one minor follow-up: the iter_unroll_factor = 1 path (Serial fallback) is not yet exercised by any test case, but the fix itself is structurally correct and the gate in mayUseTmaOuter prevents the worst-case failure mode.
  • All three prior-review blocking concerns are now addressed: the size-1 Vectorize axis is guarded, the smem overhead is accounted for, and the Welford op is rejected. The canonicalization propagation fix is well-motivated and correct. Multi-TMA-TV propagation works because all input cache TVs are reachable through the fusion graph from tma_tvs[0]. No new logic bugs found. The remaining gap — no test hitting iter_unroll_factor = 1 — is a coverage issue, not a correctness regression, and mayUseTmaOuter correctly gates that path for extreme input counts.
  • No files require special attention; reduction_outer_tma.cpp is the most complex change but the logic has been thoroughly reviewed.

Important Files Changed

Filename | Overview
csrc/scheduler/reduction_outer_tma.cpp | Core scheduler changes: adds smem-budget-driven tile shrinking, a canonicalization propagation fix (propagating [R,I] form to all TVs before TMA tiling), an iter_unroll_factor=1 guard, and rFactorHelper for multi-input rFactor. Logic is sound; multi-TMA-TV propagation relies correctly on MaxLogicalDomainInfoSpanningTree reaching sibling inputs through shared consumers.
csrc/scheduler/reduction.cpp | Replaces the blanket n_tensor_inputs==1 rejection with a proper min-tile-fits-in-smem check, and adds Welford rejection to both mayUseTma and mayUseTmaOuter. The smem_overhead formula mirrors the heuristic file exactly.
tests/cpp/test_reduction.cpp | Test suite extended with n_inputs and dtype dimensions; expectOuterTmaUsed updated to mirror the new smem check. The left-associative add chain correctly models a multi-input fusion. All 15 explicit test cases are valid and exercise the new paths.
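The canonicalization fix in reduction_outer_tma.cpp hinges on a simple invariant: once every tensor is in [R, I] form (reduction axes first), a split at axis 0 always targets the reduction dimension. A minimal sketch of that reorder, with illustrative struct and function names rather than the real nvFuser API:

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical sketch of [R, I] canonicalization: stable-partition a
// tensor's axes so reduction axes come first. Before this is propagated to
// the TMA cache TVs, a tma_tv->split(0, tma_tile_r) on a TV still in
// [I, R] form would split the iteration axis by mistake.
struct Axis {
  std::string name;
  bool is_reduction;
};

std::vector<Axis> canonicalizeRI(std::vector<Axis> axes) {
  std::vector<Axis> out;
  for (const auto& a : axes)
    if (a.is_reduction) out.push_back(a); // reduction axes first
  for (const auto& a : axes)
    if (!a.is_reduction) out.push_back(a); // iteration axes after
  return out;
}
```

The PR's fix is to run this kind of reorder (via TransformPropagator) on all tensors, including every TMA cache TV, before any tiling split is applied.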

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[mayUseTmaOuter] -->|n_inputs, dtype| B{Min tiles fit in smem?}
    B -->|No| REJECT[Reject: fall back to non-TMA]
    B -->|Yes| C{Welford op?}
    C -->|Yes| REJECT
    C -->|No| ACCEPT[Accept]

    ACCEPT --> D[getReductionHeuristics]
    D --> E[Compute smem_per_input = smem_bytes / n_inputs]
    E --> F{tma_tile_r * tma_tile_i * dtype > budget?}
    F -->|Yes, shrink r| F
    F -->|Yes, shrink i| F
    F -->|No| G[Compute iter_unroll_factor = tma_tile_i / bdimx]

    G --> H[scheduleReduction]
    H --> H1[cacheInputs → TMA TVs for each input]
    H1 --> H2[canonicalizeReduction → R,I form on reduction_tv]
    H2 --> H3[Propagate R,I to ALL TVs incl. all TMA TVs]
    H3 --> H4[Apply TMA tiling splits to tma_tvs 0]
    H4 --> H5[Propagate tiling from tma_tvs 0 → all TVs]
    H5 --> H6[Parallelize all TMA TVs with parallelizeAllLike]
    H6 --> H7[Sub-split redu_tv into thread dims]
    H7 --> H8{iter_unroll_factor > 1?}
    H8 -->|Yes| H9[axis 6 = Vectorize → Group for iter-grouped reduction]
    H8 -->|No| H10[axis 6 = Serial → regular grid reduction]
    H9 --> H11[rFactorHelper for grid reduction]
    H10 --> H11
    H11 --> H12[Propagate to non-TMA TVs]
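The iter_unroll_factor branch at the bottom of the flowchart can be sketched as a small decision function — choose Serial and a plain grid reduction when the factor is 1, since a size-1 Vectorize axis would be illegal. The enum and function names below are illustrative assumptions, not the real scheduler API:

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical sketch of the iter_unroll_factor == 1 guard: both the axis-6
// parallel type and the grouped-reduction flag condition on factor > 1.
enum class ParallelType { Serial, Vectorize };

struct IterAxisConfig {
  ParallelType ptype;
  bool use_iter_grouped_reduction;
};

IterAxisConfig pickIterAxis(int64_t iter_unroll_factor) {
  if (iter_unroll_factor > 1) {
    // Vectorize axis of extent > 1, later lowered to an iter-grouped
    // grid reduction.
    return {ParallelType::Vectorize, true};
  }
  // Fall back: a size-1 Vectorize axis would be invalid, and calling
  // iterGroupedGridReduce here would be spurious.
  return {ParallelType::Serial, false};
}
```

This mirrors the review's point that the fallback path is structurally correct even though no current test case drives iter_unroll_factor down to 1.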

Reviews (2): Last reviewed commit: "Disable vectorization when its factor is..."

