Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions backends/vulkan/runtime/graph/ComputeGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,23 @@ bool ComputeGraph::is_valid_value_idx(const ValueRef idx) const noexcept {
return idx >= 0 && idx < static_cast<int>(values_.size());
}

ValueRef ComputeGraph::get_cached_prepack(
const ValueRef input,
const std::string& kernel_name) const {
auto it = prepack_cache_.find({input, kernel_name});
if (it != prepack_cache_.end()) {
return it->second;
}
return kDummyValueRef;
}

void ComputeGraph::cache_prepack(
const ValueRef input,
const std::string& kernel_name,
const ValueRef prepacked) {
prepack_cache_.emplace(std::make_pair(input, kernel_name), prepacked);
}

std::vector<int64_t> ComputeGraph::sizes_of(const ValueRef idx) const {
const Value& val = values_.at(idx);
if (val.isTensor()) {
Expand Down
33 changes: 33 additions & 0 deletions backends/vulkan/runtime/graph/ComputeGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

#include <optional>
#include <stack>
#include <unordered_map>

#include <executorch/backends/vulkan/runtime/api/api.h>

Expand Down Expand Up @@ -204,6 +205,22 @@ class ComputeGraph final {
// Set to track which ValueRefs were updated during inference
std::unordered_set<ValueRef> updated_values_;

// Cache to prevent duplicate prepacking of the same weight tensor with the
// same kernel. Key is (inputValueRef, kernel_name).
struct PrepackCacheHash {
size_t operator()(const std::pair<ValueRef, std::string>& key) const {
size_t h1 = std::hash<ValueRef>{}(key.first);
size_t h2 = std::hash<std::string>{}(key.second);
// Combine hashes using a method similar to boost::hash_combine
return h1 ^ (h2 + 0x9e3779b9 + (h1 << 6) + (h1 >> 2));
}
};
std::unordered_map<
std::pair<ValueRef, std::string>,
ValueRef,
PrepackCacheHash>
prepack_cache_;

// Flag to indicate if re-encoding is required
bool requires_reencode_ = false;

Expand Down Expand Up @@ -687,6 +704,22 @@ class ComputeGraph final {
void check_no_active_value_ptrs();

public:
/*
* Check if a prepacked tensor already exists for the given input and kernel.
*/
ValueRef get_cached_prepack(
const ValueRef input,
const std::string& kernel_name) const;

/*
* Store a prepacked tensor in the cache, keyed by input ValueRef and kernel
* name.
*/
void cache_prepack(
const ValueRef input,
const std::string& kernel_name,
const ValueRef prepacked);

/*
* Add a `api::vTensor` value to the graph with the specified properties.
* There are various convenience overloads of this function that may be used
Expand Down
19 changes: 14 additions & 5 deletions backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,19 @@ ValueRef prepack_quantized_linear_weight(
storage_type = utils::kBuffer;
}

std::string kernel_name = weight_quant_config.nbits == 4
? "pack_q4_linear_weight"
: "pack_q8_linear_weight";
add_storage_type_suffix(kernel_name, storage_type);

// Check prepack cache before creating a new prepack node. This avoids
// allocating a duplicate output tensor when the same weight data has already
// been prepacked with the same kernel (e.g. tied embedding/linear weights).
ValueRef cached = graph.get_cached_prepack(qmat2_data, kernel_name);
if (is_valid(cached)) {
return cached;
}

ValueRef qmat2 = graph.add_tensor(
qmat2_sizes, vkcompute::vkapi::kInt, storage_type, utils::kWidthPacked);

Expand All @@ -273,11 +286,6 @@ ValueRef prepack_quantized_linear_weight(
1u};
}

std::string kernel_name = weight_quant_config.nbits == 4
? "pack_q4_linear_weight"
: "pack_q8_linear_weight";
add_storage_type_suffix(kernel_name, storage_type);

graph.prepack_nodes().emplace_back(new PrepackNode(
graph,
VK_KERNEL_FROM_STR(kernel_name),
Expand All @@ -294,6 +302,7 @@ ValueRef prepack_quantized_linear_weight(
{graph.sizes_pc_of(qmat2),
PushConstantDataInfo(&orig_sizes, sizeof(utils::ivec2))}));

graph.cache_prepack(qmat2_data, kernel_name, qmat2);
return qmat2;
}

Expand Down
Loading