From 8343a97de28f37d09789e06b8300da933e5ffbd0 Mon Sep 17 00:00:00 2001 From: ssjia Date: Thu, 19 Mar 2026 23:14:09 -0700 Subject: [PATCH 1/4] [ET-VK][embedding] Enable embedding weight dedup with tied linear weights When embedding and output linear weights are tied (same underlying tensor, as in Llama 3.2 1B), they are quantized independently with opposite nibble packing conventions, preventing SHA256-based dedup in the NamedDataStore. This adds an `is_linear_weight` flag to `embedding_q4gsw` so the embedding op can consume weights packed in the linear convention, enabling dedup and saving ~125 MB. The implementation spans Python (detection + repacking), GLSL (packed block format reading), and C++ (shared prepacking with linear): - custom_ops_lib.py: Add `is_linear_weight` param to op signature and reference impl, swapping nibble extraction order when True - patterns/quantized_embedding.py: Add `_detect_tied_linear_weight()` that unpacks the embedding weight and compares against int8 linear weight placeholders. When matched, repack using linear convention and pass `is_linear_weight=True` to the op - embedding_q4gsw.glsl: Extract weight loading into `load_embedding_weights()` returning VEC4_T, with compile-time LINEAR_WEIGHT variant that reads from the block-interleaved format produced by `pack_q4_linear_weight` - embedding_q4gsw.yaml: Add 4 new shader variants for linear_weight x {buffer, texture2d} weight storage - EmbeddingQ4gsw.cpp: When `is_linear_weight`, call `prepack_quantized_linear_weight` (shared with linear op) instead of `prepack_standard`. Pass `embed_dim` as resize_arg since the prepacked weight tensor has a different shape - test_embedding_q4gsw.cpp: Add `is_linear_weight` to EmbeddingConfig, update reference function, add 6 new test cases Authored with Claude. Differential Revision: [D97430803](https://our.internmc.facebook.com/intern/diff/D97430803/) [ghstack-poisoned] --- .../vulkan/_passes/insert_prepack_nodes.py | 6 +- backends/vulkan/custom_ops_lib.py | 40 ++ backends/vulkan/op_registry.py | 32 ++ backends/vulkan/patterns/BUCK | 1 + backends/vulkan/patterns/__init__.py | 2 + .../vulkan/patterns/quantized_embedding.py | 175 ++++++ .../graph/ops/glsl/embedding_q4gsw.glsl | 165 ++++++ .../graph/ops/glsl/embedding_q4gsw.yaml | 38 ++ .../runtime/graph/ops/impl/EmbeddingQ4gsw.cpp | 166 ++++++ backends/vulkan/test/custom_ops/targets.bzl | 1 + .../test/custom_ops/test_embedding_q4gsw.cpp | 520 ++++++++++++++++++ backends/vulkan/test/custom_ops/utils.cpp | 67 ++- backends/vulkan/test/custom_ops/utils.h | 4 + 13 files changed, 1210 insertions(+), 7 deletions(-) create mode 100644 backends/vulkan/patterns/quantized_embedding.py create mode 100644 backends/vulkan/runtime/graph/ops/glsl/embedding_q4gsw.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/embedding_q4gsw.yaml create mode 100644 backends/vulkan/runtime/graph/ops/impl/EmbeddingQ4gsw.cpp create mode 100644 backends/vulkan/test/custom_ops/test_embedding_q4gsw.cpp diff --git a/backends/vulkan/_passes/insert_prepack_nodes.py b/backends/vulkan/_passes/insert_prepack_nodes.py index 373b2a4d135..cc4f5969e8e 100644 --- a/backends/vulkan/_passes/insert_prepack_nodes.py +++ b/backends/vulkan/_passes/insert_prepack_nodes.py @@ -16,7 +16,7 @@ from torch.export import ExportedProgram -def insert_prepack_nodes(program: ExportedProgram) -> ExportedProgram: +def insert_prepack_nodes(program: ExportedProgram) -> ExportedProgram: # noqa: C901 """ Insert `et_vk.prepack` nodes for constant tensors in the graph. The prepack operator is responsible for transferring the tensor data, which is serialized with the model, @@ -54,9 +54,13 @@ def insert_prepack_nodes(program: ExportedProgram) -> ExportedProgram: # Most prepacking ops have the primary input at arg 0, but # embedding is embedding(weight, indices, ...) where the # primary input (indices) is at arg 1. + # For embedding_q4gsw, args are (weight, weight_scales, group_size, + # indices) so the primary input (indices) is at arg 3. primary_arg_idx = 0 if user.target == exir_ops.edge.aten.embedding.default: primary_arg_idx = 1 + elif user.target == exir_ops.edge.et_vk.embedding_q4gsw.default: + primary_arg_idx = 3 if node in user.args and user.args.index(node) == primary_arg_idx: nodes_to_replace_input.append(user) diff --git a/backends/vulkan/custom_ops_lib.py b/backends/vulkan/custom_ops_lib.py index 5769c3c132b..7b0a0544662 100644 --- a/backends/vulkan/custom_ops_lib.py +++ b/backends/vulkan/custom_ops_lib.py @@ -879,6 +879,46 @@ def q8ta_relu_impl( lib.impl(name, q8ta_relu_impl, "CompositeExplicitAutograd") q8ta_relu_op = getattr(getattr(torch.ops, namespace), name) +######################## +## embedding_q4gsw ## +######################## + + +def embedding_q4gsw_impl( + weight: torch.Tensor, + weight_scales: torch.Tensor, + group_size: int, + indices: torch.Tensor, + is_linear_weight: bool = False, +) -> torch.Tensor: + # Unpack 4-bit values from packed uint8 tensor + # Packing convention: packed_byte = (even_val + 8) << 4 | (odd_val + 8) + high = (weight >> 4).to(torch.int8) - 8 + low = (weight & 0xF).to(torch.int8) - 8 + if is_linear_weight: + unpacked = torch.stack([low, high], dim=-1).reshape(weight.shape[0], -1) + else: + unpacked = torch.stack([high, low], dim=-1).reshape(weight.shape[0], -1) + # Dequantize using per-group scales + num_groups = weight_scales.shape[1] if weight_scales.dim() > 1 else 1 + unpacked_groups = unpacked.reshape(weight.shape[0], num_groups, group_size) + scales = ( + weight_scales.unsqueeze(-1) + if weight_scales.dim() > 1 + else weight_scales.reshape(1, 1, 1) + ) + dequantized = unpacked_groups.float() * scales.float() + dequantized = dequantized.reshape(weight.shape[0], -1) + return torch.nn.functional.embedding(indices, dequantized) + + +name = "embedding_q4gsw" +lib.define( + f"{name}(Tensor weight, Tensor weight_scales, int group_size, Tensor indices, bool is_linear_weight = False) -> Tensor" +) +lib.impl(name, embedding_q4gsw_impl, "CompositeExplicitAutograd") +embedding_q4gsw_op = getattr(getattr(torch.ops, namespace), name) + ############################# ## select_as_symint ## ############################# diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 308718ade7d..f9881f637d3 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -1436,6 +1436,38 @@ def check_embedding_weight_size(node: torch.fx.Node) -> bool: ) +# ============================================================================= +# EmbeddingQ4gsw (Quantized Embedding) +# ============================================================================= + + +@update_features(exir_ops.edge.quantized_decomposed.embedding_4bit.dtype) +def register_quantized_decomposed_embedding_4bit(): + def check_embedding_4bit_weight_size(node: torch.fx.Node) -> bool: + weight = node.args[0] + if isinstance(weight, torch.fx.Node) and utils.is_tensor_node(weight): + numel = weight.meta["val"].numel() + if numel > utils.DEFAULT_BUFFER_LIMIT: + return False + return True + + return OpFeatures( + inputs_storage=utils.ANY_BUFFER, + supports_prepacking=True, + supports_resize=True, + are_node_inputs_supported_fn=check_embedding_4bit_weight_size, + ) + + +@update_features(exir_ops.edge.et_vk.embedding_q4gsw.default) +def register_embedding_q4gsw(): + return OpFeatures( + inputs_storage=utils.CONTIGUOUS_ANY, + supports_prepacking=True, + supports_resize=True, + ) + + # ============================================================================= # BatchNorm.cpp # ============================================================================= diff --git a/backends/vulkan/patterns/BUCK b/backends/vulkan/patterns/BUCK index 711000f74ca..73bdc7edd1e 100644 --- a/backends/vulkan/patterns/BUCK +++ b/backends/vulkan/patterns/BUCK @@ -10,6 +10,7 @@ fbcode_target(_kind = runtime.python_library, "__init__.py", "pattern_registry.py", "rope.py", + "quantized_embedding.py", "quantized_linear.py", "quantized_convolution.py", "quantized_binary.py", diff --git a/backends/vulkan/patterns/__init__.py b/backends/vulkan/patterns/__init__.py index 050680b024d..a9323a57b09 100644 --- a/backends/vulkan/patterns/__init__.py +++ b/backends/vulkan/patterns/__init__.py @@ -10,6 +10,8 @@ import executorch.backends.vulkan.patterns.quantized_convolution # noqa +import executorch.backends.vulkan.patterns.quantized_embedding # noqa + import executorch.backends.vulkan.patterns.quantized_linear # noqa import executorch.backends.vulkan.patterns.quantized_unary # noqa diff --git a/backends/vulkan/patterns/quantized_embedding.py b/backends/vulkan/patterns/quantized_embedding.py new file mode 100644 index 00000000000..511efa72a3c --- /dev/null +++ b/backends/vulkan/patterns/quantized_embedding.py @@ -0,0 +1,175 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import executorch.backends.vulkan.utils as utils +import torch +from executorch.backends.transforms.utils import get_param_tensor +from executorch.backends.vulkan.patterns.pattern_registry import ( + PatternMatch, + register_pattern_detector, + register_pattern_replacement, +) +from executorch.exir import ExportedProgram +from executorch.exir.dialects._ops import ops as exir_ops + + +class QuantizedEmbeddingMatch(PatternMatch): + def __init__(self, node: torch.fx.Node) -> None: + self.anchor_node = node + self.match_found = False + self.all_nodes = [node] + + # quantized_decomposed.embedding_4bit.dtype args: + # (weight, weight_scales, weight_zero_points, quant_min, quant_max, + # indices, *, dtype) + self.weight_node = node.args[0] + self.scales_node = node.args[1] + self.indices_node = node.args[5] + + # Validate quantization parameters match our shader's assumptions. + # The shader hardcodes the 4-bit signed offset (subtract 8), which + # corresponds to quant_min=-8, quant_max=7, zero_points=0. + quant_min = node.args[3] + quant_max = node.args[4] + if quant_min != -8 or quant_max != 7: + self.match_found = False + return + + # weight_zero_points (args[2]) should be None or all-zeros + weight_zp_node = node.args[2] + if weight_zp_node is not None: + # If it's a constant tensor, verify it's all zeros + if ( + isinstance(weight_zp_node, torch.fx.Node) + and "val" in weight_zp_node.meta + ): + zp_val = weight_zp_node.meta["val"] + if isinstance(zp_val, torch.Tensor) and not torch.all(zp_val == 0): + self.match_found = False + return + + # Trace weight to its placeholder + const_node, arg_chain = utils.trace_args_until_placeholder(self.weight_node) + if const_node is not None: + self.weight_node = const_node + self.all_nodes.extend(arg_chain) + + # Trace scales to their placeholder + scales_node, arg_chain = utils.trace_args_until_placeholder(self.scales_node) + if scales_node is not None: + self.scales_node = scales_node + self.all_nodes.extend(arg_chain) + + self.match_found = True + + +embedding_4bit_target = exir_ops.edge.quantized_decomposed.embedding_4bit.dtype + + +def _detect_tied_linear_weight( + ep: ExportedProgram, + weight_node: torch.fx.Node, + weight_tensor: torch.Tensor, +) -> bool: + """Check if this embedding weight is tied to a linear weight. + + The embedding weight is packed uint8 [vocab_size, embed_dim/2]. The linear + output weight may be stored as unpacked int8 [vocab_size, embed_dim]. If we + find a placeholder whose int8 values match our unpacked embedding values, + the weights are tied and we should use the linear packing to enable dedup. + """ + vocab_size = weight_tensor.shape[0] + embed_dim = weight_tensor.shape[1] * 2 + + # Unpack embedding weight using embedding convention (high nibble first) + emb_high = (weight_tensor >> 4).to(torch.int8) - 8 + emb_low = (weight_tensor & 0xF).to(torch.int8) - 8 + emb_unpacked = torch.stack([emb_high, emb_low], dim=-1).reshape( + vocab_size, embed_dim + ) + + for node in ep.graph_module.graph.nodes: + if node.op != "placeholder" or node == weight_node: + continue + + try: + candidate = get_param_tensor(ep, node) + except RuntimeError: + continue + if candidate is None: + continue + if candidate.shape != (vocab_size, embed_dim) or candidate.dtype != torch.int8: + continue + + if torch.equal(emb_unpacked, candidate): + return True + + return False + + +@register_pattern_detector("quantized_embedding") +def find_quantized_embedding_patterns( + node: torch.fx.Node, +) -> Optional[QuantizedEmbeddingMatch]: + if node.target != embedding_4bit_target: + return None + + matched_pattern = QuantizedEmbeddingMatch(node) + if matched_pattern.match_found: + return matched_pattern + return None + + +@register_pattern_replacement("quantized_embedding") +def replace_quantized_embedding_patterns( + ep: ExportedProgram, + graph_module: torch.fx.GraphModule, + match: QuantizedEmbeddingMatch, +): + weight_tensor = get_param_tensor(ep, match.weight_node) + assert weight_tensor is not None + + scales_tensor = get_param_tensor(ep, match.scales_node) + assert scales_tensor is not None + + is_linear_weight = _detect_tied_linear_weight(ep, match.weight_node, weight_tensor) + + if is_linear_weight: + # Repack using linear convention (low nibble = even, high nibble = odd) + vocab_size = weight_tensor.shape[0] + high = (weight_tensor >> 4).to(torch.int8) - 8 + low = (weight_tensor & 0xF).to(torch.int8) - 8 + unpacked = torch.stack([high, low], dim=-1).reshape(vocab_size, -1) + repacked = unpacked.to(torch.uint8) + 8 + weight_tensor = repacked[:, 1::2] << 4 | repacked[:, ::2] + # Update the state dict with repacked tensor + for key, value in ep.state_dict.items(): + if value.data_ptr() == get_param_tensor(ep, match.weight_node).data_ptr(): + ep.state_dict[key] = weight_tensor + break + + # Compute group_size from weight and scales shapes + embed_dim = weight_tensor.shape[1] * 2 # packed, 2 values per byte + groups_per_row = scales_tensor.shape[1] if scales_tensor.ndim > 1 else 1 + group_size = embed_dim // groups_per_row + + with graph_module.graph.inserting_before(match.anchor_node): + embedding_q4gsw_node = graph_module.graph.create_node( + "call_function", + exir_ops.edge.et_vk.embedding_q4gsw.default, + args=( + match.weight_node, + match.scales_node, + group_size, + match.indices_node, + is_linear_weight, + ), + ) + + embedding_q4gsw_node.meta["val"] = match.anchor_node.meta["val"] + match.anchor_node.replace_all_uses_with(embedding_q4gsw_node) diff --git a/backends/vulkan/runtime/graph/ops/glsl/embedding_q4gsw.glsl b/backends/vulkan/runtime/graph/ops/glsl/embedding_q4gsw.glsl new file mode 100644 index 00000000000..ce6779a0c9b --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/embedding_q4gsw.glsl @@ -0,0 +1,165 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +$if STORAGE == "buffer": + #define OUTPUT_BUFFER + +$if IS_LINEAR_WEIGHT: + #define LINEAR_WEIGHT + $if WEIGHT_STORAGE == "buffer": + #define WEIGHT_BUFFER + +${define_required_extensions(STORAGE, DTYPE)} +$if IS_LINEAR_WEIGHT: + ${define_required_extensions(WEIGHT_STORAGE, "int")} +$else: + ${define_required_extensions("buffer", "int")} +${define_required_extensions("buffer", [SCALES_DTYPE, "uint8"])} + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, STORAGE)} +#define T ${texel_load_component_type(DTYPE, STORAGE)} + +${define_active_storage_type(STORAGE)} + +#extension GL_EXT_control_flow_attributes : require + +layout(std430) buffer; + +// Output uses the graph's storage type +${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)} +// Indices use the graph's storage type +${layout_declare_tensor(B, "r", "t_indices", "int", STORAGE)} +// Weight: flat buffer for regular, packed block format for linear_weight +$if IS_LINEAR_WEIGHT: + ${layout_declare_tensor(B, "r", "t_weight", "int", WEIGHT_STORAGE, is_scalar_array=False)} +$else: + ${layout_declare_tensor(B, "r", "t_weight", "int", "buffer", is_scalar_array=False)} +// Scales are ALWAYS buffer, loaded as scalar +${layout_declare_tensor(B, "r", "t_scales", SCALES_DTYPE, "buffer")} + +layout(push_constant) uniform PushConstants { + int group_size; + int embed_dim; + int num_indices; + int out_height; + int is_linear_weight; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +// Load 4 dequantized embedding values at the given dimension offset. +// The weight storage format differs between regular and linear_weight variants, +// so this function is defined separately for each. +#ifdef LINEAR_WEIGHT + +// Linear-packed block format: weight is stored as blocks indexed by +// t_weight[n8 * K4 + k4], where each ivec4 element contains 8 interleaved +// 4-bit values from 2 sub-rows of an 8-row block. +VEC4_T load_embedding_weights( + const int embedding_idx, + const int dim, + const float scale) { + const int n8 = embedding_idx >> 3; + const int n_local = embedding_idx & 7; + const int row_in_block = n_local < 4 ? n_local : n_local - 4; + const int shift_base = n_local < 4 ? 0 : 4; + const int K4 = embed_dim >> 2; + const int k4 = dim >> 2; + +#ifdef WEIGHT_BUFFER + const ivec4 block = t_weight[n8 * K4 + k4]; +#else + const ivec4 block = texelFetch(t_weight, ivec2(k4, n8), 0); +#endif + + const uint packed_uint = uint(block[row_in_block]); + + return VEC4_T( + T(float(int((packed_uint >> shift_base) & 0xFu) - 8) * scale), + T(float(int((packed_uint >> (shift_base + 8)) & 0xFu) - 8) * scale), + T(float(int((packed_uint >> (shift_base + 16)) & 0xFu) - 8) * scale), + T(float(int((packed_uint >> (shift_base + 24)) & 0xFu) - 8) * scale)); +} + +#else // !LINEAR_WEIGHT + +// Flat buffer format: weight is stored as a contiguous array of packed bytes. +// Each ivec4 covers 32 int4 values (16 bytes) for one embedding row. +VEC4_T load_embedding_weights( + const int embedding_idx, + const int dim, + const float scale) { + const int blocks_per_row = embed_dim >> 5; + const int block_in_row = dim >> 5; + const int t = (dim >> 2) & 7; + + const ivec4 packed = t_weight[embedding_idx * blocks_per_row + block_in_row]; + const int int_idx = t >> 1; + const int byte_pair = t & 1; + + const uint u = uint(packed[int_idx]); + const int shift = byte_pair << 4; + const uint b0 = (u >> shift) & 0xFFu; + const uint b1 = (u >> (shift + 8)) & 0xFFu; + + return VEC4_T( + T(float(int(b0 >> 4) - 8) * scale), + T(float(int(b0 & 0xFu) - 8) * scale), + T(float(int(b1 >> 4) - 8) * scale), + T(float(int(b1 & 0xFu) - 8) * scale)); +} + +#endif // LINEAR_WEIGHT + +void main() { + const int block_in_row = int(gl_GlobalInvocationID.x); + const int y_idx = int(gl_GlobalInvocationID.y); + const int z_idx = int(gl_GlobalInvocationID.z); + + const int blocks_per_row = embed_dim >> 5; + const int indices_idx = z_idx * out_height + y_idx; + if (block_in_row >= blocks_per_row || indices_idx >= num_indices) { + return; + } + +#ifdef OUTPUT_BUFFER + const int embedding_idx = t_indices[indices_idx]; +#else + const ivec4 in_texel = + texelFetch(t_indices, ivec3(indices_idx >> 2, 0, 0), 0); + const int embedding_idx = in_texel[indices_idx & 3]; +#endif + + const int base_dim = block_in_row << 5; + const int groups_per_row = embed_dim / group_size; + + [[unroll]] for (int t = 0; t < 8; t++) { + const int dim = base_dim + (t << 2); + const float scale = + float(t_scales[embedding_idx * groups_per_row + dim / group_size]); + + const VEC4_T vals = + load_embedding_weights(embedding_idx, dim, scale); + +#ifdef OUTPUT_BUFFER + const int out_base = indices_idx * embed_dim + dim; + t_out[out_base] = vals.x; + t_out[out_base + 1] = vals.y; + t_out[out_base + 2] = vals.z; + t_out[out_base + 3] = vals.w; +#else + imageStore( + t_out, + ivec3((base_dim >> 2) + t, y_idx, z_idx), + vals); +#endif + } +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/embedding_q4gsw.yaml b/backends/vulkan/runtime/graph/ops/glsl/embedding_q4gsw.yaml new file mode 100644 index 00000000000..9102bbb542e --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/embedding_q4gsw.yaml @@ -0,0 +1,38 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +embedding_q4gsw: + parameter_names_with_default_values: + DTYPE: half + SCALES_DTYPE: half + STORAGE: buffer + WEIGHT_STORAGE: buffer + IS_LINEAR_WEIGHT: false + generate_variant_forall: + STORAGE: + - VALUE: buffer + - VALUE: texture3d + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: embedding_q4gsw + - NAME: embedding_q4gsw_float_scales + SCALES_DTYPE: float + - NAME: embedding_q4gsw_linear_weight_buffer + IS_LINEAR_WEIGHT: true + WEIGHT_STORAGE: buffer + - NAME: embedding_q4gsw_linear_weight_texture2d + IS_LINEAR_WEIGHT: true + WEIGHT_STORAGE: texture2d + - NAME: embedding_q4gsw_float_scales_linear_weight_buffer + SCALES_DTYPE: float + IS_LINEAR_WEIGHT: true + WEIGHT_STORAGE: buffer + - NAME: embedding_q4gsw_float_scales_linear_weight_texture2d + SCALES_DTYPE: float + IS_LINEAR_WEIGHT: true + WEIGHT_STORAGE: texture2d diff --git a/backends/vulkan/runtime/graph/ops/impl/EmbeddingQ4gsw.cpp b/backends/vulkan/runtime/graph/ops/impl/EmbeddingQ4gsw.cpp new file mode 100644 index 00000000000..c5051392074 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/EmbeddingQ4gsw.cpp @@ -0,0 +1,166 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include + +#include +#include +#include + +#include + +#include + +namespace vkcompute { + +utils::uvec3 pick_embedding_q4gsw_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef out = args.at(0).refs.at(0); + const std::vector& sizes = graph->sizes_of(out); + int ndim = sizes.size(); + + uint32_t blocks_per_row = static_cast(sizes[ndim - 1]) / 32; + uint32_t height = ndim >= 2 ? static_cast(sizes[ndim - 2]) : 1; + uint32_t depth = 1; + for (int i = 0; i < ndim - 2; i++) { + depth *= static_cast(sizes[i]); + } + + return {blocks_per_row, height, depth}; +} + +void resize_embedding_q4gsw_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef out = args.at(0).refs.at(0); + const ValueRef indices = args.at(1).refs.at(0); + + const int64_t embed_dim = graph->get_int(resize_args.at(0)); + const std::vector indices_sizes = graph->sizes_of(indices); + + // Output shape is indices.shape + [embed_dim] + std::vector out_sizes = indices_sizes; + out_sizes.push_back(embed_dim); + + graph->virtual_resize(out, out_sizes); +} + +void add_embedding_q4gsw_node( + ComputeGraph& graph, + const ValueRef indices, + const ValueRef weight, + const ValueRef weight_scales, + const int32_t group_size, + const int32_t embed_dim, + const int32_t num_indices, + const int32_t out_height, + const int32_t is_linear_weight, + const ValueRef out) { + VK_CHECK_COND(graph.packed_dim_of(out) == WHCN::kWidthDim); + VK_CHECK_COND(graph.packed_dim_of(indices) == WHCN::kWidthDim); + VK_CHECK_COND(embed_dim % 32 == 0, "embed_dim must be a multiple of 32"); + + std::string kernel_name = "embedding_q4gsw"; + kernel_name.reserve(kShaderNameReserve); + + vkapi::ScalarType scales_dtype = graph.dtype_of(weight_scales); + if (scales_dtype != vkapi::kHalf) { + kernel_name += "_float_scales"; + } + + if (is_linear_weight) { + kernel_name += "_linear_weight"; + add_storage_type_suffix(kernel_name, graph.storage_type_of(weight)); + } + + add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); + + std::vector push_constants = { + PushConstantDataInfo(&group_size, sizeof(group_size)), + PushConstantDataInfo(&embed_dim, sizeof(embed_dim)), + PushConstantDataInfo(&num_indices, sizeof(num_indices)), + PushConstantDataInfo(&out_height, sizeof(out_height)), + PushConstantDataInfo(&is_linear_weight, sizeof(is_linear_weight)), + }; + + ValueRef embed_dim_ref = graph.add_scalar(embed_dim); + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + pick_embedding_q4gsw_global_wg_size, + default_pick_local_wg_size, + {{out, vkapi::kWrite}, {{indices, weight, weight_scales}, vkapi::kRead}}, + {}, + push_constants, + {}, + {embed_dim_ref}, + resize_embedding_q4gsw_node)); +} + +void embedding_q4gsw(ComputeGraph& graph, const std::vector& args) { + ValueRef weight_data = args[0]; + ValueRef weight_scales_data = args[1]; + ValueRef group_size_ref = args[2]; + ValueRef indices = args[3]; + ValueRef is_linear_weight_ref = args[4]; + ValueRef out = args[5]; + + int32_t group_size = graph.extract_scalar(group_size_ref); + int32_t is_linear_weight = + graph.extract_scalar(is_linear_weight_ref) ? 1 : 0; + + const std::vector weight_sizes = graph.sizes_of(weight_data); + int32_t embed_dim = static_cast(weight_sizes.back() * 2); + + const std::vector indices_sizes = graph.sizes_of(indices); + int32_t num_indices = 1; + for (auto s : indices_sizes) { + num_indices *= static_cast(s); + } + int32_t out_height = static_cast(indices_sizes.back()); + + ValueRef weight; + if (is_linear_weight) { + QuantizationConfig weight_quant_config(4, kPerGroup, {group_size}); + weight = prepack_quantized_linear_weight( + graph, weight_quant_config, weight_data); + } else { + weight = prepack_standard( + graph, weight_data, utils::kBuffer, utils::kWidthPacked); + } + ValueRef weight_scales = prepack_standard( + graph, weight_scales_data, utils::kBuffer, utils::kWidthPacked); + + add_embedding_q4gsw_node( + graph, + indices, + weight, + weight_scales, + group_size, + embed_dim, + num_indices, + out_height, + is_linear_weight, + out); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(et_vk.embedding_q4gsw.default, embedding_q4gsw); +} + +} // namespace vkcompute diff --git a/backends/vulkan/test/custom_ops/targets.bzl b/backends/vulkan/test/custom_ops/targets.bzl index 84432bce30b..e8cdb3a4bf8 100644 --- a/backends/vulkan/test/custom_ops/targets.bzl +++ b/backends/vulkan/test/custom_ops/targets.bzl @@ -102,3 +102,4 @@ def define_common_targets(is_fbcode = False): define_custom_op_test_binary("test_mm") define_custom_op_test_binary("test_conv2d_pw") define_custom_op_test_binary("test_conv2d_dw") + define_custom_op_test_binary("test_embedding_q4gsw") diff --git a/backends/vulkan/test/custom_ops/test_embedding_q4gsw.cpp b/backends/vulkan/test/custom_ops/test_embedding_q4gsw.cpp new file mode 100644 index 00000000000..bb588fd3dbb --- /dev/null +++ b/backends/vulkan/test/custom_ops/test_embedding_q4gsw.cpp @@ -0,0 +1,520 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include + +#include +#include + +#include "utils.h" + +using namespace executorch::vulkan::prototyping; + +using namespace vkcompute; + +struct EmbeddingConfig { + int64_t vocab_size; + int64_t embed_dim; + int64_t group_size; + std::vector indices_shape; + std::string test_case_name = "placeholder"; + vkapi::ScalarType dtype = vkapi::kHalf; + vkapi::ScalarType scales_dtype = vkapi::kHalf; + utils::StorageType storage_type = utils::kBuffer; + bool is_linear_weight = false; +}; + +// CPU reference: unpack 4-bit weights, dequantize, and perform embedding lookup +void embedding_4bit_reference(TestCase& tc) { + auto& weight_spec = tc.inputs()[0]; + auto& scales_spec = tc.inputs()[1]; + int32_t group_size = tc.inputs()[2].get_int_value(); + auto& indices_spec = tc.inputs()[3]; + bool is_linear_weight = tc.inputs()[4].get_bool_value(); + auto& output_spec = tc.outputs()[0]; + + weight_spec.ensure_data_generated(); + scales_spec.ensure_data_generated(); + indices_spec.ensure_data_generated(); + + const auto& weight_data = weight_spec.get_uint8_data(); + const auto& indices_data = indices_spec.get_int32_data(); + + bool scales_are_half = (scales_spec.dtype == vkapi::kHalf); + + int64_t packed_dim = weight_spec.sizes[1]; + int64_t embed_dim = packed_dim * 2; + int64_t groups_per_row = scales_spec.sizes[1]; + + int64_t num_indices = 1; + for (auto s : indices_spec.sizes) { + num_indices *= s; + } + + int64_t total_output = num_indices * embed_dim; + + // Always populate ref_float_data so the caching framework can distribute it + output_spec.get_ref_float_data().resize(total_output); + + bool output_is_half = (output_spec.dtype == vkapi::kHalf); + if (output_is_half) { + output_spec.get_ref_half_data().resize(total_output); + } + + for (int64_t i = 0; i < num_indices; ++i) { + int32_t idx = indices_data[i]; + for (int64_t d = 0; d < embed_dim; ++d) { + int64_t packed_idx = d / 2; + uint8_t packed_byte = weight_data[idx * packed_dim + packed_idx]; + + // Unpack: packed_byte = (even_val + 8) << 4 | (odd_val + 8) + // Even d -> high nibble, odd d -> low nibble + // For linear weight packing, nibble order is swapped + int int4_val; + if (d % 2 == 0) { + if (is_linear_weight) { + int4_val = static_cast(packed_byte & 0xF) - 8; + } else { + int4_val = static_cast(packed_byte >> 4) - 8; + } + } else { + if (is_linear_weight) { + int4_val = static_cast(packed_byte >> 4) - 8; + } else { + int4_val = static_cast(packed_byte & 0xF) - 8; + } + } + + int64_t group_idx = d / group_size; + int64_t scale_idx = idx * groups_per_row + group_idx; + + float scale; + if (scales_are_half) { + uint16_t scale_half = scales_spec.get_half_data()[scale_idx]; + scale = half_to_float(scale_half); + } else { + scale = scales_spec.get_float_data()[scale_idx]; + } + + float result = static_cast(int4_val) * scale; + + // Always store float reference + output_spec.get_ref_float_data()[i * embed_dim + d] = result; + + if (output_is_half) { + output_spec.get_ref_half_data()[i * embed_dim + d] = + float_to_half(result); + } + } + } +} + +TestCase create_test_case(const EmbeddingConfig& config) { + TestCase test_case; + test_case.set_name(config.test_case_name); + test_case.set_operator_name("et_vk.embedding_q4gsw.default"); + test_case.set_shader_filter({}); + + // Weight: [vocab_size, embed_dim / 2] packed uint8 + ValueSpec weight( + {config.vocab_size, config.embed_dim / 2}, + vkapi::kByte, + utils::kBuffer, + utils::kWidthPacked, + DataGenType::RANDINT4); + weight.set_constant(true); + test_case.add_input_spec(weight); + + // Weight scales: [vocab_size, groups_per_row] + int64_t groups_per_row = config.embed_dim / config.group_size; + ValueSpec weight_scales( + {config.vocab_size, groups_per_row}, + config.scales_dtype, + utils::kBuffer, + utils::kWidthPacked, + DataGenType::RANDOM_SCALES); + weight_scales.set_constant(true); + test_case.add_input_spec(weight_scales); + + // Group size: int scalar + ValueSpec group_size_spec(static_cast(config.group_size)); + test_case.add_input_spec(group_size_spec); + + // Indices: [batch, seq_len] int32 + ValueSpec indices( + config.indices_shape, + vkapi::kInt, + config.storage_type, + utils::kWidthPacked, + DataGenType::RANDINT); + + // Clamp indices to valid vocab range + indices.ensure_data_generated(); + for (auto& idx : indices.get_int32_data()) { + idx = std::abs(idx) % config.vocab_size; + } + + test_case.add_input_spec(indices); + + // is_linear_weight: bool scalar + ValueSpec is_linear_weight_spec(config.is_linear_weight); + test_case.add_input_spec(is_linear_weight_spec); + + // Output: indices.shape + [embed_dim] + std::vector output_shape = config.indices_shape; + output_shape.push_back(config.embed_dim); + ValueSpec output( + output_shape, config.dtype, config.storage_type, utils::kWidthPacked); + test_case.add_output_spec(output); + + return test_case; +} + +std::vector generate_test_cases() { + std::vector test_cases; + + // --- is_linear_weight = true --- + + // Basic test with linear weight packing + test_cases.push_back(create_test_case( + {.vocab_size = 16, + .embed_dim = 32, + .group_size = 32, + .indices_shape = {4}, + .test_case_name = "small_1d_linear_weight", + .is_linear_weight = true})); + + test_cases.push_back(create_test_case( + {.vocab_size = 32, + .embed_dim = 64, + .group_size = 32, + .indices_shape = {2, 3}, + .test_case_name = "small_2d_linear_weight", + .is_linear_weight = true})); + + test_cases.push_back(create_test_case( + {.vocab_size = 100, + .embed_dim = 128, + .group_size = 32, + .indices_shape = {4, 8}, + .test_case_name = "medium_multigroup_linear_weight", + .is_linear_weight = true})); + + // fp32 output with linear weight + test_cases.push_back(create_test_case( + {.vocab_size = 16, + .embed_dim = 32, + .group_size = 32, + .indices_shape = {4}, + .test_case_name = "small_1d_fp32_linear_weight", + .dtype = vkapi::kFloat, + .is_linear_weight = true})); + + // Texture3D with linear weight + test_cases.push_back(create_test_case( + {.vocab_size = 16, + .embed_dim = 32, + .group_size = 32, + .indices_shape = {4}, + .test_case_name = "small_1d_texture_linear_weight", + .storage_type = utils::kTexture3D, + .is_linear_weight = true})); + + // Llama 3.2 1B with linear weight packing + test_cases.push_back(create_test_case( + {.vocab_size = 128256, + .embed_dim = 2048, + .group_size = 32, + .indices_shape = {1, 2047}, + .test_case_name = "llama_3_2_1b_prefill_fp32_linear_weight", + .dtype = vkapi::kFloat, + .storage_type = utils::kBuffer, + .is_linear_weight = true})); + + // --- Half scales (default) --- + + // Basic test: small vocab, small embed_dim + test_cases.push_back(create_test_case( + {.vocab_size = 16, + .embed_dim = 32, + .group_size = 32, + .indices_shape = {4}, + .test_case_name = "small_1d"})); + + // 2D indices + test_cases.push_back(create_test_case( + {.vocab_size = 32, + .embed_dim = 64, + .group_size = 32, + .indices_shape = {2, 3}, + .test_case_name = "small_2d"})); + + // Larger vocab, multiple groups + test_cases.push_back(create_test_case( + {.vocab_size = 100, + .embed_dim = 128, + .group_size = 32, + .indices_shape = {4, 8}, + .test_case_name = "medium_multigroup"})); + + // group_size == embed_dim (single group) + test_cases.push_back(create_test_case( + {.vocab_size = 50, + .embed_dim = 64, + .group_size = 64, + .indices_shape = {2, 4}, + .test_case_name = "single_group"})); + + // fp32 output variants (half scales) + test_cases.push_back(create_test_case( + {.vocab_size = 16, + .embed_dim = 32, + .group_size = 32, + .indices_shape = {4}, + .test_case_name = "small_1d_fp32", + .dtype = vkapi::kFloat})); + + test_cases.push_back(create_test_case( + {.vocab_size = 32, + .embed_dim = 64, + .group_size = 32, + .indices_shape = {2, 3}, + .test_case_name = "small_2d_fp32", + .dtype = vkapi::kFloat})); + + test_cases.push_back(create_test_case( + {.vocab_size = 100, + .embed_dim = 128, + .group_size = 32, + .indices_shape = {4, 8}, + .test_case_name = "medium_multigroup_fp32", + .dtype = vkapi::kFloat})); + + test_cases.push_back(create_test_case( + {.vocab_size = 50, + .embed_dim = 64, + .group_size = 64, + .indices_shape = {2, 4}, + .test_case_name = "single_group_fp32", + .dtype = vkapi::kFloat})); + + // Texture3D variants (fp16 output, half scales) + test_cases.push_back(create_test_case( + {.vocab_size = 16, + .embed_dim = 32, + .group_size = 32, + .indices_shape = {4}, + .test_case_name = "small_1d_texture", + .dtype = vkapi::kHalf, + .storage_type = utils::kTexture3D})); + + test_cases.push_back(create_test_case( + {.vocab_size = 32, + .embed_dim = 64, + .group_size = 32, + .indices_shape = {2, 3}, + .test_case_name = "small_2d_texture", + .dtype = vkapi::kHalf, + .storage_type = utils::kTexture3D})); + + test_cases.push_back(create_test_case( + {.vocab_size = 100, + .embed_dim = 128, + .group_size = 32, + .indices_shape = {4, 8}, + .test_case_name = "medium_multigroup_texture", + .dtype = vkapi::kHalf, + .storage_type = utils::kTexture3D})); + + test_cases.push_back(create_test_case( + {.vocab_size = 50, + .embed_dim = 64, + .group_size = 64, + .indices_shape = {2, 4}, + .test_case_name = "single_group_texture", + .dtype = vkapi::kHalf, + .storage_type = utils::kTexture3D})); + + // Texture3D variants (fp32 output, half scales) + test_cases.push_back(create_test_case( + {.vocab_size = 16, + .embed_dim = 32, + .group_size = 32, + .indices_shape = {4}, + .test_case_name = "small_1d_fp32_texture", + .dtype = vkapi::kFloat, + .storage_type = utils::kTexture3D})); + + test_cases.push_back(create_test_case( + {.vocab_size = 32, + .embed_dim = 64, + .group_size = 32, + .indices_shape = {2, 3}, + .test_case_name = "small_2d_fp32_texture", + .dtype = vkapi::kFloat, + .storage_type = utils::kTexture3D})); + + test_cases.push_back(create_test_case( + {.vocab_size = 100, + .embed_dim = 128, + .group_size = 32, + .indices_shape = {4, 8}, + .test_case_name = "medium_multigroup_fp32_texture", + .dtype = vkapi::kFloat, + .storage_type = utils::kTexture3D})); + + test_cases.push_back(create_test_case( + {.vocab_size = 50, + .embed_dim = 64, + .group_size = 64, + .indices_shape = {2, 4}, + .test_case_name = "single_group_fp32_texture", + .dtype = vkapi::kFloat, + .storage_type = utils::kTexture3D})); + + // --- Float scales --- + + // Buffer variants with float scales + test_cases.push_back(create_test_case( + {.vocab_size = 16, + .embed_dim = 32, + .group_size = 32, + .indices_shape = {4}, + .test_case_name = "small_1d_float_scales", + .dtype = vkapi::kHalf, + .scales_dtype = vkapi::kFloat})); + + test_cases.push_back(create_test_case( + {.vocab_size = 32, + .embed_dim = 64, + .group_size = 32, + .indices_shape = {2, 3}, + .test_case_name = "small_2d_float_scales", + .dtype = vkapi::kHalf, + .scales_dtype = vkapi::kFloat})); + + test_cases.push_back(create_test_case( + {.vocab_size = 100, + .embed_dim = 128, + .group_size = 32, + .indices_shape = {4, 8}, + .test_case_name = "medium_multigroup_float_scales", + .dtype = vkapi::kHalf, + .scales_dtype = vkapi::kFloat})); + + test_cases.push_back(create_test_case( + {.vocab_size = 16, + .embed_dim = 32, + .group_size = 32, + .indices_shape = {4}, + .test_case_name = "small_1d_fp32_float_scales", + .dtype = vkapi::kFloat, + .scales_dtype = vkapi::kFloat})); + + test_cases.push_back(create_test_case( + {.vocab_size = 32, + .embed_dim = 64, + .group_size = 32, + .indices_shape = {2, 3}, + .test_case_name = "small_2d_fp32_float_scales", + .dtype = vkapi::kFloat, + .scales_dtype = vkapi::kFloat})); + + test_cases.push_back(create_test_case( + {.vocab_size = 100, + .embed_dim = 128, + .group_size = 32, + .indices_shape = {4, 8}, + .test_case_name = "medium_multigroup_fp32_float_scales", + .dtype = vkapi::kFloat, + .scales_dtype = vkapi::kFloat})); + + // Texture3D with float scales + test_cases.push_back(create_test_case( + {.vocab_size = 16, + .embed_dim = 32, + .group_size = 32, + .indices_shape = {4}, + .test_case_name = "small_1d_float_scales_texture", + .dtype = vkapi::kHalf, + .scales_dtype = vkapi::kFloat, + .storage_type = utils::kTexture3D})); + + test_cases.push_back(create_test_case( + {.vocab_size = 16, + .embed_dim = 32, + .group_size = 32, + .indices_shape = {4}, + .test_case_name = "small_1d_fp32_float_scales_texture", + .dtype = vkapi::kFloat, + .scales_dtype = vkapi::kFloat, + .storage_type = utils::kTexture3D})); + + // Llama 3.2 1B prefill configuration (fp32 output, half scales) + test_cases.push_back(create_test_case( + {.vocab_size = 128256, + .embed_dim = 2048, + .group_size = 32, + .indices_shape = {1, 2047}, + .test_case_name = "llama_3_2_1b_prefill_fp32", + .dtype = vkapi::kFloat, + .storage_type = utils::kBuffer})); + + // Llama 3.2 1B prefill configuration (fp32 output, float scales) + test_cases.push_back(create_test_case( + {.vocab_size = 128256, + .embed_dim = 2048, + .group_size = 32, + .indices_shape = {1, 2047}, + .test_case_name = "llama_3_2_1b_prefill_fp32_float_scales", + .dtype = vkapi::kFloat, + .scales_dtype = vkapi::kFloat, + .storage_type = utils::kBuffer})); + + // Llama 3.2 1B prefill configuration (fp16 output, half scales) + test_cases.push_back(create_test_case( + {.vocab_size = 128256, + .embed_dim = 2048, + .group_size = 32, + .indices_shape = {1, 2047}, + .test_case_name = "llama_3_2_1b_prefill_fp16", + .dtype = vkapi::kHalf, + .storage_type = utils::kBuffer})); + + // Llama 3.2 1B prefill configuration (fp16 output, float scales) + test_cases.push_back(create_test_case( + {.vocab_size = 128256, + .embed_dim = 2048, + .group_size = 32, + .indices_shape = {1, 2047}, + .test_case_name = "llama_3_2_1b_prefill_fp16_float_scales", + .dtype = vkapi::kHalf, + .scales_dtype = vkapi::kFloat, + .storage_type = utils::kBuffer})); + + return test_cases; +} + +int main(int argc, char** argv) { + auto results = execute_test_cases( + generate_test_cases, + "embedding_q4gsw", + /* warmup_runs */ 3, + /* benchmark_runs */ 10, + embedding_4bit_reference); + + results.print_summary(); + + if (results.get_failed_count() > 0) { + std::cerr << "FAILED: " << results.get_failed_count() << " test(s) failed." + << std::endl; + return 1; + } + + std::cout << "PASSED: All tests passed." << std::endl; + return 0; +} diff --git a/backends/vulkan/test/custom_ops/utils.cpp b/backends/vulkan/test/custom_ops/utils.cpp index 2a50e7b5ec1..fc406c47403 100644 --- a/backends/vulkan/test/custom_ops/utils.cpp +++ b/backends/vulkan/test/custom_ops/utils.cpp @@ -74,6 +74,49 @@ void generate_random_int4_data( void generate_ones_data(std::vector& data); void generate_zeros_data(std::vector& data); +// Convert a float32 value to IEEE 754 half-precision (uint16_t) +uint16_t float_to_half(float value) { + uint32_t float_bits; + std::memcpy(&float_bits, &value, sizeof(float)); + uint32_t sign = (float_bits >> 31) & 0x1; + int32_t exponent = static_cast((float_bits >> 23) & 0xFF) - 127; + uint32_t mantissa = float_bits & 0x7FFFFF; + + uint16_t half_val; + if (exponent > 15) { + half_val = static_cast((sign << 15) | 0x7C00); // Inf + } else if (exponent < -14) { + half_val = static_cast(sign << 15); // Zero / subnormal + } else { + half_val = static_cast( + (sign << 15) | (static_cast(exponent + 15) << 10) | + (mantissa >> 13)); + } + return half_val; +} + +// Convert a IEEE 754 half-precision (uint16_t) value to float32 +float half_to_float(uint16_t half_val) { + uint32_t sign = (half_val >> 15) & 0x1; + uint32_t exponent = (half_val >> 10) & 0x1F; + uint32_t mantissa = half_val & 0x3FF; + + float result; + if (exponent == 0) { + result = std::ldexp(static_cast(mantissa), -24); + } else if (exponent == 31) { + result = mantissa ? std::numeric_limits::quiet_NaN() + : std::numeric_limits::infinity(); + } else { + result = std::ldexp( + 1.0f + static_cast(mantissa) / 1024.0f, exponent - 15); + } + if (sign) { + result = -result; + } + return result; +} + // Output and latency printing utilities namespace { static int print_output_enabled = 0; @@ -154,6 +197,13 @@ void ValueSpec::generate_tensor_data(int seed) { // Simple conversion to uint16_t representation of half half_data[i] = static_cast(temp_data[i] * 32767.0f); } + } else if (data_gen_type == DataGenType::RANDOM_SCALES) { + // Generate random scales in float, then convert to proper fp16 + std::vector temp_data(num_elements); + generate_random_float_data(temp_data, 0.005f, 0.015f, seed); + for (size_t i = 0; i < temp_data.size(); ++i) { + half_data[i] = float_to_half(temp_data[i]); + } } else if (data_gen_type == DataGenType::RANDINT) { generate_randint_half_data(half_data, -10, 10, seed); } else if (data_gen_type == DataGenType::RANDINT8) { @@ -1395,15 +1445,19 @@ execute_test_case(TestCase& test_case, int warmup_runs, int benchmark_runs) { graph.prepack(); // Copy input data into the graph's staging buffers + size_t graph_input_idx = 0; for (size_t i = 0; i < test_case.num_inputs(); ++i) { const ValueSpec& input_spec = test_case.inputs()[i]; - if (input_spec.is_tensor() && i < graph.inputs().size()) { - // Skip copying data for constant tensors - if (input_spec.is_constant()) { - continue; - } - const auto& input_ref = graph.inputs()[i]; + // Only non-constant tensor inputs correspond to graph.inputs() entries + bool is_graph_input = input_spec.is_tensor() && !input_spec.is_constant() && + !input_spec.is_none(); + if (!is_graph_input) { + continue; + } + + if (graph_input_idx < graph.inputs().size()) { + const auto& input_ref = graph.inputs()[graph_input_idx]; // Get the appropriate data based on dtype const void* data_ptr = nullptr; @@ -1433,6 +1487,7 @@ execute_test_case(TestCase& test_case, int warmup_runs, int benchmark_runs) { graph.maybe_cast_and_copy_into_staging( input_ref.staging, data_ptr, data_numel, input_spec.dtype); } + ++graph_input_idx; } // Warmup runs diff --git a/backends/vulkan/test/custom_ops/utils.h b/backends/vulkan/test/custom_ops/utils.h index 9b5b6a46782..4f3e4ec629f 100644 --- a/backends/vulkan/test/custom_ops/utils.h +++ b/backends/vulkan/test/custom_ops/utils.h @@ -864,6 +864,10 @@ void compute_weight_sums_4bit_grouped( int64_t out_features, int64_t group_size); +// Half-precision conversion utilities +uint16_t float_to_half(float value); +float half_to_float(uint16_t half_val); + // Setup compute graph based on TestCase and operation name ComputeGraph setup_compute_graph(TestCase& test_case, std::string op_name); From 5cc29283e3395e52f356f11b7dbbcdf9f96742ae Mon Sep 17 00:00:00 2001 From: ssjia Date: Fri, 20 Mar 2026 09:03:54 -0700 Subject: [PATCH 2/4] Update base for Update on "[ET-VK][embedding] Enable embedding weight dedup with tied linear weights" When embedding and output linear weights are tied (same underlying tensor, as in Llama 3.2 1B), they are quantized independently with opposite nibble packing conventions, preventing SHA256-based dedup in the NamedDataStore. This adds an `is_linear_weight` flag to `embedding_q4gsw` so the embedding op can consume weights packed in the linear convention, enabling dedup and saving ~125 MB. The implementation spans Python (detection + repacking), GLSL (packed block format reading), and C++ (shared prepacking with linear): - custom_ops_lib.py: Add `is_linear_weight` param to op signature and reference impl, swapping nibble extraction order when True - patterns/quantized_embedding.py: Add `_detect_tied_linear_weight()` that unpacks the embedding weight and compares against int8 linear weight placeholders. When matched, repack using linear convention and pass `is_linear_weight=True` to the op - embedding_q4gsw.glsl: Extract weight loading into `load_embedding_weights()` returning VEC4_T, with compile-time LINEAR_WEIGHT variant that reads from the block-interleaved format produced by `pack_q4_linear_weight` - embedding_q4gsw.yaml: Add 4 new shader variants for linear_weight x {buffer, texture2d} weight storage - EmbeddingQ4gsw.cpp: When `is_linear_weight`, call `prepack_quantized_linear_weight` (shared with linear op) instead of `prepack_standard`. Pass `embed_dim` as resize_arg since the prepacked weight tensor has a different shape - test_embedding_q4gsw.cpp: Add `is_linear_weight` to EmbeddingConfig, update reference function, add 6 new test cases Authored with Claude. Differential Revision: [D97430803](https://our.internmc.facebook.com/intern/diff/D97430803/) [ghstack-poisoned] From 39fe9851c6a00d1b75cf317042c8b1854b0177dc Mon Sep 17 00:00:00 2001 From: ssjia Date: Fri, 20 Mar 2026 11:22:50 -0700 Subject: [PATCH 3/4] Update base for Update on "[ET-VK][embedding] Enable embedding weight dedup with tied linear weights" When embedding and output linear weights are tied (same underlying tensor, as in Llama 3.2 1B), they are quantized independently with opposite nibble packing conventions, preventing SHA256-based dedup in the NamedDataStore. This adds an `is_linear_weight` flag to `embedding_q4gsw` so the embedding op can consume weights packed in the linear convention, enabling dedup and saving ~125 MB. The implementation spans Python (detection + repacking), GLSL (packed block format reading), and C++ (shared prepacking with linear): - custom_ops_lib.py: Add `is_linear_weight` param to op signature and reference impl, swapping nibble extraction order when True - patterns/quantized_embedding.py: Add `_detect_tied_linear_weight()` that unpacks the embedding weight and compares against int8 linear weight placeholders. When matched, repack using linear convention and pass `is_linear_weight=True` to the op - embedding_q4gsw.glsl: Extract weight loading into `load_embedding_weights()` returning VEC4_T, with compile-time LINEAR_WEIGHT variant that reads from the block-interleaved format produced by `pack_q4_linear_weight` - embedding_q4gsw.yaml: Add 4 new shader variants for linear_weight x {buffer, texture2d} weight storage - EmbeddingQ4gsw.cpp: When `is_linear_weight`, call `prepack_quantized_linear_weight` (shared with linear op) instead of `prepack_standard`. Pass `embed_dim` as resize_arg since the prepacked weight tensor has a different shape - test_embedding_q4gsw.cpp: Add `is_linear_weight` to EmbeddingConfig, update reference function, add 6 new test cases Authored with Claude. Differential Revision: [D97430803](https://our.internmc.facebook.com/intern/diff/D97430803/) [ghstack-poisoned] From 078ef058beefee40bbbd82a929125ac0ac60d12e Mon Sep 17 00:00:00 2001 From: ssjia Date: Fri, 20 Mar 2026 12:29:57 -0700 Subject: [PATCH 4/4] Update base for Update on "[ET-VK][embedding] Enable embedding weight dedup with tied linear weights" When embedding and output linear weights are tied (same underlying tensor, as in Llama 3.2 1B), they are quantized independently with opposite nibble packing conventions, preventing SHA256-based dedup in the NamedDataStore. This adds an `is_linear_weight` flag to `embedding_q4gsw` so the embedding op can consume weights packed in the linear convention, enabling dedup and saving ~125 MB. The implementation spans Python (detection + repacking), GLSL (packed block format reading), and C++ (shared prepacking with linear): - custom_ops_lib.py: Add `is_linear_weight` param to op signature and reference impl, swapping nibble extraction order when True - patterns/quantized_embedding.py: Add `_detect_tied_linear_weight()` that unpacks the embedding weight and compares against int8 linear weight placeholders. When matched, repack using linear convention and pass `is_linear_weight=True` to the op - embedding_q4gsw.glsl: Extract weight loading into `load_embedding_weights()` returning VEC4_T, with compile-time LINEAR_WEIGHT variant that reads from the block-interleaved format produced by `pack_q4_linear_weight` - embedding_q4gsw.yaml: Add 4 new shader variants for linear_weight x {buffer, texture2d} weight storage - EmbeddingQ4gsw.cpp: When `is_linear_weight`, call `prepack_quantized_linear_weight` (shared with linear op) instead of `prepack_standard`. Pass `embed_dim` as resize_arg since the prepacked weight tensor has a different shape - test_embedding_q4gsw.cpp: Add `is_linear_weight` to EmbeddingConfig, update reference function, add 6 new test cases Authored with Claude. Differential Revision: [D97430803](https://our.internmc.facebook.com/intern/diff/D97430803/) [ghstack-poisoned]