feat: primus-turbo attn add sbhd format support #650

Open
RuibinCheung wants to merge 6 commits into main from dev/zhangrb/add_sbhd_format_support

Conversation

@RuibinCheung
Contributor

@RuibinCheung RuibinCheung commented Apr 8, 2026

  • Add sbhd format support to Primus Turbo attention; this eliminates an extra transpose kernel call in Attention.

Copilot AI review requested due to automatic review settings April 8, 2026 11:30
@RuibinCheung RuibinCheung changed the title [No Merge][WIP] feat: add sbhd format support [No Merge][WIP] feat: primus-turbo attn add sbhd format support Apr 8, 2026
@RuibinCheung RuibinCheung marked this pull request as draft April 8, 2026 11:31
Contributor

Copilot AI left a comment


Pull request overview

Adds experimental support for additional QKV tensor layouts (notably sbhd) in the Primus Turbo attention wrapper, while introducing special-casing for sink attention to force a specific layout.

Changes:

  • Removes the previous manual sbhd -> bshd transpose and instead forwards qkv_format into the underlying flash_attn op.
  • Introduces a use_sink_attn flag and forces sink-attention execution to use bshd, including explicit tensor permutations for Q/K/V and the output.
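The layout change at the heart of the PR can be illustrated with a small sketch (numpy here, with shapes chosen for illustration; torch's `q.permute(1, 0, 2, 3).contiguous()` corresponds to `np.ascontiguousarray(np.transpose(...))`):

```python
import numpy as np

S, B, H, D = 4, 2, 3, 8  # illustrative sizes: seq, batch, heads, head_dim

# Query in sbhd layout, as the caller produces it.
q_sbhd = np.arange(S * B * H * D, dtype=np.float32).reshape(S, B, H, D)

# Old path: always transpose to bshd before calling the attention op,
# paying an extra kernel launch per tensor plus the .contiguous() copy.
q_bshd = np.ascontiguousarray(np.transpose(q_sbhd, (1, 0, 2, 3)))
assert q_bshd.shape == (B, S, H, D)

# New path: no transpose; the format string is forwarded to flash_attn.
qkv_format = "sbhd"
```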

Comment on lines +504 to +507

```python
# NOTE: sink attention only support bshd format
query = query.permute(1, 0, 2, 3).contiguous()
key = key.permute(1, 0, 2, 3).contiguous()
value = value.permute(1, 0, 2, 3).contiguous()
```

Copilot AI Apr 8, 2026


When use_sink_attn is enabled, query/key/value are always permuted as if the incoming layout were sbhd (S,B,H,D) -> bshd (B,S,H,D). If qkv_format is already bshd (or any non-sbhd value coming from packed_seq_params), this permutation will corrupt the tensor layout while qkv_format is forced to "bshd", creating a format/tensor mismatch.

Consider either (a) explicitly asserting qkv_format == "sbhd" before permuting in the sink-attention path, or (b) permuting conditionally based on qkv_format and ensuring o is permuted back consistently to preserve the expected output layout.
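Option (b) might be sketched roughly as follows (numpy stands in for torch, `np.transpose` playing the role of `Tensor.permute`; `to_bshd`/`from_bshd` are hypothetical helper names, not part of the PR):

```python
import numpy as np

def to_bshd(t, qkv_format):
    # Permute into bshd only when the input actually arrives as sbhd;
    # a bshd input passes through untouched instead of being corrupted.
    if qkv_format == "sbhd":
        return np.transpose(t, (1, 0, 2, 3))  # (S,B,H,D) -> (B,S,H,D)
    if qkv_format == "bshd":
        return t
    raise ValueError(f"sink attention does not support qkv_format={qkv_format!r}")

def from_bshd(o, qkv_format):
    # Undo the conversion so the caller gets its original layout back.
    if qkv_format == "sbhd":
        return np.transpose(o, (1, 0, 2, 3))  # (B,S,H,D) -> (S,B,H,D)
    return o
```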

Copilot uses AI. Check for mistakes.
```diff
@@ -465,9 +465,6 @@ def forward(
)

qkv_format = packed_seq_kwargs.get("qkv_format", self.qkv_format)
```

Copilot AI Apr 8, 2026


qkv_format is now passed through to self.attn(...) without any local validation. Since this value can come from packed_seq_params, an unexpected value will likely fail deeper in the kernel with a less actionable error (and could also interact badly with the explicit format conversions in the sink-attention path).

Recommend validating qkv_format against the set of supported formats in this backend and raising a clear ValueError (or keeping an assert) before using it.

Suggested change

```diff
 qkv_format = packed_seq_kwargs.get("qkv_format", self.qkv_format)
+supported_qkv_formats = ("sbhd", "bshd", "thd")
+if qkv_format not in supported_qkv_formats:
+    raise ValueError(
+        f"Unsupported qkv_format: {qkv_format}. "
+        f"Supported formats: {supported_qkv_formats}"
+    )
```

```python
else:
    window_size = (self.sink_sliding_window, 0)

# NOTE: sink attention only support bshd format
```

Copilot AI Apr 8, 2026


Typo/grammar in the new comment: "sink attention only support bshd format" → "sink attention only supports bshd format".

Suggested change

```diff
-# NOTE: sink attention only support bshd format
+# NOTE: sink attention only supports bshd format
```

Copilot AI review requested due to automatic review settings April 17, 2026 09:16
@RuibinCheung RuibinCheung marked this pull request as ready for review April 17, 2026 09:17
@RuibinCheung RuibinCheung changed the title [No Merge][WIP] feat: primus-turbo attn add sbhd format support feat: primus-turbo attn add sbhd format support Apr 17, 2026
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.

```python
if qkv_format == "sbhd":
    query = query.permute(1, 0, 2, 3)
    key = key.permute(1, 0, 2, 3)
    value = value.permute(1, 0, 2, 3)
```

Copilot AI Apr 17, 2026


qkv_format is now only handled for the "sbhd" case; any other value falls through and is treated as if inputs are already BSHD, which can silently produce wrong results if qkv_format is misspelled/unsupported. Please add explicit validation (e.g., accept only sbhd and bshd, and raise ValueError otherwise) rather than relying on implicit fallthrough or assert.

Suggested change

```diff
     value = value.permute(1, 0, 2, 3)
+elif qkv_format == "bshd":
+    pass
+else:
+    raise ValueError(
+        f"Unsupported qkv_format: {qkv_format!r}. Expected one of ('sbhd', 'bshd')."
+    )
```

Comment on lines +537 to +541

```python
if qkv_format == "sbhd":
    query = query.permute(1, 0, 2, 3)
    key = key.permute(1, 0, 2, 3)
    value = value.permute(1, 0, 2, 3)
```

Copilot AI Apr 17, 2026


This adds conditional layout handling based on qkv_format, but there are no tests covering the new sbhd/bshd behavior. Please add a focused unit test that sets qkv_format to both values and asserts the returned tensor layout/shape matches expectations.
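Such a test might look roughly like this (a sketch only: `fake_attn` is a hypothetical stand-in for the real attention wrapper, which is not reproduced here, and numpy stands in for torch):

```python
import numpy as np

def fake_attn(q, qkv_format):
    # Stand-in mimicking the wrapper's layout contract: compute in bshd
    # internally, return the output in the caller's requested layout.
    if qkv_format == "sbhd":
        q = np.transpose(q, (1, 0, 2, 3))
    elif qkv_format != "bshd":
        raise ValueError(f"Unsupported qkv_format: {qkv_format!r}")
    o = q  # real attention math elided
    if qkv_format == "sbhd":
        o = np.transpose(o, (1, 0, 2, 3))
    return o

def test_qkv_format_layouts():
    S, B, H, D = 4, 2, 3, 8
    # Output layout must match input layout for both supported formats.
    assert fake_attn(np.zeros((S, B, H, D)), "sbhd").shape == (S, B, H, D)
    assert fake_attn(np.zeros((B, S, H, D)), "bshd").shape == (B, S, H, D)
```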
