[ET-VK][embedding] Enable embedding weight dedup with tied linear weights#18360
Open

SS-JIA wants to merge 5 commits into `gh/SS-JIA/498/base`
When embedding and output linear weights are tied (same underlying tensor, as in
Llama 3.2 1B), they are quantized independently with opposite nibble packing
conventions, preventing SHA256-based dedup in the NamedDataStore. This adds an
`is_linear_weight` flag to `embedding_q4gsw` so the embedding op can consume
weights packed in the linear convention, enabling dedup and saving ~125 MB.
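As a minimal sketch of the dedup failure described above (the packing orders shown are illustrative; the actual conventions used by the embedding and linear quantizers may differ in detail), the same int4 values packed with opposite nibble orders produce different bytes, and therefore different SHA256 digests:

```python
import hashlib

# Illustrative int4 values (one per weight element), to be packed two per byte.
values = [1, 2, 3, 4, 5, 6, 7, 8]

# Convention A (e.g. embedding): first element of each pair in the low nibble.
packed_low_first = bytes((hi << 4) | lo for lo, hi in zip(values[0::2], values[1::2]))

# Convention B (e.g. linear): first element of each pair in the high nibble.
packed_high_first = bytes((lo << 4) | hi for lo, hi in zip(values[0::2], values[1::2]))

# Identical weights, different bytes -> different SHA256 -> no dedup.
print(packed_low_first != packed_high_first)  # True
print(hashlib.sha256(packed_low_first).hexdigest()
      == hashlib.sha256(packed_high_first).hexdigest())  # False
```

This is why the flag approach works: rather than normalizing hashes, the embedding op is taught to read the linear packing directly, so only one serialized copy of the tensor is needed.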
The implementation spans Python (detection + repacking), GLSL (packed block
format reading), and C++ (shared prepacking with linear):
- custom_ops_lib.py: Add `is_linear_weight` param to op signature and reference
impl, swapping nibble extraction order when True
- patterns/quantized_embedding.py: Add `_detect_tied_linear_weight()` that
unpacks the embedding weight and compares against int8 linear weight
placeholders. When matched, repack using linear convention and pass
`is_linear_weight=True` to the op
- embedding_q4gsw.glsl: Extract weight loading into `load_embedding_weights()`
returning VEC4_T, with compile-time LINEAR_WEIGHT variant that reads from
the block-interleaved format produced by `pack_q4_linear_weight`
- embedding_q4gsw.yaml: Add 4 new shader variants for linear_weight x
{buffer, texture2d} weight storage
- EmbeddingQ4gsw.cpp: When `is_linear_weight`, call
`prepack_quantized_linear_weight` (shared with linear op) instead of
`prepack_standard`. Pass `embed_dim` as resize_arg since the prepacked weight
tensor has a different shape
- test_embedding_q4gsw.cpp: Add `is_linear_weight` to EmbeddingConfig, update
reference function, add 6 new test cases
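The nibble-extraction swap in the reference impl can be sketched as follows (a hypothetical standalone helper, not the actual `custom_ops_lib.py` code; nibble-order naming is assumed):

```python
def unpack_q4(packed: bytes, is_linear_weight: bool = False) -> list[int]:
    """Unpack two signed int4 values per byte, swapping the low/high
    nibble extraction order when is_linear_weight is set (sketch only)."""
    def to_signed(nibble: int) -> int:
        # Map unsigned nibble [0, 15] to signed int4 [-8, 7].
        return nibble - 16 if nibble >= 8 else nibble

    out = []
    for byte in packed:
        lo = to_signed(byte & 0x0F)
        hi = to_signed(byte >> 4)
        # The flag flips which nibble is treated as the first element.
        out.extend((hi, lo) if is_linear_weight else (lo, hi))
    return out

print(unpack_q4(bytes([0x21, 0x43])))                         # [1, 2, 3, 4]
print(unpack_q4(bytes([0x21, 0x43]), is_linear_weight=True))  # [2, 1, 4, 3]
```

The GLSL side makes the analogous choice at shader compile time via the LINEAR_WEIGHT variant rather than a runtime branch.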
Authored with Claude.
Differential Revision: [D97430803](https://our.internmc.facebook.com/intern/diff/D97430803/)
[ghstack-poisoned]
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18360

Note: links to docs will display an error until the docs builds have completed. ⏳ No failures, 17 pending as of commit 58eb8af with merge base 38b40bc. This comment was automatically generated by Dr. CI and updates every 15 minutes.
added 4 commits on March 20, 2026 07:47
manuelcandales approved these changes on Mar 20, 2026