[Compat] Make hybrid ep branch compatible with PaddlePaddle#12
[Compat] Make hybrid ep branch compatible with PaddlePaddle#12SigureMo wants to merge 2 commits intohybrid-ep-paddlefrom
Conversation
7febc6e to
a484936
Compare
91af98a to
c9fc08d
Compare
ShigureNyako
left a comment
There was a problem hiding this comment.
这版兼容改造我先不建议合入,主要有几点是“保留了 PyTorch 接口外形,但语义已经变了”,其中既有兼容性问题,也有明确的性能退化风险:
-
deep_ep.Buffer的comm路径已经失效。deep_ep/buffer.py:74-98, 128-133定义了all_gather_objecthelper,但后续同步阶段全部直接调用dist.all_gather_object(..., group),而且构造 runtime 时无条件取group.id(deep_ep/buffer.py:92-93)。- 这意味着文档里仍然声称支持
comm,但实际一走comm分支就会在group.id/dist.all_gather_object(..., group)处崩掉。 - 建议:要么把
comm路径完整接回(包括context_ring_id/ object gather),要么直接删掉该 API/文档并显式报错,避免半迁移状态。
-
current stream 语义已经和 PyTorch 版本不一致。
- 原实现这里用的是
at::cuda::getCurrentCUDAStream()/setCurrentCUDAStream(...),现在csrc/deep_ep.cpp多处改成固定读calc_ctx->stream(),并通过SetAllocatorStreamForGPUContext(...)切 allocator。 - 但
EventHandle仍然用at::cuda::getCurrentCUDAStream()录制/等待(csrc/event.hpp:13-29)。 - 结果就是“计算流”来源被拆成了 Paddle
GPUContext和 ATen current stream 两套定义。上层如果在自定义torch.cuda.Stream()上排 kernel,previous_event为空时这里不再自动等待调用者当前流,容易出现缺依赖或者额外串行化。 - 建议:至少保证 wait / allocate / record_stream 读取的是同一个 stream 源;如果必须走 Paddle stream,也需要把 current-stream 相关接口一起切过去,而不是混用。
- 原实现这里用的是
-
enable_custom_allgather参数被静默忽略,热路径会直接退化。- Python 侧还保留了
enable_custom_allgather参数(deep_ep/hybrid_ep_buffer.py:67),但创建 runtime 时直接硬编码成False(deep_ep/hybrid_ep_buffer.py:157-162)。 - 这样
Executor::allgather_routing_map会永久走paddle.distributed.stream.all_gather(...)分支,而不是原来的 custom intra-node allgather(csrc/hybrid_ep/executor/executor.cu:27-34)。 - 这既是接口语义变化,也是明确的性能风险。调用方即使显式传
True也拿不到原本的快路径。 - 建议:要么把参数删掉/显式报错,要么把 custom 路径按兼容层约束修好后继续透传,不要 silent fallback。
- Python 侧还保留了
-
HybridEP 的拓扑检测/校验被删掉后,默认“全员同一 NVLink 域”风险太高。
- 原来会用
ExtendedMemoryAllocator.detect_accessible_ranks()做检测;现在deep_ep/hybrid_ep_buffer.py:79-87直接默认num_of_hybrid_ep_ranks_per_nvlink_domain = world_size,deep_ep/buffer.py:65-66的 NVLink 检查也被注释掉了。 - 这在多节点、PCIe、或者局部不可达拓扑下不再 fail fast,而是可能直接走错路径。
- 建议:至少在自动检测不可用时强制要求显式传/设
NUM_OF_HYBRID_EP_RANKS_PER_NVLINK_DOMAIN,不要默认world_size。
- 原来会用
附带一个建议尽快修的小点:csrc/hybrid_ep/hybrid_ep.cu:15-20 现在用重复的 group.id 生成 comm_id,而 csrc/hybrid_ep/jit/compiler.cuh:57 明确它的语义是 hash(all ranks in the process_group)。这会让 JIT cache key 丢失成员信息,最好还是按 group member 列表生成,避免后续 cache/句柄复用出现撞车。
| # Disable custom allgather by default because its data layout is incompatible with scan kernel | ||
| # The custom allgather kernel produces token-interleaved layout, but scan kernel expects | ||
| # the standard allgather layout (rank-blocked layout) | ||
| enable_custom_allgather = False # Always use standard allgather for correctness |
There was a problem hiding this comment.
这里把 enable_custom_allgather 直接硬编码成 False,会让 Python 暴露的参数失效,并且把原来的 intra-node fast path 永久降级到通用 all_gather。如果当前兼容层暂时不支持 custom allgather,建议至少在 Python 侧显式报错/移除该参数,而不是 silent fallback。
| self.explicitly_destroy = explicitly_destroy | ||
| self.runtime = deep_ep_cpp.Buffer(self.rank, self.group_size, num_nvl_bytes, num_rdma_bytes, low_latency_mode, | ||
| self.disable_nvlink_for_normal_mode,explicitly_destroy, use_fabric) | ||
| self.disable_nvlink_for_normal_mode,explicitly_destroy, use_fabric, group.id) |
There was a problem hiding this comment.
这里已经把 group.id 当成必选输入传进 C++ runtime 了,但这个类的文档和分支逻辑仍然保留了 comm 兜底路径。后面同步阶段也直接写死成 dist.all_gather_object(..., group),不再走上面定义的 all_gather_object helper。这样一来 comm 分支实际上已经不可用了,属于典型的半迁移状态。建议要么把 comm 路径完整接回,要么在接口层显式去掉/报错,不要保留一个看起来可用但实际会崩的 API。
| else: | ||
| self.num_of_hybrid_ep_ranks_per_nvlink_domain = detected_ranks | ||
| # Default: assume all ranks are in the same NVLink domain (single node) | ||
| self.num_of_hybrid_ep_ranks_per_nvlink_domain = self.group_size |
There was a problem hiding this comment.
这里直接把默认值设成 group_size,等价于假设“所有 rank 都在同一个 NVLink 域”。原实现会先探测 accessible ranks,这里删掉以后,多节点 / PCIe / 局部不可达拓扑都不再 fail fast,而是可能静默走错路径。建议至少在自动探测不可用时要求显式设置 NUM_OF_HYBRID_EP_RANKS_PER_NVLINK_DOMAIN,不要默认全员可达。
| // Allocate all tensors on comm stream if set | ||
| // NOTES: do not allocate tensors upfront! | ||
| auto compute_stream = at::cuda::getCurrentCUDAStream(); | ||
| auto compute_stream = calc_ctx->stream(); |
There was a problem hiding this comment.
这里把 compute_stream 从 at::cuda::getCurrentCUDAStream() 改成了固定的 calc_ctx->stream(),但 EventHandle 仍然在 at::cuda::getCurrentCUDAStream() 上 record/wait。这样 current-stream 语义已经分裂成两套来源:上层如果在自定义 torch.cuda.Stream() 上排 kernel,previous_event 为空时这里不再自动等待调用者当前流,可能出现缺依赖或额外串行化。建议把 wait / allocate / record_stream 统一到同一个 stream 源。
There was a problem hiding this comment.
你说得有一定道理,那我能不能这样改,在调用 at::cuda::setCurrentCUDAStream() 时,同时更新 calc_ctx 的 stream:
// 封装一个工具函数
inline void set_current_stream(c10::cuda::CUDAStream stream,
phi::GPUContext* calc_ctx) {
// 1. 设置 ATen current stream(影响 at::cuda::getCurrentCUDAStream())
at::cuda::setCurrentCUDAStream(stream);
// 2. 同步更新 Paddle GPUContext 的 stream
// 假设 calc_ctx 提供了设置 stream 的方法
calc_ctx->SetStream(stream.stream()); // 或者等效方法
}然后替换 deep_ep.cpp 中所有的 setCurrentCUDAStream 调用:
// 原代码(第464行)
deep_ep::SetAllocatorStreamForGPUContext(compute_stream.stream(), calc_ctx);
// 改为
set_current_stream(compute_stream, calc_ctx);
deep_ep::SetAllocatorStreamForGPUContext(compute_stream.stream(), calc_ctx);注意,这里的修改是以 https://github.com/youge325/DeepEP/tree/cRemove_Optional 和 Paddle 最新的 develop 分支 为基准,或者你有更好的,能够尽可能减小 diff 的方案
There was a problem hiding this comment.
你说得有一定道理,那我能不能这样改,在调用
at::cuda::setCurrentCUDAStream()时,同时更新calc_ctx的 stream ... 或者你有更好的、能够尽可能减小 diff 的方案
已解释:这个方向比现在更接近正确,但我不建议只在 setCurrentCUDAStream() 这一处做“双向同步”。
我前面指出的问题不只是在“切 allocator 时两边 stream 没同步”,更关键的是这里已经把 caller stream 的来源 从 at::cuda::getCurrentCUDAStream() 改成了 calc_ctx->stream():
csrc/deep_ep.cpp:391这里先读auto compute_stream = calc_ctx->stream();previous_event为空时,后面stream_wait(comm_stream, compute_stream)等待的也是它- 但
EventHandle仍然在at::cuda::getCurrentCUDAStream()上 record/wait(csrc/event.hpp:13-29)
所以如果只是在后面补一个 set_current_stream(...),默认路径里仍然可能没有等待调用方当前 torch.cuda.Stream() 上已经排好的工作;也就是说,问题的根还在“source of truth 已经换了”。
如果想尽量减小 diff,我更倾向于:
- 保持 caller stream 语义继续以 ATen current stream 为准(尽量贴近原 PyTorch 代码)
auto compute_stream = at::cuda::getCurrentCUDAStream();- wait / event / restore 都围绕同一个 stream 源
- 如果兼容层确实要求
calc_ctx/ allocator 跟随,再额外封一个很薄的 helper,把 当前选定的 ATen stream 同步给calc_ctx/ allocator;但不要反过来让calc_ctx->stream()成为主语义。
这样改的好处是:
- 语义上更接近上游 PyTorch 版本
- diff 也更小
- 不需要在每个调用点都推断“现在到底该信 ATen current stream 还是 Paddle context stream”
如果你后面准备这样收敛,我觉得是可以的;但建议先把“谁是唯一真源(single source of truth)”定下来,否则只是把 split stream contract 从一处挪到另一处。
There was a problem hiding this comment.
参考 #11 和 https://github.com/PaddlePaddle/Paddle 仓库,我们很经常使用 current_stream_wait 吗,能不能使用 void stream_wait(const at::cuda::CUDAStream& s_0, const at::cuda::CUDAStream& s_1) 和 void stream_wait(const at::cuda::CUDAStream& s, const EventHandle& event) 来代替,我看到在 https://github.com/PaddlePaddle/Paddle 仓库里,有关 current_stream_wait 的测试都被注释掉了
我们的目的就是让 Paddle 来管理 stream ,这样比较好区分通信流和计算流,因此,ProcessGroup 自然要成为唯一真源,至于 current_stream_wait ,可以用 stream_wait 来代替,通过显式传入 CUDAStream 和 EventHandle 来明确语义,从而消除代码歧义
There was a problem hiding this comment.
参考 #11 和 https://github.com/PaddlePaddle/Paddle 仓库,我们很经常使用
current_stream_wait吗,能不能使用stream_wait(...)来代替;我们的目的就是让 Paddle 来管理 stream,因此 ProcessGroup 自然要成为唯一真源
已解释:如果只是说“内部 C++ 的等待关系尽量改成显式 stream_wait(s_0, s_1) / stream_wait(s, event)”,我并不反对,这个形式本身比隐式 current-stream 更清晰。
但这还不能单独消除我前面提的 blocker,因为这里的问题已经不只是内部 helper 选型,而是外部 contract 还没有一起迁移完:
-
current_stream_wait在这个 PR 里仍然是公开 API,不是单纯内部实现细节。- Python wrapper 还保留了
EventOverlap.current_stream_wait()(deep_ep/utils.py:36-41),__exit__里也会直接调它。 - repo 里的测试也还在大量依赖它,比如
tests/test_intranode.py:105-106、tests/test_low_latency.py:71-72。
所以如果目标真的是“ProcessGroup / Paddle 成为唯一真源”,这里只改内部 wait helper 还不够,外层语义契约还没有一起切过去。
- Python wrapper 还保留了
-
当前迁移本身还是半闭环。
- Python wrapper 已经新增了
calc_stream_wait/comm_stream_wait(deep_ep/utils.py:43-47), - 但 C++ / pybind 侧目前只实现并导出了
current_stream_wait(csrc/deep_ep.cpp:1928-1930)。
也就是说,PR 现在同时暴露了“current stream 语义”和“想转向显式 calc/comm 语义”的两套接口,但后者还没有真正闭合。
- Python wrapper 已经新增了
所以我的核心意见不变:
- 如果想维持最小 diff / 最接近上游 PyTorch 语义:继续让 ATen current stream 做唯一真源,再把选定 stream 同步给 Paddle allocator/context。
- 如果想彻底改成 Paddle/ProcessGroup 真源:那就需要把 event capture、wait target、Python
EventOverlapAPI、tests/docs 一起切到显式语义,不能只把current_stream_wait换成内部的stream_wait(...)就算完成。
换句话说,我不是反对 stream_wait(...) 这个形式;我反对的是当前 PR 还处在“内部语义想切,外部 contract 还没切完”的状态,这一点在 review 里仍然是 blocker。
另外,Paddle 那边把 current_stream_wait 相关测试注释掉,最多说明那边对 current-stream 契约本身也比较谨慎;对这个 PR 来说,反而更说明这里需要一次性把外部语义讲清楚,而不是保留半迁移状态。
No description provided.