diff --git a/backends/vulkan/_passes/insert_prepack_nodes.py b/backends/vulkan/_passes/insert_prepack_nodes.py index 373b2a4d135..cc4f5969e8e 100644 --- a/backends/vulkan/_passes/insert_prepack_nodes.py +++ b/backends/vulkan/_passes/insert_prepack_nodes.py @@ -16,7 +16,7 @@ from torch.export import ExportedProgram -def insert_prepack_nodes(program: ExportedProgram) -> ExportedProgram: +def insert_prepack_nodes(program: ExportedProgram) -> ExportedProgram: # noqa: C901 """ Insert `et_vk.prepack` nodes for constant tensors in the graph. The prepack operator is responsible for transferring the tensor data, which is serialized with the model, @@ -54,9 +54,13 @@ def insert_prepack_nodes(program: ExportedProgram) -> ExportedProgram: # Most prepacking ops have the primary input at arg 0, but # embedding is embedding(weight, indices, ...) where the # primary input (indices) is at arg 1. + # For embedding_q4gsw, args are (weight, weight_scales, group_size, + # indices) so the primary input (indices) is at arg 3. 
primary_arg_idx = 0 if user.target == exir_ops.edge.aten.embedding.default: primary_arg_idx = 1 + elif user.target == exir_ops.edge.et_vk.embedding_q4gsw.default: + primary_arg_idx = 3 if node in user.args and user.args.index(node) == primary_arg_idx: nodes_to_replace_input.append(user) diff --git a/backends/vulkan/custom_ops_lib.py b/backends/vulkan/custom_ops_lib.py index 5769c3c132b..7b0a0544662 100644 --- a/backends/vulkan/custom_ops_lib.py +++ b/backends/vulkan/custom_ops_lib.py @@ -879,6 +879,46 @@ def q8ta_relu_impl( lib.impl(name, q8ta_relu_impl, "CompositeExplicitAutograd") q8ta_relu_op = getattr(getattr(torch.ops, namespace), name) +######################## +## embedding_q4gsw ## +######################## + + +def embedding_q4gsw_impl( + weight: torch.Tensor, + weight_scales: torch.Tensor, + group_size: int, + indices: torch.Tensor, + is_linear_weight: bool = False, +) -> torch.Tensor: + # Unpack 4-bit values from packed uint8 tensor + # Packing convention: packed_byte = (even_val + 8) << 4 | (odd_val + 8) + high = (weight >> 4).to(torch.int8) - 8 + low = (weight & 0xF).to(torch.int8) - 8 + if is_linear_weight: + unpacked = torch.stack([low, high], dim=-1).reshape(weight.shape[0], -1) + else: + unpacked = torch.stack([high, low], dim=-1).reshape(weight.shape[0], -1) + # Dequantize using per-group scales + num_groups = weight_scales.shape[1] if weight_scales.dim() > 1 else 1 + unpacked_groups = unpacked.reshape(weight.shape[0], num_groups, group_size) + scales = ( + weight_scales.unsqueeze(-1) + if weight_scales.dim() > 1 + else weight_scales.reshape(1, 1, 1) + ) + dequantized = unpacked_groups.float() * scales.float() + dequantized = dequantized.reshape(weight.shape[0], -1) + return torch.nn.functional.embedding(indices, dequantized) + + +name = "embedding_q4gsw" +lib.define( + f"{name}(Tensor weight, Tensor weight_scales, int group_size, Tensor indices, bool is_linear_weight = False) -> Tensor" +) +lib.impl(name, embedding_q4gsw_impl, 
"CompositeExplicitAutograd") +embedding_q4gsw_op = getattr(getattr(torch.ops, namespace), name) + ############################# ## select_as_symint ## ############################# diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 308718ade7d..f9881f637d3 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -1436,6 +1436,38 @@ def check_embedding_weight_size(node: torch.fx.Node) -> bool: ) +# ============================================================================= +# EmbeddingQ4gsw (Quantized Embedding) +# ============================================================================= + + +@update_features(exir_ops.edge.quantized_decomposed.embedding_4bit.dtype) +def register_quantized_decomposed_embedding_4bit(): + def check_embedding_4bit_weight_size(node: torch.fx.Node) -> bool: + weight = node.args[0] + if isinstance(weight, torch.fx.Node) and utils.is_tensor_node(weight): + numel = weight.meta["val"].numel() + if numel > utils.DEFAULT_BUFFER_LIMIT: + return False + return True + + return OpFeatures( + inputs_storage=utils.ANY_BUFFER, + supports_prepacking=True, + supports_resize=True, + are_node_inputs_supported_fn=check_embedding_4bit_weight_size, + ) + + +@update_features(exir_ops.edge.et_vk.embedding_q4gsw.default) +def register_embedding_q4gsw(): + return OpFeatures( + inputs_storage=utils.CONTIGUOUS_ANY, + supports_prepacking=True, + supports_resize=True, + ) + + # ============================================================================= # BatchNorm.cpp # ============================================================================= diff --git a/backends/vulkan/patterns/BUCK b/backends/vulkan/patterns/BUCK index 711000f74ca..73bdc7edd1e 100644 --- a/backends/vulkan/patterns/BUCK +++ b/backends/vulkan/patterns/BUCK @@ -10,6 +10,7 @@ fbcode_target(_kind = runtime.python_library, "__init__.py", "pattern_registry.py", "rope.py", + "quantized_embedding.py", "quantized_linear.py", 
"quantized_convolution.py", "quantized_binary.py", diff --git a/backends/vulkan/patterns/__init__.py b/backends/vulkan/patterns/__init__.py index 050680b024d..a9323a57b09 100644 --- a/backends/vulkan/patterns/__init__.py +++ b/backends/vulkan/patterns/__init__.py @@ -10,6 +10,8 @@ import executorch.backends.vulkan.patterns.quantized_convolution # noqa +import executorch.backends.vulkan.patterns.quantized_embedding # noqa + import executorch.backends.vulkan.patterns.quantized_linear # noqa import executorch.backends.vulkan.patterns.quantized_unary # noqa diff --git a/backends/vulkan/patterns/quantized_embedding.py b/backends/vulkan/patterns/quantized_embedding.py new file mode 100644 index 00000000000..d260c8feb71 --- /dev/null +++ b/backends/vulkan/patterns/quantized_embedding.py @@ -0,0 +1,177 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import executorch.backends.vulkan.utils as utils +import torch +from executorch.backends.transforms.utils import get_param_tensor +from executorch.backends.vulkan.patterns.pattern_registry import ( + PatternMatch, + register_pattern_detector, + register_pattern_replacement, +) +from executorch.exir import ExportedProgram +from executorch.exir.dialects._ops import ops as exir_ops + + +class QuantizedEmbeddingMatch(PatternMatch): + def __init__(self, node: torch.fx.Node) -> None: + self.anchor_node = node + self.match_found = False + self.all_nodes = [node] + + # quantized_decomposed.embedding_4bit.dtype args: + # (weight, weight_scales, weight_zero_points, quant_min, quant_max, + # indices, *, dtype) + self.weight_node = node.args[0] + self.scales_node = node.args[1] + self.indices_node = node.args[5] + + # Validate quantization parameters match our shader's assumptions. 
+ # The shader hardcodes the 4-bit signed offset (subtract 8), which + # corresponds to quant_min=-8, quant_max=7, zero_points=0. + quant_min = node.args[3] + quant_max = node.args[4] + if quant_min != -8 or quant_max != 7: + self.match_found = False + return + + # weight_zero_points (args[2]) should be None or all-zeros + weight_zp_node = node.args[2] + if weight_zp_node is not None: + # If it's a constant tensor, verify it's all zeros + if ( + isinstance(weight_zp_node, torch.fx.Node) + and "val" in weight_zp_node.meta + ): + zp_val = weight_zp_node.meta["val"] + if isinstance(zp_val, torch.Tensor) and not torch.all(zp_val == 0): + self.match_found = False + return + + # Trace weight to its placeholder + const_node, arg_chain = utils.trace_args_until_placeholder(self.weight_node) + if const_node is not None: + self.weight_node = const_node + self.all_nodes.extend(arg_chain) + + # Trace scales to their placeholder + scales_node, arg_chain = utils.trace_args_until_placeholder(self.scales_node) + if scales_node is not None: + self.scales_node = scales_node + self.all_nodes.extend(arg_chain) + + self.match_found = True + + +embedding_4bit_target = exir_ops.edge.quantized_decomposed.embedding_4bit.dtype + + +def _detect_tied_linear_weight( + ep: ExportedProgram, + weight_node: torch.fx.Node, + weight_tensor: torch.Tensor, +) -> bool: + """Check if this embedding weight is tied to a linear weight. + + The embedding weight is packed uint8 [vocab_size, embed_dim/2]. The linear + output weight may be stored as unpacked int8 [vocab_size, embed_dim]. If we + find a placeholder whose int8 values match our unpacked embedding values, + the weights are tied and we should use the linear packing to enable dedup. 
+ """ + vocab_size = weight_tensor.shape[0] + embed_dim = weight_tensor.shape[1] * 2 + + # Unpack embedding weight using embedding convention (high nibble first) + emb_high = (weight_tensor >> 4).to(torch.int8) - 8 + emb_low = (weight_tensor & 0xF).to(torch.int8) - 8 + emb_unpacked = torch.stack([emb_high, emb_low], dim=-1).reshape( + vocab_size, embed_dim + ) + + for node in ep.graph_module.graph.nodes: + if node.op != "placeholder" or node == weight_node: + continue + + try: + candidate = get_param_tensor(ep, node) + except RuntimeError: + continue + if candidate is None: + continue + if candidate.shape != (vocab_size, embed_dim) or candidate.dtype != torch.int8: + continue + + if torch.equal(emb_unpacked, candidate): + return True + + return False + + +@register_pattern_detector("quantized_embedding") +def find_quantized_embedding_patterns( + node: torch.fx.Node, +) -> Optional[QuantizedEmbeddingMatch]: + if node.target != embedding_4bit_target: + return None + + matched_pattern = QuantizedEmbeddingMatch(node) + if matched_pattern.match_found: + return matched_pattern + return None + + +@register_pattern_replacement("quantized_embedding") +def replace_quantized_embedding_patterns( + ep: ExportedProgram, + graph_module: torch.fx.GraphModule, + match: QuantizedEmbeddingMatch, +): + weight_tensor = get_param_tensor(ep, match.weight_node) + assert weight_tensor is not None + + scales_tensor = get_param_tensor(ep, match.scales_node) + assert scales_tensor is not None + + is_linear_weight = _detect_tied_linear_weight(ep, match.weight_node, weight_tensor) + + if is_linear_weight: + # Repack using linear convention (low nibble = even, high nibble = odd) + vocab_size = weight_tensor.shape[0] + high = (weight_tensor >> 4).to(torch.int8) - 8 + low = (weight_tensor & 0xF).to(torch.int8) - 8 + unpacked = torch.stack([high, low], dim=-1).reshape(vocab_size, -1) + repacked = unpacked.to(torch.uint8) + 8 + weight_tensor = repacked[:, 1::2] << 4 | repacked[:, ::2] + # Update the 
state dict with repacked tensor + original_weight = get_param_tensor(ep, match.weight_node) + if original_weight is not None: + for key, value in ep.state_dict.items(): + if value.data_ptr() == original_weight.data_ptr(): + ep.state_dict[key] = weight_tensor + break + + # Compute group_size from weight and scales shapes + embed_dim = weight_tensor.shape[1] * 2 # packed, 2 values per byte + groups_per_row = scales_tensor.shape[1] if scales_tensor.ndim > 1 else 1 + group_size = embed_dim // groups_per_row + + with graph_module.graph.inserting_before(match.anchor_node): + embedding_q4gsw_node = graph_module.graph.create_node( + "call_function", + exir_ops.edge.et_vk.embedding_q4gsw.default, + args=( + match.weight_node, + match.scales_node, + group_size, + match.indices_node, + is_linear_weight, + ), + ) + + embedding_q4gsw_node.meta["val"] = match.anchor_node.meta["val"] + match.anchor_node.replace_all_uses_with(embedding_q4gsw_node) diff --git a/backends/vulkan/runtime/graph/ops/glsl/embedding_q4gsw.glsl b/backends/vulkan/runtime/graph/ops/glsl/embedding_q4gsw.glsl new file mode 100644 index 00000000000..ce6779a0c9b --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/embedding_q4gsw.glsl @@ -0,0 +1,165 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#version 450 core + +$if STORAGE == "buffer": + #define OUTPUT_BUFFER + +$if IS_LINEAR_WEIGHT: + #define LINEAR_WEIGHT + $if WEIGHT_STORAGE == "buffer": + #define WEIGHT_BUFFER + +${define_required_extensions(STORAGE, DTYPE)} +$if IS_LINEAR_WEIGHT: + ${define_required_extensions(WEIGHT_STORAGE, "int")} +$else: + ${define_required_extensions("buffer", "int")} +${define_required_extensions("buffer", [SCALES_DTYPE, "uint8"])} + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, STORAGE)} +#define T ${texel_load_component_type(DTYPE, STORAGE)} + +${define_active_storage_type(STORAGE)} + +#extension GL_EXT_control_flow_attributes : require + +layout(std430) buffer; + +// Output uses the graph's storage type +${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)} +// Indices use the graph's storage type +${layout_declare_tensor(B, "r", "t_indices", "int", STORAGE)} +// Weight: flat buffer for regular, packed block format for linear_weight +$if IS_LINEAR_WEIGHT: + ${layout_declare_tensor(B, "r", "t_weight", "int", WEIGHT_STORAGE, is_scalar_array=False)} +$else: + ${layout_declare_tensor(B, "r", "t_weight", "int", "buffer", is_scalar_array=False)} +// Scales are ALWAYS buffer, loaded as scalar +${layout_declare_tensor(B, "r", "t_scales", SCALES_DTYPE, "buffer")} + +layout(push_constant) uniform PushConstants { + int group_size; + int embed_dim; + int num_indices; + int out_height; + int is_linear_weight; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +// Load 4 dequantized embedding values at the given dimension offset. +// The weight storage format differs between regular and linear_weight variants, +// so this function is defined separately for each. +#ifdef LINEAR_WEIGHT + +// Linear-packed block format: weight is stored as blocks indexed by +// t_weight[n8 * K4 + k4], where each ivec4 element contains 8 interleaved +// 4-bit values from 2 sub-rows of an 8-row block. 
+VEC4_T load_embedding_weights( + const int embedding_idx, + const int dim, + const float scale) { + const int n8 = embedding_idx >> 3; + const int n_local = embedding_idx & 7; + const int row_in_block = n_local < 4 ? n_local : n_local - 4; + const int shift_base = n_local < 4 ? 0 : 4; + const int K4 = embed_dim >> 2; + const int k4 = dim >> 2; + +#ifdef WEIGHT_BUFFER + const ivec4 block = t_weight[n8 * K4 + k4]; +#else + const ivec4 block = texelFetch(t_weight, ivec2(k4, n8), 0); +#endif + + const uint packed_uint = uint(block[row_in_block]); + + return VEC4_T( + T(float(int((packed_uint >> shift_base) & 0xFu) - 8) * scale), + T(float(int((packed_uint >> (shift_base + 8)) & 0xFu) - 8) * scale), + T(float(int((packed_uint >> (shift_base + 16)) & 0xFu) - 8) * scale), + T(float(int((packed_uint >> (shift_base + 24)) & 0xFu) - 8) * scale)); +} + +#else // !LINEAR_WEIGHT + +// Flat buffer format: weight is stored as a contiguous array of packed bytes. +// Each ivec4 covers 32 int4 values (16 bytes) for one embedding row. 
+VEC4_T load_embedding_weights( + const int embedding_idx, + const int dim, + const float scale) { + const int blocks_per_row = embed_dim >> 5; + const int block_in_row = dim >> 5; + const int t = (dim >> 2) & 7; + + const ivec4 packed = t_weight[embedding_idx * blocks_per_row + block_in_row]; + const int int_idx = t >> 1; + const int byte_pair = t & 1; + + const uint u = uint(packed[int_idx]); + const int shift = byte_pair << 4; + const uint b0 = (u >> shift) & 0xFFu; + const uint b1 = (u >> (shift + 8)) & 0xFFu; + + return VEC4_T( + T(float(int(b0 >> 4) - 8) * scale), + T(float(int(b0 & 0xFu) - 8) * scale), + T(float(int(b1 >> 4) - 8) * scale), + T(float(int(b1 & 0xFu) - 8) * scale)); +} + +#endif // LINEAR_WEIGHT + +void main() { + const int block_in_row = int(gl_GlobalInvocationID.x); + const int y_idx = int(gl_GlobalInvocationID.y); + const int z_idx = int(gl_GlobalInvocationID.z); + + const int blocks_per_row = embed_dim >> 5; + const int indices_idx = z_idx * out_height + y_idx; + if (block_in_row >= blocks_per_row || indices_idx >= num_indices) { + return; + } + +#ifdef OUTPUT_BUFFER + const int embedding_idx = t_indices[indices_idx]; +#else + const ivec4 in_texel = + texelFetch(t_indices, ivec3(indices_idx >> 2, 0, 0), 0); + const int embedding_idx = in_texel[indices_idx & 3]; +#endif + + const int base_dim = block_in_row << 5; + const int groups_per_row = embed_dim / group_size; + + [[unroll]] for (int t = 0; t < 8; t++) { + const int dim = base_dim + (t << 2); + const float scale = + float(t_scales[embedding_idx * groups_per_row + dim / group_size]); + + const VEC4_T vals = + load_embedding_weights(embedding_idx, dim, scale); + +#ifdef OUTPUT_BUFFER + const int out_base = indices_idx * embed_dim + dim; + t_out[out_base] = vals.x; + t_out[out_base + 1] = vals.y; + t_out[out_base + 2] = vals.z; + t_out[out_base + 3] = vals.w; +#else + imageStore( + t_out, + ivec3((base_dim >> 2) + t, y_idx, z_idx), + vals); +#endif + } +} diff --git 
a/backends/vulkan/runtime/graph/ops/glsl/embedding_q4gsw.yaml b/backends/vulkan/runtime/graph/ops/glsl/embedding_q4gsw.yaml new file mode 100644 index 00000000000..9102bbb542e --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/embedding_q4gsw.yaml @@ -0,0 +1,38 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +embedding_q4gsw: + parameter_names_with_default_values: + DTYPE: half + SCALES_DTYPE: half + STORAGE: buffer + WEIGHT_STORAGE: buffer + IS_LINEAR_WEIGHT: false + generate_variant_forall: + STORAGE: + - VALUE: buffer + - VALUE: texture3d + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: embedding_q4gsw + - NAME: embedding_q4gsw_float_scales + SCALES_DTYPE: float + - NAME: embedding_q4gsw_linear_weight_buffer + IS_LINEAR_WEIGHT: true + WEIGHT_STORAGE: buffer + - NAME: embedding_q4gsw_linear_weight_texture2d + IS_LINEAR_WEIGHT: true + WEIGHT_STORAGE: texture2d + - NAME: embedding_q4gsw_float_scales_linear_weight_buffer + SCALES_DTYPE: float + IS_LINEAR_WEIGHT: true + WEIGHT_STORAGE: buffer + - NAME: embedding_q4gsw_float_scales_linear_weight_texture2d + SCALES_DTYPE: float + IS_LINEAR_WEIGHT: true + WEIGHT_STORAGE: texture2d diff --git a/backends/vulkan/runtime/graph/ops/impl/EmbeddingQ4gsw.cpp b/backends/vulkan/runtime/graph/ops/impl/EmbeddingQ4gsw.cpp new file mode 100644 index 00000000000..c5051392074 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/EmbeddingQ4gsw.cpp @@ -0,0 +1,166 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include +#include +#include + +#include +#include +#include + +#include + +#include + +namespace vkcompute { + +utils::uvec3 pick_embedding_q4gsw_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef out = args.at(0).refs.at(0); + const std::vector& sizes = graph->sizes_of(out); + int ndim = sizes.size(); + + uint32_t blocks_per_row = static_cast(sizes[ndim - 1]) / 32; + uint32_t height = ndim >= 2 ? static_cast(sizes[ndim - 2]) : 1; + uint32_t depth = 1; + for (int i = 0; i < ndim - 2; i++) { + depth *= static_cast(sizes[i]); + } + + return {blocks_per_row, height, depth}; +} + +void resize_embedding_q4gsw_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef out = args.at(0).refs.at(0); + const ValueRef indices = args.at(1).refs.at(0); + + const int64_t embed_dim = graph->get_int(resize_args.at(0)); + const std::vector indices_sizes = graph->sizes_of(indices); + + // Output shape is indices.shape + [embed_dim] + std::vector out_sizes = indices_sizes; + out_sizes.push_back(embed_dim); + + graph->virtual_resize(out, out_sizes); +} + +void add_embedding_q4gsw_node( + ComputeGraph& graph, + const ValueRef indices, + const ValueRef weight, + const ValueRef weight_scales, + const int32_t group_size, + const int32_t embed_dim, + const int32_t num_indices, + const int32_t out_height, + const int32_t is_linear_weight, + const ValueRef out) { + VK_CHECK_COND(graph.packed_dim_of(out) == WHCN::kWidthDim); + VK_CHECK_COND(graph.packed_dim_of(indices) == WHCN::kWidthDim); + VK_CHECK_COND(embed_dim % 32 == 0, "embed_dim must be a multiple of 32"); + + std::string kernel_name = "embedding_q4gsw"; + kernel_name.reserve(kShaderNameReserve); + + vkapi::ScalarType scales_dtype = graph.dtype_of(weight_scales); + if (scales_dtype != vkapi::kHalf) { + kernel_name += "_float_scales"; + } + + if (is_linear_weight) 
{ + kernel_name += "_linear_weight"; + add_storage_type_suffix(kernel_name, graph.storage_type_of(weight)); + } + + add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); + + std::vector push_constants = { + PushConstantDataInfo(&group_size, sizeof(group_size)), + PushConstantDataInfo(&embed_dim, sizeof(embed_dim)), + PushConstantDataInfo(&num_indices, sizeof(num_indices)), + PushConstantDataInfo(&out_height, sizeof(out_height)), + PushConstantDataInfo(&is_linear_weight, sizeof(is_linear_weight)), + }; + + ValueRef embed_dim_ref = graph.add_scalar(embed_dim); + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + pick_embedding_q4gsw_global_wg_size, + default_pick_local_wg_size, + {{out, vkapi::kWrite}, {{indices, weight, weight_scales}, vkapi::kRead}}, + {}, + push_constants, + {}, + {embed_dim_ref}, + resize_embedding_q4gsw_node)); +} + +void embedding_q4gsw(ComputeGraph& graph, const std::vector& args) { + ValueRef weight_data = args[0]; + ValueRef weight_scales_data = args[1]; + ValueRef group_size_ref = args[2]; + ValueRef indices = args[3]; + ValueRef is_linear_weight_ref = args[4]; + ValueRef out = args[5]; + + int32_t group_size = graph.extract_scalar(group_size_ref); + int32_t is_linear_weight = + graph.extract_scalar(is_linear_weight_ref) ? 
1 : 0; + + const std::vector weight_sizes = graph.sizes_of(weight_data); + int32_t embed_dim = static_cast(weight_sizes.back() * 2); + + const std::vector indices_sizes = graph.sizes_of(indices); + int32_t num_indices = 1; + for (auto s : indices_sizes) { + num_indices *= static_cast(s); + } + int32_t out_height = static_cast(indices_sizes.back()); + + ValueRef weight; + if (is_linear_weight) { + QuantizationConfig weight_quant_config(4, kPerGroup, {group_size}); + weight = prepack_quantized_linear_weight( + graph, weight_quant_config, weight_data); + } else { + weight = prepack_standard( + graph, weight_data, utils::kBuffer, utils::kWidthPacked); + } + ValueRef weight_scales = prepack_standard( + graph, weight_scales_data, utils::kBuffer, utils::kWidthPacked); + + add_embedding_q4gsw_node( + graph, + indices, + weight, + weight_scales, + group_size, + embed_dim, + num_indices, + out_height, + is_linear_weight, + out); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(et_vk.embedding_q4gsw.default, embedding_q4gsw); +} + +} // namespace vkcompute diff --git a/backends/vulkan/test/custom_ops/targets.bzl b/backends/vulkan/test/custom_ops/targets.bzl index 84432bce30b..e8cdb3a4bf8 100644 --- a/backends/vulkan/test/custom_ops/targets.bzl +++ b/backends/vulkan/test/custom_ops/targets.bzl @@ -102,3 +102,4 @@ def define_common_targets(is_fbcode = False): define_custom_op_test_binary("test_mm") define_custom_op_test_binary("test_conv2d_pw") define_custom_op_test_binary("test_conv2d_dw") + define_custom_op_test_binary("test_embedding_q4gsw") diff --git a/backends/vulkan/test/custom_ops/test_embedding_q4gsw.cpp b/backends/vulkan/test/custom_ops/test_embedding_q4gsw.cpp new file mode 100644 index 00000000000..bb588fd3dbb --- /dev/null +++ b/backends/vulkan/test/custom_ops/test_embedding_q4gsw.cpp @@ -0,0 +1,520 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. 
+// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include + +#include +#include + +#include "utils.h" + +using namespace executorch::vulkan::prototyping; + +using namespace vkcompute; + +struct EmbeddingConfig { + int64_t vocab_size; + int64_t embed_dim; + int64_t group_size; + std::vector indices_shape; + std::string test_case_name = "placeholder"; + vkapi::ScalarType dtype = vkapi::kHalf; + vkapi::ScalarType scales_dtype = vkapi::kHalf; + utils::StorageType storage_type = utils::kBuffer; + bool is_linear_weight = false; +}; + +// CPU reference: unpack 4-bit weights, dequantize, and perform embedding lookup +void embedding_4bit_reference(TestCase& tc) { + auto& weight_spec = tc.inputs()[0]; + auto& scales_spec = tc.inputs()[1]; + int32_t group_size = tc.inputs()[2].get_int_value(); + auto& indices_spec = tc.inputs()[3]; + bool is_linear_weight = tc.inputs()[4].get_bool_value(); + auto& output_spec = tc.outputs()[0]; + + weight_spec.ensure_data_generated(); + scales_spec.ensure_data_generated(); + indices_spec.ensure_data_generated(); + + const auto& weight_data = weight_spec.get_uint8_data(); + const auto& indices_data = indices_spec.get_int32_data(); + + bool scales_are_half = (scales_spec.dtype == vkapi::kHalf); + + int64_t packed_dim = weight_spec.sizes[1]; + int64_t embed_dim = packed_dim * 2; + int64_t groups_per_row = scales_spec.sizes[1]; + + int64_t num_indices = 1; + for (auto s : indices_spec.sizes) { + num_indices *= s; + } + + int64_t total_output = num_indices * embed_dim; + + // Always populate ref_float_data so the caching framework can distribute it + output_spec.get_ref_float_data().resize(total_output); + + bool output_is_half = (output_spec.dtype == vkapi::kHalf); + if (output_is_half) { + output_spec.get_ref_half_data().resize(total_output); + } + + for (int64_t i = 0; i < num_indices; ++i) { + int32_t idx = indices_data[i]; + for (int64_t d = 
0; d < embed_dim; ++d) { + int64_t packed_idx = d / 2; + uint8_t packed_byte = weight_data[idx * packed_dim + packed_idx]; + + // Unpack: packed_byte = (even_val + 8) << 4 | (odd_val + 8) + // Even d -> high nibble, odd d -> low nibble + // For linear weight packing, nibble order is swapped + int int4_val; + if (d % 2 == 0) { + if (is_linear_weight) { + int4_val = static_cast(packed_byte & 0xF) - 8; + } else { + int4_val = static_cast(packed_byte >> 4) - 8; + } + } else { + if (is_linear_weight) { + int4_val = static_cast(packed_byte >> 4) - 8; + } else { + int4_val = static_cast(packed_byte & 0xF) - 8; + } + } + + int64_t group_idx = d / group_size; + int64_t scale_idx = idx * groups_per_row + group_idx; + + float scale; + if (scales_are_half) { + uint16_t scale_half = scales_spec.get_half_data()[scale_idx]; + scale = half_to_float(scale_half); + } else { + scale = scales_spec.get_float_data()[scale_idx]; + } + + float result = static_cast(int4_val) * scale; + + // Always store float reference + output_spec.get_ref_float_data()[i * embed_dim + d] = result; + + if (output_is_half) { + output_spec.get_ref_half_data()[i * embed_dim + d] = + float_to_half(result); + } + } + } +} + +TestCase create_test_case(const EmbeddingConfig& config) { + TestCase test_case; + test_case.set_name(config.test_case_name); + test_case.set_operator_name("et_vk.embedding_q4gsw.default"); + test_case.set_shader_filter({}); + + // Weight: [vocab_size, embed_dim / 2] packed uint8 + ValueSpec weight( + {config.vocab_size, config.embed_dim / 2}, + vkapi::kByte, + utils::kBuffer, + utils::kWidthPacked, + DataGenType::RANDINT4); + weight.set_constant(true); + test_case.add_input_spec(weight); + + // Weight scales: [vocab_size, groups_per_row] + int64_t groups_per_row = config.embed_dim / config.group_size; + ValueSpec weight_scales( + {config.vocab_size, groups_per_row}, + config.scales_dtype, + utils::kBuffer, + utils::kWidthPacked, + DataGenType::RANDOM_SCALES); + 
weight_scales.set_constant(true);
+  test_case.add_input_spec(weight_scales);
+
+  // Group size: int scalar
+  ValueSpec group_size_spec(static_cast<int32_t>(config.group_size));
+  test_case.add_input_spec(group_size_spec);
+
+  // Indices: [batch, seq_len] int32
+  ValueSpec indices(
+      config.indices_shape,
+      vkapi::kInt,
+      config.storage_type,
+      utils::kWidthPacked,
+      DataGenType::RANDINT);
+
+  // Clamp indices to valid vocab range
+  indices.ensure_data_generated();
+  for (auto& idx : indices.get_int32_data()) {
+    idx = std::abs(idx) % config.vocab_size;
+  }
+
+  test_case.add_input_spec(indices);
+
+  // is_linear_weight: bool scalar
+  ValueSpec is_linear_weight_spec(config.is_linear_weight);
+  test_case.add_input_spec(is_linear_weight_spec);
+
+  // Output: indices.shape + [embed_dim]
+  std::vector<int64_t> output_shape = config.indices_shape;
+  output_shape.push_back(config.embed_dim);
+  ValueSpec output(
+      output_shape, config.dtype, config.storage_type, utils::kWidthPacked);
+  test_case.add_output_spec(output);
+
+  return test_case;
+}
+
+std::vector<TestCase> generate_test_cases() {
+  std::vector<TestCase> test_cases;
+
+  // --- is_linear_weight = true ---
+
+  // Basic test with linear weight packing
+  test_cases.push_back(create_test_case(
+      {.vocab_size = 16,
+       .embed_dim = 32,
+       .group_size = 32,
+       .indices_shape = {4},
+       .test_case_name = "small_1d_linear_weight",
+       .is_linear_weight = true}));
+
+  test_cases.push_back(create_test_case(
+      {.vocab_size = 32,
+       .embed_dim = 64,
+       .group_size = 32,
+       .indices_shape = {2, 3},
+       .test_case_name = "small_2d_linear_weight",
+       .is_linear_weight = true}));
+
+  test_cases.push_back(create_test_case(
+      {.vocab_size = 100,
+       .embed_dim = 128,
+       .group_size = 32,
+       .indices_shape = {4, 8},
+       .test_case_name = "medium_multigroup_linear_weight",
+       .is_linear_weight = true}));
+
+  // fp32 output with linear weight
+  test_cases.push_back(create_test_case(
+      {.vocab_size = 16,
+       .embed_dim = 32,
+       .group_size = 32,
+       .indices_shape = {4},
+       .test_case_name = "small_1d_fp32_linear_weight",
+       .dtype = vkapi::kFloat,
+       .is_linear_weight = true}));
+
+  // Texture3D with linear weight
+  test_cases.push_back(create_test_case(
+      {.vocab_size = 16,
+       .embed_dim = 32,
+       .group_size = 32,
+       .indices_shape = {4},
+       .test_case_name = "small_1d_texture_linear_weight",
+       .storage_type = utils::kTexture3D,
+       .is_linear_weight = true}));
+
+  // Llama 3.2 1B with linear weight packing
+  test_cases.push_back(create_test_case(
+      {.vocab_size = 128256,
+       .embed_dim = 2048,
+       .group_size = 32,
+       .indices_shape = {1, 2047},
+       .test_case_name = "llama_3_2_1b_prefill_fp32_linear_weight",
+       .dtype = vkapi::kFloat,
+       .storage_type = utils::kBuffer,
+       .is_linear_weight = true}));
+
+  // --- Half scales (default) ---
+
+  // Basic test: small vocab, small embed_dim
+  test_cases.push_back(create_test_case(
+      {.vocab_size = 16,
+       .embed_dim = 32,
+       .group_size = 32,
+       .indices_shape = {4},
+       .test_case_name = "small_1d"}));
+
+  // 2D indices
+  test_cases.push_back(create_test_case(
+      {.vocab_size = 32,
+       .embed_dim = 64,
+       .group_size = 32,
+       .indices_shape = {2, 3},
+       .test_case_name = "small_2d"}));
+
+  // Larger vocab, multiple groups
+  test_cases.push_back(create_test_case(
+      {.vocab_size = 100,
+       .embed_dim = 128,
+       .group_size = 32,
+       .indices_shape = {4, 8},
+       .test_case_name = "medium_multigroup"}));
+
+  // group_size == embed_dim (single group)
+  test_cases.push_back(create_test_case(
+      {.vocab_size = 50,
+       .embed_dim = 64,
+       .group_size = 64,
+       .indices_shape = {2, 4},
+       .test_case_name = "single_group"}));
+
+  // fp32 output variants (half scales)
+  test_cases.push_back(create_test_case(
+      {.vocab_size = 16,
+       .embed_dim = 32,
+       .group_size = 32,
+       .indices_shape = {4},
+       .test_case_name = "small_1d_fp32",
+       .dtype = vkapi::kFloat}));
+
+  test_cases.push_back(create_test_case(
+      {.vocab_size = 32,
+       .embed_dim = 64,
+       .group_size = 32,
+       .indices_shape = {2, 3},
+       .test_case_name = "small_2d_fp32",
+       .dtype = vkapi::kFloat}));
+
+  test_cases.push_back(create_test_case(
+      {.vocab_size = 100,
+       .embed_dim = 128,
+       .group_size = 32,
+       .indices_shape = {4, 8},
+       .test_case_name = "medium_multigroup_fp32",
+       .dtype = vkapi::kFloat}));
+
+  test_cases.push_back(create_test_case(
+      {.vocab_size = 50,
+       .embed_dim = 64,
+       .group_size = 64,
+       .indices_shape = {2, 4},
+       .test_case_name = "single_group_fp32",
+       .dtype = vkapi::kFloat}));
+
+  // Texture3D variants (fp16 output, half scales)
+  test_cases.push_back(create_test_case(
+      {.vocab_size = 16,
+       .embed_dim = 32,
+       .group_size = 32,
+       .indices_shape = {4},
+       .test_case_name = "small_1d_texture",
+       .dtype = vkapi::kHalf,
+       .storage_type = utils::kTexture3D}));
+
+  test_cases.push_back(create_test_case(
+      {.vocab_size = 32,
+       .embed_dim = 64,
+       .group_size = 32,
+       .indices_shape = {2, 3},
+       .test_case_name = "small_2d_texture",
+       .dtype = vkapi::kHalf,
+       .storage_type = utils::kTexture3D}));
+
+  test_cases.push_back(create_test_case(
+      {.vocab_size = 100,
+       .embed_dim = 128,
+       .group_size = 32,
+       .indices_shape = {4, 8},
+       .test_case_name = "medium_multigroup_texture",
+       .dtype = vkapi::kHalf,
+       .storage_type = utils::kTexture3D}));
+
+  test_cases.push_back(create_test_case(
+      {.vocab_size = 50,
+       .embed_dim = 64,
+       .group_size = 64,
+       .indices_shape = {2, 4},
+       .test_case_name = "single_group_texture",
+       .dtype = vkapi::kHalf,
+       .storage_type = utils::kTexture3D}));
+
+  // Texture3D variants (fp32 output, half scales)
+  test_cases.push_back(create_test_case(
+      {.vocab_size = 16,
+       .embed_dim = 32,
+       .group_size = 32,
+       .indices_shape = {4},
+       .test_case_name = "small_1d_fp32_texture",
+       .dtype = vkapi::kFloat,
+       .storage_type = utils::kTexture3D}));
+
+  test_cases.push_back(create_test_case(
+      {.vocab_size = 32,
+       .embed_dim = 64,
+       .group_size = 32,
+       .indices_shape = {2, 3},
+       .test_case_name = "small_2d_fp32_texture",
+       .dtype = vkapi::kFloat,
+       .storage_type = utils::kTexture3D}));
+
+  test_cases.push_back(create_test_case(
+      {.vocab_size = 100,
+       .embed_dim = 128,
+       .group_size = 32,
+       .indices_shape = {4, 8},
+       .test_case_name = "medium_multigroup_fp32_texture",
+       .dtype = vkapi::kFloat,
+       .storage_type = utils::kTexture3D}));
+
+  test_cases.push_back(create_test_case(
+      {.vocab_size = 50,
+       .embed_dim = 64,
+       .group_size = 64,
+       .indices_shape = {2, 4},
+       .test_case_name = "single_group_fp32_texture",
+       .dtype = vkapi::kFloat,
+       .storage_type = utils::kTexture3D}));
+
+  // --- Float scales ---
+
+  // Buffer variants with float scales
+  test_cases.push_back(create_test_case(
+      {.vocab_size = 16,
+       .embed_dim = 32,
+       .group_size = 32,
+       .indices_shape = {4},
+       .test_case_name = "small_1d_float_scales",
+       .dtype = vkapi::kHalf,
+       .scales_dtype = vkapi::kFloat}));
+
+  test_cases.push_back(create_test_case(
+      {.vocab_size = 32,
+       .embed_dim = 64,
+       .group_size = 32,
+       .indices_shape = {2, 3},
+       .test_case_name = "small_2d_float_scales",
+       .dtype = vkapi::kHalf,
+       .scales_dtype = vkapi::kFloat}));
+
+  test_cases.push_back(create_test_case(
+      {.vocab_size = 100,
+       .embed_dim = 128,
+       .group_size = 32,
+       .indices_shape = {4, 8},
+       .test_case_name = "medium_multigroup_float_scales",
+       .dtype = vkapi::kHalf,
+       .scales_dtype = vkapi::kFloat}));
+
+  test_cases.push_back(create_test_case(
+      {.vocab_size = 16,
+       .embed_dim = 32,
+       .group_size = 32,
+       .indices_shape = {4},
+       .test_case_name = "small_1d_fp32_float_scales",
+       .dtype = vkapi::kFloat,
+       .scales_dtype = vkapi::kFloat}));
+
+  test_cases.push_back(create_test_case(
+      {.vocab_size = 32,
+       .embed_dim = 64,
+       .group_size = 32,
+       .indices_shape = {2, 3},
+       .test_case_name = "small_2d_fp32_float_scales",
+       .dtype = vkapi::kFloat,
+       .scales_dtype = vkapi::kFloat}));
+
+  test_cases.push_back(create_test_case(
+      {.vocab_size = 100,
+       .embed_dim = 128,
+       .group_size = 32,
+       .indices_shape = {4, 8},
+       .test_case_name = "medium_multigroup_fp32_float_scales",
+       .dtype = vkapi::kFloat,
+       .scales_dtype = vkapi::kFloat}));
+
+  // Texture3D with float scales
+  test_cases.push_back(create_test_case(
+      {.vocab_size = 16,
+       .embed_dim = 32,
+       .group_size = 32,
+       .indices_shape = {4},
+       .test_case_name = "small_1d_float_scales_texture",
+       .dtype = vkapi::kHalf,
+       .scales_dtype = vkapi::kFloat,
+       .storage_type = utils::kTexture3D}));
+
+  test_cases.push_back(create_test_case(
+      {.vocab_size = 16,
+       .embed_dim = 32,
+       .group_size = 32,
+       .indices_shape = {4},
+       .test_case_name = "small_1d_fp32_float_scales_texture",
+       .dtype = vkapi::kFloat,
+       .scales_dtype = vkapi::kFloat,
+       .storage_type = utils::kTexture3D}));
+
+  // Llama 3.2 1B prefill configuration (fp32 output, half scales)
+  test_cases.push_back(create_test_case(
+      {.vocab_size = 128256,
+       .embed_dim = 2048,
+       .group_size = 32,
+       .indices_shape = {1, 2047},
+       .test_case_name = "llama_3_2_1b_prefill_fp32",
+       .dtype = vkapi::kFloat,
+       .storage_type = utils::kBuffer}));
+
+  // Llama 3.2 1B prefill configuration (fp32 output, float scales)
+  test_cases.push_back(create_test_case(
+      {.vocab_size = 128256,
+       .embed_dim = 2048,
+       .group_size = 32,
+       .indices_shape = {1, 2047},
+       .test_case_name = "llama_3_2_1b_prefill_fp32_float_scales",
+       .dtype = vkapi::kFloat,
+       .scales_dtype = vkapi::kFloat,
+       .storage_type = utils::kBuffer}));
+
+  // Llama 3.2 1B prefill configuration (fp16 output, half scales)
+  test_cases.push_back(create_test_case(
+      {.vocab_size = 128256,
+       .embed_dim = 2048,
+       .group_size = 32,
+       .indices_shape = {1, 2047},
+       .test_case_name = "llama_3_2_1b_prefill_fp16",
+       .dtype = vkapi::kHalf,
+       .storage_type = utils::kBuffer}));
+
+  // Llama 3.2 1B prefill configuration (fp16 output, float scales)
+  test_cases.push_back(create_test_case(
+      {.vocab_size = 128256,
+       .embed_dim = 2048,
+       .group_size = 32,
+       .indices_shape = {1, 2047},
+       .test_case_name = "llama_3_2_1b_prefill_fp16_float_scales",
+       .dtype = vkapi::kHalf,
+       .scales_dtype = vkapi::kFloat,
+       .storage_type = utils::kBuffer}));
+
+  return test_cases;
+}
+
+int main(int argc, char** argv) {
+  auto results = execute_test_cases(
+      generate_test_cases,
+      "embedding_q4gsw",
+      /* warmup_runs */ 3,
+      /* benchmark_runs */ 10,
+      embedding_4bit_reference);
+
+  results.print_summary();
+
+  if (results.get_failed_count() > 0) {
+    std::cerr << "FAILED: " << results.get_failed_count() << " test(s) failed."
+              << std::endl;
+    return 1;
+  }
+
+  std::cout << "PASSED: All tests passed." << std::endl;
+  return 0;
+}
diff --git a/backends/vulkan/test/custom_ops/utils.cpp b/backends/vulkan/test/custom_ops/utils.cpp
index 2a50e7b5ec1..fc406c47403 100644
--- a/backends/vulkan/test/custom_ops/utils.cpp
+++ b/backends/vulkan/test/custom_ops/utils.cpp
@@ -74,6 +74,82 @@ void generate_random_int4_data(
 void generate_ones_data(std::vector<float>& data);
 void generate_zeros_data(std::vector<float>& data);
 
+// Convert a float32 value to IEEE 754 half-precision (uint16_t).
+//
+// Rounds to nearest, ties to even. Values too large for fp16 become +/-Inf,
+// NaN stays NaN (quiet), and values below the smallest normal half are
+// converted to half subnormals (or signed zero once they round to nothing).
+uint16_t float_to_half(float value) {
+  uint32_t float_bits;
+  std::memcpy(&float_bits, &value, sizeof(float));
+  const uint32_t sign = (float_bits >> 16) & 0x8000u;
+  const uint32_t magnitude = float_bits & 0x7FFFFFFFu;
+
+  // NaN: exponent all ones with a non-zero mantissa. Emit a quiet NaN rather
+  // than letting it fall into the overflow (Inf) path.
+  if (magnitude > 0x7F800000u) {
+    return static_cast<uint16_t>(sign | 0x7E00u);
+  }
+
+  // Inf, or a finite value of at least 2^16: not representable, use Inf.
+  if (magnitude >= 0x47800000u) {
+    return static_cast<uint16_t>(sign | 0x7C00u);
+  }
+
+  // Normal half range (|value| >= 2^-14): rebias the exponent from 127 to 15
+  // and round away the 13 low mantissa bits. A mantissa carry bumps the
+  // exponent, so values in [65520, 65536) correctly round up to Inf.
+  if (magnitude >= 0x38800000u) {
+    const uint32_t rebiased = magnitude - 0x38000000u;
+    const uint32_t rounded = rebiased + 0x0FFFu + ((rebiased >> 13) & 0x1u);
+    return static_cast<uint16_t>(sign | (rounded >> 13));
+  }
+
+  // Below the normal range the result is a multiple of 2^-24. Anything that
+  // needs a shift of more than 24 bits is below 2^-25 and rounds to zero.
+  const uint32_t shift = 126u - (magnitude >> 23);
+  if (shift > 24u) {
+    return static_cast<uint16_t>(sign);
+  }
+
+  // Restore the implicit leading one, then shift with round-to-nearest-even.
+  const uint32_t significand = (magnitude & 0x7FFFFFu) | 0x800000u;
+  const uint32_t halfway = 1u << (shift - 1u);
+  const uint32_t remainder = significand & ((1u << shift) - 1u);
+  uint32_t quotient = significand >> shift;
+  if (remainder > halfway || (remainder == halfway && (quotient & 0x1u))) {
+    ++quotient;
+  }
+  return static_cast<uint16_t>(sign | quotient);
+}
+
+// Convert an IEEE 754 half-precision (uint16_t) value to float32.
+// Handles subnormals, Inf and NaN; the conversion is exact.
+float half_to_float(uint16_t half_val) {
+  const uint32_t sign = (half_val >> 15) & 0x1;
+  const uint32_t exponent = (half_val >> 10) & 0x1F;
+  const uint32_t mantissa = half_val & 0x3FF;
+
+  float result;
+  if (exponent == 0) {
+    // Zero or subnormal: mantissa * 2^-24, no implicit leading one.
+    result = std::ldexp(static_cast<float>(mantissa), -24);
+  } else if (exponent == 31) {
+    result = mantissa ? std::numeric_limits<float>::quiet_NaN()
+                      : std::numeric_limits<float>::infinity();
+  } else {
+    // Subtract the bias in signed arithmetic; exponent is unsigned and would
+    // otherwise wrap around for values below 1.0.
+    result = std::ldexp(
+        1.0f + static_cast<float>(mantissa) / 1024.0f,
+        static_cast<int>(exponent) - 15);
+  }
+  if (sign) {
+    result = -result;
+  }
+  return result;
+}
+
 // Output and latency printing utilities
 namespace {
 static int print_output_enabled = 0;
@@ -154,6 +230,13 @@ void ValueSpec::generate_tensor_data(int seed) {
         // Simple conversion to uint16_t representation of half
         half_data[i] = static_cast<uint16_t>(temp_data[i] * 32767.0f);
       }
+    } else if (data_gen_type == DataGenType::RANDOM_SCALES) {
+      // Generate random scales in float, then convert to proper fp16
+      std::vector<float> temp_data(num_elements);
+      generate_random_float_data(temp_data, 0.005f, 0.015f, seed);
+      for (size_t i = 0; i < temp_data.size(); ++i) {
+        half_data[i] = float_to_half(temp_data[i]);
+      }
     } else if (data_gen_type == DataGenType::RANDINT) {
       generate_randint_half_data(half_data, -10, 10, seed);
     } else if (data_gen_type == DataGenType::RANDINT8) {
@@ -1395,15 +1478,19 @@ execute_test_case(TestCase& test_case, int warmup_runs, int benchmark_runs) {
   graph.prepack();
 
   // Copy input data into the graph's staging buffers
+  size_t graph_input_idx = 0;
   for (size_t i = 0; i < test_case.num_inputs(); ++i) {
     const ValueSpec& input_spec = test_case.inputs()[i];
 
-    if (input_spec.is_tensor() && i < graph.inputs().size()) {
-      // Skip copying data for constant tensors
-      if (input_spec.is_constant()) {
-        continue;
-      }
-      const auto& input_ref = graph.inputs()[i];
+    // Only non-constant tensor inputs correspond to graph.inputs() entries
+    bool is_graph_input = input_spec.is_tensor() && !input_spec.is_constant() &&
+        !input_spec.is_none();
+    if (!is_graph_input) {
+      continue;
+    }
+
+    if (graph_input_idx < graph.inputs().size()) {
+      const auto& input_ref = graph.inputs()[graph_input_idx];
 
       // Get the appropriate data based on dtype
       const void* data_ptr = nullptr;
@@ -1433,6 +1520,7 @@ execute_test_case(TestCase& test_case, int warmup_runs, int benchmark_runs) {
       graph.maybe_cast_and_copy_into_staging(
           input_ref.staging, data_ptr, data_numel, input_spec.dtype);
     }
+    ++graph_input_idx;
   }
 
   // Warmup runs
diff --git a/backends/vulkan/test/custom_ops/utils.h b/backends/vulkan/test/custom_ops/utils.h
index 9b5b6a46782..4f3e4ec629f 100644
--- a/backends/vulkan/test/custom_ops/utils.h
+++ b/backends/vulkan/test/custom_ops/utils.h
@@ -864,6 +864,10 @@ void compute_weight_sums_4bit_grouped(
     int64_t out_features,
     int64_t group_size);
 
+// Half-precision conversion utilities
+uint16_t float_to_half(float value);
+float half_to_float(uint16_t half_val);
+
 // Setup compute graph based on TestCase and operation name
 ComputeGraph setup_compute_graph(TestCase& test_case, std::string op_name);