@@ -64,6 +64,49 @@ using ExtraTensorList = std::vector<std::pair<std::string, TensorAttr>>;
6464
// Capsule name the DLPack protocol assigns to an unconsumed tensor
// capsule; PyCapsule_GetPointer must be queried with exactly this name.
constexpr char kDlpackCapsuleName[] = "dltensor";
6666
67+ struct DlpackTensorView {
68+ PyObject *capsule = nullptr ;
69+ DLManagedTensor *managed = nullptr ;
70+
71+ DlpackTensorView () = default ;
72+ DlpackTensorView (PyObject *capsule_obj, DLManagedTensor *managed_tensor)
73+ : capsule(capsule_obj), managed(managed_tensor) {}
74+
75+ DlpackTensorView (const DlpackTensorView &) = delete ;
76+ auto operator =(const DlpackTensorView &) -> DlpackTensorView & = delete ;
77+
78+ DlpackTensorView (DlpackTensorView &&other) noexcept
79+ : capsule(other.capsule), managed(other.managed) {
80+ other.capsule = nullptr ;
81+ other.managed = nullptr ;
82+ }
83+
84+ auto operator =(DlpackTensorView &&other) noexcept -> DlpackTensorView & {
85+ if (this != &other) {
86+ reset ();
87+ capsule = other.capsule ;
88+ managed = other.managed ;
89+ other.capsule = nullptr ;
90+ other.managed = nullptr ;
91+ }
92+ return *this ;
93+ }
94+
95+ ~DlpackTensorView () { reset (); }
96+
97+ void reset () {
98+ if (capsule == nullptr ) {
99+ return ;
100+ }
101+ py::gil_scoped_acquire gil;
102+ Py_DECREF (capsule);
103+ capsule = nullptr ;
104+ managed = nullptr ;
105+ }
106+ };
107+
108+ using DlpackTensorViews = std::vector<DlpackTensorView>;
109+
67110cudnn_frontend::DataType_t convert_dlpack_dtype (const DLDataType &dtype) {
68111 switch (dtype.code ) {
69112 case DLDataTypeCode::kDLUInt :
@@ -127,10 +170,16 @@ decltype(auto) with_dlpack_tensor(py::handle tensor_obj, Fn &&fn) {
127170 return std::forward<Fn>(fn)(*managed);
128171}
129172
130- void *get_dlpack_data_pointer (py::handle tensor_obj) {
131- return with_dlpack_tensor (tensor_obj, [](const DLManagedTensor &managed) -> void * {
132- return static_cast <char *>(managed.dl_tensor .data ) + managed.dl_tensor .byte_offset ;
133- });
173+ DlpackTensorView make_dlpack_tensor_view (py::handle tensor_obj) {
174+ NVTE_CHECK (py::hasattr (tensor_obj, " __dlpack__" ),
175+ " score_mod_tensors entries must support __dlpack__()." );
176+ py::capsule capsule = tensor_obj.attr (" __dlpack__" )();
177+ NVTE_CHECK (!capsule.is_none (), " Failed to retrieve DLPack capsule for score_mod_tensors entry." );
178+ auto *managed =
179+ static_cast <DLManagedTensor *>(PyCapsule_GetPointer (capsule.ptr (), kDlpackCapsuleName ));
180+ NVTE_CHECK (managed != nullptr , " Invalid DLPack capsule in score_mod_tensors entry." );
181+ Py_INCREF (capsule.ptr ());
182+ return DlpackTensorView (capsule.ptr (), managed);
134183}
135184
136185TensorAttr create_tensor_attr_from_dlpack (const std::shared_ptr<cudnn_frontend::graph::Graph> &graph,
@@ -229,20 +278,26 @@ py::dict get_score_mod_tensor_attrs(const std::shared_ptr<cudnn_frontend::graph:
229278 return callback_tensors;
230279}
231280
232- void extend_variant_pack_with_extra_tensors (
281+ DlpackTensorViews extend_variant_pack_with_extra_tensors (
233282 void *extra_tensors_ptr, const ExtraTensorList &extra_tensor_attrs,
234283 std::unordered_map<TensorAttr, void *> &variant_pack) {
284+ DlpackTensorViews views;
235285 if (extra_tensors_ptr == nullptr || extra_tensor_attrs.empty ()) {
236- return ;
286+ return views ;
237287 }
238288
239289 py::gil_scoped_acquire gil;
240290 py::dict extra_tensors = py::reinterpret_borrow<py::dict>(reinterpret_cast <PyObject *>(extra_tensors_ptr));
291+ views.reserve (extra_tensor_attrs.size ());
241292 for (const auto &[name, tensor_attr] : extra_tensor_attrs) {
242293 py::str key (name);
243294 NVTE_CHECK (extra_tensors.contains (key), " Missing score_mod tensor entry: " , name);
244- variant_pack[tensor_attr] = get_dlpack_data_pointer (extra_tensors[key]);
295+ auto view = make_dlpack_tensor_view (extra_tensors[key]);
296+ variant_pack[tensor_attr] =
297+ static_cast <char *>(view.managed ->dl_tensor .data ) + view.managed ->dl_tensor .byte_offset ;
298+ views.emplace_back (std::move (view));
245299 }
300+ return views;
246301}
247302
248303auto make_attention_score_modifier (void *callback_ptr, void *extra_tensors_ptr,
@@ -788,7 +843,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
788843 variant_pack[softmax_offset] = devPtrSoftmaxOffset;
789844 }
790845
791- extend_variant_pack_with_extra_tensors (score_mod_tensors, score_mod_extra_tensors, variant_pack);
846+ auto score_mod_tensor_views =
847+ extend_variant_pack_with_extra_tensors (score_mod_tensors, score_mod_extra_tensors, variant_pack);
792848
793849 NVTE_CHECK_CUDNN_FE (mha_graph->execute (handle, variant_pack, workspace));
794850 } catch (cudnn_frontend::cudnnException &e) {
@@ -1320,9 +1376,10 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
13201376 variant_pack[d_softmax_offset] = devPtrdSoftmaxOffset;
13211377 }
13221378
1323- extend_variant_pack_with_extra_tensors (score_mod_tensors, score_mod_extra_tensors, variant_pack);
1324- extend_variant_pack_with_extra_tensors (score_mod_bprop_tensors, score_mod_bprop_extra_tensors,
1325- variant_pack);
1379+ auto score_mod_tensor_views =
1380+ extend_variant_pack_with_extra_tensors (score_mod_tensors, score_mod_extra_tensors, variant_pack);
1381+ auto score_mod_bprop_tensor_views = extend_variant_pack_with_extra_tensors (
1382+ score_mod_bprop_tensors, score_mod_bprop_extra_tensors, variant_pack);
13261383
13271384 NVTE_CHECK_CUDNN_FE (mha_graph->execute (handle, variant_pack, workspace));
13281385 } catch (cudnn_frontend::cudnnException &e) {
0 commit comments