
Fix int16 overflow in SDPA NAX mask indexing for KV sequences > 32K #3361

Merged

angeloskath merged 5 commits into ml-explore:main from Clydingus:fix/sdpa-nax-int16-overflow on Apr 10, 2026
Conversation

@Clydingus
Contributor

Proposed changes

Fixes #3360.
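For context, here is a minimal NumPy sketch of the failure mode (illustrative only; the actual overflow happens in the kernel's 16-bit mask column index, not in NumPy): positions past 32767 wrap negative when narrowed to a signed 16-bit integer, so indexing with them reads the wrong mask elements.

```python
import numpy as np

# Illustration only: mask column positions for a KV length of 36864,
# narrowed to a signed 16-bit integer. Values past 32767 wrap around.
positions = np.arange(36864, dtype=np.int32).astype(np.int16)

print(positions[32767])  # 32767  -> last representable position
print(positions[32768])  # -32768 -> wrapped negative: wrong mask element
```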

Test

import mlx.core as mx
import numpy as np

D_HEAD = 64

def test(cap, active, n_heads=32, n_kv=16, t=512, dtype=mx.bfloat16):
    np.random.seed(0)
    q = mx.array(np.random.randn(1, n_heads, t, D_HEAD).astype(np.float32)).astype(dtype)
    k = mx.array(np.random.randn(1, n_kv, cap, D_HEAD).astype(np.float32)).astype(dtype)
    v = mx.array(np.random.randn(1, n_kv, cap, D_HEAD).astype(np.float32)).astype(dtype)
    # Sparse additive mask: only the last `active` KV positions attend.
    mask = mx.full((1, 1, 1, cap), -1e4).astype(dtype)
    mask[:, :, :, cap - active:] = 0.0
    mx.eval(q, k, v, mask)

    y_fast = mx.fast.scaled_dot_product_attention(q, k, v, scale=D_HEAD**-0.5, mask=mask)
    mx.eval(y_fast)

    # Reference: expand GQA heads and compute attention explicitly.
    kk = mx.repeat(k, n_heads // n_kv, axis=1) if n_heads != n_kv else k
    vv = mx.repeat(v, n_heads // n_kv, axis=1) if n_heads != n_kv else v
    scores = (q @ kk.transpose(0, 1, 3, 2)) * (D_HEAD**-0.5) + mask
    y_ref = mx.softmax(scores.astype(mx.float32), axis=-1).astype(dtype) @ vv
    mx.eval(y_ref)

    # Compare output magnitudes; a large deviation means the fast kernel
    # applied the mask to the wrong positions.
    yf = np.array(y_fast.astype(mx.float32))
    yr = np.array(y_ref.astype(mx.float32))
    ratio = np.nanmax(np.abs(yf)) / (np.max(np.abs(yr)) + 1e-10)
    ok = abs(ratio - 1.0) < 0.05
    status = "PASS" if ok else f"FAIL ({ratio:.3f})"
    print(f"  cap={cap:6d}  active={active:5d}  {status}")
    return ok

print(f"MLX {mx.__version__}\n")
n_fail = 0
for cap in [8192, 32768, 36864, 49152, 66048]:
    if not test(cap, 1024):
        n_fail += 1
print(f"\n{'ALL PASS' if n_fail == 0 else f'{n_fail} FAILED'}")

Expected output:

  cap=  8192  active= 1024  PASS
  cap= 32768  active= 1024  PASS
  cap= 36864  active= 1024  PASS
  cap= 49152  active= 1024  PASS
  cap= 66048  active= 1024  PASS

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

@zcbenz (Collaborator) left a comment

Can you add a test?

@Clydingus
Contributor Author

Sorry, missed that 😓. Added now. I tried adding it under the existing test_sdpa function, but the divergence mainly shows up with sparse masks at long sequence lengths. Let me know if any other changes are required.

@Clydingus Clydingus requested a review from zcbenz April 8, 2026 09:19
@angeloskath (Member) left a comment

Thanks for finding that.

There was a pretty big slowdown for masked attention due to constant bounds checking for the mask (which was made worse by the switch to wider index integers). I changed it to only bounds-check at the edges and otherwise load as-is.

Results (lower is better)

Shape (Q K D)      Before   After
4096  5000  64      2.028    1.366
2048 32121  64      9.175    5.664
4096  5000 128      4.538    3.313
2048 32121 128     18.012   11.511
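The edge-only bounds-checking strategy described above can be sketched in NumPy terms. This is a hypothetical illustration of the pattern, not the actual Metal kernel: `load_mask_tile` and the tile layout are invented for the example. Interior tiles are known to be fully in range and load with a plain slice; only the final partial tile pays for per-element padding.

```python
import numpy as np

def load_mask_tile(mask, tile_start, tile_size):
    """Hypothetical sketch of edge-only bounds checking (not the real kernel).

    Interior tiles are provably in range, so they load without any
    per-element checks; only the last, partial tile needs padding.
    """
    n = mask.shape[-1]
    if tile_start + tile_size <= n:
        # Interior tile: direct load, no bounds checks.
        return mask[..., tile_start:tile_start + tile_size]
    # Edge tile: pad the out-of-range tail with -inf (fully masked out).
    out = np.full(mask.shape[:-1] + (tile_size,), -np.inf, dtype=mask.dtype)
    valid = n - tile_start
    out[..., :valid] = mask[..., tile_start:n]
    return out
```

Per-element checking would branch on every load; with this scheme the branch runs once per tile, which matches the direction of the speedups in the table above.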

@zcbenz (Collaborator) left a comment

🚀

@angeloskath angeloskath merged commit a33b791 into ml-explore:main Apr 10, 2026
16 checks passed
@Clydingus Clydingus deleted the fix/sdpa-nax-int16-overflow branch April 10, 2026 08:08


Development

Successfully merging this pull request may close these issues.

[BUG] SDPA NAX kernel: int16 overflow in mask col_pos for KV sequences > 32K

3 participants