feat: primus-turbo attn add sbhd format support #650
RuibinCheung wants to merge 6 commits into main
Conversation
Pull request overview
Adds experimental support for additional QKV tensor layouts (notably sbhd) in the Primus Turbo attention wrapper, while introducing special-casing for sink attention to force a specific layout.
Changes:
- Removes the previous manual `sbhd` -> `bshd` transpose and instead forwards `qkv_format` into the underlying `flash_attn` op.
- Introduces a `use_sink_attn` flag and forces sink-attention execution to use `bshd`, including explicit tensor permutations for Q/K/V and the output.
```python
# NOTE: sink attention only support bshd format
query = query.permute(1, 0, 2, 3).contiguous()
key = key.permute(1, 0, 2, 3).contiguous()
value = value.permute(1, 0, 2, 3).contiguous()
```
When use_sink_attn is enabled, query/key/value are always permuted as if the incoming layout were sbhd (S,B,H,D) -> bshd (B,S,H,D). If qkv_format is already bshd (or any non-sbhd value coming from packed_seq_params), this permutation will corrupt the tensor layout while qkv_format is forced to "bshd", creating a format/tensor mismatch.
Consider either (a) explicitly asserting qkv_format == "sbhd" before permuting in the sink-attention path, or (b) permuting conditionally based on qkv_format and ensuring o is permuted back consistently to preserve the expected output layout.
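A minimal sketch of option (b), using NumPy's `np.transpose` as a stand-in for `torch.Tensor.permute` (the helper name `to_bshd` and the concrete shapes are illustrative, not from the PR):

```python
import numpy as np

def to_bshd(x, qkv_format):
    """Convert a Q/K/V array to the bshd layout, validating the incoming format."""
    if qkv_format == "sbhd":      # (S, B, H, D) -> (B, S, H, D)
        return np.transpose(x, (1, 0, 2, 3))
    if qkv_format == "bshd":      # already (B, S, H, D); leave untouched
        return x
    raise ValueError(f"sink attention supports only sbhd/bshd, got {qkv_format!r}")

q_sbhd = np.zeros((4, 2, 3, 8))       # S=4, B=2, H=3, D=8
print(to_bshd(q_sbhd, "sbhd").shape)  # (2, 4, 3, 8)
```

Permuting only when the declared format is actually `sbhd` keeps the forced `qkv_format = "bshd"` consistent with the tensors it describes.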
```diff
@@ -465,9 +465,6 @@ def forward(
     )

 qkv_format = packed_seq_kwargs.get("qkv_format", self.qkv_format)
```
qkv_format is now passed through to self.attn(...) without any local validation. Since this value can come from packed_seq_params, an unexpected value will likely fail deeper in the kernel with a less actionable error (and could also interact badly with the explicit format conversions in the sink-attention path).
Recommend validating qkv_format against the set of supported formats in this backend and raising a clear ValueError (or keeping an assert) before using it.
Suggested change:

```python
qkv_format = packed_seq_kwargs.get("qkv_format", self.qkv_format)
supported_qkv_formats = ("sbhd", "bshd", "thd")
if qkv_format not in supported_qkv_formats:
    raise ValueError(
        f"Unsupported qkv_format: {qkv_format}. "
        f"Supported formats: {supported_qkv_formats}"
    )
```
```python
else:
    window_size = (self.sink_sliding_window, 0)

# NOTE: sink attention only support bshd format
```
Typo/grammar in the new comment: "sink attention only support bshd format" → "sink attention only supports bshd format".
Suggested change:

```python
# NOTE: sink attention only supports bshd format
```
```python
if qkv_format == "sbhd":
    query = query.permute(1, 0, 2, 3)
    key = key.permute(1, 0, 2, 3)
    value = value.permute(1, 0, 2, 3)
```
qkv_format is now only handled for the "sbhd" case; any other value falls through and is treated as if the inputs are already BSHD, which can silently produce wrong results if qkv_format is misspelled or unsupported. Please add explicit validation (e.g., accept only sbhd and bshd, and raise a ValueError otherwise) rather than relying on implicit fallthrough or an assert.
Suggested change:

```python
    value = value.permute(1, 0, 2, 3)
elif qkv_format == "bshd":
    pass
else:
    raise ValueError(
        f"Unsupported qkv_format: {qkv_format!r}. Expected one of ('sbhd', 'bshd')."
    )
```
```python
if qkv_format == "sbhd":
    query = query.permute(1, 0, 2, 3)
    key = key.permute(1, 0, 2, 3)
    value = value.permute(1, 0, 2, 3)
```
This adds conditional layout handling based on qkv_format, but there are no tests covering the new sbhd/bshd behavior. Please add a focused unit test that sets qkv_format to both values and asserts the returned tensor layout/shape matches expectations.
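A hedged sketch of such a test, using NumPy arrays in place of `torch.Tensor` (the names `forward` and `fake_attn_bshd` are illustrative stand-ins for the wrapper and the bshd-only kernel, not the PR's actual APIs):

```python
import numpy as np

def fake_attn_bshd(q, k, v):
    # Stand-in for the bshd-only sink-attention kernel: returns q unchanged.
    return q

def forward(q, k, v, qkv_format):
    """Hypothetical wrapper mirroring the layout handling under review."""
    if qkv_format == "sbhd":      # (S, B, H, D) -> (B, S, H, D)
        q, k, v = (np.transpose(t, (1, 0, 2, 3)) for t in (q, k, v))
    elif qkv_format != "bshd":
        raise ValueError(f"unsupported qkv_format: {qkv_format!r}")
    o = fake_attn_bshd(q, k, v)
    if qkv_format == "sbhd":      # permute the output back to sbhd
        o = np.transpose(o, (1, 0, 2, 3))
    return o

S, B, H, D = 4, 2, 3, 8
q = np.zeros((S, B, H, D))
assert forward(q, q, q, "sbhd").shape == (S, B, H, D)   # sbhd in -> sbhd out
q_b = np.transpose(q, (1, 0, 2, 3))
assert forward(q_b, q_b, q_b, "bshd").shape == (B, S, H, D)
```

A real test would call the actual attention module with both `qkv_format` values and additionally check numerical equivalence between the two layouts, not just shapes.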