
[ET-VK][embedding] Enable embedding weight dedup with tied linear weights#18360

Open
SS-JIA wants to merge 5 commits into gh/SS-JIA/498/base from gh/SS-JIA/498/head

Conversation

@SS-JIA (Contributor) commented Mar 20, 2026

Stack from ghstack (oldest at bottom):

When embedding and output linear weights are tied (same underlying tensor, as in
Llama 3.2 1B), they are quantized independently with opposite nibble packing
conventions, which prevents SHA256-based dedup in the NamedDataStore. This adds an
`is_linear_weight` flag to `embedding_q4gsw` so the embedding op can consume
weights packed in the linear convention, enabling dedup and saving ~125 MB.
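The dedup failure can be illustrated with a minimal sketch: two packing conventions that only disagree on nibble order serialize the same 4-bit weights to different bytes, so their content hashes never match. The function names and conventions below are illustrative assumptions, not the actual quantizer code.

```python
import hashlib

import numpy as np


def pack_low_first(q4: np.ndarray) -> np.ndarray:
    """Pack 4-bit value pairs with the first element in the low nibble."""
    pairs = q4.reshape(-1, 2)
    return ((pairs[:, 1] << 4) | pairs[:, 0]).astype(np.uint8)


def pack_high_first(q4: np.ndarray) -> np.ndarray:
    """Pack 4-bit value pairs with the first element in the high nibble."""
    pairs = q4.reshape(-1, 2)
    return ((pairs[:, 0] << 4) | pairs[:, 1]).astype(np.uint8)


# The same underlying 4-bit weights (the tied embedding/output tensor) ...
quants = np.arange(64, dtype=np.uint8) % 16

emb_bytes = pack_low_first(quants).tobytes()
lin_bytes = pack_high_first(quants).tobytes()

# ... serialize to different byte strings, so the SHA256 digests differ and a
# content-addressed store keeps two copies instead of one.
assert hashlib.sha256(emb_bytes).hexdigest() != hashlib.sha256(lin_bytes).hexdigest()
```

Packing both weights with a single convention makes the bytes, and therefore the hashes, identical, which is exactly what the `is_linear_weight` flag enables.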

The implementation spans Python (detection + repacking), GLSL (packed block
format reading), and C++ (prepacking shared with linear):

- custom_ops_lib.py: Add an `is_linear_weight` param to the op signature and
  reference impl, swapping the nibble extraction order when True
- patterns/quantized_embedding.py: Add `_detect_tied_linear_weight()`, which
  unpacks the embedding weight and compares it against int8 linear weight
  placeholders. When matched, repack using the linear convention and pass
  `is_linear_weight=True` to the op
- embedding_q4gsw.glsl: Extract weight loading into `load_embedding_weights()`
  returning `VEC4_T`, with a compile-time `LINEAR_WEIGHT` variant that reads the
  block-interleaved format produced by `pack_q4_linear_weight`
- embedding_q4gsw.yaml: Add 4 new shader variants for linear_weight x
  {buffer, texture2d} weight storage
- EmbeddingQ4gsw.cpp: When `is_linear_weight` is set, call
  `prepack_quantized_linear_weight` (shared with the linear op) instead of
  `prepack_standard`, and pass `embed_dim` as the resize_arg since the prepacked
  weight tensor has a different shape
- test_embedding_q4gsw.cpp: Add `is_linear_weight` to `EmbeddingConfig`, update
  the reference function, and add 6 new test cases
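The flag's effect on the reference impl can be sketched as follows: unpacking extracts both nibbles of each byte and the flag only decides which nibble is treated as the first element of the pair. This is a hypothetical mock mirroring the described behavior, not the actual `custom_ops_lib.py` code, and the "high nibble first" choice for the linear convention is an assumption for illustration.

```python
import numpy as np


def unpack_q4(packed: np.ndarray, is_linear_weight: bool = False) -> np.ndarray:
    """Unpack uint8 bytes into signed 4-bit value pairs.

    The is_linear_weight flag swaps which nibble is taken as the first
    element of each pair, mirroring the swapped extraction order described
    in the PR. Layout and names here are illustrative.
    """
    low = packed & 0x0F
    high = (packed >> 4) & 0x0F
    if is_linear_weight:
        # Linear packing convention (assumed): first element in the high nibble
        pairs = np.stack([high, low], axis=-1)
    else:
        # Embedding packing convention (assumed): first element in the low nibble
        pairs = np.stack([low, high], axis=-1)
    # Recenter unsigned nibbles [0, 15] to signed 4-bit values [-8, 7]
    return pairs.reshape(*packed.shape[:-1], -1).astype(np.int8) - 8


# A weight packed high-nibble-first round-trips only when the flag matches
# the convention it was packed with.
quants = (np.arange(8, dtype=np.uint8) % 16).reshape(1, 8)
packed = ((quants[..., 0::2] << 4) | quants[..., 1::2]).astype(np.uint8)
restored = (unpack_q4(packed, is_linear_weight=True) + 8).astype(np.uint8)
assert np.array_equal(restored, quants)
```

Reading the same bytes with `is_linear_weight=False` yields the pair elements transposed, which is why the embedding shader needs the dedicated `LINEAR_WEIGHT` read path rather than reusing its default one.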

Authored with Claude.

Differential Revision: D97430803

@pytorch-bot (bot) commented Mar 20, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18360

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 17 Pending

As of commit 58eb8af with merge base 38b40bc:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Mar 20, 2026
@github-actions

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

ssjia added 4 commits March 20, 2026 07:47

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. fb-exported meta-exported

2 participants