-
Notifications
You must be signed in to change notification settings - Fork 80
Cuda-graph capturable Dispatch and combine #6031
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
75f42b3
a83937d
93af651
d67780b
7480e88
8b7c2de
5377fed
ce82b07
75fa32f
b8c4c78
ee790cb
86bc54b
1f622c7
d955e65
9c3e464
2cc0919
c5e1ae6
2130af3
0c0fad3
b45edf3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1457,25 +1457,10 @@ void alltoallvWithCudaBackend( | |
| const at::Tensor& send, | ||
| const at::Tensor& recv, | ||
| const AlltoallvMetadata& metadata, | ||
| const std::vector<void*>& recv_ptrs, | ||
| const at::Tensor& recv_ptrs_gpu, | ||
| CUstream stream) { | ||
| NVF_CHECK(send.is_cuda(), "alltoallv send must be CUDA."); | ||
| NVF_CHECK(recv.is_cuda(), "alltoallv recv must be CUDA."); | ||
| NVF_CHECK( | ||
| (int64_t)recv_ptrs.size() == metadata.world_size, | ||
| "recv_ptrs size must match world size."); | ||
|
|
||
| auto cpu_options = at::TensorOptions().dtype(at::kLong).device(at::kCPU); | ||
| auto recv_ptrs_cpu = at::empty({metadata.world_size}, cpu_options); | ||
| auto* ptrs = recv_ptrs_cpu.data_ptr<int64_t>(); | ||
| for (int64_t rank = 0; rank < metadata.world_size; ++rank) { | ||
| ptrs[rank] = | ||
| static_cast<int64_t>(reinterpret_cast<uintptr_t>(recv_ptrs[rank])); | ||
| } | ||
| auto recv_ptrs_cuda = recv_ptrs_cpu.to(send.device()); | ||
|
|
||
| const int64_t elem_stride = | ||
| metadata.max_send_total > 0 ? send.numel() / metadata.max_send_total : 1; | ||
| NVF_CHECK( | ||
| metadata.max_send_total == 0 || | ||
| send.numel() % metadata.max_send_total == 0, | ||
|
|
@@ -1484,6 +1469,9 @@ void alltoallvWithCudaBackend( | |
| metadata.max_recv == 0 || recv.numel() % metadata.max_recv == 0, | ||
| "alltoallv recv numel must be divisible by max_recv."); | ||
|
|
||
| const int64_t elem_stride = | ||
| metadata.max_send_total > 0 ? send.numel() / metadata.max_send_total : 1; | ||
|
Comment on lines
1472
to
1473
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

Divisibility guard removed — silently wrong results. The PR removes these checks:

NVF_CHECK(
    metadata.max_send_total == 0 ||
        send.numel() % metadata.max_send_total == 0, ...);
NVF_CHECK(
    metadata.max_recv == 0 || recv.numel() % metadata.max_recv == 0, ...);
|
||
|
|
||
| auto send_offsets = metadata.send_offsets; | ||
| auto send_counts = metadata.send_counts; | ||
| auto recv_offsets = metadata.recv_offsets; | ||
|
|
@@ -1497,7 +1485,7 @@ void alltoallvWithCudaBackend( | |
|
|
||
| launchAlltoallvKernel( | ||
| send.data_ptr(), | ||
| reinterpret_cast<const uint64_t*>(recv_ptrs_cuda.data_ptr<int64_t>()), | ||
| reinterpret_cast<const uint64_t*>(recv_ptrs_gpu.data_ptr<int64_t>()), | ||
| send_offsets.data_ptr<int64_t>(), | ||
| send_counts.data_ptr<int64_t>(), | ||
| recv_offsets.data_ptr<int64_t>(), | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`recv_ptrs_gpu` — no size or device check. The old call-site accepted `const std::vector<void*>& recv_ptrs` and explicitly verified that `recv_ptrs.size()` matched `metadata.world_size`. It also coerced the pointer table onto the send device via `.to(send.device())`.

The new `at::Tensor recv_ptrs_gpu` has neither safeguard: if it holds fewer than `world_size` entries the kernel silently reads garbage pointers, and if it lives on the wrong device the launch will fault. `remotePointersTensor()` always produces a `[world_size]` tensor on the correct device by construction, but the API contract is now implicit and fragile for any future caller. Consider adding explicit `NVF_CHECK`s on `recv_ptrs_gpu`'s size and device.