8 changes: 8 additions & 0 deletions build.sh
@@ -0,0 +1,8 @@
rm -rf build
rm -rf dist
rm -rf deep_ep_cpp.cpython-38-x86_64-linux-gnu.so
export TORCH_CUDA_ARCH_LIST="10.0"
export PADDLE_CUDA_ARCH_LIST="10.0"
python setup_deep_ep.py bdist_wheel
python setup_hybrid_ep.py bdist_wheel
pip install dist/*.whl --force-reinstall
104 changes: 65 additions & 39 deletions csrc/deep_ep.cpp
@@ -1,6 +1,7 @@
// Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#include <ATen/cuda/CUDAContext.h>
// #include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDADataType.h>
#include <ATen/Functions.h>
#include <chrono>
#include <cuda_runtime.h>
#include <memory>
@@ -131,16 +132,29 @@ Buffer::Buffer(int rank,
bool low_latency_mode,
bool disable_nvlink_for_normal_mode,
bool explicitly_destroy,
bool use_fabric)
bool use_fabric,
int context_ring_id)
: rank(rank),
num_ranks(num_ranks),
num_nvl_bytes(num_nvl_bytes),
num_rdma_bytes(num_rdma_bytes),
low_latency_mode(low_latency_mode),
disable_nvlink_for_normal_mode(disable_nvlink_for_normal_mode),
explicitly_destroy(explicitly_destroy),
comm_stream(at::cuda::getStreamFromPool(true)),
shared_memory_allocator(use_fabric) {

CUDA_CHECK(cudaGetDevice(&device_id));
auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance();
paddle::distributed::ProcessGroup* pg = map->get(context_ring_id);
const auto& place = phi::GPUPlace(device_id);
comm_ctx =
reinterpret_cast<paddle::distributed::ProcessGroupNCCL*>(pg)
->GetOrCreateCommContext(place, phi::distributed::CommType::ALLTOALL);
comm_stream = comm_ctx->GetStream();
calc_ctx = reinterpret_cast<phi::GPUContext*>(
reinterpret_cast<paddle::distributed::ProcessGroupNCCL*>(pg)
->GetDeviceContext(place, true));

// Metadata memory
int64_t barrier_signal_bytes = NUM_MAX_NVL_PEERS * sizeof(int);
int64_t buffer_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(void*);
@@ -262,7 +276,7 @@ torch::Tensor Buffer::get_local_buffer_tensor(const pybind11::object& dtype, int
return torch::from_blob(base_ptr, num_bytes / element_bytes, torch::TensorOptions().dtype(casted_dtype).device(at::kCUDA));
}

torch::Stream Buffer::get_comm_stream() const {
cudaStream_t Buffer::get_comm_stream() const {
return comm_stream;
}

@@ -374,10 +388,10 @@ Buffer::get_dispatch_layout(const torch::Tensor& topk_idx, int num_experts,

// Allocate all tensors on comm stream if set
// NOTES: do not allocate tensors upfront!
auto compute_stream = at::cuda::getCurrentCUDAStream();
auto compute_stream = calc_ctx->stream();

This changes `compute_stream` from `at::cuda::getCurrentCUDAStream()` to the fixed `calc_ctx->stream()`, while `EventHandle` still records/waits on `at::cuda::getCurrentCUDAStream()`. The current-stream semantics are now split between two sources: if the caller enqueues kernels on a custom `torch.cuda.Stream()` and `previous_event` is empty, this code no longer automatically waits on the caller's current stream, which can lead to missing dependencies or extra serialization. I suggest unifying wait / allocate / record_stream on a single stream source.
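For illustration only (not code from this PR), a minimal sketch of the kind of explicit bridge this is asking for; error checking with the project's `CUDA_CHECK` macro is omitted and both streams are assumed to live on the same device:

```cpp
#include <cuda_runtime.h>

// Hypothetical helper: whichever stream source is chosen as the single source
// of truth, capture the caller's pending work once and make the comm stream
// depend on it, so wait / allocate / record_stream cannot diverge.
inline void bridge_caller_to_comm(cudaStream_t caller, cudaStream_t comm) {
    cudaEvent_t ev;
    cudaEventCreateWithFlags(&ev, cudaEventDisableTiming);
    cudaEventRecord(ev, caller);       // snapshot of work already enqueued by the caller
    cudaStreamWaitEvent(comm, ev, 0);  // comm-side kernels now run after that work
    cudaEventDestroy(ev);              // safe: the wait is already enqueued on `comm`
}
```

Whether `caller` comes from `at::cuda::getCurrentCUDAStream()` or from the Paddle context is secondary; the point is that every wait, allocation and record goes through the same choice.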

@youge325 youge325 Apr 5, 2026

You have a point. Could I change it this way instead: whenever `at::cuda::setCurrentCUDAStream()` is called, also update `calc_ctx`'s stream:

// A small wrapper utility
inline void set_current_stream(c10::cuda::CUDAStream stream,
                               phi::GPUContext* calc_ctx) {
    // 1. Set the ATen current stream (affects at::cuda::getCurrentCUDAStream())
    at::cuda::setCurrentCUDAStream(stream);

    // 2. Keep the Paddle GPUContext stream in sync
    //    (assuming calc_ctx exposes a setter for its stream)
    calc_ctx->SetStream(stream.stream());  // or an equivalent method
}

Then replace every `setCurrentCUDAStream` call in deep_ep.cpp:

// Original code (line 464)
deep_ep::SetAllocatorStreamForGPUContext(compute_stream.stream(), calc_ctx);

// Changed to
set_current_stream(compute_stream, calc_ctx);
deep_ep::SetAllocatorStreamForGPUContext(compute_stream.stream(), calc_ctx);

Note that these changes are based on the latest develop branch of https://github.com/youge325/DeepEP/tree/cRemove_OptionalPaddle; or perhaps you have a better approach that keeps the diff as small as possible.


> You have a point. Could I change it this way instead: whenever `at::cuda::setCurrentCUDAStream()` is called, also update `calc_ctx`'s stream ... or perhaps you have a better approach that keeps the diff as small as possible

Explained: this direction is closer to correct than the current state, but I don't recommend doing the "bidirectional sync" only at the `setCurrentCUDAStream()` call site.

The problem I pointed out earlier is not only that the two streams fall out of sync when the allocator is switched; the more fundamental issue is that the source of the caller stream has already been changed from `at::cuda::getCurrentCUDAStream()` to `calc_ctx->stream()`:

  • csrc/deep_ep.cpp:391 first reads `auto compute_stream = calc_ctx->stream();`
  • when `previous_event` is empty, the later `stream_wait(comm_stream, compute_stream)` waits on that stream as well
  • `EventHandle` still records/waits on `at::cuda::getCurrentCUDAStream()` (csrc/event.hpp:13-29)

So if you only add a `set_current_stream(...)` afterwards, the default path may still never wait for work the caller has already enqueued on its current `torch.cuda.Stream()`; in other words, the root of the problem is still that the source of truth has been swapped.

If the goal is to keep the diff as small as possible, I would prefer:

  1. Keep the caller-stream semantics anchored to the ATen current stream (staying as close to the original PyTorch code as possible)
    • `auto compute_stream = at::cuda::getCurrentCUDAStream();`
    • wait / event / restore all revolve around the same stream source
  2. If the compatibility layer really requires `calc_ctx` / the allocator to follow, add a very thin helper that propagates the currently selected ATen stream to `calc_ctx` / the allocator (see the sketch after this comment); but do not let `calc_ctx->stream()` become the primary semantics the other way around.

The benefits of this approach:

  • semantically closer to the upstream PyTorch version
  • a smaller diff
  • no need to work out at every call site whether to trust the ATen current stream or the Paddle context stream

If you plan to converge on this later, I think that is fine; but I suggest first settling who the single source of truth is, otherwise you are only moving the split stream contract from one place to another.
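For illustration, a minimal sketch of the thin helper described in option 2, assuming the ATen current stream stays the source of truth; `sync_paddle_to_aten` is a hypothetical name, and `SetAllocatorStreamForGPUContext` is the helper this PR already adds in csrc/deep_ep.hpp:

```cpp
#include <ATen/cuda/CUDAContext.h>
// #include "deep_ep.hpp"  // deep_ep::SetAllocatorStreamForGPUContext, added by this PR

// Hypothetical one-way sync: read the authoritative stream from ATen and mirror
// it into the Paddle context/allocator. The caller stream is never read back
// from calc_ctx->stream().
inline void sync_paddle_to_aten(phi::GPUContext* calc_ctx) {
    cudaStream_t current = at::cuda::getCurrentCUDAStream().stream();
    deep_ep::SetAllocatorStreamForGPUContext(current, calc_ctx);
}
```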


Referring to #11 and the https://github.com/PaddlePaddle/Paddle repo: do we really use `current_stream_wait` that often? Could we replace it with `void stream_wait(const at::cuda::CUDAStream& s_0, const at::cuda::CUDAStream& s_1)` and `void stream_wait(const at::cuda::CUDAStream& s, const EventHandle& event)`? I see that in the https://github.com/PaddlePaddle/Paddle repo the tests involving `current_stream_wait` have all been commented out.
Our goal is to let Paddle manage the streams, which makes it easier to separate the communication stream from the compute stream, so the ProcessGroup naturally has to become the single source of truth. As for `current_stream_wait`, it can be replaced by `stream_wait`, passing the `CUDAStream` and `EventHandle` explicitly to make the semantics clear and remove the ambiguity in the code.
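For illustration, a sketch of how a call site could read with the explicit overloads, using the `cudaStream_t`-based `stream_wait` signatures this PR adds in csrc/event.hpp; the function name `dispatch_prologue` and its parameters are hypothetical:

```cpp
#include <optional>
// #include "event.hpp"  // deep_ep::stream_wait / deep_ep::EventHandle as defined in this PR

// Hypothetical prologue: every wait names its streams explicitly, so nothing
// depends on whichever stream happens to be "current".
void dispatch_prologue(cudaStream_t comm_stream,
                       cudaStream_t compute_stream,
                       const std::optional<deep_ep::EventHandle>& previous_event) {
    if (previous_event.has_value()) {
        // Wait on an event the caller captured earlier.
        deep_ep::stream_wait(comm_stream, previous_event.value());
    } else {
        // Make the comm stream wait on the compute stream directly.
        deep_ep::stream_wait(comm_stream, compute_stream);
    }
}
```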


> Referring to #11 and the https://github.com/PaddlePaddle/Paddle repo: do we really use `current_stream_wait` that often? Could we use `stream_wait(...)` instead? Our goal is to let Paddle manage the streams, so the ProcessGroup naturally has to become the single source of truth

Explained: if the proposal is only that the internal C++ wait relationships should be expressed as explicit `stream_wait(s_0, s_1)` / `stream_wait(s, event)`, I have no objection; that form is indeed clearer than an implicit current-stream wait.

But on its own this does not remove the blocker I raised earlier, because the problem is no longer just the choice of internal helper; the external contract has not been migrated along with it:

  1. `current_stream_wait` is still a public API in this PR, not merely an internal implementation detail.

    • The Python wrapper still keeps `EventOverlap.current_stream_wait()` (deep_ep/utils.py:36-41), and `__exit__` calls it directly.
    • The tests in the repo still rely on it heavily, e.g. tests/test_intranode.py:105-106 and tests/test_low_latency.py:71-72.
      So if the goal really is "ProcessGroup / Paddle as the single source of truth", changing only the internal wait helper is not enough; the outer semantic contract has not been switched over with it.
  2. The migration itself is still only half closed.

    • The Python wrapper has added `calc_stream_wait` / `comm_stream_wait` (deep_ep/utils.py:43-47),
    • but the C++ / pybind side currently only implements and exports `current_stream_wait` (csrc/deep_ep.cpp:1928-1930).
      In other words, the PR now exposes both the current-stream semantics and the intended explicit calc/comm semantics, but the latter is not actually closed yet (a sketch of what closing it could look like follows this comment).

So my core opinion is unchanged:

  • If you want the smallest diff and semantics closest to upstream PyTorch: keep the ATen current stream as the single source of truth, and propagate the selected stream to the Paddle allocator/context.
  • If you want to switch fully to Paddle/ProcessGroup as the source of truth: then the event capture, wait targets, the Python EventOverlap API, and the tests/docs all need to move to the explicit semantics together; replacing `current_stream_wait` with an internal `stream_wait(...)` alone does not complete it.

In other words, I am not against the `stream_wait(...)` form; what I object to is that the PR is currently in a state where the internal semantics are being switched while the external contract has not finished switching, and in this review that remains a blocker.

Also, the fact that Paddle commented out the tests related to `current_stream_wait` at most shows they are cautious about the current-stream contract itself; for this PR it is all the more reason to settle the external semantics in one pass rather than leaving a half-migrated state.
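For illustration only, a hypothetical sketch of what closing that loop on the pybind side could look like if the Paddle/ProcessGroup route is taken; `Buffer::get_calc_stream()` does not exist in this PR and is assumed here to expose `calc_ctx->stream()`, while `get_comm_stream()` already returns a `cudaStream_t`:

```cpp
// Hypothetical pybind additions (names illustrative): give deep_ep/utils.py's
// calc_stream_wait / comm_stream_wait a C++ counterpart next to the existing
// current_stream_wait binding.
pybind11::class_<deep_ep::EventHandle>(m, "EventHandle", py::module_local())
    .def(pybind11::init<>())
    .def("current_stream_wait", &deep_ep::EventHandle::current_stream_wait)
    .def("calc_stream_wait",
         [](const deep_ep::EventHandle& self, deep_ep::Buffer& buffer) {
             // Assumed accessor returning calc_ctx->stream().
             deep_ep::stream_wait(buffer.get_calc_stream(), self);
         })
    .def("comm_stream_wait",
         [](const deep_ep::EventHandle& self, deep_ep::Buffer& buffer) {
             // get_comm_stream() returns a cudaStream_t in this PR.
             deep_ep::stream_wait(buffer.get_comm_stream(), self);
         });
```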

if (allocate_on_comm_stream) {
EP_HOST_ASSERT(previous_event.has_value() and async);
at::cuda::setCurrentCUDAStream(comm_stream);
deep_ep::SetAllocatorStreamForGPUContext(comm_stream, calc_ctx);
}

// Wait previous tasks to be finished
@@ -423,7 +437,7 @@ Buffer::get_dispatch_layout(const torch::Tensor& topk_idx, int num_experts,

// Switch back compute stream
if (allocate_on_comm_stream)
at::cuda::setCurrentCUDAStream(compute_stream);
deep_ep::SetAllocatorStreamForGPUContext(compute_stream, calc_ctx);

return {num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, event};
}
@@ -534,10 +548,10 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te

// Allocate all tensors on comm stream if set
// NOTES: do not allocate tensors upfront!
auto compute_stream = at::cuda::getCurrentCUDAStream();
auto compute_stream = calc_ctx->stream();
if (allocate_on_comm_stream) {
EP_HOST_ASSERT(previous_event.has_value() and async);
at::cuda::setCurrentCUDAStream(comm_stream);
EP_HOST_ASSERT(previous_event.has_value() && async);
deep_ep::SetAllocatorStreamForGPUContext(comm_stream, calc_ctx);
}

// Wait previous tasks to be finished
@@ -686,8 +700,9 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
}

// Switch back compute stream
if (allocate_on_comm_stream)
at::cuda::setCurrentCUDAStream(compute_stream);
if (allocate_on_comm_stream) {
deep_ep::SetAllocatorStreamForGPUContext(compute_stream, calc_ctx);
}

// Return values
return {recv_x, recv_x_scales, recv_sf_scale_for_nvfp4, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, send_head, event};
@@ -718,10 +733,10 @@ Buffer::intranode_combine(const torch::Tensor& x, const std::optional<torch::Ten

// Allocate all tensors on comm stream if set
// NOTES: do not allocate tensors upfront!
auto compute_stream = at::cuda::getCurrentCUDAStream();
auto compute_stream = calc_ctx->stream();
if (allocate_on_comm_stream) {
EP_HOST_ASSERT(previous_event.has_value() and async);
at::cuda::setCurrentCUDAStream(comm_stream);
EP_HOST_ASSERT(previous_event.has_value() && async);
deep_ep::SetAllocatorStreamForGPUContext(comm_stream, calc_ctx);
}

// Wait previous tasks to be finished
@@ -798,8 +813,9 @@ Buffer::intranode_combine(const torch::Tensor& x, const std::optional<torch::Ten
}

// Switch back compute stream
if (allocate_on_comm_stream)
at::cuda::setCurrentCUDAStream(compute_stream);
if (allocate_on_comm_stream) {
deep_ep::SetAllocatorStreamForGPUContext(compute_stream, calc_ctx);
}

return {recv_x, recv_topk_weights, event};
}
@@ -905,10 +921,10 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional<torch::Te

// Allocate all tensors on comm stream if set
// NOTES: do not allocate tensors upfront!
auto compute_stream = at::cuda::getCurrentCUDAStream();
auto compute_stream = calc_ctx->stream();
if (allocate_on_comm_stream) {
EP_HOST_ASSERT(previous_event.has_value() and async);
at::cuda::setCurrentCUDAStream(comm_stream);
EP_HOST_ASSERT(previous_event.has_value() && async);
deep_ep::SetAllocatorStreamForGPUContext(comm_stream, calc_ctx);
}

// Wait previous tasks to be finished
@@ -1070,8 +1086,9 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
}

// Switch back compute stream
if (allocate_on_comm_stream)
at::cuda::setCurrentCUDAStream(compute_stream);
if (allocate_on_comm_stream) {
deep_ep::SetAllocatorStreamForGPUContext(compute_stream, calc_ctx);
}

// Return values
return {recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list,
@@ -1119,10 +1136,10 @@ Buffer::internode_combine(const torch::Tensor& x, const std::optional<torch::Ten

// Allocate all tensors on comm stream if set
// NOTES: do not allocate tensors upfront!
auto compute_stream = at::cuda::getCurrentCUDAStream();
auto compute_stream = calc_ctx->stream();
if (allocate_on_comm_stream) {
EP_HOST_ASSERT(previous_event.has_value() and async);
at::cuda::setCurrentCUDAStream(comm_stream);
EP_HOST_ASSERT(previous_event.has_value() && async);
deep_ep::SetAllocatorStreamForGPUContext(comm_stream, calc_ctx);
}

// Wait previous tasks to be finished
@@ -1207,8 +1224,9 @@ Buffer::internode_combine(const torch::Tensor& x, const std::optional<torch::Ten
}

// Switch back compute stream
if (allocate_on_comm_stream)
at::cuda::setCurrentCUDAStream(compute_stream);
if (allocate_on_comm_stream) {
deep_ep::SetAllocatorStreamForGPUContext(compute_stream, calc_ctx);
}

// Return values
return {combined_x, combined_topk_weights, event};
@@ -1616,10 +1634,10 @@ Buffer::dispatch_pcie(const torch::Tensor& x, const std::optional<torch::Tensor>
}

// Stream Management
auto compute_stream = at::cuda::getCurrentCUDAStream();
auto compute_stream = calc_ctx->stream();
if (allocate_on_comm_stream) {
EP_HOST_ASSERT(previous_event.has_value() && async);
at::cuda::setCurrentCUDAStream(comm_stream);
deep_ep::SetAllocatorStreamForGPUContext(comm_stream, calc_ctx);
}

if (previous_event.has_value()) {
@@ -1757,7 +1775,7 @@ Buffer::dispatch_pcie(const torch::Tensor& x, const std::optional<torch::Tensor>
}

if (allocate_on_comm_stream)
at::cuda::setCurrentCUDAStream(compute_stream);
deep_ep::SetAllocatorStreamForGPUContext(compute_stream, calc_ctx);

return {recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list,
rdma_channel_prefix_matrix,
@@ -1830,10 +1848,10 @@ Buffer::combine_pcie(const torch::Tensor& recv_x, const std::optional<torch::Ten
}

// Stream Management
auto compute_stream = at::cuda::getCurrentCUDAStream();
auto compute_stream = calc_ctx->stream();
if (allocate_on_comm_stream) {
EP_HOST_ASSERT(previous_event.has_value() && async);
at::cuda::setCurrentCUDAStream(comm_stream);
EP_HOST_ASSERT(previous_event.has_value() and async);
deep_ep::SetAllocatorStreamForGPUContext(comm_stream, calc_ctx);
}

if (previous_event.has_value()) {
@@ -1883,7 +1901,7 @@ Buffer::combine_pcie(const torch::Tensor& recv_x, const std::optional<torch::Ten
}

if (allocate_on_comm_stream)
at::cuda::setCurrentCUDAStream(compute_stream);
deep_ep::SetAllocatorStreamForGPUContext(compute_stream, calc_ctx);

return {combined_x, combined_topk_weights, event};
#else
@@ -1897,7 +1915,7 @@ Buffer::combine_pcie(const torch::Tensor& recv_x, const std::optional<torch::Ten
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "DeepEP: an efficient expert-parallel communication library";

pybind11::class_<deep_ep::Config>(m, "Config")
pybind11::class_<deep_ep::Config>(m, "Config", py::module_local())
.def(pybind11::init<int, int, int, int, int>(),
py::arg("num_sms") = 20,
py::arg("num_max_nvl_chunked_send_tokens") = 6, py::arg("num_max_nvl_chunked_recv_tokens") = 256,
@@ -1907,12 +1925,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def("get_pcie_buffer_size_hint", &deep_ep::Config::get_pcie_buffer_size_hint);
m.def("get_low_latency_rdma_size_hint", &deep_ep::get_low_latency_rdma_size_hint);

pybind11::class_<deep_ep::EventHandle>(m, "EventHandle")
pybind11::class_<deep_ep::EventHandle>(m, "EventHandle", py::module_local())
.def(pybind11::init<>())
.def("current_stream_wait", &deep_ep::EventHandle::current_stream_wait);

pybind11::class_<deep_ep::Buffer>(m, "Buffer")
.def(pybind11::init<int, int, int64_t, int64_t, bool, bool, bool, bool>())
pybind11::class_<deep_ep::Buffer>(m, "Buffer", py::module_local())
.def(pybind11::init<int, int, int64_t, int64_t, bool, bool, bool, bool, int>())
.def("is_available", &deep_ep::Buffer::is_available)
.def("get_num_rdma_ranks", &deep_ep::Buffer::get_num_rdma_ranks)
.def("get_rdma_rank", &deep_ep::Buffer::get_rdma_rank)
@@ -1921,7 +1939,15 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def("get_local_ipc_handle", &deep_ep::Buffer::get_local_ipc_handle)
.def("get_local_nvshmem_unique_id", &deep_ep::Buffer::get_local_nvshmem_unique_id)
.def("get_local_buffer_tensor", &deep_ep::Buffer::get_local_buffer_tensor)
.def("get_comm_stream", &deep_ep::Buffer::get_comm_stream)
.def("get_comm_stream",
[](deep_ep::Buffer &self) {
int device_id = self.get_local_device_id();
cudaStream_t comm_stream = self.get_comm_stream();
auto s = phi::Stream(reinterpret_cast<phi::StreamId>(comm_stream));
#if defined(PADDLE_WITH_CUDA)
return phi::CUDAStream(phi::GPUPlace(device_id), s);
#endif
})
.def("sync", &deep_ep::Buffer::sync)
.def("destroy", &deep_ep::Buffer::destroy)
.def("get_dispatch_layout", &deep_ep::Buffer::get_dispatch_layout)
21 changes: 18 additions & 3 deletions csrc/deep_ep.hpp
@@ -9,6 +9,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <torch/types.h>
#include <c10/cuda/CUDAStream.h>
#include <tuple>
#include <vector>

@@ -17,6 +18,9 @@
#include "kernels/configs.cuh"
#include "kernels/exception.cuh"

#include "paddle/phi/core/memory/allocation/allocator_facade.h"
#include "paddle/fluid/distributed/collective/process_group_nccl.h"

#ifndef TORCH_EXTENSION_NAME
#define TORCH_EXTENSION_NAME deep_ep_cpp
#endif
@@ -79,7 +83,10 @@ struct Buffer {
shared_memory::MemHandle ipc_handles[NUM_MAX_NVL_PEERS];

// Stream for communication
at::cuda::CUDAStream comm_stream;
cudaStream_t comm_stream;

phi::distributed::NCCLCommContext* comm_ctx;
phi::GPUContext* calc_ctx;

// After IPC/NVSHMEM synchronization, this flag will be true
bool available = false;
@@ -118,7 +125,8 @@ struct Buffer {
bool low_latency_mode,
bool disable_nvlink_for_normal_mode,
bool explicitly_destroy,
bool use_fabric);
bool use_fabric,
int context_ring_id);

~Buffer() noexcept(false);

@@ -140,7 +148,7 @@

torch::Tensor get_local_buffer_tensor(const pybind11::object& dtype, int64_t offset, bool use_rdma_buffer) const;

torch::Stream get_comm_stream() const;
cudaStream_t get_comm_stream() const;

void sync(const std::vector<int>& device_ids, const std::vector<std::optional<pybind11::bytearray>>& all_gathered_handles, const std::optional<pybind11::bytearray>& root_unique_id_opt);

@@ -224,4 +232,11 @@
get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) const;
};

inline void SetAllocatorStreamForGPUContext(gpuStream_t stream,
phi::GPUContext* ctx) {
ctx->SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(ctx->GetPlace(), stream)
.get());
}

} // namespace deep_ep
23 changes: 14 additions & 9 deletions csrc/event.hpp
@@ -1,4 +1,6 @@
#include <ATen/cuda/CUDAContext.h>
#pragma once
// #include <ATen/cuda/CUDAContext.h>
#include <c10/core/Event.h>
#include <memory>

#include "kernels/exception.cuh"
@@ -13,31 +15,34 @@ struct EventHandle {
event->record(at::cuda::getCurrentCUDAStream());
}

explicit EventHandle(const at::cuda::CUDAStream& stream) {
explicit EventHandle(const cudaStream_t& stream) {
event = std::make_shared<torch::Event>(torch::kCUDA);
event->record(stream);
}

EventHandle(const EventHandle& other) = default;

void current_stream_wait() const {
at::cuda::getCurrentCUDAStream().unwrap().wait(*event);
CUDA_CHECK(cudaStreamWaitEvent(
at::cuda::getCurrentCUDAStream().raw_stream(),
event->cuda_event(),
0));
}
};

torch::Event create_event(const at::cuda::CUDAStream &s) {
torch::Event create_event(const cudaStream_t &s) {
auto event = torch::Event(torch::kCUDA);
event.record(s);
return event;
}

void stream_wait(const at::cuda::CUDAStream& s_0, const at::cuda::CUDAStream& s_1) {
EP_HOST_ASSERT(s_0.id() != s_1.id());
s_0.unwrap().wait(create_event(s_1));
inline void stream_wait(const cudaStream_t& s_0, const cudaStream_t& s_1) {
EP_HOST_ASSERT(s_0 != s_1);
CUDA_CHECK(cudaStreamWaitEvent(s_0, create_event(s_1).cuda_event(), 0));
}

void stream_wait(const at::cuda::CUDAStream& s, const EventHandle& event) {
s.unwrap().wait(*event.event);
inline void stream_wait(const cudaStream_t& s, const EventHandle& event) {
CUDA_CHECK(cudaStreamWaitEvent(s, event.event->cuda_event(), 0));
}

} // namespace deep_ep
4 changes: 2 additions & 2 deletions csrc/hybrid_ep/allocator/allocator.cu
@@ -174,8 +174,8 @@ bool ExtendedMemoryAllocator::is_accessible(MemHandle* mem_handle) {

int ExtendedMemoryAllocator::detect_accessible_ranks(pybind11::object process_group) {
auto torch_distributed = py::module_::import("torch.distributed");
int world_size = process_group.attr("size")().cast<int>();
int current_rank = process_group.attr("rank")().cast<int>();
int world_size = process_group.attr("world_size").cast<int>();
int current_rank = process_group.attr("rank").cast<int>();
auto stream = at::cuda::getCurrentCUDAStream();

// Put the test memory handle on a CUDA tensor