Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions backends/vulkan/op_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -802,6 +802,47 @@ def check_conv_node(node: torch.fx.Node) -> bool:

return True

def pick_conv_storage(
    node: torch.fx.Node,
) -> Tuple[List[utils.TensorRepSet], utils.TensorRepSet]:
    """Choose texture storage layouts for a convolution node's I/O tensors.

    Returns a (per-input storage list, output storage) pair. The default is
    channels-packed textures (conv2d and the generic conv1d fallback); 3-dim
    inputs that are pointwise (kernel width 1) or depthwise (groups == C_in,
    C_out == C_in, weight dim 1 == 1) conv1d ops are routed to height-packed
    textures instead.
    """
    inp = node.args[0]
    assert isinstance(inp, torch.fx.Node)
    in_sizes = inp.meta["val"].size()

    # Default: channels-packed texture (conv2d and fallback conv1d)
    in_repr = utils.CHANNELS_PACKED_TEXTURE
    out_repr = utils.CHANNELS_PACKED_TEXTURE

    if len(in_sizes) == 3:
        # Conv1d: check whether the height-packed fast path applies
        weight = node.args[1]
        assert isinstance(weight, torch.fx.Node)
        w_sizes = weight.meta["val"].size()
        groups = node.args[8]

        in_channels = in_sizes[1]
        out_channels = w_sizes[0]
        kernel_width = w_sizes[2]

        pointwise = kernel_width == 1
        depthwise = (
            isinstance(groups, int)
            and groups == in_channels
            and out_channels == in_channels
            and w_sizes[1] == 1
        )
        if pointwise or depthwise:
            in_repr = utils.HEIGHT_PACKED_TEXTURE
            out_repr = utils.HEIGHT_PACKED_TEXTURE

    # Build per-input storage list. The convolution op has variable args:
    #   aten.convolution.default: input, weight, bias, stride, padding,
    #     dilation, transposed, output_padding, groups
    #   et_vk.conv_with_clamp.default: + output_min, output_max
    # All args after input are NO_STORAGE (prepacked or non-tensor)
    per_input = [in_repr] + [utils.NO_STORAGE] * 10
    return per_input, out_repr

return OpFeatures(
inputs_storage=[
utils.CHANNELS_PACKED_TEXTURE, # input
Expand All @@ -820,6 +861,7 @@ def check_conv_node(node: torch.fx.Node) -> bool:
supports_resize=True,
supports_prepacking=True,
are_node_inputs_supported_fn=check_conv_node,
pick_io_storage_fn=pick_conv_storage,
)


Expand Down
50 changes: 50 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/Convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -675,6 +675,56 @@ void conv(ComputeGraph& graph, const std::vector<ValueRef>& args) {
true);
}
} else {
// Conv1d path
if (graph.packed_dim_of(args[0]) == WHCN::kHeightDim) {
// Height-packed: route to optimized conv1d implementations
const auto weight_sizes = graph.sizes_of(args[1]);
const int64_t groups_val = graph.get_int(args[8]);
const bool is_pointwise = weight_sizes.at(2) == 1;
const bool is_depthwise =
groups_val == weight_sizes.at(0) && weight_sizes.at(1) == 1;

// Build unified 10-arg vector:
// in, weight, bias, stride, padding, dilation, groups,
// output_min, output_max, out
// For non-clamp (args.size() == 10): output_min/max = kDummyValueRef
// For clamp (args.size() == 12): output_min/max from args[9]/args[10]
ValueRef output_min = kDummyValueRef;
ValueRef output_max = kDummyValueRef;
ValueRef out;
if (args.size() == 10) {
out = args[9];
} else {
output_min = args[9];
output_max = args[10];
out = args[11];
}

std::vector<ValueRef> conv1d_args = {
args[0],
args[1],
args[2],
args[3],
args[4],
args[5],
args[8],
output_min,
output_max,
out};

if (is_pointwise) {
VK_GET_OP_FN("et_vk.conv1d_pw.default")(graph, conv1d_args);
} else if (is_depthwise) {
VK_GET_OP_FN("et_vk.conv1d_dw.default")(graph, conv1d_args);
} else {
VK_THROW(
"Height-packed conv1d only supports pointwise (K=1) or "
"depthwise (groups=C)");
}
return;
}

// Existing channels-packed fallback
if (args.size() == 10) {
// ordinary conv1d
return add_conv1d_node(
Expand Down
Loading