From c395ca3957cac801cae4e1e5f2cb703042dfbaec Mon Sep 17 00:00:00 2001 From: unexploredtest <53617231+unexploredtest@users.noreply.github.com> Date: Tue, 31 Mar 2026 00:17:06 +0000 Subject: [PATCH 1/2] [wasm] Fix GatherV2 batch indices bug --- tfjs-backend-wasm/src/kernels/GatherV2.ts | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tfjs-backend-wasm/src/kernels/GatherV2.ts b/tfjs-backend-wasm/src/kernels/GatherV2.ts index b686b80138b..80431784c5a 100644 --- a/tfjs-backend-wasm/src/kernels/GatherV2.ts +++ b/tfjs-backend-wasm/src/kernels/GatherV2.ts @@ -24,7 +24,7 @@ import {CppDType} from './types'; let wasmGather: ( xId: number, dtype: CppDType, xStrides: Uint8Array, stridesSize: number, - indicesId: number, batchSize: number, outStrides: Uint8Array, + indicesId: number, indicesPerBatch: number, outStrides: Uint8Array, outId: number) => void; function setup(backend: BackendWasm): void { @@ -34,7 +34,7 @@ function setup(backend: BackendWasm): void { 'array', // xStrides 'number', // stridesSize 'number', // indicesId - 'number', // batchSize + 'number', // indicesPerBatch 'array', // outStrides 'number' // outId ]); @@ -88,6 +88,7 @@ function gatherV2( return out; } const stridesSize = flattenX.shape.length - 1; + const indicesPerBatch = indicesSize / shapeInfo.batchSize; const xData = backend.dataIdMap.get(flattenX.dataId); const xId = xData.id; @@ -104,7 +105,7 @@ function gatherV2( wasmGather( xId, CppDType[x.dtype], xStridesBytes, stridesSize, indicesId, - shapeInfo.batchSize, outStridesBytes, outId); + indicesPerBatch, outStridesBytes, outId); backend.disposeData(flattenX.dataId); backend.disposeData(flattenIndex.dataId); From 44a368cc1251f72cfb9f7ce7208847ff61e9d957 Mon Sep 17 00:00:00 2001 From: unexploredtest <53617231+unexploredtest@users.noreply.github.com> Date: Tue, 31 Mar 2026 00:19:27 +0000 Subject: [PATCH 2/2] [wasm] Change gather parameter name --- tfjs-backend-wasm/src/cc/kernels/Gather.cc | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tfjs-backend-wasm/src/cc/kernels/Gather.cc b/tfjs-backend-wasm/src/cc/kernels/Gather.cc index 5ce0a5f3c22..5519e4fbd4c 100644 --- a/tfjs-backend-wasm/src/cc/kernels/Gather.cc +++ b/tfjs-backend-wasm/src/cc/kernels/Gather.cc @@ -24,13 +24,13 @@ namespace { template void gather_impl(const T* x_ptr, const std::vector& x_strides, const int32_t* indices_ptr, const size_t out_size, - const size_t batch_size, + const size_t indices_per_batch, const std::vector& out_strides, T* out_buf_ptr) { for (size_t i = 0; i < out_size; ++i) { auto loc = tfjs::util::offset_to_loc(i, out_strides); const size_t batch_loc = loc[0]; const size_t indices_loc = loc[2]; - loc[2] = indices_ptr[batch_loc * batch_size + indices_loc]; + loc[2] = indices_ptr[batch_loc * indices_per_batch + indices_loc]; const size_t original_index = tfjs::util::loc_to_offset(loc, x_strides); @@ -50,7 +50,7 @@ EMSCRIPTEN_KEEPALIVE void Gather(const size_t x_id, const DType dtype, const int32_t* x_strides_ptr, const size_t strides_size, const size_t indices_id, - const size_t batch_size, const int32_t* out_strides_ptr, + const size_t indices_per_batch, const int32_t* out_strides_ptr, const size_t out_id) { auto& x_info = backend::get_tensor_info(x_id); auto& indices_info = backend::get_tensor_info(indices_id); @@ -67,15 +67,15 @@ void Gather(const size_t x_id, const DType dtype, const int32_t* x_strides_ptr, switch (dtype) { case DType::float32: gather_impl(x_info.f32(), x_strides, indices_buf, out_size, - batch_size, out_strides, out_info.f32_write()); + indices_per_batch, out_strides, out_info.f32_write()); break; case DType::int32: gather_impl(x_info.i32(), x_strides, indices_buf, out_size, - batch_size, out_strides, out_info.i32_write()); + indices_per_batch, out_strides, out_info.i32_write()); break; case DType::boolean: gather_impl(x_info.b(), x_strides, indices_buf, out_size, - batch_size, out_strides, out_info.b_write()); + indices_per_batch, out_strides, out_info.b_write()); break; default: util::warn("Gather for tensor id %d failed. Unknown dtype %d", x_id,