Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 3 additions & 15 deletions ext/AMDGPUExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,10 @@ module AMDGPUExt

import MPI
isdefined(Base, :get_extension) ? (import AMDGPU) : (import ..AMDGPU)
import MPI: MPIPtr, Buffer, Datatype
import MPI: MPIPtr, Buffer, Datatype, CConvWrapper

function Base.cconvert(::Type{MPIPtr}, A::AMDGPU.ROCArray{T}) where T
A
end

function Base.unsafe_convert(::Type{MPIPtr}, X::AMDGPU.ROCArray{T}) where T
reinterpret(MPIPtr, Base.unsafe_convert(Ptr{T}, X))
end

# only need to define this for strided arrays: all others can be handled by generic machinery
function Base.unsafe_convert(::Type{MPIPtr}, V::SubArray{T,N,P,I,true}) where {T,N,P<:AMDGPU.ROCArray,I}
X = parent(V)
pX = Base.unsafe_convert(Ptr{T}, X)
pV = pX + ((V.offset1 + V.stride1) - first(LinearIndices(X)))*sizeof(T)
return reinterpret(MPIPtr, pV)
# ccall argument conversion for ROCArray buffers: hand back a CConvWrapper
# built for `Ptr{T}`, where `T` is the element type of the array.
function Base.cconvert(::Type{MPIPtr}, x::AMDGPU.ROCArray{T}) where {T}
    return CConvWrapper(Ptr{T}, x)
end

function Buffer(arr::AMDGPU.ROCArray)
Expand Down
16 changes: 4 additions & 12 deletions ext/CUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,14 @@ module CUDAExt

import MPI
isdefined(Base, :get_extension) ? (import CUDA) : (import ..CUDA)
import MPI: MPIPtr, Buffer, Datatype
import MPI: MPIPtr, Buffer, Datatype, CConvWrapper

# ccall argument conversion for CuArray buffers: return a CConvWrapper built
# for `CUDA.CuPtr{T}`, where `T` is the element type of the array.
#
# The CConvWrapper constructor itself calls `Base.cconvert(CUDA.CuPtr{T}, buf)`
# and stores the result, so no separate conversion call is made here; a bare
# call whose result is discarded would only repeat that work.
function Base.cconvert(::Type{MPIPtr}, buf::CUDA.CuArray{T}) where {T}
    return CConvWrapper(CUDA.CuPtr{T}, buf)
end

function Base.unsafe_convert(::Type{MPIPtr}, X::CUDA.CuArray{T}) where T
reinterpret(MPIPtr, Base.unsafe_convert(CUDA.CuPtr{T}, X))
end

# only need to define this for strided arrays: all others can be handled by generic machinery
function Base.unsafe_convert(::Type{MPIPtr}, V::SubArray{T,N,P,I,true}) where {T,N,P<:CUDA.CuArray,I}
X = parent(V)
pX = Base.unsafe_convert(CUDA.CuPtr{T}, X)
pV = pX + ((V.offset1 + V.stride1) - first(LinearIndices(X)))*sizeof(T)
return reinterpret(MPIPtr, pV)
# ccall argument conversion for SubArray views whose parent is a CuArray and
# whose last type parameter is `true`: return a CConvWrapper built for
# `CUDA.CuPtr{T}`, where `T` is the element type of the view.
function Base.cconvert(
    ::Type{MPIPtr}, buf::SubArray{T,N,P,I,true}
) where {T,N,P<:CUDA.CuArray,I}
    return CConvWrapper(CUDA.CuPtr{T}, buf)
end

function Buffer(arr::CUDA.CuArray)
Expand Down
4 changes: 1 addition & 3 deletions src/api/api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,7 @@ end
# Bits type that is `Sys.WORD_SIZE` bits wide; the assertion that follows this
# declaration checks that it has the same size as `Ptr{Cvoid}`.
# It is the target type of the `Base.cconvert` / `Base.unsafe_convert` methods
# for MPI arguments, which produce values with `reinterpret(MPIPtr, ...)`.
primitive type MPIPtr Sys.WORD_SIZE
end
@assert sizeof(MPIPtr) == sizeof(Ptr{Cvoid})
Base.cconvert(::Type{MPIPtr}, x::SentinelPtr) = x
Base.unsafe_convert(::Type{MPIPtr}, x::SentinelPtr) = reinterpret(MPIPtr, x)

Base.cconvert(::Type{MPIPtr}, x::SentinelPtr) = reinterpret(MPIPtr, x)

# Initialize the ref constants from the library.
# This is not `API.__init__`, as it should be called _after_
Expand Down
71 changes: 64 additions & 7 deletions src/buffers.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,73 @@
# Union of the buffer-like argument types with element type `T`: raw pointers,
# arrays, array views, and references.
MPIBuffertype{T} = Union{Ptr{T}, Array{T}, SubArray{T}, Ref{T}}
# `MPIBuffertype{T}` extended with `SentinelPtr` (declared outside this file).
# NOTE(review): presumably used in method signatures elsewhere in the package —
# confirm before changing either alias.
MPIBuffertypeOrConst{T} = Union{MPIBuffertype{T}, SentinelPtr}

