Which component has the problem?
CuTe DSL
Bug Report
Describe the bug
Previously, in CuTe DSL 4.3.5, there was an issue (#2880) where a Python ternary expression (e.g. `a = b if c else d`) lowered into `select`, which means both `b` and `d` are evaluated and can cause an IMA (illegal memory access).
The fix introduced in 4.4.1 lowers the ternary into function blocks dispatched through `scf.IfOp`. The problem is that this can cost more registers.
Claude:
At the MLIR level, `ifExp_executor` calls `_ifexp_execute_dynamic`, which creates an `scf.IfOp` with two regions (then/else). This means:
1. **Both branches are fully materialized** as separate MLIR regions with `scf.yield`
2. **The compiler must emit branch instructions** (`BRA`, `BRX`) instead of a single predicated move
3. **Register liveness spans both branches**, potentially increasing register pressure at merge points
4. **The PTX optimizer may not be able to lower this back to a `selp`** (select with predicate) instruction
In contrast, `cutlass.select_(cond, true_val, false_val)` lowers directly to `arith.select`, which becomes a single `selp` instruction in PTX — branchless, no register pressure overhead.
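For illustration, a hand-written schematic of the two lowerings of `x = a if cond else b` (assuming `f32` values; this is a sketch, not actual compiler output):

```mlir
// scf.IfOp path (4.4.1 visit_IfExp): two regions, each yielding a value,
// which the backend must turn into branches unless it can re-predicate.
%x = scf.if %cond -> (f32) {
  scf.yield %a : f32
} else {
  scf.yield %b : f32
}

// arith.select path (cutlass.select_): one op, maps to a single selp in PTX.
%y = arith.select %cond, %a, %b : f32
```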
Steps/Code to reproduce bug
We struggled to find a minimal repro, so this is the best we have: we tested the Flash Attention 4 forward kernel (SM100/Blackwell, hdim=128, bf16) with paged KV (page_size=128) and local/sliding window attention (window_size=(2047, 0)), which exercises many runtime ternaries in the masking, softmax, and paged KV code paths. We use 4.4.1 as the baseline and compare trunk vs. replacing every ternary with `select_`:
Register analysis (paged KV + local window)
Using #2658 (comment) and `nvdisasm -plr -lrm count` on the compiled CUBIN to measure per-warp-role register usage:
Baseline (default `visit_IfExp` with `scf.IfOp`):
```
Warp Role         Warps   Budget   Max Idx   Max Live   Margin
--------------------------------------------------------------
Load (TMA)        14      48       R19       63         -15
MMA               12      48       R17       15         +33
Softmax (s0+s1)   0-7     192      R171      160        +32
Correction        8-11    80       R63       146        -66
Empty             15      48       R2        2          +46
```
num_regs=128, local_size_bytes=0, perf=0.64ms
With manual `cutlass.select_()` rewrites:
```
Warp Role         Warps   Budget   Max Idx   Max Live   Margin
--------------------------------------------------------------
Load (TMA)        14      48       R19       63         -15
MMA               12      48       R17       15         +33
Softmax (s0+s1)   0-7     192      R171      159        +33
Correction        8-11    80       R63       146        -66
Empty             15      48       R2        2          +46
```
num_regs=128, local_size_bytes=0, perf=0.64ms
In this particular kernel configuration, the codegen difference is small (softmax Max Live: 160 → 159); the PTX optimizer is able to lower most of the `scf.IfOp` patterns back to predicated instructions. However:
- The `scf.IfOp` path relies on the optimizer to recover what should have been a simple `selp` from the start
- In more complex kernels or with more aggressive ternary usage, the optimizer may not recover it
- The function-block approach adds compilation overhead (more IR to generate and optimize)
Workaround
Users can replace runtime ternaries with explicit `cutlass.select_()` calls:

```python
# Before (generates scf.IfOp via visit_IfExp):
result = value_a if runtime_cond else value_b

# After (generates arith.select directly):
result = cutlass.select_(runtime_cond, value_a, value_b)
```
Note: `cutlass.select_` evaluates both branches (no short-circuit), so it can cause an IMA if not used carefully.
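The hazard in the note above can be modeled in plain Python (`select_` here is a toy stand-in for the eager-evaluation semantics, not the real cutlass API; the `IndexError` plays the role of the IMA):

```python
# Toy model: like cutlass.select_, both arguments are evaluated by the
# caller before the function runs, regardless of cond.
def select_(cond, a, b):
    return a if cond else b

data = [10, 20, 30]
i = 5  # out-of-bounds index, guarded by the condition

# Ternary: the unsafe branch is never evaluated when the condition is False.
safe = data[i] if i < len(data) else -1

# select_-style: data[i] is evaluated eagerly even though cond picks -1,
# so the guarded "illegal access" actually executes.
try:
    eager = select_(i < len(data), data[i], -1)
except IndexError:
    eager = "faulted"
```

Here `safe` is `-1` while the eager version raises, which is the short-circuit difference to keep in mind when applying the workaround to memory accesses.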
Environment details (please complete the following information):
- cutlass-dsl 4.4.1 (CUDA 12.9 build)
- GPU: NVIDIA GB200 (SM100)
- Kernel: Flash Attention 4 forward (paged KV, local/sliding window)
- PyTorch (cu128 build)
Thanks @v0i0, @drisspg, @Alkaid-Benetnash for pointing me in the right direction.