Fix SDPA vmap with GQA/MQA shapes (n_heads != n_kv_heads) #3385
Open
Brooooooklyn wants to merge 1 commit into ml-explore:main from
Conversation
Force-pushed from 4e42445 to 6fbbec5
The ScaledDotProductAttention primitive relied on Custom::vmap, which re-vmapped the fallback lambda. That lambda captured n_q_heads and n_kv_heads at creation time, causing shape mismatches (SIGSEGV/hang) when vmap changed the array dimensions. Add a dedicated vmap override that merges the vmap axis into the batch dimension and re-invokes scaled_dot_product_attention, which recomputes the head counts from the actual shapes. Falls back to Custom::vmap when attention sinks are present.
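For reference, a minimal repro sketch of the failing pattern (not taken from this PR's tests; the shapes and the sdpa wrapper are illustrative), assuming an installed mlx exposing mx.vmap and mx.fast.scaled_dot_product_attention:

```python
import mlx.core as mx

# Illustrative GQA shapes: 8 query heads sharing 2 KV heads,
# vmapped over a leading axis of size 3.
B, n_heads, n_kv_heads, L, D = 2, 8, 2, 16, 64
q = mx.random.normal((3, B, n_heads, L, D))
k = mx.random.normal((3, B, n_kv_heads, L, D))
v = mx.random.normal((3, B, n_kv_heads, L, D))

def sdpa(q, k, v):
    return mx.fast.scaled_dot_product_attention(q, k, v, scale=D**-0.5)

# Before this fix, the re-vmapped fallback captured stale head counts,
# which on MLX 0.31.1 could SIGSEGV/hang when n_heads != n_kv_heads.
out = mx.vmap(sdpa)(q, k, v)
print(out.shape)  # (3, 2, 8, 16, 64)
```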
Force-pushed from 6fbbec5 to 14700f3
Proposed changes
ScaledDotProductAttention relied on Custom::vmap, which re-vmapped the fallback lambda. This always took the decomposed matmul-softmax-matmul path, bypassing the fused Metal/CUDA kernel even when it was available. On MLX 0.31.1 this also caused a SIGSEGV/hang with GQA shapes due to a since-fixed transforms infrastructure bug.

Add a dedicated vmap override that merges the vmap axis into the batch dimension and re-invokes scaled_dot_product_attention directly, so the fused kernel is dispatched under vmap just as it is without vmap. Falls back to Custom::vmap when attention sinks are present.

Close #3383
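The override itself lives in the C++ primitive; as a rough illustration only, here is the same axis-merging idea sketched in Python (sdpa_vmap_rule and fold are hypothetical names, and masks/sinks are ignored):

```python
import mlx.core as mx

def sdpa_vmap_rule(q, k, v, axis, scale):
    # Hypothetical sketch: fold the vmapped axis into the batch dimension,
    # re-invoke the fused SDPA (which derives head counts from the actual
    # shapes at call time), then split the vmap axis back out.
    def fold(x):
        x = mx.moveaxis(x, axis, 0)  # bring the vmap axis to the front
        return x.reshape(x.shape[0] * x.shape[1], *x.shape[2:])

    n = q.shape[axis]
    out = mx.fast.scaled_dot_product_attention(
        fold(q), fold(k), fold(v), scale=scale
    )
    out = out.reshape(n, -1, *out.shape[1:])  # un-merge (vmap, batch)
    return mx.moveaxis(out, 0, axis)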
Checklist
Put an x in the boxes that apply.

- [ ] I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes