Skip to content

Commit cd86c51

Browse files
Support host score mod tensors via retained DLPack capsules
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
1 parent 9d4ba52 commit cd86c51

2 files changed

Lines changed: 71 additions & 13 deletions

File tree

tests/pytorch/attention/test_attention.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1715,7 +1715,8 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
17151715
@pytest.mark.parametrize("dtype", [torch.bfloat16])
17161716
@pytest.mark.parametrize("model_configs", [model_configs_score_mod])
17171717
@pytest.mark.parametrize("model", model_configs_score_mod.keys())
1718-
def test_dpa_score_mod_causal_external_neg_inf(dtype, model_configs, model):
1718+
@pytest.mark.parametrize("neg_inf_device", ["cuda", "cpu"], ids=["cuda_tensor", "cpu_by_value_tensor"])
1719+
def test_dpa_score_mod_causal_external_neg_inf(dtype, model_configs, model, neg_inf_device):
17191720
"""Test DotProductAttention causal masking via score_mod with external variant-pack tensor."""
17201721

17211722
config = model_configs[model]
@@ -1748,7 +1749,7 @@ def test_dpa_score_mod_causal_external_neg_inf(dtype, model_configs, model):
17481749
v = (torch.randn(b, sq, h, d, dtype=dtype, device="cuda") * 0.1).detach().requires_grad_(True)
17491750
cu_seqlens = torch.arange(0, (b + 1) * sq, sq, dtype=torch.int32, device="cuda")
17501751
out_grad = (torch.randn(b, sq, h * d, dtype=dtype, device="cuda") * 0.01).detach()
1751-
neg_inf = torch.full((1, 1, 1, 1), float("-inf"), dtype=torch.float32, device="cuda")
1752+
neg_inf = torch.full((1, 1, 1, 1), float("-inf"), dtype=torch.float32, device=neg_inf_device)
17521753

17531754
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
17541755
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)

transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu

Lines changed: 68 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,49 @@ using ExtraTensorList = std::vector<std::pair<std::string, TensorAttr>>;
6464

6565
constexpr char kDlpackCapsuleName[] = "dltensor";
6666

