
fix: Handle zero-stride TMA basis in fill_tma_gmem_shape_stride #3121

Closed
RobTand wants to merge 1 commit into NVIDIA:main from RobTand:fix/sm120-k64-blockscaled-tma-layout
Closed


Conversation

@RobTand

@RobTand RobTand commented Mar 20, 2026

Summary

When a block-scaled GEMM layout folds scale factors over a broadcast dimension (stride == 0), fill_tma_gmem_shape_stride passes the zero-stride basis to basis_get(), which produces undefined gmem_prob_shape and gmem_prob_stride values in the TMA descriptor. This results in invalid TMA descriptors that cause silent data corruption or runtime errors.

Fix: detect zero-stride basis via is_constant<0> and emit the correct degenerate descriptor (shape=1, stride=0) for broadcast dimensions. This is the standard TMA encoding for a dimension that contributes no unique data.

The bug is latent in all block-scaled GEMM configurations but only manifests when K/SFVectorSize creates a broadcast fold — e.g. when the scale factor vector size equals or exceeds the tile K extent.
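The rule the fix applies can be sketched in isolation. This is a hypothetical runtime illustration, not the patch itself: the real change operates at compile time via `is_constant<0>` on the CuTe basis inside `fill_tma_gmem_shape_stride`, and the function name below is invented for the example.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>

// Illustrative sketch of the descriptor-filling rule described above.
// A zero gmem stride marks a broadcast dimension; instead of indexing
// through the zero-stride basis (undefined values), emit the degenerate
// TMA mode (shape = 1, stride = 0).
std::pair<uint64_t, uint64_t>
tma_mode_shape_stride(uint64_t gmem_extent, uint64_t gmem_stride) {
  if (gmem_stride == 0) {
    // Broadcast: this mode contributes no unique gmem data.
    return {1, 0};
  }
  return {gmem_extent, gmem_stride};
}
```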

Changes

1 file, +8/-2 lines in include/cute/atom/copy_traits_sm90_tma.hpp

Test plan

  • Existing block-scaled GEMM unit tests pass (K≥128 tiles unaffected — codepath is only reached with zero-stride basis)
  • sm120_bs_gemm_nvf4_nvf4_f32_bf16.cu with TileShape = Shape<_128,_128,_64> no longer produces invalid TMA descriptors

Signed-off-by: Rob Tand robert.tand@icloud.com

@brandonmmusic-max

brandonmmusic-max commented Mar 20, 2026 via email

@voipmonitor

Testing on RTX PRO 6000 (SM120, TP=4)

We applied both this PR and FlashInfer #2786 together on 4x RTX PRO 6000 Blackwell Server Edition (SM120, PCIe, 96GB each) running SGLang with nvidia/Qwen3.5-397B-A17B-NVFP4.

Setup:

  • Docker image: voipmonitor/sglang:test-cu132 (CUDA 13.2, PyTorch nightly cu132)
  • CUTLASS 4.4.2 + this PR patch applied to headers
  • FlashInfer 0.6.6 reinstalled from the PR #2786 branch ("Fixed compilation error when using StreamK scheduler + PDL. (#2686)")
  • Backend: --fp4-gemm-backend flashinfer_cutlass --moe-runner-backend flashinfer_cutlass
  • No speculative decoding, no torch.compile
  • Single-request decode benchmark: 512 tokens, temperature=0

Results (single-user decode throughput):

| Config | Run 1 | Run 2 | Run 3 |
|---|---|---|---|
| Baseline (K=128 only) | 68.5 tok/s | 70.2 tok/s | 70.3 tok/s |
| With K=64 patches (both PRs) | 69.9 tok/s | 70.2 tok/s | 70.2 tok/s |

No measurable difference on this configuration. The patches applied cleanly and FlashInfer JIT compiled the K=64 kernels without errors, but the dispatch doesn't seem to select K=64 tiles for this model/shape combination, or the benefit is offset by TP=4 communication overhead.

Questions:

  1. What model and configuration did you benchmark for the ~2x speedup? (The PR mentions Nemotron-3 on DGX Spark SM121 — was that single-GPU unified memory?)
  2. Does K=64 benefit primarily smaller batch sizes or specific expert dimensions?
  3. Is there a way to force K=64 tile selection to verify it's being dispatched?

Both patches are functionally correct on SM120 — no crashes, no compilation errors. The CUTLASS TMA zero-stride fix and EffBlk_SF clamping work as described.

@RobTand
Author

RobTand commented Mar 22, 2026

Thanks for taking the time to review this — really appreciate it.

To address your question about forcing K=64 tiles and demonstrating necessity:

SMEM constraint on SM120/121

SM120/SM121 devices (RTX 5090, DGX Spark GB10) have 99 KB of opt-in shared memory per block, compared to 227 KB on SM100. This changes which tile shapes are viable for block-scaled GEMM.

Pipeline stage impact

The builder computes pipeline stages as available_smem / bytes_per_stage. For NVF4 with SFVectorSize=32:

| Tile Shape | Bytes/Stage | SM100 Stages | SM120/121 Stages |
|---|---|---|---|
| 128×128×K=128 | ~17 KB | 12 | 5 |
| 128×256×K=128 | ~26 KB | 8 | 3 |
| 128×128×K=64 | ~9 KB | 25 | 10 |
| 128×256×K=64 | ~13 KB | 17 | 7 |
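The stage computation above can be sketched directly. This is a naive illustration of `available_smem / bytes_per_stage`; the actual builder subtracts carve-outs (mbarriers, epilogue buffers), which is presumably why the naive division is slightly above some table entries.

```cpp
#include <cassert>

// Naive pipeline-stage count: available SMEM divided by bytes per stage,
// as described in the text. Real builder logic reserves additional SMEM,
// so its exact counts can be one lower than this division.
constexpr int pipeline_stages(int smem_bytes, int bytes_per_stage) {
  return smem_bytes / bytes_per_stage;
}
```

For example, 99 KB at ~17 KB/stage gives 5 stages on SM120/121, while halving the K extent roughly doubles the achievable pipeline depth.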

K=128 tiles do work on SM120/121 (3–5 stages), so this isn't strictly about making things functional at K=128. The issue is that K=64 tiles — the natural choice for the constrained SMEM budget — cannot be instantiated at all.

Why K=64 fails without this fix

Two bugs surface only when K/SFVectorSize < Blk_SF (i.e., K=64 with SFVectorSize=32 gives NumSFAlongK=2, but Blk_SF=4):

  • TMA descriptor corruption (copy_traits_sm90_tma.hpp): Scale factor folding creates a broadcast dimension with zero stride. basis_get() on a zero-stride basis produces undefined gmem_prob_shape/gmem_prob_stride values, resulting in an invalid TMA descriptor.

  • Scale factor layout overflow (sm120_blockscaled_mma_builder.inl): The division Blk_SF/MMA_NSF assumes Blk_SF ≤ NumSFAlongK. When that doesn't hold, the resulting Blk_Elems layout overflows, producing an incorrect TMA tensor map for scale factors.
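The trigger condition for both bugs can be written out explicitly. This is a sketch mirroring the names used in the discussion (`NumSFAlongK`, `Blk_SF`), not builder code:

```cpp
#include <cassert>

// Both bugs surface only when the number of scale factors along K
// (tile_k / sf_vector_size) falls below Blk_SF, creating the broadcast
// fold. Names mirror the discussion above; this is illustrative only.
constexpr bool triggers_broadcast_fold(int tile_k, int sf_vector_size, int blk_sf) {
  int num_sf_along_k = tile_k / sf_vector_size;
  return num_sf_along_k < blk_sf;
}
```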

Why these were never caught on SM100

With 227 KB of SMEM, K=128 tiles get 8–12 pipeline stages — more than sufficient. K=64 tiles are never needed, so the code path was never exercised. Both bugs are latent but only manifest with small K values.

Performance expectations

The performance improvement from deeper pipelining may be modest on SM120/121, since these devices tend to be memory-bandwidth-bound. The primary motivation is correctness: K=64 tile shapes should be valid instantiations of the block-scaled GEMM builder, and currently they are not.

How to test

The simplest way to verify is a one-line change to an existing unit test. In test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_nvf4_nvf4_f32_bf16.cu, change:

using TileShape = Shape<_128,_128,_256>;

to:

using TileShape = Shape<_128,_128,_64>;

Without this PR, this will either fail to compile or produce corrupt results. With the fix, it should compile and pass TestSmall correctness checks.

Both FlashInfer (v0.6.6) and vLLM (v0.18.0) already define K=64 tile configs in their SM120 enums (CtaShape128x128x64B, etc.) but deliberately exclude them from dispatch tables because they can't be instantiated against upstream CUTLASS today.

@brandonmmusic-max


I think his benchmarks are fair. I did my own benchmarking here: https://github.com/brandonmmusic-max/sm120-moe-bench, though I was testing for effects on different things. The fix may help more on the prefill side (I was getting ~17k tok/s), and the prefill numbers are pretty good. But I haven't run my benchmarks in a manner that lets me exclude any increase being related to P2P or MTP; I was optimizing for my particular neuro-symbolic pipeline in my local workflow. I'm glad to hear you were getting a 2x speedup! Would love to hear more about your setup!

@voipmonitor

Update: Prefill benchmarks + methodology fix

My earlier decode results were correct (no difference), but the initial prefill numbers were wrong — the "+19%" was a cold-start artifact (first request without warmup hitting radix cache miss + JIT warmup). Re-ran with proper methodology: 3 warmup + 5 measured runs per prompt size.

Hardware: 4x RTX PRO 6000 Blackwell Server Edition (SM120, 96GB GDDR7 each, PCIe Gen5)
Model: nvidia/Qwen3.5-397B-A17B-NVFP4, TP=4, flashinfer_cutlass backend
Patches: CUTLASS #3121 + FlashInfer #2786 both applied, JIT cache cleared, FlashInfer reinstalled from PR branch

Decode throughput (single user, 512 output tokens)

| Config | Run 1 | Run 2 | Run 3 |
|---|---|---|---|
| Baseline (K=128 only) | 68.5 tok/s | 70.2 tok/s | 70.3 tok/s |
| With K=64 patches | 69.9 tok/s | 70.2 tok/s | 70.2 tok/s |

Prefill throughput (3 warmup + 5 measured, max_tokens=1)

| Prompt tokens | Baseline (K=128) | K=64 patch | Diff |
|---|---|---|---|
| 885 | 8,761 tok/s | 8,767 tok/s | +0.1% |
| 1,769 | 10,507 tok/s | 10,563 tok/s | +0.5% |
| 3,550 | 11,855 tok/s | 11,967 tok/s | +0.9% |
| 7,099 | 12,445 tok/s | 12,549 tok/s | +0.8% |

All differences are within noise (<1%). Standard deviation across runs was 0.1–3.1ms.

Analysis

The MoE GEMM dimensions for Qwen3.5-397B with TP=4 are:

  • GEMM1 (gate+up): M=tokens, N=512, K=4096
  • GEMM2 (down): M=tokens, N=4096, K=256

For decode (M=1): the operation is a GEMV, entirely memory-bandwidth bound. Pipeline stages (K=64's advantage) don't help because compute is not the bottleneck — reading weights from GDDR7 dominates.
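A back-of-envelope arithmetic intensity makes the bandwidth-bound claim concrete. Assuming NVFP4 weights at ~0.5 bytes/element and ignoring scale-factor and activation traffic (both simplifications), the decode GEMV has a fixed, very low FLOP/byte ratio, so any GPU whose compute-to-bandwidth balance exceeds it is memory-bound on this kernel:

```cpp
#include <cassert>

// Arithmetic intensity of the M=1 decode GEMV: one multiply-add (2 FLOPs)
// per weight element, against ~0.5 bytes read per FP4 weight. The ratio
// is 4 FLOP/byte regardless of N and K -- far below typical Blackwell
// compute/bandwidth balance, hence bandwidth-bound.
constexpr double gemv_arithmetic_intensity(long n, long k) {
  double flops = 2.0 * n * k;   // multiply-add per weight
  double bytes = 0.5 * n * k;   // FP4 weight read dominates traffic
  return flops / bytes;
}
```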

For prefill (M=885–7099): with TP=4, per-GPU M is ~220–1775. Even at these sizes, GDDR7 bandwidth (1.79 TB/s per GPU) appears sufficient that the additional pipeline depth from K=64 tiles doesn't provide measurable benefit over K=128.

Note on hardware difference: Your PR reports 2x speedup on DGX Spark (SM121, unified LPDDR5X). LPDDR5X has significantly lower bandwidth (~273 GB/s) than GDDR7 (1.79 TB/s per GPU × 4 = 7.16 TB/s aggregate). This likely explains why K=64's compute efficiency gains are visible on DGX Spark (compute-bound) but not on discrete RTX PRO 6000 GPUs (memory-bandwidth bound even for prefill).

The patches are functionally correct — JIT compilation succeeds, kernels load, no errors. They just don't provide a throughput benefit on this specific hardware configuration.

@johnnynunez

@depaulmillz

@depaulmillz
Contributor

When instantiating the 128x128x64 NVFP4 tile I am seeing refcheck failures with this MR.

For 128x128x64 with MXFP4, the reason you are hitting issues is that the SMEM layout you have computed is (((32, 4),1),((32,2),1,1),10) : (((16, 4),256),((0,1),2,256),256).

(image: SMEM layout diagram; copied elements shaded in green)

Along the contiguous dimension you are copying 2 elements (shaded in green) then skipping over 2 elements (not shaded) and repeating this pattern. This is not possible to copy since we require 16B of contiguous elements.

To use this tiling pattern, it requires switching to a universal copy atom instead (copies are 2B contiguous) or loading 1 SFA/SFB tile per every 2 MMA tiles at least.
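The admissibility constraint described above can be checked numerically. This is a sketch under the assumption of 1-byte scale-factor elements; the 16 B threshold is the TMA contiguity requirement cited in the comment:

```cpp
#include <cassert>

// TMA requires at least 16 B of contiguous data per copied run. With
// 1-byte scale factors and the 2-contiguous / 2-skipped pattern above,
// the run is only 2 B, so the layout is not TMA-copyable.
constexpr bool tma_contiguity_ok(int contiguous_elems, int elem_bytes) {
  return contiguous_elems * elem_bytes >= 16;
}
```

A universal copy atom (2 B contiguous) or a coarser SFA/SFB tiling, as suggested, would sidestep this constraint.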

aleozlx pushed a commit to flashinfer-ai/flashinfer that referenced this pull request Apr 1, 2026
…es on SM12x (#2913)

### Summary

- Add missing `-DCUTLASS_ENABLE_GDC_FOR_SM100=1` compile flag to all
CUTLASS fused MoE JIT modules (SM100/SM103/SM120) and
`-DCUTLASS_ENABLE_GDC_FOR_SM90=1` to SM90 modules
- Sync nv_internal `grid_dependency_control.h` with upstream CUTLASS to
support SM100/SM103/SM110/SM120/SM121 GDC
- Add `-DCUTLASS_ENABLE_GDC_FOR_SM90=1` to FP8 blockscale GEMM SM90
module

### Problem

Random `cudaErrorIllegalInstruction` crashes on DGX Spark (SM121) and
RTX 50-series (SM120) when running NVFP4 MoE models (e.g., Nemotron,
Qwen3.5-122B) under load. The crashes are intermittent and worsen with
longer context lengths and higher concurrency.

**Root cause:** PR #2780 fixed the missing GDC compile flags for GEMM
modules (`flashinfer/jit/gemm/core.py`), but the **CUTLASS fused MoE
modules** in `flashinfer/jit/fused_moe.py` and the **FP8 blockscale GEMM
module** were not fixed. This is the exact same class of bug as #2708.

Without `-DCUTLASS_ENABLE_GDC_FOR_SM100=1`, CUTLASS's
`grid_dependency_control.h` compiles `wait_on_dependent_grids()` and
`launch_dependent_grids()` as **empty no-ops**:

```cpp
CUTLASS_DEVICE void wait_on_dependent_grids() {
#if (defined(CUTLASS_GDC_ENABLED))   // ← not defined without the flag
  asm volatile("griddepcontrol.wait;");
#endif
}
```

Meanwhile, the host-side code still sets
`programmaticStreamSerializationAllowed = true` (PDL enabled) via
`device_support_pdl()` which returns `True` for all `major >= 9`,
including SM12x. This means:

1. **Host enables PDL** → CUDA runtime overlaps consecutive kernels
2. **Device GDC barriers are no-ops** → No synchronization between
overlapping kernels
3. **Race condition** → Dependent kernel reads stale global memory →
corruption → `cudaErrorIllegalInstruction`

The crash is random because it depends on exact kernel scheduling
timing, which varies per request.

### Fix

**`flashinfer/jit/fused_moe.py`** — Added GDC flags to all CUTLASS fused
MoE modules:

| Module | Flag | Architectures Covered |
|---|---|---|
| `gen_cutlass_fused_moe_sm120_module()` | `-DCUTLASS_ENABLE_GDC_FOR_SM100=1` | SM120, SM121 |
| `gen_cutlass_fused_moe_sm103_module()` | `-DCUTLASS_ENABLE_GDC_FOR_SM100=1` | SM103, SM120, SM121 |
| `gen_cutlass_fused_moe_sm100_module()` | `-DCUTLASS_ENABLE_GDC_FOR_SM100=1` | SM100, SM110, SM120, SM121 |
| `gen_cutlass_fused_moe_sm90_module()` | `-DCUTLASS_ENABLE_GDC_FOR_SM90=1` | SM90 |
| `gen_trtllm_gen_fused_moe_sm100_module()` | `-DCUTLASS_ENABLE_GDC_FOR_SM100=1` | SM100+, SM120, SM121 |

**`flashinfer/jit/gemm/fp8_blockscale.py`** — Added
`-DCUTLASS_ENABLE_GDC_FOR_SM90=1` to
`gen_fp8_blockscale_gemm_sm90_module()`.

**`csrc/nv_internal/.../grid_dependency_control.h`** — Synced with
upstream CUTLASS
(`3rdparty/cutlass/include/cutlass/arch/grid_dependency_control.h`) to
add SM100+ GDC support. Previously only handled SM90, so any nv_internal
TensorRT-LLM code compiled for SM12x would have GDC barriers silently
compiled as no-ops.

### Why `-DCUTLASS_ENABLE_GDC_FOR_SM100=1` covers SM12x

CUTLASS uses a single flag for the entire Blackwell family. From
`grid_dependency_control.h`:

```cpp
#if(CUDA_BARRIER_ENABLED && defined(CUTLASS_ENABLE_GDC_FOR_SM100) && defined(__CUDA_ARCH__) && \
    ((__CUDA_ARCH__ == 1000 && ...) ||   // SM100
     (__CUDA_ARCH__ == 1030 && ...) ||   // SM103
     (__CUDA_ARCH__ == 1100 && ...) ||   // SM110
     (__CUDA_ARCH__ == 1200 && ...) ||   // SM120 (RTX 50-series)
     (__CUDA_ARCH__ == 1210 && ...)))    // SM121 (DGX Spark)
#define CUTLASS_GDC_ENABLED
```

### Why SM90 GDC flag was NOT added to SM100+ modules

PR #2716 attempted to add both `-DCUTLASS_ENABLE_GDC_FOR_SM90=1` and
`-DCUTLASS_ENABLE_GDC_FOR_SM100=1` to all modules. It broke AOT builds
because `sm120_gemm_tma_warpspecialized_cooperative_asymmetric_dma.hpp`
checks `CUTLASS_ENABLE_GDC_FOR_SM90` and calls
`scheduler.is_last_tile()` — a method not present on the SM120
scheduler. PR #2780 corrected this by using only the SM100 flag for
SM100+ modules. This PR follows the same approach.

### Related

- #2708 — Original issue: missing GDC flags cause PDL race condition
- #2716 — First fix attempt (reverted — broke AOT)
- #2780 — Corrected fix for GEMM modules only
- [vllm-project/vllm#38423](vllm-project/vllm#38423) — NVFP4 bugfix on DGX Spark
- [NVIDIA/cutlass#3121](NVIDIA/cutlass#3121) — K=64 block-scaled GEMM tiles (separate issue)

### Test plan

- [x] Clear JIT cache: `rm -rf ~/.cache/flashinfer/`
- [x] Run NVFP4 MoE model on SM121 (DGX Spark) with 128K context under
load — verify no `cudaErrorIllegalInstruction`
- [x] Run NVFP4 MoE model on SM120 (RTX 50-series) with concurrent
requests — verify no NaN/garbage output
- [x] Verify `CUDA_LAUNCH_BLOCKING=1` workaround is no longer needed
- [x] AOT build with `FLASHINFER_CUDA_ARCH_LIST="12.1a"` completes
without errors
- [x] SM90 (Hopper) fused MoE tests pass: `pytest tests/moe/`
- [x] SM100 GEMM tests still pass (no regression from existing GDC
flags)


@RobTand force-pushed the fix/sm120-k64-blockscaled-tma-layout branch from 5197972 to ab4a7c2 on April 1, 2026 16:48
@RobTand
Author

RobTand commented Apr 1, 2026

@johnnynunez I updated this and tried to address @depaulmillz's issues with tiling.

@RobTand force-pushed the fix/sm120-k64-blockscaled-tma-layout branch from ab4a7c2 to 760feb5 on April 1, 2026 19:49
@RobTand changed the title from "fix: Support K=64 block-scaled GEMM tiles on SM120/SM121" to "fix: Handle zero-stride TMA basis in fill_tma_gmem_shape_stride" on Apr 1, 2026
When a block-scaled GEMM layout folds scale factors over a broadcast
dimension (stride == 0), the existing code passes the zero-stride
basis to basis_get(), which produces undefined gmem_prob_shape and
gmem_prob_stride values in the TMA descriptor. This results in
invalid TMA descriptors that cause silent data corruption or runtime
errors.

Fix: detect zero-stride basis via is_constant<0> and emit the
correct degenerate descriptor (shape=1, stride=0) for broadcast
dimensions. This is the standard TMA encoding for a dimension that
contributes no unique data.

The bug is latent in all block-scaled GEMM configurations but only
manifests when K/SFVectorSize creates a broadcast fold — e.g. when
scale factor vector size equals or exceeds the tile K extent.

Signed-off-by: Rob Tand <robert.tand@icloud.com>
@RobTand force-pushed the fix/sm120-k64-blockscaled-tma-layout branch from 760feb5 to 216faae on April 1, 2026 19:51
@RobTand
Author

RobTand commented Apr 1, 2026

Dramatically reduced the scope here, since K=64 doesn't seem to improve performance and significantly increases the complexity. This now just fixes the root bug that was producing the invalid descriptors.

@ccecka
Contributor

ccecka commented Apr 1, 2026

I believe this should result in a TMA admissibility failure -- the gmem layout and the smem layouts are incompatible for TMA to execute the copy.

This should be a static assertion instead, correct?

@RobTand
Author

RobTand commented Apr 1, 2026

@ccecka Thanks for taking a look — appreciate the review.

Good question — I initially thought the same, but I think this is actually a valid case rather than an inadmissible one.

The zero-stride basis here comes from a broadcast dimension in block-scaled GEMM scale factor layouts. Concrete example:

With NVF4 (SFVectorSize=32) and a K=64 tile, the scale factor TMA basis decomposition produces a mode where 32 elements share the same scale factor — CuTe encodes this as stride=0. The basis for that mode is _0, meaning "this dimension doesn't index into unique gmem data."

Currently basis_get(_0, gmem_shape) gets called on that zero basis and returns garbage, which corrupts the TMA descriptor. But the intent is sound: the broadcast dimension genuinely contributes no unique data, so the correct TMA encoding is shape=1, stride=0 — "this mode is degenerate, skip it."

The downstream admissibility problem @depaulmillz found (SMEM contiguity < 16B with the 128x128x64 tile) is real but it's in the builder's layout math, not here. The TMA descriptor itself is well-formed with the broadcast handled correctly.

A static_assert would reject any layout that produces a zero-stride basis during TMA setup, which would block legitimate broadcast folds. I think the is_constant<0> check is the right fix at this layer — it makes fill_tma_gmem_shape_stride handle broadcast dimensions the way CuTe already represents them elsewhere.

Happy to be wrong though — if there's a reason broadcast modes should never reach TMA descriptor construction, a static assert makes more sense and the fix should move to the caller.

@RobTand
Author

RobTand commented Apr 1, 2026

@ccecka That's a really good point and I want to make sure I get this right.

My reasoning was that the zero-stride basis here comes from SFVectorSize broadcast folding — 32 SMEM positions aliased to the same physical address via stride=0, so I thought TMA could serve it by writing once to the aliased address with a degenerate shape=1, stride=0 GMEM dimension. But I may be wrong about how TMA admissibility validates the box shape against the SMEM tile — if it enforces strict shape matching regardless of aliasing, then this would be invalid.

Could you help me understand: does a zero-stride basis in tma_gbasis_stride always indicate an inadmissible layout, or can it arise legitimately for broadcast dimensions? If it's always inadmissible, I'll change this to a static_assert so the incompatibility is caught at compile time rather than papered over. The builder would then need to handle the K < SFVectorSize × Blk_SF case before it reaches TMA descriptor construction.

Really appreciate you taking the time to look at this — your insight on the TMA internals is exactly what I need here.

@ccecka
Contributor

ccecka commented Apr 1, 2026

To my knowledge, it should always be invalid -- TMA is a glorified memcpy into smem so the contiguous elements that it sees in SMEM must correspond to real elements in gmem.

If you print the smem layout and the gmem layout I suspect you will find the "broadcast" modes of each to be misaligned (either different sizes or corresponding to different coordinates) and, therefore, the source and target are TMA-incompatible.

@RobTand
Author

RobTand commented Apr 1, 2026

@ccecka appreciate the feedback, will take a look!

@RobTand
Author

RobTand commented Apr 3, 2026

Closing this out. The underlying zero-stride TMA basis bug is real but only manifests with K=64 tiles (where K/SFVectorSize creates a broadcast fold), and testing showed K=64 doesn't improve performance — we're already at 99% theoretical memory bandwidth with K=128.

ccecka's feedback that this should be a static_assert rather than a graceful workaround is fair, but a compile-time guard for a tile shape nobody is using doesn't add much practical value. If K=64 support is revisited in the future, the builder-level layout fix (avoiding the zero-stride basis entirely) is the right approach, and depaulmillz's SMEM contiguity finding would need to be addressed as well.

@RobTand RobTand closed this Apr 3, 2026