[ET-VK][embedding] Enable embedding weight dedup with tied linear weights #18360
When embedding and output linear weights are tied (same underlying tensor, as in
Llama 3.2 1B), they are quantized independently with opposite nibble packing
conventions, preventing SHA256-based dedup in the NamedDataStore. This adds an
`is_linear_weight` flag to `embedding_q4gsw` so the embedding op can consume
weights packed in the linear convention, enabling dedup and saving ~125 MB.
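To see why tied weights fail to dedup, consider that content-addressed storage hashes the serialized bytes, not the logical values. The two packing conventions below are illustrative (not the exact bit layouts used by these ops); they just demonstrate that the same int4 data packed with opposite nibble orders yields different bytes, and therefore different SHA256 digests:

```python
import hashlib
import numpy as np

def pack_low_first(q):
    # Convention A (illustrative): element 2i goes in the low nibble,
    # element 2i+1 in the high nibble of each packed byte.
    lo, hi = q[0::2] & 0xF, q[1::2] & 0xF
    return ((hi << 4) | lo).astype(np.uint8)

def pack_high_first(q):
    # Convention B (illustrative): the opposite nibble order.
    lo, hi = q[0::2] & 0xF, q[1::2] & 0xF
    return ((lo << 4) | hi).astype(np.uint8)

# The same quantized values...
q = np.arange(16, dtype=np.uint8)

a = pack_low_first(q).tobytes()
b = pack_high_first(q).tobytes()

# ...serialize to different bytes, so hash-based dedup sees two
# distinct blobs even though the underlying tensor is shared.
print(a == b)  # False
print(hashlib.sha256(a).hexdigest() == hashlib.sha256(b).hexdigest())  # False
```

Repacking one side into the other's convention makes the serialized bytes identical, which is what lets the NamedDataStore store the tensor once.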
The implementation spans Python (detection + repacking), GLSL (packed block
format reading), and C++ (shared prepacking with linear):
- custom_ops_lib.py: Add `is_linear_weight` param to op signature and reference
impl, swapping nibble extraction order when True
- patterns/quantized_embedding.py: Add `_detect_tied_linear_weight()` that
unpacks the embedding weight and compares against int8 linear weight
placeholders. When matched, repack using linear convention and pass
`is_linear_weight=True` to the op
- embedding_q4gsw.glsl: Extract weight loading into `load_embedding_weights()`
returning VEC4_T, with compile-time LINEAR_WEIGHT variant that reads from
the block-interleaved format produced by `pack_q4_linear_weight`
- embedding_q4gsw.yaml: Add 4 new shader variants for linear_weight x
{buffer, texture2d} weight storage
- EmbeddingQ4gsw.cpp: When `is_linear_weight`, call
`prepack_quantized_linear_weight` (shared with linear op) instead of
`prepack_standard`. Pass `embed_dim` as resize_arg since the prepacked weight
tensor has a different shape
- test_embedding_q4gsw.cpp: Add `is_linear_weight` to EmbeddingConfig, update
reference function, add 6 new test cases
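The reference-implementation change can be sketched as follows. This is a hypothetical illustration, not the actual `custom_ops_lib.py` code: the function name, the baseline nibble convention, and the signed-int4 offset are all assumptions; only the idea that the flag swaps which nibble is read first comes from the description above.

```python
import numpy as np

def unpack_q4_row(packed_row, is_linear_weight=False):
    # Split each packed byte into its two nibbles.
    lo = packed_row & 0xF
    hi = (packed_row >> 4) & 0xF
    # The flag only changes which nibble is treated as the
    # even-indexed element; the data itself is untouched.
    first, second = (hi, lo) if is_linear_weight else (lo, hi)
    out = np.empty(packed_row.size * 2, dtype=np.int8)
    out[0::2] = first
    out[1::2] = second
    # Shift from unsigned nibble [0, 15] to signed int4 [-8, 7].
    return out - 8
```

With this shape of reference function, the new test cases only need to toggle `is_linear_weight` in the config and compare against the correspondingly swapped unpacking.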
Authored with Claude.
Differential Revision: [D97430803](https://our.internmc.facebook.com/intern/diff/D97430803/)
[ghstack-poisoned]