Commit 88bb7da

Fix int64 workspace usage
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
1 parent: ed0deaf · commit: 88bb7da

2 files changed: 50 additions & 16 deletions

transformer_engine/jax/cpp_extensions/gemm.py

Lines changed: 18 additions & 1 deletion
@@ -1549,7 +1549,24 @@ def abstract(
             shape=(get_grouped_gemm_setup_workspace_size(num_groups),), dtype=jnp.uint8
         )
         # Temporary buffer for int32 -> int64 conversion of group_sizes on device.
-        int64_workspace_size = num_groups * jnp.dtype(jnp.int64).itemsize
+        # Each non-empty *_dims buffer needs its own slot of num_groups int64 elements so that
+        # make_grouped_tensor can write to a distinct region per ragged dimension. Allocate
+        # exactly as many slots as there are non-empty buffers (minimum 1 to avoid zero-size).
+        num_ragged_dim_buffers = sum(
+            1
+            for aval in [
+                lhs_first_dims_aval,
+                lhs_last_dims_aval,
+                rhs_first_dims_aval,
+                rhs_last_dims_aval,
+                out_first_dims_aval,
+                out_last_dims_aval,
+            ]
+            if aval.size > 0
+        )
+        int64_workspace_size = (
+            max(num_ragged_dim_buffers, 1) * num_groups * jnp.dtype(jnp.int64).itemsize
+        )
         int64_workspace_aval = jax.core.ShapedArray(
             shape=(int64_workspace_size,), dtype=jnp.uint8
         )
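
For concreteness, the new sizing can be exercised in isolation. Below is a minimal sketch, assuming illustrative stand-ins for the six `*_dims` abstract values (the shapes and values are hypothetical, not the extension's actual operands):

```python
import jax.numpy as jnp

num_groups = 4
# Hypothetical stand-ins for the six *_dims abstract values; in the real
# abstract() they describe the primitive's ragged-dimension operands.
avals = [
    jnp.zeros((num_groups,), jnp.int32),  # lhs_first_dims_aval: ragged
    jnp.zeros((0,), jnp.int32),           # lhs_last_dims_aval: empty, no slot
    jnp.zeros((0,), jnp.int32),           # rhs_first_dims_aval
    jnp.zeros((num_groups,), jnp.int32),  # rhs_last_dims_aval: ragged
    jnp.zeros((num_groups,), jnp.int32),  # out_first_dims_aval: ragged
    jnp.zeros((0,), jnp.int32),           # out_last_dims_aval
]

# One slot of num_groups int64 elements per non-empty buffer, minimum 1 slot.
num_ragged_dim_buffers = sum(1 for aval in avals if aval.size > 0)
int64_workspace_size = (
    max(num_ragged_dim_buffers, 1) * num_groups * jnp.dtype(jnp.int64).itemsize
)

assert int64_workspace_size == 3 * 4 * 8  # 96 bytes for 3 non-empty buffers
```

Under the old sizing, the same configuration would have reserved only `num_groups * 8 = 32` bytes, a single slot that all conversions had to share.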

transformer_engine/jax/csrc/extensions/gemm.cpp

Lines changed: 32 additions & 15 deletions
@@ -559,13 +559,17 @@ JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const &data,
   return std::move(grouped_tensor_wrapper);
 }
 
-// V2 variant: derives data shape from the 2D XLA buffer directly, converts group_sizes
-// int32→int64 per-tensor into int64_workspace, and wires first_dims/last_dims.
-// Only NO_SCALING is supported.
+// V2 variant: derives data shape from the XLA buffer directly, converts group_sizes
+// int32→int64 per-tensor into a dedicated slot of int64_workspace, and wires first_dims/last_dims.
+// int64_offset (in int64 elements) is updated on return to the next available slot so callers can
+// thread it through successive make_grouped_tensor calls without aliasing. Bounds are checked
+// before each slot is used. Only NO_SCALING is supported.
 JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const &data,
                                               Buffer_Type const &first_dims,
                                               Buffer_Type const &last_dims,
-                                              Result_Type int64_workspace,
+                                              int64_t *int64_workspace_base,
+                                              size_t int64_workspace_capacity,
+                                              size_t &int64_offset,
                                               size_t num_gemms,
                                               cudaStream_t stream) {
   auto dims = data.dimensions();
@@ -577,22 +581,27 @@ JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const &data,
                                .ndim = 2};
   JAXX_GroupedTensorWrapper wrapper(JAXX_Scaling_Mode::NO_SCALING, num_gemms, dataShape);
   wrapper.set_rowwise(data, std::nullopt);
-  auto *int64_sizes_ptr = reinterpret_cast<int64_t *>(int64_workspace->untyped_data());
   if (first_dims.element_count() > 0) {
     NVTE_CHECK(first_dims.element_type() == xla::ffi::DataType::S32,
                "group_sizes must be int32.");
+    NVTE_CHECK(int64_offset + num_gemms <= int64_workspace_capacity,
+               "int64_workspace overflow: not enough space for first_dims conversion.");
+    auto *slot = int64_workspace_base + int64_offset;
     nvte_convert_int32_to_int64(
-        reinterpret_cast<const int32_t *>(first_dims.untyped_data()),
-        int64_sizes_ptr, num_gemms, stream);
-    wrapper.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedFirstDims);
+        reinterpret_cast<const int32_t *>(first_dims.untyped_data()), slot, num_gemms, stream);
+    wrapper.set_group_sizes_only(slot, num_gemms, kNVTEGroupedFirstDims);
+    int64_offset += num_gemms;
   }
   if (last_dims.element_count() > 0) {
     NVTE_CHECK(last_dims.element_type() == xla::ffi::DataType::S32,
                "group_sizes must be int32.");
+    NVTE_CHECK(int64_offset + num_gemms <= int64_workspace_capacity,
+               "int64_workspace overflow: not enough space for last_dims conversion.");
+    auto *slot = int64_workspace_base + int64_offset;
     nvte_convert_int32_to_int64(
-        reinterpret_cast<const int32_t *>(last_dims.untyped_data()),
-        int64_sizes_ptr, num_gemms, stream);
-    wrapper.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedLastDims);
+        reinterpret_cast<const int32_t *>(last_dims.untyped_data()), slot, num_gemms, stream);
+    wrapper.set_group_sizes_only(slot, num_gemms, kNVTEGroupedLastDims);
+    int64_offset += num_gemms;
   }
   return wrapper;
 }
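
Stripped of the FFI machinery, the slot contract here is a simple bump allocation. A minimal Python sketch of the same pattern, assuming a NumPy `int64` array as the workspace (`convert_into_slot` is a hypothetical helper, not part of the extension):

```python
import numpy as np

def convert_into_slot(workspace, offset, group_sizes_i32):
    """Copy int32 group sizes into the next free int64 slot.

    Mirrors the C++ contract: bounds-check before use, write into a distinct
    region, then bump the offset so successive calls never alias.
    """
    n = group_sizes_i32.size
    assert offset + n <= workspace.size, "int64_workspace overflow"
    slot = workspace[offset : offset + n]       # view into the workspace
    slot[:] = group_sizes_i32.astype(np.int64)  # done on device by nvte_convert_int32_to_int64
    return slot, offset + n

num_gemms = 4
workspace = np.empty(3 * num_gemms, dtype=np.int64)  # 3 non-empty *_dims buffers
offset = 0
rhs_slot, offset = convert_into_slot(workspace, offset, np.array([1, 2, 3, 4], np.int32))
lhs_slot, offset = convert_into_slot(workspace, offset, np.array([5, 6, 7, 8], np.int32))
assert offset == 2 * num_gemms                   # each call consumed its own slot
assert not np.shares_memory(rhs_slot, lhs_slot)  # slots are disjoint
```

Returning the bumped offset plays the role of the C++ `size_t &int64_offset` reference parameter.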
@@ -770,13 +779,21 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty
                                            convert_ffi_datatype_to_te_dtype(beta.element_type()));
 
   // Build grouped tensors from XLA buffer shapes and group_sizes — no m/n/k derivation needed.
-  // int32→int64 conversion for group_sizes is handled per-tensor inside make_grouped_tensor.
+  // int64_workspace is partitioned into per-ragged-buffer slots of num_gemms int64 elements each.
+  // int64_offset is threaded through the three make_grouped_tensor calls so each non-empty *_dims
+  // buffer gets its own non-aliasing slot; bounds are checked inside make_grouped_tensor.
+  auto *int64_base = reinterpret_cast<int64_t *>(int64_workspace->untyped_data());
+  size_t int64_capacity = int64_workspace->element_count() / sizeof(int64_t);
+  size_t int64_offset = 0;
   auto rhs_tensor = make_grouped_tensor(rhs_data, rhs_first_dims, rhs_last_dims,
-                                        int64_workspace, num_gemms, stream);
+                                        int64_base, int64_capacity, int64_offset, num_gemms,
+                                        stream);
   auto lhs_tensor = make_grouped_tensor(lhs_data, lhs_first_dims, lhs_last_dims,
-                                        int64_workspace, num_gemms, stream);
+                                        int64_base, int64_capacity, int64_offset, num_gemms,
+                                        stream);
   auto out_tensor = make_grouped_tensor(*output, out_first_dims, out_last_dims,
-                                        int64_workspace, num_gemms, stream);
+                                        int64_base, int64_capacity, int64_offset, num_gemms,
+                                        stream);
 
   nvte_grouped_gemm(rhs_tensor, rhs_is_trans, lhs_tensor, lhs_is_trans, nullptr, out_tensor,
                     alpha_tensor.data(), beta_tensor.data(), workspace_setup.data(),
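
To see why the shared-pointer version was broken, consider what successive calls did to the old single-slot workspace: every conversion wrote to the same base pointer while each wrapper kept that pointer, so the last conversion won (the same aliasing occurred between first_dims and last_dims within a single call). A toy Python reproduction of that aliasing, with illustrative values only:

```python
import numpy as np

num_gemms = 4
# Old sizing: one shared slot of num_gemms int64 elements for everything.
old_workspace = np.empty(num_gemms, dtype=np.int64)

rhs_sizes = np.array([1, 2, 3, 4], dtype=np.int32)
out_sizes = np.array([9, 9, 9, 9], dtype=np.int32)

rhs_view = old_workspace[:]   # what the rhs wrapper effectively held
old_workspace[:] = rhs_sizes  # rhs conversion
old_workspace[:] = out_sizes  # out's conversion clobbers rhs's group sizes

# By the time nvte_grouped_gemm runs, the rhs wrapper sees out's sizes.
assert np.array_equal(rhs_view, out_sizes)
```

With the per-slot workspace above, each wrapper instead points at its own region, so all group sizes survive until nvte_grouped_gemm reads them.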