Base.cconvert(::Type{MPIPtr}, x::Union{Ptr{T}, Array{T}, Ref{T}}) where T = Base.cconvert(Ptr{T}, x)
Base.cconvert(::Type{MPIPtr}, x::SubArray{T}) where T = Base.cconvert(Ptr{T}, x)
function Base.unsafe_convert(::Type{MPIPtr}, x::MPIBuffertype{T}) where T
ptr = Base.unsafe_convert(Ptr{T}, x)
# CConvWrapper: GC-safe adapter for converting Julia objects to MPIPtr in ccall.
#
# Background: ccall's argument conversion protocol works in two steps:
# 1. cconvert(T, x) — called before the ccall. Its return value is GC-rooted
# by ccall for the duration of the foreign call, keeping the underlying
# Julia object alive while a pointer to it is in use.
# 2. unsafe_convert(T, result_of_cconvert) — called on the GC-rooted result
# to extract the raw pointer. Crucially, dispatch is on the *return type*
# of cconvert, not the original argument type.
#
# Problem: because unsafe_convert dispatches on the cconvert return type, the
# unsafe_convert(::Type{MPIPtr}, ...) method must match whatever cconvert
# returned. If cconvert delegates to e.g. Base.cconvert(Ptr{T}, x), the return
# type depends on the Base implementation, so an unsafe_convert method written
# for the original type will never be called.
#
# Solution: CConvWrapper provides a single, predictable return type from
# cconvert(MPIPtr, x). The conversion proceeds as:
#
# ccall argument x::Array{Float64}
# │
# ▼
# cconvert(MPIPtr, x)
# calls Base.cconvert(Ptr{Float64}, x) — returns the Array (kept alive)
# wraps it in CConvWrapper{Ptr{Float64}}(array)
# ◄── ccall GC-roots this CConvWrapper, which holds the Array
# │
# ▼
# unsafe_convert(MPIPtr, wrapper::CConvWrapper{Ptr{Float64}})
# calls Base.unsafe_convert(Ptr{Float64}, wrapper.cconv) — extracts raw ptr
# reinterprets to MPIPtr
# ◄── only called while ccall holds the GC root on the wrapper
#
# Types that don't need GC protection (Ptr, Nothing, InPlace, SentinelPtr) skip
# the wrapper and return an MPIPtr directly from cconvert, since they are plain
# bit types with no GC-managed backing memory.
# Wrapper around the object returned by `Base.cconvert(T, x)`.
#
# Type parameters:
#   T — the intermediate pointer type (e.g. Ptr{Float64}, CuPtr{Float64});
#       it is carried only in the type, no field of type T is stored.
#   C — the type of the wrapped cconvert result.
#
# Field:
#   cconv — the cconvert result itself. ccall keeps the wrapper rooted for the
#           duration of the call, and the wrapper in turn references `cconv`.
struct CConvWrapper{T, C}
    cconv::C
end

# Outer constructor: run `Base.cconvert(T, x)` and wrap whatever it returns,
# recording both the requested pointer type `T` and the concrete type of the
# result, so that `unsafe_convert` dispatch on the wrapper is predictable.
function CConvWrapper(::Type{T}, x) where {T}
    rooted = Base.cconvert(T, x)
    return CConvWrapper{T, typeof(rooted)}(rooted)
end

# Second step of the ccall conversion protocol for wrapped arguments.
# ccall invokes this on the wrapper returned by `cconvert(MPIPtr, ...)`: the
# wrapped object is converted to the intermediate pointer type `T` with
# `Base.unsafe_convert`, and the resulting pointer is reinterpreted as MPIPtr.
function Base.unsafe_convert(::Type{MPIPtr}, x::CConvWrapper{T}) where {T}
    return reinterpret(MPIPtr, Base.unsafe_convert(T, x.cconv))
end

# --- cconvert methods for types with GC-managed memory (use CConvWrapper) ---

# Arrays, array views and references with element type `T` are wrapped in a
# CConvWrapper built for `Ptr{T}`.
Base.cconvert(::Type{MPIPtr}, x::Union{Array{T}, SubArray{T}, Ref{T}}) where {T} =
    CConvWrapper(Ptr{T}, x)
# Strings are wrapped in a CConvWrapper built for `Ptr{UInt8}`.
Base.cconvert(::Type{MPIPtr}, x::String) = CConvWrapper(Ptr{UInt8}, x)

# --- cconvert methods for plain bit types (no GC protection needed) ---

Base.cconvert(::Type{MPIPtr}, x::String) = x
Base.unsafe_convert(::Type{MPIPtr}, x::String) = reinterpret(MPIPtr, pointer(x))
# A raw pointer is reinterpreted as MPIPtr directly, without a wrapper.
function Base.cconvert(::Type{MPIPtr}, ptr::Ptr)
    return reinterpret(MPIPtr, ptr)
end

# `nothing` is passed to MPI as the null pointer.
function Base.cconvert(::Type{MPIPtr}, ::Nothing)
    return reinterpret(MPIPtr, C_NULL)
end

Expand Down Expand Up @@ -45,7 +102,7 @@ MPIPtr

# Field-less marker type. The `Base.cconvert(::Type{MPIPtr}, ::InPlace)` method
# defined alongside it maps any instance to the value stored in
# `API.MPI_IN_PLACE`.
struct InPlace
end
Base.cconvert(::Type{MPIPtr}, ::InPlace) = API.MPI_IN_PLACE[]
Base.cconvert(::Type{MPIPtr}, ::InPlace) = reinterpret(MPIPtr, API.MPI_IN_PLACE[])


"""
Expand Down
Loading