[BUG] SDPA NAX kernel: int16 overflow in mask col_pos for KV sequences > 32K #3360

@Clydingus

Description

Describe the bug
mx.fast.scaled_dot_product_attention produces incorrect output (2-10x magnitude error) when the KV sequence length exceeds 32,768 and an additive mask deactivates most positions. The bug is in the NAX attention kernel only — the non-NAX variant is correct.

In mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h, lines 324-325:

const short row_pos = base_row + iq * kU;
const short col_pos = base_col + ik * kU;

base_col is an int (it can reach kL, e.g. 66048), but col_pos is a short (max 32767). When base_col exceeds 32767 the sum wraps around (36864 becomes -28672, for instance), so the mask is loaded from the wrong positions.
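A minimal sketch of the truncation, mirroring the kernel arithmetic with NumPy (the base_col, ik, and kU values here are hypothetical; Metal's short is 16-bit signed, which np.int16 reproduces):

```python
import numpy as np

# Hypothetical values; what matters is only that base_col + ik * kU
# can exceed 32767 for long KV sequences.
base_col, ik, kU = 36800, 8, 8

# Buggy kernel: `const short col_pos = base_col + ik * kU;`
col_pos_bug = int(np.array(base_col + ik * kU, dtype=np.int64).astype(np.int16))
# Proposed fix: `const int col_pos = base_col + ik * kU;`
col_pos_fix = int(np.int32(base_col + ik * kU))

print(col_pos_bug, col_pos_fix)  # -28672 36864
```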

The non-NAX variant (steel_attention.h line 345) already uses int for col_pos.

This is the same class of bug as #2894 / PR #2903 (int32 overflow in the mask strides), but on the NAX code path.

To Reproduce

import mlx.core as mx
import numpy as np

D_HEAD = 64

def test(cap, active, n_heads=32, n_kv=16, t=512, dtype=mx.bfloat16):
    np.random.seed(0)
    q = mx.array(np.random.randn(1, n_heads, t, D_HEAD).astype(np.float32)).astype(dtype)
    k = mx.array(np.random.randn(1, n_kv, cap, D_HEAD).astype(np.float32)).astype(dtype)
    v = mx.array(np.random.randn(1, n_kv, cap, D_HEAD).astype(np.float32)).astype(dtype)
    mask = mx.full((1, 1, 1, cap), -1e4).astype(dtype)
    mask[:, :, :, cap - active:] = 0.0
    mx.eval(q, k, v, mask)

    y_fast = mx.fast.scaled_dot_product_attention(q, k, v, scale=D_HEAD**-0.5, mask=mask)
    mx.eval(y_fast)

    kk = mx.repeat(k, n_heads // n_kv, axis=1) if n_heads != n_kv else k
    vv = mx.repeat(v, n_heads // n_kv, axis=1) if n_heads != n_kv else v
    scores = (q @ kk.transpose(0, 1, 3, 2)) * (D_HEAD**-0.5) + mask
    y_ref = mx.softmax(scores.astype(mx.float32), axis=-1).astype(dtype) @ vv
    mx.eval(y_ref)

    yf = np.array(y_fast.astype(mx.float32))
    yr = np.array(y_ref.astype(mx.float32))
    ratio = np.nanmax(np.abs(yf)) / (np.max(np.abs(yr)) + 1e-10)
    ok = abs(ratio - 1.0) < 0.05
    status = "PASS" if ok else f"FAIL ({ratio:.3f})"
    print(f"  cap={cap:6d}  active={active:5d}  {status}")
    return ok

print(f"MLX {mx.__version__}\n")
n_fail = 0
for cap in [8192, 32768, 36864, 49152, 66048]:
    if not test(cap, 1024): n_fail += 1
print(f"\n{'ALL PASS' if n_fail == 0 else f'{n_fail} FAILED'}")

Expected output:

  cap=  8192  active= 1024  PASS
  cap= 32768  active= 1024  PASS
  cap= 36864  active= 1024  PASS
  cap= 49152  active= 1024  PASS
  cap= 66048  active= 1024  PASS

Actual output:

  cap=  8192  active= 1024  PASS
  cap= 32768  active= 1024  PASS
  cap= 36864  active= 1024  FAIL (0.719)
  cap= 49152  active= 1024  FAIL (0.219)
  cap= 66048  active= 1024  FAIL (0.123)
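The failure pattern lines up with int16 truncation of the column offsets near the end of the KV sequence (a sanity check on the hypothesis, not part of the repro — it assumes offsets reach roughly cap - 1):

```python
import numpy as np

# Truncate a representative end-of-sequence column offset to int16,
# the width of Metal's `short`, for each cap in the repro.
for cap in [8192, 32768, 36864, 49152, 66048]:
    wrapped = int(np.array(cap - 1, dtype=np.int64).astype(np.int16))
    print(f"cap={cap:6d}  int16(cap-1)={wrapped:6d}")
```

The first two caps stay representable (hence PASS); the rest wrap. Note that 66048 wraps to a small positive offset, so the reads are still in-bounds but point at the wrong mask entries, which is why the outputs are plausible-looking yet wrong rather than NaN.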

Expected behavior
mx.fast.scaled_dot_product_attention should match the unfused reference attention for all KV sequence lengths, i.e. all five cases above should print PASS.

System information:

  • MLX version: 0.31.2 (commit 6a9a121)
  • Apple M5 Max
  • macOS 26.3.1

Additional context

Proposed Fix (PR #3361)
Change short to int on lines 324-325:

-          const short row_pos = base_row + iq * kU;
-          const short col_pos = base_col + ik * kU;
+          const int row_pos = base_row + iq * kU;
+          const int col_pos = base_col + ik * kU;

There are two occurrences of this pattern in the file (mask loading in the aligned and non-aligned branches).
