diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 9ec5c4f..dc703c2 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -31,3 +31,35 @@ steps: cuda: "*" if: build.message !~ /\[skip tests\]/ timeout_in_minutes: 30 + + - label: "Julia v1 -- AMDGPU" + plugins: + - JuliaCI/julia#v1: + version: "1" + - JuliaCI/julia-test#v1: ~ + - JuliaCI/julia-coverage#v1: + dirs: + - src + - ext + agents: + queue: "juliagpu" + rocm: "*" + rocmgpu: "*" + if: build.message !~ /\[skip tests\]/ + timeout_in_minutes: 30 + + - label: "Julia LTS -- AMDGPU" + plugins: + - JuliaCI/julia#v1: + version: "1.10" # "lts" isn't valid + - JuliaCI/julia-test#v1: ~ + - JuliaCI/julia-coverage#v1: + dirs: + - src + - ext + agents: + queue: "juliagpu" + rocm: "*" + rocmgpu: "*" + if: build.message !~ /\[skip tests\]/ + timeout_in_minutes: 30 diff --git a/Project.toml b/Project.toml index 169cca7..9e3d854 100644 --- a/Project.toml +++ b/Project.toml @@ -8,14 +8,17 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930" [weakdeps] +AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" PtrArrays = "43287f4e-b6f4-7ad1-bb20-aadabca52c3d" [extensions] +StridedViewsAMDGPUExt = "AMDGPU" StridedViewsCUDAExt = "CUDA" StridedViewsPtrArraysExt = "PtrArrays" [compat] +AMDGPU = "2" Aqua = "0.8" CUDA = "4,5" JET = "0.9, 0.10, 0.11" @@ -27,6 +30,7 @@ Test = "1.6" julia = "1.10" [extras] +AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" @@ -35,4 +39,4 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "Random", "Aqua", "JET", "PtrArrays", "CUDA"] +test = ["Test", "Random", "Aqua", "JET", "PtrArrays", "CUDA", "AMDGPU"] diff --git a/ext/StridedViewsAMDGPUExt.jl b/ext/StridedViewsAMDGPUExt.jl new file mode 100644 index 0000000..c15a112 --- /dev/null +++ b/ext/StridedViewsAMDGPUExt.jl @@ -0,0 +1,27 @@ +module StridedViewsAMDGPUExt + +using StridedViews +using AMDGPU +using AMDGPU: Adapt, ROCPtr + +const ROCStridedView{T, N, A <: ROCArray{T}} = StridedView{T, N, A} + +function Adapt.adapt_structure(to, A::ROCStridedView) + return StridedView( + Adapt.adapt_structure(to, parent(A)), + A.size, A.strides, A.offset, A.op + ) +end + +function Base.pointer(x::ROCStridedView{T}) where {T} + return Base.unsafe_convert(Ptr{T}, pointer(x.parent, x.offset + 1)) +end +function Base.unsafe_convert(::Type{Ptr{T}}, a::ROCStridedView{T}) where {T} + return convert(Ptr{T}, pointer(a)) +end + +function Base.print_array(io::IO, X::ROCStridedView) + return Base.print_array(io, Adapt.adapt_structure(Array, X)) +end + +end # module diff --git a/ext/StridedViewsCUDAExt.jl b/ext/StridedViewsCUDAExt.jl index 503e750..92f9bc2 100644 --- a/ext/StridedViewsCUDAExt.jl +++ b/ext/StridedViewsCUDAExt.jl @@ -6,9 +6,9 @@ using CUDA: Adapt, CuPtr const CuStridedView{T, N, A <: CuArray{T}} = StridedView{T, N, A} -function Adapt.adapt_structure(::Type{T}, A::StridedView) where {T} +function Adapt.adapt_structure(to, A::CuStridedView) return StridedView( - Adapt.adapt_structure(T, parent(A)), + Adapt.adapt_structure(to, parent(A)), A.size, A.strides, A.offset, A.op ) end diff --git a/test/runtests.jl b/test/runtests.jl index de22471..6e46933 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -298,3 +298,29 @@ if !is_buildkite JET.test_package(StridedViews; target_modules = (StridedViews,)) end end + +using CUDA, AMDGPU + +if CUDA.functional() + @testset "CuArrays with StridedView" begin + @testset for T in (Float64, ComplexF64) + A = CUDA.randn!(T, 10, 10, 10, 10) + @test isstrided(A) + B = StridedView(A) + @test B isa StridedView + @test B == A + end + end +end + +if AMDGPU.functional() + @testset "ROCArrays with StridedView" begin + @testset for T in (Float64, ComplexF64) + A = AMDGPU.randn!(T, 10, 10, 10, 10) + @test isstrided(A) + B = StridedView(A) + @test B isa StridedView + @test B == A + end + end +end