@@ -559,13 +559,17 @@ JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const &data,
   return std::move(grouped_tensor_wrapper);
 }
 
-// V2 variant: derives data shape from the 2D XLA buffer directly, converts group_sizes
-// int32→int64 per-tensor into int64_workspace, and wires first_dims/last_dims.
-// Only NO_SCALING is supported.
+// V2 variant: derives data shape from the XLA buffer directly, converts group_sizes
+// int32→int64 per-tensor into a dedicated slot of int64_workspace, and wires first_dims/last_dims.
+// int64_offset (in int64 elements) is updated on return to the next available slot so callers can
+// thread it through successive make_grouped_tensor calls without aliasing. Bounds are checked
+// before each slot is used. Only NO_SCALING is supported.
 JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const &data,
                                               Buffer_Type const &first_dims,
                                               Buffer_Type const &last_dims,
-                                              Result_Type int64_workspace,
+                                              int64_t *int64_workspace_base,
+                                              size_t int64_workspace_capacity,
+                                              size_t &int64_offset,
                                               size_t num_gemms,
                                               cudaStream_t stream) {
   auto dims = data.dimensions();
@@ -577,22 +581,27 @@ JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const &data,
                                             .ndim = 2};
   JAXX_GroupedTensorWrapper wrapper(JAXX_Scaling_Mode::NO_SCALING, num_gemms, dataShape);
   wrapper.set_rowwise(data, std::nullopt);
-  auto *int64_sizes_ptr = reinterpret_cast<int64_t *>(int64_workspace->untyped_data());
   if (first_dims.element_count() > 0) {
     NVTE_CHECK(first_dims.element_type() == xla::ffi::DataType::S32,
                "group_sizes must be int32.");
+    NVTE_CHECK(int64_offset + num_gemms <= int64_workspace_capacity,
+               "int64_workspace overflow: not enough space for first_dims conversion.");
+    auto *slot = int64_workspace_base + int64_offset;
     nvte_convert_int32_to_int64(
-        reinterpret_cast<const int32_t *>(first_dims.untyped_data()),
-        int64_sizes_ptr, num_gemms, stream);
-    wrapper.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedFirstDims);
+        reinterpret_cast<const int32_t *>(first_dims.untyped_data()), slot, num_gemms, stream);
+    wrapper.set_group_sizes_only(slot, num_gemms, kNVTEGroupedFirstDims);
+    int64_offset += num_gemms;
   }
   if (last_dims.element_count() > 0) {
     NVTE_CHECK(last_dims.element_type() == xla::ffi::DataType::S32,
                "group_sizes must be int32.");
+    NVTE_CHECK(int64_offset + num_gemms <= int64_workspace_capacity,
+               "int64_workspace overflow: not enough space for last_dims conversion.");
+    auto *slot = int64_workspace_base + int64_offset;
     nvte_convert_int32_to_int64(
-        reinterpret_cast<const int32_t *>(last_dims.untyped_data()),
-        int64_sizes_ptr, num_gemms, stream);
-    wrapper.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedLastDims);
+        reinterpret_cast<const int32_t *>(last_dims.untyped_data()), slot, num_gemms, stream);
+    wrapper.set_group_sizes_only(slot, num_gemms, kNVTEGroupedLastDims);
+    int64_offset += num_gemms;
   }
   return wrapper;
 }
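For context on the new signature, here is a minimal standalone sketch of the slot-carving contract (hypothetical names, not part of this PR): each claim takes `num_gemms` int64 elements from the shared workspace, checks bounds first, and bumps the caller's offset so successive claims never alias.

```cpp
// Illustrative sketch only; carve_slot is a hypothetical stand-in for the
// bounds-check + slot + offset-bump sequence inside make_grouped_tensor.
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

int64_t *carve_slot(int64_t *base, size_t capacity, size_t &offset, size_t count) {
  assert(offset + count <= capacity && "int64_workspace overflow");
  int64_t *slot = base + offset;
  offset += count;  // the next caller starts where this slot ends
  return slot;
}

int main() {
  constexpr size_t num_gemms = 4;
  std::vector<int64_t> workspace(6 * num_gemms);
  size_t offset = 0;

  // Two consecutive claims, as make_grouped_tensor performs for first_dims
  // and last_dims: the slots are contiguous and never overlap.
  int64_t *first = carve_slot(workspace.data(), workspace.size(), offset, num_gemms);
  int64_t *last = carve_slot(workspace.data(), workspace.size(), offset, num_gemms);
  assert(last == first + num_gemms);
  return 0;
}
```

Passing `int64_offset` by reference is what lets the three call sites in `GroupedGemmV2FFI` below share one workspace without coordinating slot indices explicitly.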
@@ -770,13 +779,21 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty
                                                convert_ffi_datatype_to_te_dtype(beta.element_type()));
 
   // Build grouped tensors from XLA buffer shapes and group_sizes; no m/n/k derivation needed.
-  // int32→int64 conversion for group_sizes is handled per-tensor inside make_grouped_tensor.
+  // int64_workspace is partitioned into per-ragged-buffer slots of num_gemms int64 elements each.
+  // int64_offset is threaded through the three make_grouped_tensor calls so each non-empty *_dims
+  // buffer gets its own non-aliasing slot; bounds are checked inside make_grouped_tensor.
+  auto *int64_base = reinterpret_cast<int64_t *>(int64_workspace->untyped_data());
+  size_t int64_capacity = int64_workspace->element_count() / sizeof(int64_t);
+  size_t int64_offset = 0;
   auto rhs_tensor = make_grouped_tensor(rhs_data, rhs_first_dims, rhs_last_dims,
-                                        int64_workspace, num_gemms, stream);
+                                        int64_base, int64_capacity, int64_offset, num_gemms,
+                                        stream);
   auto lhs_tensor = make_grouped_tensor(lhs_data, lhs_first_dims, lhs_last_dims,
-                                        int64_workspace, num_gemms, stream);
+                                        int64_base, int64_capacity, int64_offset, num_gemms,
+                                        stream);
   auto out_tensor = make_grouped_tensor(*output, out_first_dims, out_last_dims,
-                                        int64_workspace, num_gemms, stream);
+                                        int64_base, int64_capacity, int64_offset, num_gemms,
+                                        stream);
 
   nvte_grouped_gemm(rhs_tensor, rhs_is_trans, lhs_tensor, lhs_is_trans, nullptr, out_tensor,
                     alpha_tensor.data(), beta_tensor.data(), workspace_setup.data(),
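A sizing note, stated as an assumption rather than something the diff spells out: with three grouped tensors and up to two non-empty `*_dims` buffers each, the worst case consumes six slots of `num_gemms` int64 elements. The sketch below captures that bound; it matches the `element_count() / sizeof(int64_t)` conversion above only if the workspace result buffer is declared in bytes.

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical sizing helper, not from this PR: worst-case int64_workspace
// footprint when all six ragged-dims buffers ({rhs,lhs,out} x {first,last})
// are non-empty, each claiming one slot of num_gemms int64 elements.
constexpr size_t worst_case_int64_workspace_bytes(size_t num_gemms) {
  constexpr size_t kMaxRaggedDimsBuffers = 6;
  return kMaxRaggedDimsBuffers * num_gemms * sizeof(int64_t);
}

static_assert(worst_case_int64_workspace_bytes(8) == 384, "6 slots x 8 gemms x 8 bytes");
```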