From 77854e8b4b6ece0650a36e9c332063714eca63e2 Mon Sep 17 00:00:00 2001 From: ssjia Date: Thu, 19 Mar 2026 23:14:16 -0700 Subject: [PATCH] [ET-VK][runtime] Add prepack cache to avoid duplicate weight prepacking When embedding and linear ops share tied weights and both use the same prepacking function (prepack_quantized_linear_weight), the weight gets prepacked twice, wasting GPU memory. Add a cache to ComputeGraph keyed by (input ValueRef, kernel name) that returns the already-prepacked tensor on cache hit, avoiding the duplicate allocation. Differential Revision: [D97430801](https://our.internmc.facebook.com/intern/diff/D97430801/) [ghstack-poisoned] --- .../vulkan/runtime/graph/ComputeGraph.cpp | 17 ++++++++++ backends/vulkan/runtime/graph/ComputeGraph.h | 33 +++++++++++++++++++ .../graph/ops/impl/QuantizedLinear.cpp | 19 ++++++++--- 3 files changed, 64 insertions(+), 5 deletions(-) diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp index 4b435ae6215..5e9c7b7ad2a 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.cpp +++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp @@ -297,6 +297,23 @@ bool ComputeGraph::is_valid_value_idx(const ValueRef idx) const noexcept { return idx >= 0 && idx < static_cast(values_.size()); } +ValueRef ComputeGraph::get_cached_prepack( + const ValueRef input, + const std::string& kernel_name) const { + auto it = prepack_cache_.find({input, kernel_name}); + if (it != prepack_cache_.end()) { + return it->second; + } + return kDummyValueRef; +} + +void ComputeGraph::cache_prepack( + const ValueRef input, + const std::string& kernel_name, + const ValueRef prepacked) { + prepack_cache_.emplace(std::make_pair(input, kernel_name), prepacked); +} + std::vector ComputeGraph::sizes_of(const ValueRef idx) const { const Value& val = values_.at(idx); if (val.isTensor()) { diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index 9935b9be51b..61968348edc 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -12,6 +12,7 @@ #include #include +#include #include @@ -204,6 +205,22 @@ class ComputeGraph final { // Set to track which ValueRefs were updated during inference std::unordered_set updated_values_; + // Cache to prevent duplicate prepacking of the same weight tensor with the + // same kernel. Key is (inputValueRef, kernel_name). + struct PrepackCacheHash { + size_t operator()(const std::pair& key) const { + size_t h1 = std::hash{}(key.first); + size_t h2 = std::hash{}(key.second); + // Combine hashes using a method similar to boost::hash_combine + return h1 ^ (h2 + 0x9e3779b9 + (h1 << 6) + (h1 >> 2)); + } + }; + std::unordered_map< + std::pair, + ValueRef, + PrepackCacheHash> + prepack_cache_; + // Flag to indicate if re-encoding is required bool requires_reencode_ = false; @@ -687,6 +704,22 @@ class ComputeGraph final { void check_no_active_value_ptrs(); public: + /* + * Check if a prepacked tensor already exists for the given input and kernel. + */ + ValueRef get_cached_prepack( + const ValueRef input, + const std::string& kernel_name) const; + + /* + * Store a prepacked tensor in the cache, keyed by input ValueRef and kernel + * name. + */ + void cache_prepack( + const ValueRef input, + const std::string& kernel_name, + const ValueRef prepacked); + /* * Add a `api::vTensor` value to the graph with the specified properties. * There are various convenience overloads of this function that may be used diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp index 7a42d463f2a..4a29fe91c3d 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp @@ -256,6 +256,19 @@ ValueRef prepack_quantized_linear_weight( storage_type = utils::kBuffer; } + std::string kernel_name = weight_quant_config.nbits == 4 + ? "pack_q4_linear_weight" + : "pack_q8_linear_weight"; + add_storage_type_suffix(kernel_name, storage_type); + + // Check prepack cache before creating a new prepack node. This avoids + // allocating a duplicate output tensor when the same weight data has already + // been prepacked with the same kernel (e.g. tied embedding/linear weights). + ValueRef cached = graph.get_cached_prepack(qmat2_data, kernel_name); + if (is_valid(cached)) { + return cached; + } + ValueRef qmat2 = graph.add_tensor( qmat2_sizes, vkcompute::vkapi::kInt, storage_type, utils::kWidthPacked); @@ -273,11 +286,6 @@ ValueRef prepack_quantized_linear_weight( 1u}; } - std::string kernel_name = weight_quant_config.nbits == 4 - ? "pack_q4_linear_weight" - : "pack_q8_linear_weight"; - add_storage_type_suffix(kernel_name, storage_type); - graph.prepack_nodes().emplace_back(new PrepackNode( graph, VK_KERNEL_FROM_STR(kernel_name), @@ -294,6 +302,7 @@ ValueRef prepack_quantized_linear_weight( {graph.sizes_pc_of(qmat2), PushConstantDataInfo(&orig_sizes, sizeof(utils::ivec2))})); + graph.cache_prepack(qmat2_data, kernel_name, qmat2); return qmat2; }