forked from deepseek-ai/DeepEP
[Compat] Make hybrid ep branch compatible with PaddlePaddle #12
Draft · SigureMo wants to merge 2 commits into hybrid-ep-paddle from hybrid-ep-paddle-dev
Diff (new file, +8 lines):

```shell
rm -rf build
rm -rf dist
rm -rf deep_ep_cpp.cpython-38-x86_64-linux-gnu.so
export TORCH_CUDA_ARCH_LIST="10.0"
export PADDLE_CUDA_ARCH_LIST="10.0"
python setup_deep_ep.py bdist_wheel
python setup_hybrid_ep.py bdist_wheel
pip install dist/*.whl --force-reinstall
```
This change switches `compute_stream` from `at::cuda::getCurrentCUDAStream()` to the fixed `calc_ctx->stream()`, but `EventHandle` still records and waits on `at::cuda::getCurrentCUDAStream()`. The current-stream semantics are now split across two sources: if the caller enqueues kernels on a custom `torch.cuda.Stream()` and `previous_event` is empty, this code no longer automatically waits on the caller's current stream, which can lead to missing dependencies or unnecessary serialization. I suggest unifying wait / allocate / `record_stream` on a single stream source.
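The hazard described in this comment can be sketched with a tiny pure-Python stand-in. `FakeStream` and the op names below are invented for illustration and are not the real CUDA/ATen API; the point is only that a wait placed on `calc_ctx`'s fixed stream does not cover work the caller enqueued on its own current stream.

```python
# Toy model: streams are ordered op lists; a wait only creates a
# dependency on the specific stream it is pointed at.
class FakeStream:
    def __init__(self, name):
        self.name = name
        self.ops = []       # ops enqueued on this stream so far
        self.deps = set()   # ops this stream's future work is ordered after

    def enqueue(self, op):
        self.ops.append(op)

    def wait_stream(self, other):
        # Future work on self waits for everything enqueued on `other` so far.
        self.deps.update(other.ops)

caller_stream = FakeStream("caller")    # caller's current torch.cuda.Stream()
calc_ctx_stream = FakeStream("calc")    # the fixed calc_ctx->stream()
comm_stream = FakeStream("comm")

# The caller enqueues a producer kernel on its own current stream.
caller_stream.enqueue("produce_tokens")

# Split-source path: comm waits on calc_ctx's stream, not the caller's.
comm_stream.wait_stream(calc_ctx_stream)
missing = "produce_tokens" not in comm_stream.deps  # True: dependency lost

# Unified path: wait on the caller's actual current stream instead.
fixed_comm = FakeStream("comm")
fixed_comm.wait_stream(caller_stream)
fixed = "produce_tokens" in fixed_comm.deps  # True: dependency captured
```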
You have a point. Could I instead change it so that whenever `at::cuda::setCurrentCUDAStream()` is called, `calc_ctx`'s stream is updated at the same time:

```cpp
// Wrap this in a small utility function
inline void set_current_stream(c10::cuda::CUDAStream stream, phi::GPUContext* calc_ctx) {
    // 1. Set the ATen current stream (affects at::cuda::getCurrentCUDAStream())
    at::cuda::setCurrentCUDAStream(stream);
    // 2. Keep the Paddle GPUContext stream in sync
    //    (assuming calc_ctx provides a stream setter, or an equivalent method)
    calc_ctx->SetStream(stream.stream());
}
```

Then replace every `setCurrentCUDAStream` call in `deep_ep.cpp` with this helper.

Note that these changes are based on https://github.com/youge325/DeepEP/tree/cRemove_Optional and the latest Paddle develop branch; or maybe you have a better approach that minimizes the diff.
As explained: this direction is closer to correct than the current state, but I don't recommend doing a "bidirectional sync" only at the `setCurrentCUDAStream()` call site. The problem I pointed out earlier is not just that the two streams fall out of sync when switching allocators; the key issue is that the source of the caller's stream has already been changed from `at::cuda::getCurrentCUDAStream()` to `calc_ctx->stream()`:

- `csrc/deep_ep.cpp:391` first reads `auto compute_stream = calc_ctx->stream();`.
- When `previous_event` is empty, the later `stream_wait(comm_stream, compute_stream)` also waits on that stream.
- `EventHandle` still records/waits on `at::cuda::getCurrentCUDAStream()` (`csrc/event.hpp:13-29`).

So if you only add a `set_current_stream(...)` afterwards, the default path may still fail to wait on work already enqueued on the caller's current `torch.cuda.Stream()`; in other words, the root of the problem is that the source of truth has been swapped. If you want to minimize the diff, I would prefer:

- Keep `auto compute_stream = at::cuda::getCurrentCUDAStream();`.
- Have `calc_ctx` / the allocator follow it, plus a very thin helper that syncs the currently selected ATen stream to `calc_ctx` / the allocator; but do not go the other way and let `calc_ctx->stream()` become the primary semantics.

The advantage of this change is that the current-stream contract keeps a single source of truth. If you plan to converge on this later, that's fine; but I suggest first settling who the single source of truth is, otherwise you are only moving the split stream contract from one place to another.
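The suggested direction can be sketched as a tiny stand-in model (all names here are invented placeholders, not the real ATen or Paddle APIs): the ATen current stream stays the single source of truth, and a thin helper pushes it into the Paddle-side context, which follows but never leads.

```python
# Toy model of a one-directional sync helper.
class FakeATen:
    # Stand-in for the ATen current-stream state
    # (mirrors at::cuda::getCurrentCUDAStream()).
    current_stream = "default"

class FakeCalcCtx:
    # Stand-in for the Paddle GPUContext.
    def __init__(self):
        self.stream = "default"

def set_current_stream(stream, aten, calc_ctx):
    # 1. Update the primary source of truth first.
    aten.current_stream = stream
    # 2. Make the Paddle-side context follow; it never leads.
    calc_ctx.stream = stream

aten, calc_ctx = FakeATen(), FakeCalcCtx()
set_current_stream("custom_stream", aten, calc_ctx)
# Both sides now agree, and readers consult aten.current_stream as the
# single source of truth.
```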
Referring to #11 and the https://github.com/PaddlePaddle/Paddle repository: do we actually use `current_stream_wait` that often? Could we replace it with `void stream_wait(const at::cuda::CUDAStream& s_0, const at::cuda::CUDAStream& s_1)` and `void stream_wait(const at::cuda::CUDAStream& s, const EventHandle& event)` instead? I noticed that in the https://github.com/PaddlePaddle/Paddle repository, the tests related to `current_stream_wait` are all commented out.

Our goal is to let Paddle manage the streams, which makes it easier to distinguish the communication stream from the compute stream, so `ProcessGroup` naturally has to become the single source of truth. As for `current_stream_wait`, it can be replaced by `stream_wait`, with the `CUDAStream` and `EventHandle` passed in explicitly to make the semantics clear and remove the ambiguity in the code.
As explained: if the proposal is just "change the internal C++ wait relations to explicit `stream_wait(s_0, s_1)` / `stream_wait(s, event)` where possible", I have no objection; that form is clearer than the implicit current-stream one. But on its own it does not remove the blocker I raised earlier, because the problem here is no longer just the choice of internal helper; the external contract has not been migrated along with it:

- `current_stream_wait` is still a public API in this PR, not merely an internal implementation detail.
- `EventOverlap.current_stream_wait()` (`deep_ep/utils.py:36-41`) is also called directly from `__exit__`.
- See `tests/test_intranode.py:105-106` and `tests/test_low_latency.py:71-72`.

So if the goal really is "ProcessGroup / Paddle becomes the single source of truth", changing only the internal wait helper is not enough; the outer semantic contract has not been switched over with it. The current migration is still a half-closed loop:

- `calc_stream_wait` / `comm_stream_wait` (`deep_ep/utils.py:43-47`).
- `current_stream_wait` (`csrc/deep_ep.cpp:1928-1930`).

That is, the PR currently exposes two sets of interfaces at once, the "current stream semantics" and the intended "explicit calc/comm semantics", but the latter has not actually been closed. So my core opinion is unchanged:

- Switch the `EventOverlap` API, tests, and docs to the explicit semantics together; swapping `current_stream_wait` for an internal `stream_wait(...)` alone does not count as done.

In other words, I am not against the `stream_wait(...)` form; I am against the PR staying in a state where the internal semantics want to switch but the external contract has not finished switching, and in review that remains a blocker. Also, Paddle commenting out the `current_stream_wait` tests at most shows that they too are cautious about the current-stream contract itself; for this PR, that is all the more reason to settle the external semantics in one pass rather than keep a half-migrated state.
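The explicit-wait shape under discussion can be illustrated with the same kind of toy stand-in (the `Stream`/`Event` classes and op names below are invented, not the real CUDA types): instead of an implicit `current_stream_wait()`, the caller names both sides of every dependency.

```python
# Toy model of explicit stream_wait(stream, event) semantics.
class Event:
    def __init__(self):
        self.recorded_after = ()   # snapshot of work the event captures

class Stream:
    def __init__(self, name):
        self.name = name
        self.ops = []
        self.deps = set()

    def enqueue(self, op):
        self.ops.append(op)

    def record_event(self):
        ev = Event()
        ev.recorded_after = tuple(self.ops)   # work enqueued so far
        return ev

def stream_wait(stream, event):
    # Explicit form: `stream` waits on exactly the work `event` captured,
    # with no hidden dependence on a "current stream" global.
    stream.deps.update(event.recorded_after)

comm, calc = Stream("comm"), Stream("calc")
comm.enqueue("dispatch_tokens")
ev = comm.record_event()
stream_wait(calc, ev)   # calc explicitly waits on comm's dispatch
```

Because both the waiting stream and the event are passed explicitly, reading a call site is enough to see which communication work the compute stream depends on.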