diff --git a/Project.toml b/Project.toml
index 55aca8e..3472e21 100644
--- a/Project.toml
+++ b/Project.toml
@@ -10,11 +10,13 @@ TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
 
 [weakdeps]
 AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
+JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
 GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 
 [extensions]
 StridedAMDGPUExt = "AMDGPU"
+StridedJLArraysExt = "JLArrays"
 StridedGPUArraysExt = "GPUArrays"
 StridedCUDAExt = "CUDA"
 
@@ -22,6 +24,7 @@ StridedCUDAExt = "CUDA"
 AMDGPU = "2"
 Aqua = "0.8"
 CUDA = "5"
 GPUArrays = "11.4.1"
+JLArrays = "0.3.1"
 LinearAlgebra = "1.6"
 Random = "1.6"
@@ -34,9 +37,10 @@ julia = "1.6"
 AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
 Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
+JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
-test = ["Test", "Random", "Aqua", "AMDGPU", "CUDA", "GPUArrays"]
+test = ["Test", "Random", "Aqua", "AMDGPU", "CUDA", "GPUArrays", "JLArrays"]
diff --git a/ext/StridedJLArraysExt.jl b/ext/StridedJLArraysExt.jl
new file mode 100644
index 0000000..1ea05b5
--- /dev/null
+++ b/ext/StridedJLArraysExt.jl
@@ -0,0 +1,23 @@
+module StridedJLArraysExt
+
+using Strided, StridedViews, JLArrays
+using JLArrays: Adapt
+using JLArrays: GPUArrays
+
+const ALL_FS = Union{typeof(adjoint), typeof(conj), typeof(identity), typeof(transpose)}
+
+function Base.copy!(dst::StridedView{TD, ND, TAD, FD}, src::StridedView{TS, NS, TAS, FS}) where {TD <: Number, ND, TAD <: JLArray{TD}, FD <: ALL_FS, TS <: Number, NS, TAS <: JLArray{TS}, FS <: ALL_FS}
+    bc_style = Base.Broadcast.BroadcastStyle(TAS)
+    bc = Base.Broadcast.Broadcasted(bc_style, identity, (src,), axes(dst))
+    GPUArrays._copyto!(dst, bc)
+    return dst
+end
+
+function Base.copy!(dst::AbstractArray{TD, ND}, src::StridedView{TS, NS, TAS, FS}) where {TD <: Number, ND, TS <: Number, NS, TAS <: JLArray{TS}, FS <: ALL_FS}
+    bc_style = Base.Broadcast.BroadcastStyle(TAS)
+    bc = Base.Broadcast.Broadcasted(bc_style, identity, (src,), axes(dst))
+    GPUArrays._copyto!(dst, bc)
+    return dst
+end
+
+end
diff --git a/test/jlarrays.jl b/test/jlarrays.jl
new file mode 100644
index 0000000..9848347
--- /dev/null
+++ b/test/jlarrays.jl
@@ -0,0 +1,14 @@
+for T in (Float32, Float64, Complex{Float32}, Complex{Float64})
+    @testset "Copy with JLArrayStridedView: $T, $f1, $f2" for f2 in (identity, conj, adjoint, transpose), f1 in (identity, conj, transpose, adjoint)
+        for m1 in (0, 16, 32), m2 in (0, 16, 32)
+            A1 = JLArray(randn(T, (m1, m2)))
+            A2 = similar(A1)
+            A1c = copy(A1)
+            A2c = copy(A2)
+            B1 = f1(StridedView(A1c))
+            B2 = f2(StridedView(A2c))
+            axes(f1(A1)) == axes(f2(A2)) || continue
+            @test collect(Matrix(copy!(f2(A2), f1(A1)))) == JLArrays.Adapt.adapt(Vector{T}, copy!(B2, B1))
+        end
+    end
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 3f9ee6f..0b0dcc8 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -4,7 +4,7 @@ using Random
 using Strided
 using Strided: StridedView
 using Aqua
-using AMDGPU, CUDA, GPUArrays
+using JLArrays, AMDGPU, CUDA, GPUArrays
 
 Random.seed!(1234)
 
@@ -28,6 +28,7 @@ if !is_buildkite
     include("blasmultests.jl")
     Strided.disable_threaded_mul()
 
+    include("jlarrays.jl")
     Aqua.test_all(Strided; piracies = false)
 end
 