Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion backends/vulkan/_passes/insert_prepack_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from torch.export import ExportedProgram


def insert_prepack_nodes(program: ExportedProgram) -> ExportedProgram:
def insert_prepack_nodes(program: ExportedProgram) -> ExportedProgram: # noqa: C901
"""
Insert `et_vk.prepack` nodes for constant tensors in the graph. The prepack operator
is responsible for transferring the tensor data, which is serialized with the model,
Expand Down Expand Up @@ -54,9 +54,13 @@ def insert_prepack_nodes(program: ExportedProgram) -> ExportedProgram:
# Most prepacking ops have the primary input at arg 0, but
# embedding is embedding(weight, indices, ...) where the
# primary input (indices) is at arg 1.
# For embedding_q4gsw, args are (weight, weight_scales, group_size,
# indices) so the primary input (indices) is at arg 3.
primary_arg_idx = 0
if user.target == exir_ops.edge.aten.embedding.default:
primary_arg_idx = 1
elif user.target == exir_ops.edge.et_vk.embedding_q4gsw.default:
primary_arg_idx = 3

if node in user.args and user.args.index(node) == primary_arg_idx:
nodes_to_replace_input.append(user)
Expand Down
40 changes: 40 additions & 0 deletions backends/vulkan/custom_ops_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -879,6 +879,46 @@ def q8ta_relu_impl(
lib.impl(name, q8ta_relu_impl, "CompositeExplicitAutograd")
q8ta_relu_op = getattr(getattr(torch.ops, namespace), name)

########################
## embedding_q4gsw ##
########################


def embedding_q4gsw_impl(
    weight: torch.Tensor,
    weight_scales: torch.Tensor,
    group_size: int,
    indices: torch.Tensor,
    is_linear_weight: bool = False,
) -> torch.Tensor:
    """Reference implementation of the 4-bit group-symmetric quantized embedding.

    Args:
        weight: Packed uint8 tensor of shape [vocab_size, embed_dim / 2]; each
            byte holds two signed 4-bit values offset by +8.
        weight_scales: Per-group dequantization scales (one row per vocab entry
            when 2-D, a single scalar-like scale otherwise).
        group_size: Number of consecutive embedding elements sharing one scale.
        indices: Integer tensor of token indices to look up.
        is_linear_weight: Selects the nibble-packing convention. Embedding
            packing stores the even element in the high nibble; linear packing
            stores it in the low nibble.

    Returns:
        Float tensor of dequantized embeddings for ``indices``.
    """
    vocab_size = weight.shape[0]

    # Split each packed byte into its two signed 4-bit values.
    # Packing convention: packed_byte = (even_val + 8) << 4 | (odd_val + 8)
    high_nibble = (weight >> 4).to(torch.int8) - 8
    low_nibble = (weight & 0xF).to(torch.int8) - 8

    # Interleave nibbles back into original element order according to the
    # packing convention in use.
    if is_linear_weight:
        nibble_pair = [low_nibble, high_nibble]
    else:
        nibble_pair = [high_nibble, low_nibble]
    unpacked = torch.stack(nibble_pair, dim=-1).reshape(vocab_size, -1)

    # Dequantize with per-group scales via broadcasting over the group dim.
    if weight_scales.dim() > 1:
        num_groups = weight_scales.shape[1]
        scales = weight_scales.unsqueeze(-1)
    else:
        num_groups = 1
        scales = weight_scales.reshape(1, 1, 1)
    grouped = unpacked.reshape(vocab_size, num_groups, group_size)
    dequantized = (grouped.float() * scales.float()).reshape(vocab_size, -1)

    return torch.nn.functional.embedding(indices, dequantized)


# Register the operator schema and its reference (eager) implementation so the
# op can be traced and executed off-device. The schema takes the packed uint8
# weight, per-group scales, the quantization group size, the lookup indices,
# and a flag selecting the nibble-packing convention.
name = "embedding_q4gsw"
lib.define(
    f"{name}(Tensor weight, Tensor weight_scales, int group_size, Tensor indices, bool is_linear_weight = False) -> Tensor"
)
lib.impl(name, embedding_q4gsw_impl, "CompositeExplicitAutograd")
# Convenience handle to the registered op (torch.ops.<namespace>.embedding_q4gsw).
embedding_q4gsw_op = getattr(getattr(torch.ops, namespace), name)

#############################
## select_as_symint ##
#############################
Expand Down
32 changes: 32 additions & 0 deletions backends/vulkan/op_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -1436,6 +1436,38 @@ def check_embedding_weight_size(node: torch.fx.Node) -> bool:
)


# =============================================================================
# EmbeddingQ4gsw (Quantized Embedding)
# =============================================================================


@update_features(exir_ops.edge.quantized_decomposed.embedding_4bit.dtype)
def register_quantized_decomposed_embedding_4bit():
    """Declare Vulkan support for quantized_decomposed.embedding_4bit.dtype."""

    def check_embedding_4bit_weight_size(node: torch.fx.Node) -> bool:
        # Reject graphs whose packed weight exceeds the default buffer limit;
        # anything else is accepted.
        weight_arg = node.args[0]
        if not isinstance(weight_arg, torch.fx.Node):
            return True
        if not utils.is_tensor_node(weight_arg):
            return True
        return weight_arg.meta["val"].numel() <= utils.DEFAULT_BUFFER_LIMIT

    return OpFeatures(
        inputs_storage=utils.ANY_BUFFER,
        supports_prepacking=True,
        supports_resize=True,
        are_node_inputs_supported_fn=check_embedding_4bit_weight_size,
    )


@update_features(exir_ops.edge.et_vk.embedding_q4gsw.default)
def register_embedding_q4gsw():
    # Fused 4-bit group-symmetric quantized embedding op emitted by the
    # quantized_embedding pattern replacement. Unlike the decomposed variant
    # above, no weight-size check is registered here — TODO confirm that is
    # intentional.
    return OpFeatures(
        inputs_storage=utils.CONTIGUOUS_ANY,
        supports_prepacking=True,
        supports_resize=True,
    )


# =============================================================================
# BatchNorm.cpp
# =============================================================================
Expand Down
1 change: 1 addition & 0 deletions backends/vulkan/patterns/BUCK
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ fbcode_target(_kind = runtime.python_library,
"__init__.py",
"pattern_registry.py",
"rope.py",
"quantized_embedding.py",
"quantized_linear.py",
"quantized_convolution.py",
"quantized_binary.py",
Expand Down
2 changes: 2 additions & 0 deletions backends/vulkan/patterns/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

import executorch.backends.vulkan.patterns.quantized_convolution # noqa

import executorch.backends.vulkan.patterns.quantized_embedding # noqa

import executorch.backends.vulkan.patterns.quantized_linear # noqa

import executorch.backends.vulkan.patterns.quantized_unary # noqa
Expand Down
177 changes: 177 additions & 0 deletions backends/vulkan/patterns/quantized_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import executorch.backends.vulkan.utils as utils
import torch
from executorch.backends.transforms.utils import get_param_tensor
from executorch.backends.vulkan.patterns.pattern_registry import (
PatternMatch,
register_pattern_detector,
register_pattern_replacement,
)
from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops


class QuantizedEmbeddingMatch(PatternMatch):
    """Match state for a quantized_decomposed.embedding_4bit.dtype node.

    Records the anchor node plus the weight/scales/indices argument nodes, and
    validates that the node's quantization parameters match the fixed
    convention assumed by the Vulkan shader (signed 4-bit with an offset of 8).
    ``match_found`` is False when any validation fails.
    """

    def __init__(self, node: torch.fx.Node) -> None:
        # The embedding_4bit node this match is anchored on.
        self.anchor_node = node
        self.match_found = False
        # All nodes that belong to the matched pattern (anchor plus any
        # intermediate nodes between the weight/scales and their placeholders).
        self.all_nodes = [node]

        # quantized_decomposed.embedding_4bit.dtype args:
        # (weight, weight_scales, weight_zero_points, quant_min, quant_max,
        # indices, *, dtype)
        self.weight_node = node.args[0]
        self.scales_node = node.args[1]
        self.indices_node = node.args[5]

        # Validate quantization parameters match our shader's assumptions.
        # The shader hardcodes the 4-bit signed offset (subtract 8), which
        # corresponds to quant_min=-8, quant_max=7, zero_points=0.
        quant_min = node.args[3]
        quant_max = node.args[4]
        if quant_min != -8 or quant_max != 7:
            self.match_found = False
            return

        # weight_zero_points (args[2]) should be None or all-zeros
        weight_zp_node = node.args[2]
        if weight_zp_node is not None:
            # If it's a constant tensor, verify it's all zeros
            # NOTE(review): during export, meta["val"] is typically a
            # FakeTensor, for which torch.all(...) may not be concretely
            # evaluable — verify this check behaves as intended on real graphs.
            if (
                isinstance(weight_zp_node, torch.fx.Node)
                and "val" in weight_zp_node.meta
            ):
                zp_val = weight_zp_node.meta["val"]
                if isinstance(zp_val, torch.Tensor) and not torch.all(zp_val == 0):
                    self.match_found = False
                    return

        # Trace weight to its placeholder
        const_node, arg_chain = utils.trace_args_until_placeholder(self.weight_node)
        if const_node is not None:
            self.weight_node = const_node
            self.all_nodes.extend(arg_chain)

        # Trace scales to their placeholder
        scales_node, arg_chain = utils.trace_args_until_placeholder(self.scales_node)
        if scales_node is not None:
            self.scales_node = scales_node
            self.all_nodes.extend(arg_chain)

        self.match_found = True


embedding_4bit_target = exir_ops.edge.quantized_decomposed.embedding_4bit.dtype


def _detect_tied_linear_weight(
    ep: ExportedProgram,
    weight_node: torch.fx.Node,
    weight_tensor: torch.Tensor,
) -> bool:
    """Return True if the embedding weight is tied to a linear weight.

    The embedding weight arrives packed as uint8 [vocab_size, embed_dim / 2].
    A tied linear output weight may be stored elsewhere as unpacked int8
    [vocab_size, embed_dim]. If any other placeholder's int8 values equal our
    unpacked embedding values, the weights are tied, and the caller should use
    the linear packing convention to enable deduplication.
    """
    vocab_size = weight_tensor.shape[0]
    embed_dim = weight_tensor.shape[1] * 2

    # Unpack with the embedding convention (high nibble holds the earlier
    # element), +8 offset removed.
    even_vals = (weight_tensor >> 4).to(torch.int8) - 8
    odd_vals = (weight_tensor & 0xF).to(torch.int8) - 8
    unpacked = torch.stack([even_vals, odd_vals], dim=-1).reshape(
        vocab_size, embed_dim
    )

    expected_shape = (vocab_size, embed_dim)
    for candidate_node in ep.graph_module.graph.nodes:
        if candidate_node.op != "placeholder":
            continue
        if candidate_node == weight_node:
            continue

        # Non-parameter placeholders either raise or yield None — skip them.
        try:
            candidate_tensor = get_param_tensor(ep, candidate_node)
        except RuntimeError:
            continue
        if candidate_tensor is None:
            continue

        if candidate_tensor.dtype != torch.int8:
            continue
        if candidate_tensor.shape != expected_shape:
            continue

        if torch.equal(unpacked, candidate_tensor):
            return True

    return False


@register_pattern_detector("quantized_embedding")
def find_quantized_embedding_patterns(
    node: torch.fx.Node,
) -> Optional[QuantizedEmbeddingMatch]:
    """Return a validated match for an embedding_4bit node, else None."""
    if node.target != embedding_4bit_target:
        return None

    match = QuantizedEmbeddingMatch(node)
    return match if match.match_found else None


@register_pattern_replacement("quantized_embedding")
def replace_quantized_embedding_patterns(
    ep: ExportedProgram,
    graph_module: torch.fx.GraphModule,
    match: QuantizedEmbeddingMatch,
):
    """Replace a matched embedding_4bit node with et_vk.embedding_q4gsw.

    If the embedding weight is found to be tied to a linear weight stored
    elsewhere in the program, the packed bytes are rewritten in the linear
    nibble convention (updating the state dict in place) so the two tensors
    can be deduplicated.
    """
    weight_tensor = get_param_tensor(ep, match.weight_node)
    assert weight_tensor is not None

    scales_tensor = get_param_tensor(ep, match.scales_node)
    assert scales_tensor is not None

    is_linear_weight = _detect_tied_linear_weight(ep, match.weight_node, weight_tensor)

    if is_linear_weight:
        # Repack using linear convention (low nibble = even, high nibble = odd)
        vocab_size = weight_tensor.shape[0]
        high = (weight_tensor >> 4).to(torch.int8) - 8
        low = (weight_tensor & 0xF).to(torch.int8) - 8
        unpacked = torch.stack([high, low], dim=-1).reshape(vocab_size, -1)
        # Re-apply the +8 offset. Negative int8 values wrap when cast to
        # uint8, but the subsequent +8 wraps them back into 0..15, so the
        # unsigned nibble values come out correct.
        repacked = unpacked.to(torch.uint8) + 8
        weight_tensor = repacked[:, 1::2] << 4 | repacked[:, ::2]
        # Update the state dict with repacked tensor
        original_weight = get_param_tensor(ep, match.weight_node)
        if original_weight is not None:
            # Locate the state-dict entry by storage pointer, since the graph
            # node name need not match the state-dict key.
            for key, value in ep.state_dict.items():
                if value.data_ptr() == original_weight.data_ptr():
                    ep.state_dict[key] = weight_tensor
                    break

    # Compute group_size from weight and scales shapes
    embed_dim = weight_tensor.shape[1] * 2  # packed, 2 values per byte
    groups_per_row = scales_tensor.shape[1] if scales_tensor.ndim > 1 else 1
    group_size = embed_dim // groups_per_row

    with graph_module.graph.inserting_before(match.anchor_node):
        embedding_q4gsw_node = graph_module.graph.create_node(
            "call_function",
            exir_ops.edge.et_vk.embedding_q4gsw.default,
            args=(
                match.weight_node,
                match.scales_node,
                group_size,
                match.indices_node,
                is_linear_weight,
            ),
        )

    # Propagate output metadata and reroute all users to the fused node.
    # NOTE(review): the anchor node itself is not erased here — presumably
    # cleaned up later by dead-code elimination; confirm.
    embedding_q4gsw_node.meta["val"] = match.anchor_node.meta["val"]
    match.anchor_node.replace_all_uses_with(embedding_q4gsw_node)
Loading
Loading