fix: Handle zero-stride TMA basis in fill_tma_gmem_shape_stride #3121
RobTand wants to merge 1 commit into NVIDIA:main
Conversation
|
Thank you for the mention. I am not working on it at this time, as I've
been busy trying to build an SM120-optimized attention kernel, which is
proving to be harder than I thought.
I will be happy to coordinate with you on anything you need.
|
Testing on RTX PRO 6000 (SM120, TP=4)
We applied both this PR and FlashInfer #2786 together on 4x RTX PRO 6000 Blackwell Server Edition (SM120, PCIe, 96GB each) running SGLang.
Setup:
Results (single-user decode throughput):
No measurable difference on this configuration. The patches applied cleanly and FlashInfer JIT compiled the K=64 kernels without errors, but the dispatch doesn't seem to select K=64 tiles for this model/shape combination, or the benefit is offset by TP=4 communication overhead. Questions:
Both patches are functionally correct on SM120 — no crashes, no compilation errors. The CUTLASS TMA zero-stride fix and EffBlk_SF clamping work as described. |
|
Thanks for taking the time to review this — really appreciate it. To address your question about forcing K=64 tiles and demonstrating necessity:

SMEM constraint on SM120/121
SM120/SM121 devices (RTX 5090, DGX Spark GB10) have 99 KB of opt-in shared memory per block, compared to 227 KB on SM100. This changes which tile shapes are viable for block-scaled GEMM.

Pipeline stage impact
The builder computes the pipeline stage count from the available SMEM budget and the per-stage tile footprint.
K=128 tiles do work on SM120/121 (3–5 stages), so this isn't strictly about making things functional at K=128. The issue is that K=64 tiles — the natural choice for the constrained SMEM budget — cannot be instantiated at all.

Why K=64 fails without this fix
Two bugs surface only when `K/SFVectorSize` creates a broadcast fold.
Why these were never caught on SM100
With 227 KB of SMEM, K=128 tiles get 8–12 pipeline stages — more than sufficient. K=64 tiles are never needed, so the code path was never exercised. Both bugs are latent but only manifest with small K values.

Performance expectations
The performance improvement from deeper pipelining may be modest on SM120/121, since these devices tend to be memory-bandwidth-bound. The primary motivation is correctness: K=64 tile shapes should be valid instantiations of the block-scaled GEMM builder, and currently they are not.

How to test
The simplest way to verify is a one-line change to an existing unit test. In `sm120_bs_gemm_nvf4_nvf4_f32_bf16.cu`, change `using TileShape = Shape<_128,_128,_256>;` to `using TileShape = Shape<_128,_128,_64>;`. Without this PR, this will either fail to compile or produce corrupt results. With the fix, it should compile and pass. Both FlashInfer (v0.6.6) and vLLM (v0.18.0) already define K=64 tile configs in their SM120 enums. |
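To make the SMEM arithmetic above concrete, here is a standalone sketch (not the actual CUTLASS builder logic) of how stage count falls out of the budget. All footprint numbers are illustrative: the sketch counts only the A and B tiles for 4-bit operands and ignores scale-factor and epilogue SMEM, so it overestimates slightly relative to the 3–5 stages the real builder reports at K=128.

```cpp
#include <cassert>
#include <cstddef>

// Illustrative sketch: pipeline stages ~= opt-in SMEM budget divided by the
// per-stage footprint of one A tile plus one B tile. Scale-factor and
// epilogue SMEM are ignored, so real counts are somewhat lower.
constexpr std::size_t stages(std::size_t smem_bytes,
                             std::size_t m, std::size_t n, std::size_t k,
                             std::size_t bits_per_elem) {
  std::size_t per_stage = (m * k + n * k) * bits_per_elem / 8;  // A + B tiles
  return smem_bytes / per_stage;
}

// 4-bit operands (e.g. NVF4), hypothetical 128x128xK tiles:
static_assert(stages(99 * 1024, 128, 128, 128, 4) == 6,
              "K=128 on the 99 KB SM120/121 budget: only a handful of stages");
static_assert(stages(99 * 1024, 128, 128, 64, 4) == 12,
              "K=64 halves the per-stage footprint, roughly doubling depth");
static_assert(stages(227 * 1024, 128, 128, 128, 4) == 14,
              "K=128 on the 227 KB SM100 budget: already deep enough");
```

This is why K=64 is the natural tile for the constrained 99 KB budget: it recovers pipeline depth without needing more SMEM.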
|
I have no reason to doubt it —
I think his benchmarks are fair. I did my own benchmarking here: https://github.com/brandonmmusic-max/sm120-moe-bench, but I was testing for effects on different things. The fix may help more on the prefill side (I was getting 17k or so) — the prefill numbers are pretty good. But I haven't run benchmarks in a manner that lets me exclude any increase being related to p2p or MTP; I was optimizing for my particular neuro-symbolic pipeline for my local workflow. I'm glad to hear you were getting a 2x speedup! Would love to hear more about your setup! |
Update: Prefill benchmarks + methodology fix
My earlier decode results were correct (no difference), but the initial prefill numbers were wrong — the "+19%" was a cold-start artifact (first request without warmup hitting a radix cache miss + JIT warmup). Re-ran with proper methodology: 3 warmup + 5 measured runs per prompt size.
Hardware: 4x RTX PRO 6000 Blackwell Server Edition (SM120, 96GB GDDR7 each, PCIe Gen5)
Decode throughput (single user, 512 output tokens)
Prefill throughput (3 warmup + 5 measured, max_tokens=1)
All differences are within noise (<1%). Standard deviation across runs was 0.1–3.1 ms.
Analysis
The MoE GEMM dimensions for Qwen3.5-397B with TP=4 are:
For decode (M=1): the operation is a GEMV, entirely memory-bandwidth bound. Pipeline stages (K=64's advantage) don't help because compute is not the bottleneck — reading weights from GDDR7 dominates. For prefill (M=885–7099): with TP=4, per-GPU M is ~220–1775. Even at these sizes, GDDR7 bandwidth (1.79 TB/s per GPU) appears sufficient that the additional pipeline depth from K=64 tiles doesn't provide measurable benefit over K=128. Note on hardware difference: Your PR reports 2x speedup on DGX Spark (SM121, unified LPDDR5X). LPDDR5X has significantly lower bandwidth (~273 GB/s) than GDDR7 (1.79 TB/s per GPU × 4 = 7.16 TB/s aggregate). This likely explains why K=64's compute efficiency gains are visible on DGX Spark (compute-bound) but not on discrete RTX PRO 6000 GPUs (memory-bandwidth bound even for prefill). The patches are functionally correct — JIT compilation succeeds, kernels load, no errors. They just don't provide a throughput benefit on this specific hardware configuration. |
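The bandwidth-bound argument above can be checked with a back-of-envelope roofline sketch. This is illustrative, not measured: the N=K=4096 layer shape is hypothetical, and only the 4-bit weight-matrix traffic is counted, since it dominates at small M.

```cpp
#include <cassert>

// Arithmetic intensity of one M x N x K GEMM with 4-bit weights, counting
// only the B (weight) operand traffic, which dominates at small M.
constexpr double intensity(double m, double n, double k) {
  double flops = 2.0 * m * n * k;  // one multiply-add per weight per A row
  double bytes = n * k * 0.5;      // 4-bit B operand, read once
  return flops / bytes;
}

// Decode (M = 1) has a fixed intensity of 4 flops/byte regardless of layer
// shape -- far below any tensor-core balance point, so the GEMV is
// bandwidth-bound and extra pipeline stages cannot help.
static_assert(intensity(1, 4096, 4096) == 4.0, "M=1 is pure GEMV");

// Prefill with per-GPU M ~ 1775 (the upper end of the TP=4 split above) is
// three orders of magnitude more intense, scaling linearly in M:
static_assert(intensity(1775, 4096, 4096) == 7100.0, "4 flops/byte per row");
```

Whether the prefill shape ends up compute- or bandwidth-bound then depends on the device's balance point, which is consistent with K=64 helping on low-bandwidth LPDDR5X (DGX Spark) but not on GDDR7.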
…es on SM12x (#2913)

### Summary
- Add missing `-DCUTLASS_ENABLE_GDC_FOR_SM100=1` compile flag to all CUTLASS fused MoE JIT modules (SM100/SM103/SM120) and `-DCUTLASS_ENABLE_GDC_FOR_SM90=1` to SM90 modules
- Sync nv_internal `grid_dependency_control.h` with upstream CUTLASS to support SM100/SM103/SM110/SM120/SM121 GDC
- Add `-DCUTLASS_ENABLE_GDC_FOR_SM90=1` to FP8 blockscale GEMM SM90 module

### Problem
Random `cudaErrorIllegalInstruction` crashes on DGX Spark (SM121) and RTX 50-series (SM120) when running NVFP4 MoE models (e.g., Nemotron, Qwen3.5-122B) under load. The crashes are intermittent and worsen with longer context lengths and higher concurrency.

**Root cause:** PR #2780 fixed the missing GDC compile flags for GEMM modules (`flashinfer/jit/gemm/core.py`), but the **CUTLASS fused MoE modules** in `flashinfer/jit/fused_moe.py` and the **FP8 blockscale GEMM module** were not fixed. This is the exact same class of bug as #2708.

Without `-DCUTLASS_ENABLE_GDC_FOR_SM100=1`, CUTLASS's `grid_dependency_control.h` compiles `wait_on_dependent_grids()` and `launch_dependent_grids()` as **empty no-ops**:

```cpp
CUTLASS_DEVICE
void wait_on_dependent_grids() {
#if (defined(CUTLASS_GDC_ENABLED))  // ← not defined without the flag
  asm volatile("griddepcontrol.wait;");
#endif
}
```

Meanwhile, the host-side code still sets `programmaticStreamSerializationAllowed = true` (PDL enabled) via `device_support_pdl()`, which returns `True` for all `major >= 9`, including SM12x. This means:

1. **Host enables PDL** → CUDA runtime overlaps consecutive kernels
2. **Device GDC barriers are no-ops** → No synchronization between overlapping kernels
3. **Race condition** → Dependent kernel reads stale global memory → corruption → `cudaErrorIllegalInstruction`

The crash is random because it depends on exact kernel scheduling timing, which varies per request.
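The no-op failure mode described above can be mocked standalone. This is not the real CUTLASS header — `gdc_calls` is my own instrumentation counter standing in for the `griddepcontrol.wait` PTX instruction — but it shows how the guard compiles the barrier away when the flag is absent.

```cpp
#include <cassert>

// Mock of the guard pattern from grid_dependency_control.h. The enabling
// macro is left undefined here, reproducing the buggy build:
// #define CUTLASS_ENABLE_GDC_FOR_SM100 1   // <-- the flag this PR adds

#if defined(CUTLASS_ENABLE_GDC_FOR_SM100)
#define CUTLASS_GDC_ENABLED
#endif

// Instrumentation (not part of CUTLASS): counts barrier executions.
static int gdc_calls = 0;

inline void wait_on_dependent_grids() {
#if defined(CUTLASS_GDC_ENABLED)
  ++gdc_calls;  // stand-in for: asm volatile("griddepcontrol.wait;");
#endif
}
```

With the flag undefined, `wait_on_dependent_grids()` has an empty body, while the host side still enables PDL and overlaps kernels — exactly the unsynchronized overlap that produces the intermittent corruption.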
### Fix
**`flashinfer/jit/fused_moe.py`** — Added GDC flags to all CUTLASS fused MoE modules:

| Module | Flag | Architectures Covered |
|---|---|---|
| `gen_cutlass_fused_moe_sm120_module()` | `-DCUTLASS_ENABLE_GDC_FOR_SM100=1` | SM120, SM121 |
| `gen_cutlass_fused_moe_sm103_module()` | `-DCUTLASS_ENABLE_GDC_FOR_SM100=1` | SM103, SM120, SM121 |
| `gen_cutlass_fused_moe_sm100_module()` | `-DCUTLASS_ENABLE_GDC_FOR_SM100=1` | SM100, SM110, SM120, SM121 |
| `gen_cutlass_fused_moe_sm90_module()` | `-DCUTLASS_ENABLE_GDC_FOR_SM90=1` | SM90 |
| `gen_trtllm_gen_fused_moe_sm100_module()` | `-DCUTLASS_ENABLE_GDC_FOR_SM100=1` | SM100+, SM120, SM121 |

**`flashinfer/jit/gemm/fp8_blockscale.py`** — Added `-DCUTLASS_ENABLE_GDC_FOR_SM90=1` to `gen_fp8_blockscale_gemm_sm90_module()`.

**`csrc/nv_internal/.../grid_dependency_control.h`** — Synced with upstream CUTLASS (`3rdparty/cutlass/include/cutlass/arch/grid_dependency_control.h`) to add SM100+ GDC support. Previously only handled SM90, so any nv_internal TensorRT-LLM code compiled for SM12x would have GDC barriers silently compiled as no-ops.

### Why `-DCUTLASS_ENABLE_GDC_FOR_SM100=1` covers SM12x
CUTLASS uses a single flag for the entire Blackwell family. From `grid_dependency_control.h`:

```cpp
#if (CUDA_BARRIER_ENABLED && defined(CUTLASS_ENABLE_GDC_FOR_SM100) && defined(__CUDA_ARCH__) && \
     ((__CUDA_ARCH__ == 1000 && ...) ||  // SM100
      (__CUDA_ARCH__ == 1030 && ...) ||  // SM103
      (__CUDA_ARCH__ == 1100 && ...) ||  // SM110
      (__CUDA_ARCH__ == 1200 && ...) ||  // SM120 (RTX 50-series)
      (__CUDA_ARCH__ == 1210 && ...)))   // SM121 (DGX Spark)
#define CUTLASS_GDC_ENABLED
```

### Why SM90 GDC flag was NOT added to SM100+ modules
PR #2716 attempted to add both `-DCUTLASS_ENABLE_GDC_FOR_SM90=1` and `-DCUTLASS_ENABLE_GDC_FOR_SM100=1` to all modules.
It broke AOT builds because `sm120_gemm_tma_warpspecialized_cooperative_asymmetric_dma.hpp` checks `CUTLASS_ENABLE_GDC_FOR_SM90` and calls `scheduler.is_last_tile()` — a method not present on the SM120 scheduler. PR #2780 corrected this by using only the SM100 flag for SM100+ modules. This PR follows the same approach.

### Related
- #2708 — Original issue: missing GDC flags cause PDL race condition
- #2716 — First fix attempt (reverted — broke AOT)
- #2780 — Corrected fix for GEMM modules only
- [vllm-project/vllm#38423](vllm-project/vllm#38423) — NVFP4 bugfix on DGX Spark
- [NVIDIA/cutlass#3121](NVIDIA/cutlass#3121) — K=64 block-scaled GEMM tiles (separate issue)

### Test plan
- [x] Clear JIT cache: `rm -rf ~/.cache/flashinfer/`
- [x] Run NVFP4 MoE model on SM121 (DGX Spark) with 128K context under load — verify no `cudaErrorIllegalInstruction`
- [x] Run NVFP4 MoE model on SM120 (RTX 50-series) with concurrent requests — verify no NaN/garbage output
- [x] Verify `CUDA_LAUNCH_BLOCKING=1` workaround is no longer needed
- [x] AOT build with `FLASHINFER_CUDA_ARCH_LIST="12.1a"` completes without errors
- [x] SM90 (Hopper) fused MoE tests pass: `pytest tests/moe/`
- [x] SM100 GEMM tests still pass (no regression from existing GDC flags)

## Summary by CodeRabbit
* **New Features**
  * Expanded GPU kernel compilation support: enabled additional optimizations for NVIDIA SM100 and SM90 GPUs, activating dependency-control optimizations where available.
  * Updated JIT/GEMM build configs to include these architecture-specific compile options, improving performance and compatibility on supported hardware.
Force-pushed from 5197972 to ab4a7c2
|
@johnnynunez I updated this and tried to address @depaulmillz's issues with tiling |
Force-pushed from ab4a7c2 to 760feb5
When a block-scaled GEMM layout folds scale factors over a broadcast dimension (stride == 0), the existing code passes the zero-stride basis to basis_get(), which produces undefined gmem_prob_shape and gmem_prob_stride values in the TMA descriptor. This results in invalid TMA descriptors that cause silent data corruption or runtime errors. Fix: detect zero-stride basis via is_constant<0> and emit the correct degenerate descriptor (shape=1, stride=0) for broadcast dimensions. This is the standard TMA encoding for a dimension that contributes no unique data. The bug is latent in all block-scaled GEMM configurations but only manifests when K/SFVectorSize creates a broadcast fold — e.g. when scale factor vector size equals or exceeds the tile K extent. Signed-off-by: Rob Tand <robert.tand@icloud.com>
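The rule the commit message describes can be sketched in plain C++ (this is not the actual code in `copy_traits_sm90_tma.hpp`; `TmaDim` and `fill_dim` are hypothetical names for illustration):

```cpp
#include <cassert>
#include <cstdint>

// Sketch of the fix's rule: a broadcast mode (stride == 0) must be encoded
// as the degenerate (shape = 1, stride = 0) pair in the TMA descriptor,
// rather than fed through a basis lookup that yields undefined values.
struct TmaDim { std::uint64_t shape; std::uint64_t stride; };

inline TmaDim fill_dim(std::uint64_t shape, std::uint64_t stride) {
  if (stride == 0) {
    return {1, 0};  // broadcast: this dimension contributes no unique data
  }
  return {shape, stride};  // ordinary strided dimension, passed through
}
```

In the real code the zero-stride case is detected statically via `is_constant<0>` on the basis, but the emitted descriptor values are the same degenerate pair.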
Force-pushed from 760feb5 to 216faae
|
Dramatically reduced the scope here, since K=64 doesn't seem to improve performance and significantly increases complexity. This now just fixes the root bug that was causing things to barf. |
|
I believe this should result in a TMA admissibility failure -- the gmem layout and the smem layouts are incompatible for TMA to execute the copy. This should be a static assertion instead, correct? |
|
@ccecka Thanks for taking a look — appreciate the review. Good question — I initially thought the same, but I think this is actually a valid case rather than an inadmissible one.

The zero-stride basis here comes from a broadcast dimension in block-scaled GEMM scale factor layouts. Concrete example: with NVF4 (SFVectorSize=32) and a K=64 tile, the scale factor TMA basis decomposition produces a mode where 32 elements share the same scale factor — CuTe encodes this as stride=0. Currently the zero-stride basis is passed straight through, producing undefined shape/stride values in the descriptor; this patch instead emits the degenerate (shape=1, stride=0) encoding for that mode.

The downstream admissibility problem @depaulmillz found (SMEM contiguity < 16B with the 128x128x64 tile) is real, but it's in the builder's layout math, not here. The TMA descriptor itself is well-formed with the broadcast handled correctly.

Happy to be wrong though — if there's a reason broadcast modes should never reach TMA descriptor construction, a static assert makes more sense and the fix should move to the caller. |
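The broadcast fold in that concrete example can be modeled in a few lines (a toy, not CuTe): with SFVectorSize=32, every group of 32 K-positions reads the same scale factor, so the inner mode has extent 32 and stride 0.

```cpp
#include <cassert>

// Toy model of the SFVectorSize broadcast fold (illustrative, not CuTe).
constexpr int kSFVectorSize = 32;

// Which scale factor a given K-position reads.
constexpr int sf_index(int k) { return k / kSFVectorSize; }

// Moving k by 1 inside a group does not change the scale-factor address,
// i.e. the inner (within-vector) mode has stride 0:
static_assert(sf_index(1) - sf_index(0) == 0, "broadcast mode: stride 0");

// A K=64 tile folds into two groups over the scale factors -- extents
// (32, 2) with strides (0, 1) in CuTe terms:
static_assert(sf_index(31) == 0, "first group of 32");
static_assert(sf_index(32) == 1 && sf_index(63) == 1, "second group of 32");
```

With K=128 and SFVectorSize=32 the outer mode has extent 4 and the stride-0 fold never dominates the basis decomposition, which is why the bug only surfaces at K=64.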
|
@ccecka That's a really good point and I want to make sure I get this right. My reasoning was that the zero-stride basis here comes from SFVectorSize broadcast folding — 32 SMEM positions aliased to the same physical address via stride=0 — so I thought TMA could serve it by writing once to the aliased address with a degenerate (shape=1, stride=0) descriptor dimension.

Could you help me understand: does a zero-stride basis in the gmem layout always make the SMEM and GMEM layouts TMA-incompatible, or are there cases where the degenerate encoding is valid?

Really appreciate you taking the time to look at this — your insight on the TMA internals is exactly what I need here. |
|
To my knowledge, it should always be invalid -- TMA is a glorified memcpy into smem so the contiguous elements that it sees in SMEM must correspond to real elements in gmem. If you print the smem layout and the gmem layout I suspect you will find the "broadcast" modes of each to be misaligned (either different sizes or corresponding to different coordinates) and, therefore, the source and target are TMA-incompatible. |
|
@ccecka appreciate the feedback, will take a look! |
|
Closing this out. The underlying zero-stride TMA basis bug is real but only manifests with K=64 tiles (where K/SFVectorSize creates a broadcast fold), and testing showed K=64 doesn't improve performance — we're already at 99% theoretical memory bandwidth with K=128. ccecka's feedback that this should be a static_assert rather than a graceful workaround is fair, but a compile-time guard for a tile shape nobody is using doesn't add much practical value. If K=64 support is revisited in the future, the builder-level layout fix (avoiding the zero-stride basis entirely) is the right approach, and depaulmillz's SMEM contiguity finding would need to be addressed as well. |

Summary
When a block-scaled GEMM layout folds scale factors over a broadcast dimension (stride == 0), `fill_tma_gmem_shape_stride` passes the zero-stride basis to `basis_get()`, which produces undefined `gmem_prob_shape` and `gmem_prob_stride` values in the TMA descriptor. This results in invalid TMA descriptors that cause silent data corruption or runtime errors.

Fix: detect zero-stride basis via `is_constant<0>` and emit the correct degenerate descriptor (shape=1, stride=0) for broadcast dimensions. This is the standard TMA encoding for a dimension that contributes no unique data.

The bug is latent in all block-scaled GEMM configurations but only manifests when `K/SFVectorSize` creates a broadcast fold — e.g. when the scale factor vector size equals or exceeds the tile K extent.

Changes
1 file, +8/-2 lines in `include/cute/atom/copy_traits_sm90_tma.hpp`

Test plan
- `sm120_bs_gemm_nvf4_nvf4_f32_bf16.cu` with `TileShape = Shape<_128,_128,_64>` no longer produces invalid TMA descriptors

Signed-off-by: Rob Tand <robert.tand@icloud.com>