[BUG] visit_IfExp fix in 4.4.1 costs more registers #3122

@henrylhtsang

Which component has the problem?

CuTe DSL

Bug Report

Describe the bug

Previously, in CuTe DSL 4.3.5, there was an issue (#2880) where a Python ternary expression (e.g. a = b if c else d) was lowered into a select, which means both b and d are evaluated, and that can cause an IMA (illegal memory access).

The fix introduced in 4.4.1 lowers the ternary into function blocks dispatched through scf.IfOp. The problem is that this can cost more registers.

Claude:

At the MLIR level, `ifExp_executor` calls `_ifexp_execute_dynamic`, which creates an `scf.IfOp` with two regions (then/else). This means:

1. **Both branches are fully materialized** as separate MLIR regions with `scf.yield`
2. **The compiler must emit branch instructions** (`BRA`, `BRX`) instead of a single predicated move
3. **Register liveness spans both branches**, potentially increasing register pressure at merge points
4. **The PTX optimizer may not be able to lower this back to a `selp`** (select with predicate) instruction

In contrast, `cutlass.select_(cond, true_val, false_val)` lowers directly to `arith.select`, which becomes a single `selp` instruction in PTX — branchless, no register pressure overhead.
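The evaluation-order difference between the two lowerings can be modeled in plain Python. This is only an illustrative sketch: `branch_style` and `select_style` are hypothetical helpers, not CuTe DSL APIs, and they mimic which branch bodies run rather than any actual code generation.

```python
# Illustrative model only: these helpers are hypothetical and not part of CuTe DSL.
evaluated = []  # records which branch bodies actually ran


def then_branch():
    evaluated.append("then")
    return 1


def else_branch():
    evaluated.append("else")
    return 2


def branch_style(cond, then_fn, else_fn):
    # Models the scf.IfOp shape: each branch is a separate region,
    # and only the region picked by the condition executes.
    return then_fn() if cond else else_fn()


def select_style(cond, true_val, false_val):
    # Models the arith.select shape: both operands are already computed
    # values by the time the select runs.
    return true_val if cond else false_val


branch_result = branch_style(True, then_branch, else_branch)
branch_trace = list(evaluated)  # only the taken branch ran

evaluated.clear()
select_result = select_style(True, then_branch(), else_branch())
select_trace = list(evaluated)  # both branch bodies ran

print(branch_result, branch_trace)
print(select_result, select_trace)
```

Both forms return the same value; they differ only in how much work is done to produce it, which is the trade-off between safety (branch) and branchless code (select) discussed in this issue.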

Steps/Code to reproduce bug

We struggled to find a clean repro, so this is the best we have. We tested the Flash Attention 4 forward kernel (SM100/Blackwell, hdim=128, bf16) with paged KV (page_size=128) and local/sliding window attention (window_size=(2047, 0)), which exercises many runtime ternaries in the masking, softmax, and paged KV code paths. Using 4.4.1 as the baseline, we compare trunk against a version in which every ternary is replaced with select_:

Register analysis (paged KV + local window)

Using #2658 (comment) and nvdisasm -plr -lrm count on the compiled CUBIN to measure per-warp-role register usage:

Baseline (default visit_IfExp with scf.IfOp):

Warp Role                Warps    Budget  Max Idx Max Live  Margin
----------------------------------------------------------------------
Load (TMA)               14           48      R19       63     -15
MMA                      12           48      R17       15     +33
Softmax (s0+s1)          0-7         192     R171      160     +32
Correction               8-11         80      R63      146     -66
Empty                    15           48       R2        2     +46

num_regs=128, local_size_bytes=0, perf=0.64ms

With manual cutlass.select_() rewrites:

Warp Role                Warps    Budget  Max Idx Max Live  Margin
----------------------------------------------------------------------
Load (TMA)               14           48      R19       63     -15
MMA                      12           48      R17       15     +33
Softmax (s0+s1)          0-7         192     R171      159     +33
Correction               8-11         80      R63      146     -66
Empty                    15           48       R2        2     +46

num_regs=128, local_size_bytes=0, perf=0.64ms
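As a quick cross-check of the two tables above, the following sketch recomputes the Margin column and reports which warp roles differ between the two builds. It assumes Margin is simply Budget minus Max Live, which matches every row shown; the dictionary layout and function name are illustrative only.

```python
# Numbers copied from the two tables above.
# Assumption: Margin = Budget - Max Live (consistent with every row shown).
baseline = {
    "Load (TMA)": (48, 63),
    "MMA": (48, 15),
    "Softmax (s0+s1)": (192, 160),
    "Correction": (80, 146),
    "Empty": (48, 2),
}

# The select_ build differs from the baseline only in the softmax row.
with_select = dict(baseline)
with_select["Softmax (s0+s1)"] = (192, 159)


def margins(table):
    # Map each warp role to its register margin.
    return {role: budget - max_live for role, (budget, max_live) in table.items()}


base_m = margins(baseline)
sel_m = margins(with_select)

# Keep only the roles whose margin changed between the two builds.
changed = {
    role: (base_m[role], sel_m[role])
    for role in baseline
    if base_m[role] != sel_m[role]
}
print(changed)
```

Only the softmax warps change, and only by one live register, so this configuration does not show a large gap by itself.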

In this particular kernel configuration, the codegen difference is small (softmax Max Live: 160 → 159). The PTX optimizer is able to lower most of the scf.IfOp patterns back to predicated instructions. However:

  1. The scf.IfOp path relies on the optimizer to recover what should have been a simple selp from the start
  2. In more complex kernels or with more aggressive ternary usage, the optimizer may not recover
  3. The function-block approach adds compilation overhead (more IR to generate and optimize)

Workaround

Users can replace runtime ternaries with explicit cutlass.select_() calls:

# Before (generates scf.IfOp via visit_IfExp):
result = value_a if runtime_cond else value_b

# After (generates arith.select directly):
result = cutlass.select_(runtime_cond, value_a, value_b)

Note: cutlass.select_ evaluates both operands (no short-circuit), so it can cause an IMA if either operand performs a memory access that is only valid when the condition selects it.
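The hazard in the note above, and one possible way around it, can be sketched in plain Python. Here `select_` is a hypothetical stand-in that only mimics eager evaluation (it is not `cutlass.select_`), and a Python `IndexError` plays the role of an out-of-bounds access. Clamping the index is one option under these assumptions, not necessarily the right fix for every kernel.

```python
# Plain-Python model; select_ is a hypothetical stand-in, not cutlass.select_.
def select_(cond, true_val, false_val):
    # Both operands were evaluated by the caller before this runs.
    return true_val if cond else false_val


table = [10, 20, 30]
idx = 5  # out of range on purpose
in_bounds = idx < len(table)

# Ternary: the load is skipped when the condition is false.
lazy = table[idx] if in_bounds else 0

# Eager select: the load is evaluated before the select, so it faults.
try:
    select_(in_bounds, table[idx], 0)
    eager_faulted = False
except IndexError:
    eager_faulted = True

# One way to keep the eager form safe: clamp the index so the load is
# always valid, then let the select discard the unused value.
safe_idx = min(idx, len(table) - 1)
guarded = select_(in_bounds, table[safe_idx], 0)

print(lazy, eager_faulted, guarded)
```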

Environment details (please complete the following information):

  • cutlass-dsl 4.4.1 (cu 12.9 version)
  • GPU: NVIDIA GB200 (SM100)
  • Kernel: Flash Attention 4 forward (paged KV, local/sliding window)
  • torch cu128

Thanks to @v0i0, @drisspg, and @Alkaid-Benetnash for pointing me in the right direction.