67+
struct DlpackTensorView {
68+
PyObject *capsule = nullptr;
69+
DLManagedTensor *managed = nullptr;
70+
71+
DlpackTensorView() = default;
72+
DlpackTensorView(PyObject *capsule_obj, DLManagedTensor *managed_tensor)
73+
: capsule(capsule_obj), managed(managed_tensor) {}
74+
75+
DlpackTensorView(const DlpackTensorView &) = delete;
76+
auto operator=(const DlpackTensorView &) -> DlpackTensorView & = delete;
77+
78+
DlpackTensorView(DlpackTensorView &&other) noexcept
79+
: capsule(other.capsule), managed(other.managed) {
80+
other.capsule = nullptr;
81+
other.managed = nullptr;
82+
}
83+
84+
auto operator=(DlpackTensorView &&other) noexcept -> DlpackTensorView & {
85+
if (this != &other) {
86+
reset();
87+
capsule = other.capsule;
88+
managed = other.managed;
89+
other.capsule = nullptr;
90+
other.managed = nullptr;
91+
}
92+
return *this;
93+
}
94+
95+
~DlpackTensorView() { reset(); }
96+
97+
void reset() {
98+
if (capsule == nullptr) {
99+
return;
100+
}
101+
py::gil_scoped_acquire gil;
102+
Py_DECREF(capsule);
103+
capsule = nullptr;
104+
managed = nullptr;
105+
}
106+
};
107+
108+
using DlpackTensorViews = std::vector<DlpackTensorView>;
109+
67110
cudnn_frontend::DataType_t convert_dlpack_dtype(const DLDataType &dtype) {
68111
switch (dtype.code) {
69112
case DLDataTypeCode::kDLUInt:
@@ -127,10 +170,16 @@ decltype(auto) with_dlpack_tensor(py::handle tensor_obj, Fn &&fn) {
127170
return std::forward<Fn>(fn)(*managed);
128171
}
129172

130-
void *get_dlpack_data_pointer(py::handle tensor_obj) {
131-
return with_dlpack_tensor(tensor_obj, [](const DLManagedTensor &managed) -> void * {
132-
return static_cast<char *>(managed.dl_tensor.data) + managed.dl_tensor.byte_offset;
133-
});
173+
// Obtain a DLPack capsule from `tensor_obj` via the __dlpack__ protocol and
// retain it. The returned view holds an extra reference to the capsule so the
// tensor memory stays alive until the view is destroyed.
//
// Caller must hold the GIL. Throws (via NVTE_CHECK) if the object does not
// support __dlpack__ or the capsule is missing/invalid.
DlpackTensorView make_dlpack_tensor_view(py::handle tensor_obj) {
  NVTE_CHECK(py::hasattr(tensor_obj, "__dlpack__"),
             "score_mod_tensors entries must support __dlpack__().");
  py::capsule capsule = tensor_obj.attr("__dlpack__")();
  NVTE_CHECK(!capsule.is_none(), "Failed to retrieve DLPack capsule for score_mod_tensors entry.");
  auto *managed =
      static_cast<DLManagedTensor *>(PyCapsule_GetPointer(capsule.ptr(), kDlpackCapsuleName));
  if (managed == nullptr) {
    // PyCapsule_GetPointer sets a pending Python exception on failure; clear
    // it so the error state does not leak into subsequent Python calls while
    // we unwind with a C++ exception instead.
    PyErr_Clear();
  }
  NVTE_CHECK(managed != nullptr, "Invalid DLPack capsule in score_mod_tensors entry.");
  // Extra reference transferred into the view; the local py::capsule releases
  // its own reference at scope exit.
  Py_INCREF(capsule.ptr());
  return DlpackTensorView(capsule.ptr(), managed);
}
135184

136185
TensorAttr create_tensor_attr_from_dlpack(const std::shared_ptr<cudnn_frontend::graph::Graph> &graph,
@@ -229,20 +278,26 @@ py::dict get_score_mod_tensor_attrs(const std::shared_ptr<cudnn_frontend::graph:
229278
return callback_tensors;
230279
}
231280

232-
void extend_variant_pack_with_extra_tensors(
281+
DlpackTensorViews extend_variant_pack_with_extra_tensors(
233282
void *extra_tensors_ptr, const ExtraTensorList &extra_tensor_attrs,
234283
std::unordered_map<TensorAttr, void *> &variant_pack) {
284+
DlpackTensorViews views;
235285
if (extra_tensors_ptr == nullptr || extra_tensor_attrs.empty()) {
236-
return;
286+
return views;
237287
}
238288

239289
py::gil_scoped_acquire gil;
240290
py::dict extra_tensors = py::reinterpret_borrow<py::dict>(reinterpret_cast<PyObject *>(extra_tensors_ptr));
291+
views.reserve(extra_tensor_attrs.size());
241292
for (const auto &[name, tensor_attr] : extra_tensor_attrs) {
242293
py::str key(name);
243294
NVTE_CHECK(extra_tensors.contains(key), "Missing score_mod tensor entry: ", name);
244-
variant_pack[tensor_attr] = get_dlpack_data_pointer(extra_tensors[key]);
295+
auto view = make_dlpack_tensor_view(extra_tensors[key]);
296+
variant_pack[tensor_attr] =
297+
static_cast<char *>(view.managed->dl_tensor.data) + view.managed->dl_tensor.byte_offset;
298+
views.emplace_back(std::move(view));
245299
}
300+
return views;
246301
}
247302

248303
auto make_attention_score_modifier(void *callback_ptr, void *extra_tensors_ptr,
@@ -788,7 +843,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
788843
variant_pack[softmax_offset] = devPtrSoftmaxOffset;
789844
}
790845

791-
extend_variant_pack_with_extra_tensors(score_mod_tensors, score_mod_extra_tensors, variant_pack);
846+
auto score_mod_tensor_views =
847+
extend_variant_pack_with_extra_tensors(score_mod_tensors, score_mod_extra_tensors, variant_pack);
792848

793849
NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace));
794850
} catch (cudnn_frontend::cudnnException &e) {
@@ -1320,9 +1376,10 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
13201376
variant_pack[d_softmax_offset] = devPtrdSoftmaxOffset;
13211377
}
13221378

1323-
extend_variant_pack_with_extra_tensors(score_mod_tensors, score_mod_extra_tensors, variant_pack);
1324-
extend_variant_pack_with_extra_tensors(score_mod_bprop_tensors, score_mod_bprop_extra_tensors,
1325-
variant_pack);
1379+
auto score_mod_tensor_views =
1380+
extend_variant_pack_with_extra_tensors(score_mod_tensors, score_mod_extra_tensors, variant_pack);
1381+
auto score_mod_bprop_tensor_views = extend_variant_pack_with_extra_tensors(
1382+
score_mod_bprop_tensors, score_mod_bprop_extra_tensors, variant_pack);
13261383

13271384
NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace));
13281385
} catch (cudnn_frontend::cudnnException &e) {

0 commit comments

Comments
 (0)