
Fix SDPA vmap with GQA/MQA shapes (n_heads != n_kv_heads) #3385

Open
Brooooooklyn wants to merge 1 commit into ml-explore:main from mlx-node:fix/sdpa-vmap-gqa

Conversation

Contributor

@Brooooooklyn Brooooooklyn commented Apr 8, 2026

Proposed changes

ScaledDotProductAttention relied on Custom::vmap, which re-vmapped the
fallback lambda. This always took the decomposed matmul-softmax-matmul
path, bypassing the fused Metal/CUDA kernel even when it was available.
On MLX 0.31.1 this also caused a SIGSEGV/hang with GQA shapes, due to a
since-fixed bug in the transforms infrastructure.

Add a dedicated vmap override that merges the vmap axis into the batch
dimension and re-invokes scaled_dot_product_attention directly, so the
fused kernel is dispatched under vmap just as it is without vmap.
Falls back to Custom::vmap when attention sinks are present.
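The merge trick above can be illustrated with a plain NumPy sketch (this is a hand-rolled reference SDPA, not MLX's actual C++ implementation): folding the vmapped axis into the batch dimension, running attention once, and splitting the batch back out gives the same result as mapping over that axis slice by slice.

```python
import numpy as np

def sdpa(q, k, v, scale):
    # Reference scaled dot-product attention.
    # Shapes: q (B, H, L, D), k/v (B, H, S, D).
    scores = (q * scale) @ np.swapaxes(k, -1, -2)
    scores = scores - scores.max(axis=-1, keepdims=True)  # stable softmax
    w = np.exp(scores)
    w = w / w.sum(axis=-1, keepdims=True)
    return w @ v

rng = np.random.default_rng(0)
V, B, H, L, S, D = 3, 2, 4, 5, 6, 8  # V is the vmapped axis
q = rng.standard_normal((V, B, H, L, D))
k = rng.standard_normal((V, B, H, S, D))
v = rng.standard_normal((V, B, H, S, D))
scale = 1.0 / np.sqrt(D)

# "vmap" path: apply sdpa independently to each slice along axis 0.
mapped = np.stack([sdpa(q[i], k[i], v[i], scale) for i in range(V)])

# Merged path: fold the vmap axis into the batch dimension, run once,
# then split the batch dimension back out.
merged = sdpa(
    q.reshape(V * B, H, L, D),
    k.reshape(V * B, H, S, D),
    v.reshape(V * B, H, S, D),
    scale,
).reshape(V, B, H, L, D)

assert np.allclose(mapped, merged)
```

Because attention reduces only over the sequence and head dimensions, the batch axis is embarrassingly parallel, which is what makes the merge safe (and lets the fused kernel run once over the combined batch).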

Close #3383

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

The ScaledDotProductAttention primitive relied on Custom::vmap which
re-vmapped the fallback lambda. That lambda captured n_q_heads and
n_kv_heads at creation time, causing shape mismatches (SIGSEGV/hang)
when vmap changed the array dimensions.

Add a dedicated vmap override that merges the vmap axis into the batch
dimension and re-invokes scaled_dot_product_attention, which recomputes
head counts from actual shapes. Falls back to Custom::vmap for sinks.


Development

Successfully merging this pull request may close these issues.

SIGSEGV / hang in mx.fast.scaled_dot_product_attention under vmap with GQA/MQA shapes (n_heads != n_kv_heads)
