From 60804eecbba6f428e946efc4f63981e305eaaa01 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 25 Feb 2026 10:44:32 -0500 Subject: [PATCH] Apply runic formatting --- benchmarks/benchtests.jl | 56 ++++++----- src/Strided.jl | 4 +- src/broadcast.jl | 26 +++--- src/convert.jl | 4 +- src/linalg.jl | 82 +++++++++++------ src/macros.jl | 22 +++-- src/mapreduce.jl | 162 ++++++++++++++++++++------------ test/othertests.jl | 194 ++++++++++++++++++++++++--------------- test/runtests.jl | 2 +- 9 files changed, 341 insertions(+), 211 deletions(-) diff --git a/benchmarks/benchtests.jl b/benchmarks/benchtests.jl index f860151..9617987 100644 --- a/benchmarks/benchtests.jl +++ b/benchmarks/benchtests.jl @@ -22,7 +22,7 @@ function benchmark_sum(sizes) return times end -function benchmark_permute(sizes, p=(4, 3, 2, 1)) +function benchmark_permute(sizes, p = (4, 3, 2, 1)) times = zeros(length(sizes), 4) for (i, s) in enumerate(sizes) A = randn(Float64, s .* one.(p)) @@ -41,7 +41,7 @@ permute_times1 = benchmark_permute(sizes, (4, 3, 2, 1)) permute_times2 = benchmark_permute(sizes, (2, 3, 4, 1)) permute_times3 = benchmark_permute(sizes, (3, 4, 1, 2)) -function benchmark_mul(sizesm, sizesk=sizesm, sizesn=sizesm) +function benchmark_mul(sizesm, sizesk = sizesm, sizesn = sizesm) N = Threads.nthreads() @assert length(sizesm) == length(sizesk) == length(sizesn) times = zeros(length(sizesm), 4) @@ -62,23 +62,23 @@ function benchmark_mul(sizesm, sizesk=sizesm, sizesn=sizesm) BLAS.set_num_threads(1) # single-threaded blas with strided multithreading Strided.enable_threaded_mul() times[i, 4] = @belapsed @strided mul!($C, $A, $B) - println("step $i: sizes $((m,k,n)) => times = $(times[i, :])") + println("step $i: sizes $((m, k, n)) => times = $(times[i, :])") end return times end function tensorcontraction!(wEnv, hamAB, hamBA, rhoBA, rhoAB, w, v, u) @tensor wEnv[-1, -2, -3] = hamAB[7, 8, -1, 9] * rhoBA[4, 3, -3, 2] * conj(w[7, 5, 4]) * - u[9, 10, -2, 11] * conj(u[8, 10, 5, 6]) * v[1, 11, 2] * - conj(v[1, 6, 3]) + - hamBA[1, 2, 3, 4] * rhoBA[10, 7, -3, 6] * - conj(w[-1, 11, 10]) * u[3, 4, -2, 8] * conj(u[1, 2, 11, 9]) * - v[5, 8, 6] * conj(v[5, 9, 7]) + - hamAB[5, 7, 3, 1] * rhoBA[10, 9, -3, 8] * - conj(w[-1, 11, 10]) * u[4, 3, -2, 2] * conj(u[4, 5, 11, 6]) * - v[1, 2, 8] * conj(v[7, 6, 9]) + - hamBA[3, 7, 2, -1] * rhoAB[5, 6, 4, -3] * v[2, 1, 4] * - conj(v[3, 1, 5]) * conj(w[7, -2, 6]) + u[9, 10, -2, 11] * conj(u[8, 10, 5, 6]) * v[1, 11, 2] * + conj(v[1, 6, 3]) + + hamBA[1, 2, 3, 4] * rhoBA[10, 7, -3, 6] * + conj(w[-1, 11, 10]) * u[3, 4, -2, 8] * conj(u[1, 2, 11, 9]) * + v[5, 8, 6] * conj(v[5, 9, 7]) + + hamAB[5, 7, 3, 1] * rhoBA[10, 9, -3, 8] * + conj(w[-1, 11, 10]) * u[4, 3, -2, 2] * conj(u[4, 5, 11, 6]) * + v[1, 2, 8] * conj(v[7, 6, 9]) + + hamBA[3, 7, 2, -1] * rhoAB[5, 6, 4, -3] * v[2, 1, 4] * + conj(v[3, 1, 5]) * conj(w[7, -2, 6]) return wEnv end @@ -100,32 +100,42 @@ function benchmark_tensorcontraction(sizes) BLAS.set_num_threads(1) Strided.disable_threads() Strided.disable_threaded_mul() - times[i, 1] = @belapsed tensorcontraction!($wEnv, $hamAB, $hamBA, $rhoBA, $rhoAB, - $w, $v, $u) + times[i, 1] = @belapsed tensorcontraction!( + $wEnv, $hamAB, $hamBA, $rhoBA, $rhoAB, + $w, $v, $u + ) BLAS.set_num_threads(1) Strided.enable_threads() Strided.disable_threaded_mul() - times[i, 2] = @belapsed tensorcontraction!($wEnv, $hamAB, $hamBA, $rhoBA, $rhoAB, - $w, $v, $u) + times[i, 2] = @belapsed tensorcontraction!( + $wEnv, $hamAB, $hamBA, $rhoBA, $rhoAB, + $w, $v, $u + ) BLAS.set_num_threads(N) Strided.disable_threads() Strided.disable_threaded_mul() - times[i, 3] = @belapsed tensorcontraction!($wEnv, $hamAB, $hamBA, $rhoBA, $rhoAB, - $w, $v, $u) + times[i, 3] = @belapsed tensorcontraction!( + $wEnv, $hamAB, $hamBA, $rhoBA, $rhoAB, + $w, $v, $u + ) BLAS.set_num_threads(N) Strided.enable_threads() Strided.disable_threaded_mul() - times[i, 4] = @belapsed tensorcontraction!($wEnv, $hamAB, $hamBA, $rhoBA, $rhoAB, - $w, $v, $u) + times[i, 4] = @belapsed tensorcontraction!( + $wEnv, $hamAB, $hamBA, $rhoBA, $rhoAB, + $w, $v, $u + ) BLAS.set_num_threads(1) Strided.enable_threads() Strided.enable_threaded_mul() - times[i, 5] = @belapsed tensorcontraction!($wEnv, $hamAB, $hamBA, $rhoBA, $rhoAB, - $w, $v, $u) + times[i, 5] = @belapsed tensorcontraction!( + $wEnv, $hamAB, $hamBA, $rhoBA, $rhoAB, + $w, $v, $u + ) println("step $i: size $s => times = $(times[i, :])") end diff --git a/src/Strided.jl b/src/Strided.jl index da4deb2..0daba19 100644 --- a/src/Strided.jl +++ b/src/Strided.jl @@ -2,7 +2,7 @@ module Strided import Base: parent, size, strides, tail, setindex using Base: @propagate_inbounds, RangeIndex, Dims -const SliceIndex = Union{RangeIndex,Colon} +const SliceIndex = Union{RangeIndex, Colon} using LinearAlgebra @@ -27,7 +27,7 @@ function set_num_threads(n::Int) return _NTHREADS[] = n end @noinline function _set_num_threads_warn(n) - @warn "Maximal number of threads limited by number of Julia threads, + return @warn "Maximal number of threads limited by number of Julia threads, setting number of threads equal to Threads.nthreads() = $n" end diff --git a/src/broadcast.jl b/src/broadcast.jl index 15da5ad..229acbb 100644 --- a/src/broadcast.jl +++ b/src/broadcast.jl @@ -3,29 +3,33 @@ using Base.Broadcast: BroadcastStyle, AbstractArrayStyle, DefaultArrayStyle, Bro struct StridedArrayStyle{N} <: AbstractArrayStyle{N} end -Broadcast.BroadcastStyle(::Type{<:StridedView{<:Any,N}}) where {N} = StridedArrayStyle{N}() +Broadcast.BroadcastStyle(::Type{<:StridedView{<:Any, N}}) where {N} = StridedArrayStyle{N}() StridedArrayStyle(::Val{N}) where {N} = StridedArrayStyle{N}() -StridedArrayStyle{M}(::Val{N}) where {M,N} = StridedArrayStyle{N}() +StridedArrayStyle{M}(::Val{N}) where {M, N} = StridedArrayStyle{N}() Broadcast.BroadcastStyle(a::StridedArrayStyle, ::DefaultArrayStyle{0}) = a function Broadcast.BroadcastStyle(::StridedArrayStyle{N}, a::DefaultArrayStyle) where {N} return BroadcastStyle(DefaultArrayStyle{N}(), a) end -function Broadcast.BroadcastStyle(::StridedArrayStyle{N}, - ::Broadcast.Style{Tuple}) where {N} +function Broadcast.BroadcastStyle( + ::StridedArrayStyle{N}, + ::Broadcast.Style{Tuple} + ) where {N} return DefaultArrayStyle{N}() end -function Base.similar(bc::Broadcasted{<:StridedArrayStyle{N}}, ::Type{T}) where {N,T} +function Base.similar(bc::Broadcasted{<:StridedArrayStyle{N}}, ::Type{T}) where {N, T} return StridedView(similar(convert(Broadcasted{DefaultArrayStyle{N}}, bc), T)) end -Base.dotview(a::StridedView{<:Any,N}, I::Vararg{SliceIndex,N}) where {N} = getindex(a, I...) +Base.dotview(a::StridedView{<:Any, N}, I::Vararg{SliceIndex, N}) where {N} = getindex(a, I...) # Broadcasting implementation -@inline function Base.copyto!(dest::StridedView{<:Any,N}, - bc::Broadcasted{StridedArrayStyle{N}}) where {N} +@inline function Base.copyto!( + dest::StridedView{<:Any, N}, + bc::Broadcasted{StridedArrayStyle{N}} + ) where {N} # convert to map # flatten and only keep the StridedView arguments @@ -36,7 +40,7 @@ Base.dotview(a::StridedView{<:Any,N}, I::Vararg{SliceIndex,N}) where {N} = getin return dest end -const WrappedScalarArgs = Union{AbstractArray{<:Any,0},Ref{<:Any}} +const WrappedScalarArgs = Union{AbstractArray{<:Any, 0}, Ref{<:Any}} @inline function capturestridedargs(t::Broadcasted, rest...) return (capturestridedargs(t.args...)..., capturestridedargs(rest...)...) @@ -64,7 +68,7 @@ function promoteshape1(sz::Dims{N}, a::StridedView) where {N} return StridedView(a.parent, sz, newstrides, a.offset, a.op) end -struct CaptureArgs{F,Args<:Tuple} +struct CaptureArgs{F, Args <: Tuple} f::F args::Args end @@ -84,7 +88,7 @@ end # Evaluate CaptureArgs (c::CaptureArgs)(vals...) = consume(c, vals)[1] -@inline function consume(c::CaptureArgs{F,Args}, vals) where {F,Args} +@inline function consume(c::CaptureArgs{F, Args}, vals) where {F, Args} args, newvals = t_consume(c.args, vals) return c.f(args...), newvals end diff --git a/src/convert.jl b/src/convert.jl index 78b837e..c1d6e84 100644 --- a/src/convert.jl +++ b/src/convert.jl @@ -4,13 +4,13 @@ function Base.Array(a::StridedView) return b end -function (Base.Array{T})(a::StridedView{S,N}) where {T,S,N} +function (Base.Array{T})(a::StridedView{S, N}) where {T, S, N} b = Array{T}(undef, size(a)) copy!(StridedView(b), a) return b end -function (Base.Array{T,N})(a::StridedView{S,N}) where {T,S,N} +function (Base.Array{T, N})(a::StridedView{S, N}) where {T, S, N} b = Array{T}(undef, size(a)) copy!(StridedView(b), a) return b diff --git a/src/linalg.jl b/src/linalg.jl index fade2f1..5b054ca 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -2,8 +2,10 @@ LinearAlgebra.rmul!(dst::StridedView, α::Number) = mul!(dst, dst, α) LinearAlgebra.lmul!(α::Number, dst::StridedView) = mul!(dst, α, dst) -function LinearAlgebra.mul!(dst::StridedView{<:Number,N}, α::Number, - src::StridedView{<:Number,N}) where {N} +function LinearAlgebra.mul!( + dst::StridedView{<:Number, N}, α::Number, + src::StridedView{<:Number, N} + ) where {N} if α == 1 copy!(dst, src) else @@ -11,8 +13,10 @@ function LinearAlgebra.mul!(dst::StridedView{<:Number,N}, α::Number, end return dst end -function LinearAlgebra.mul!(dst::StridedView{<:Number,N}, src::StridedView{<:Number,N}, - α::Number) where {N} +function LinearAlgebra.mul!( + dst::StridedView{<:Number, N}, src::StridedView{<:Number, N}, + α::Number + ) where {N} if α == 1 copy!(dst, src) else @@ -20,8 +24,10 @@ function LinearAlgebra.mul!(dst::StridedView{<:Number,N}, src::StridedView{<:Num end return dst end -function LinearAlgebra.axpy!(a::Number, X::StridedView{<:Number,N}, - Y::StridedView{<:Number,N}) where {N} +function LinearAlgebra.axpy!( + a::Number, X::StridedView{<:Number, N}, + Y::StridedView{<:Number, N} + ) where {N} if a == 1 Y .= X .+ Y else @@ -29,8 +35,10 @@ function LinearAlgebra.axpy!(a::Number, X::StridedView{<:Number,N}, end return Y end -function LinearAlgebra.axpby!(a::Number, X::StridedView{<:Number,N}, - b::Number, Y::StridedView{<:Number,N}) where {N} +function LinearAlgebra.axpby!( + a::Number, X::StridedView{<:Number, N}, + b::Number, Y::StridedView{<:Number, N} + ) where {N} if b == 1 axpy!(a, X, Y) elseif b == 0 @@ -41,9 +49,11 @@ function LinearAlgebra.axpby!(a::Number, X::StridedView{<:Number,N}, return Y end -function LinearAlgebra.mul!(C::StridedView{T,2}, - A::StridedView{<:Any,2}, B::StridedView{<:Any,2}, - α::Number=true, β::Number=false) where {T} +function LinearAlgebra.mul!( + C::StridedView{T, 2}, + A::StridedView{<:Any, 2}, B::StridedView{<:Any, 2}, + α::Number = true, β::Number = false + ) where {T} if !(eltype(C) <: LinearAlgebra.BlasFloat && eltype(A) == eltype(B) == eltype(C)) return __mul!(C, A, B, α, β) end @@ -62,7 +72,7 @@ function LinearAlgebra.mul!(C::StridedView{T,2}, return C end -function isblasmatrix(A::StridedView{T,2}) where {T<:LinearAlgebra.BlasFloat} +function isblasmatrix(A::StridedView{T, 2}) where {T <: LinearAlgebra.BlasFloat} if A.op == identity return stride(A, 1) == 1 || stride(A, 2) == 1 elseif A.op == conj @@ -71,7 +81,7 @@ function isblasmatrix(A::StridedView{T,2}) where {T<:LinearAlgebra.BlasFloat} return false end end -function getblasmatrix(A::StridedView{T,2}) where {T<:LinearAlgebra.BlasFloat} +function getblasmatrix(A::StridedView{T, 2}) where {T <: LinearAlgebra.BlasFloat} if A.op == identity if stride(A, 1) == 1 return A, 'N' @@ -84,8 +94,10 @@ function getblasmatrix(A::StridedView{T,2}) where {T<:LinearAlgebra.BlasFloat} end # here we will have C.op == :identity && stride(C,1) < stride(C,2) -function _mul!(C::StridedView{T,2}, A::StridedView{T,2}, B::StridedView{T,2}, - α::Number, β::Number) where {T<:LinearAlgebra.BlasFloat} +function _mul!( + C::StridedView{T, 2}, A::StridedView{T, 2}, B::StridedView{T, 2}, + α::Number, β::Number + ) where {T <: LinearAlgebra.BlasFloat} if stride(C, 1) == 1 && isblasmatrix(A) && isblasmatrix(B) nthreads = use_threaded_mul() ? get_num_threads() : 1 _threaded_blas_mul!(C, A, B, α, β, nthreads) @@ -94,12 +106,14 @@ function _mul!(C::StridedView{T,2}, A::StridedView{T,2}, B::StridedView{T,2}, end end -function _threaded_blas_mul!(C::StridedView{T,2}, A::StridedView{T,2}, B::StridedView{T,2}, - α::Number, β::Number, - nthreads) where {T<:LinearAlgebra.BlasFloat} +function _threaded_blas_mul!( + C::StridedView{T, 2}, A::StridedView{T, 2}, B::StridedView{T, 2}, + α::Number, β::Number, + nthreads + ) where {T <: LinearAlgebra.BlasFloat} m, n = size(C) m == size(A, 1) && n == size(B, 2) || throw(DimensionMismatch()) - if nthreads == 1 || m * n < 1024 + return if nthreads == 1 || m * n < 1024 A2, CA = getblasmatrix(A) B2, CB = getblasmatrix(B) LinearAlgebra.BLAS.gemm!(CA, CB, convert(T, α), A2, B2, convert(T, β), C) @@ -107,19 +121,27 @@ function _threaded_blas_mul!(C::StridedView{T,2}, A::StridedView{T,2}, B::Stride if m > n m2 = round(Int, m / 16) * 8 nthreads2 = nthreads >> 1 - t = Threads.@spawn _threaded_blas_mul!(C[1:($m2), :], A[1:($m2), :], B, α, β, - $nthreads2) - _threaded_blas_mul!(C[(m2 + 1):m, :], A[(m2 + 1):m, :], B, α, β, - nthreads - nthreads2) + t = Threads.@spawn _threaded_blas_mul!( + C[1:($m2), :], A[1:($m2), :], B, α, β, + $nthreads2 + ) + _threaded_blas_mul!( + C[(m2 + 1):m, :], A[(m2 + 1):m, :], B, α, β, + nthreads - nthreads2 + ) wait(t) return C else n2 = round(Int, n / 16) * 8 nthreads2 = nthreads >> 1 - t = Threads.@spawn _threaded_blas_mul!(C[:, 1:($n2)], A, B[:, 1:($n2)], α, β, - $nthreads2) - _threaded_blas_mul!(C[:, (n2 + 1):n], A, B[:, (n2 + 1):n], α, β, - nthreads - nthreads2) + t = Threads.@spawn _threaded_blas_mul!( + C[:, 1:($n2)], A, B[:, 1:($n2)], α, β, + $nthreads2 + ) + _threaded_blas_mul!( + C[:, (n2 + 1):n], A, B[:, (n2 + 1):n], α, β, + nthreads - nthreads2 + ) wait(t) return C end @@ -127,8 +149,10 @@ function _threaded_blas_mul!(C::StridedView{T,2}, A::StridedView{T,2}, B::Stride end # This implementation is faster than LinearAlgebra.generic_matmatmul -function __mul!(C::StridedView{<:Any,2}, A::StridedView{<:Any,2}, B::StridedView{<:Any,2}, - α::Number, β::Number) +function __mul!( + C::StridedView{<:Any, 2}, A::StridedView{<:Any, 2}, B::StridedView{<:Any, 2}, + α::Number, β::Number + ) (size(C, 1) == size(A, 1) && size(C, 2) == size(B, 2) && size(A, 2) == size(B, 1)) || throw(DimensionMismatch("A has size $(size(A)), B has size $(size(B)), C has size $(size(C))")) diff --git a/src/macros.jl b/src/macros.jl index 46f4fd1..a6005dc 100644 --- a/src/macros.jl +++ b/src/macros.jl @@ -15,11 +15,15 @@ function _strided(ex::Expr) return Expr(:call, ex.args[1], map(_strided, ex.args[2:end])...) end elseif (ex.head == :(=) || ex.head == :(kw)) && ex.args[1] isa Symbol - return Expr(ex.head, ex.args[1], - Expr(:call, :(Strided.maybeunstrided), _strided(ex.args[2]))) + return Expr( + ex.head, ex.args[1], + Expr(:call, :(Strided.maybeunstrided), _strided(ex.args[2])) + ) elseif (ex.head == :(->)) - return Expr(ex.head, ex.args[1], - Expr(:call, :(Strided.maybeunstrided), _strided(ex.args[2]))) + return Expr( + ex.head, ex.args[1], + Expr(:call, :(Strided.maybeunstrided), _strided(ex.args[2])) + ) else return Expr(ex.head, map(_strided, ex.args)...) end @@ -35,7 +39,7 @@ maybestrided(A) = A function maybeunstrided(A::StridedView) Ap = A.parent if size(A) == size(Ap) && strides(A) == strides(Ap) && offset(A) == 0 && - A.op == identity + A.op == identity return Ap else return reshape(copy(A).parent, size(A)) @@ -52,8 +56,12 @@ macro unsafe_strided(args...) error("The first arguments to `@unsafe_strided` must be variable names") ex = Expr(:let, Expr(:block, [:($s = Strided.StridedView($s)) for s in syms]...), ex) - warnex = :(Base.depwarn("`@unsafe_strided A B C ... ex` is deprecated, use `@strided ex` instead.", - Symbol("@unsafe_strided"); force=true)) + warnex = :( + Base.depwarn( + "`@unsafe_strided A B C ... ex` is deprecated, use `@strided ex` instead.", + Symbol("@unsafe_strided"); force = true + ) + ) return esc(Expr(:block, warnex, ex)) end diff --git a/src/mapreduce.jl b/src/mapreduce.jl index 270f2fc..8bfd7d9 100644 --- a/src/mapreduce.jl +++ b/src/mapreduce.jl @@ -1,19 +1,23 @@ # Methods based on map! -function Base.copy!(dst::StridedView{<:Any,N}, src::StridedView{<:Any,N}) where {N} +function Base.copy!(dst::StridedView{<:Any, N}, src::StridedView{<:Any, N}) where {N} return map!(identity, dst, src) end Base.conj!(a::StridedView{<:Real}) = a Base.conj!(a::StridedView) = map!(conj, a, a) -function LinearAlgebra.adjoint!(dst::StridedView{<:Any,N}, - src::StridedView{<:Any,N}) where {N} +function LinearAlgebra.adjoint!( + dst::StridedView{<:Any, N}, + src::StridedView{<:Any, N} + ) where {N} return copy!(dst, adjoint(src)) end -function Base.permutedims!(dst::StridedView{<:Any,N}, src::StridedView{<:Any,N}, - p) where {N} +function Base.permutedims!( + dst::StridedView{<:Any, N}, src::StridedView{<:Any, N}, + p + ) where {N} return copy!(dst, permutedims(src, p)) end -function Base.mapreduce(f, op, A::StridedView; dims=:, kw...) +function Base.mapreduce(f, op, A::StridedView; dims = :, kw...) return Base._mapreduce_dim(f, op, values(kw), A, dims) end @@ -29,14 +33,18 @@ function Base._mapreduce_dim(f, op, ::NamedTuple{()}, A::StridedView, dims) return Base.mapreducedim!(f, op, Base.reducedim_init(f, op, A, dims), A) end -function Base.map(f::F, a1::StridedView{<:Any,N}, - A::Vararg{StridedView{<:Any,N}}) where {F,N} +function Base.map( + f::F, a1::StridedView{<:Any, N}, + A::Vararg{StridedView{<:Any, N}} + ) where {F, N} T = Base.promote_eltype(a1, A...) return map!(f, similar(a1, T), a1, A...) end -function Base.map!(f::F, b::StridedView{<:Any,N}, a1::StridedView{<:Any,N}, - A::Vararg{StridedView{<:Any,N}}) where {F,N} +function Base.map!( + f::F, b::StridedView{<:Any, N}, a1::StridedView{<:Any, N}, + A::Vararg{StridedView{<:Any, N}} + ) where {F, N} dims = size(b) # Check dimesions @@ -52,7 +60,7 @@ function Base.map!(f::F, b::StridedView{<:Any,N}, a1::StridedView{<:Any,N}, return b end -function _mapreduce(f, op, A::StridedView{T}, nt=nothing) where {T} +function _mapreduce(f, op, A::StridedView{T}, nt = nothing) where {T} if length(A) == 0 b = Base.mapreduce_empty(f, op, T) return nt === nothing ? b : op(b, nt.init) @@ -71,9 +79,11 @@ function _mapreduce(f, op, A::StridedView{T}, nt=nothing) where {T} return out[ParentIndex(1)] end -function Base.mapreducedim!(f, op, b::StridedView{<:Any,N}, - a1::StridedView{<:Any,N}, - A::Vararg{StridedView{<:Any,N}}) where {N} +function Base.mapreducedim!( + f, op, b::StridedView{<:Any, N}, + a1::StridedView{<:Any, N}, + A::Vararg{StridedView{<:Any, N}} + ) where {N} outdims = size(b) dims = map(max, outdims, map(max, map(size, (a1, A...))...)) @@ -83,8 +93,10 @@ function Base.mapreducedim!(f, op, b::StridedView{<:Any,N}, return _mapreducedim!(f, op, nothing, dims, (b, a1, A...)) end -function _mapreducedim!((f), (op), (initop), - dims::Dims, arrays::Tuple{Vararg{StridedView}}) +function _mapreducedim!( + (f), (op), (initop), + dims::Dims, arrays::Tuple{Vararg{StridedView}} + ) if any(isequal(0), dims) if length(arrays[1]) != 0 && !isnothing(initop) map!(initop, arrays[1], arrays[1]) @@ -95,8 +107,10 @@ function _mapreducedim!((f), (op), (initop), return arrays[1] end -function _mapreduce_fuse!((f), (op), (initop), - dims::Dims, arrays::Tuple{Vararg{StridedView}}) +function _mapreduce_fuse!( + (f), (op), (initop), + dims::Dims, arrays::Tuple{Vararg{StridedView}} + ) # Fuse dimensions if possible: assume that at least one array, e.g. the output array in # arrays[1], has its strides sorted allstrides = map(strides, arrays) @@ -116,8 +130,10 @@ function _mapreduce_fuse!((f), (op), (initop), return _mapreduce_order!(f, op, initop, dims, allstrides, arrays) end -function _mapreduce_order!((f), (op), (initop), - dims, strides, arrays) +function _mapreduce_order!( + (f), (op), (initop), + dims, strides, arrays + ) M = length(arrays) N = length(dims) # sort order of loops/dimensions by modelling the importance of each dimension @@ -130,7 +146,7 @@ function _mapreduce_order!((f), (op), (initop), end importance = importance .* (dims .> 1) # put dims 1 at the back - p = TupleTools.sortperm(importance; rev=true) + p = TupleTools.sortperm(importance; rev = true) dims = TupleTools.getindices(dims, p) strides = broadcast(TupleTools.getindices, strides, (p,)) offsets = map(offset, arrays) @@ -139,8 +155,10 @@ function _mapreduce_order!((f), (op), (initop), end const MINTHREADLENGTH = 1 << 15 # minimal length before any kind of threading is applied -function _mapreduce_block!((f), (op), (initop), - dims, strides, offsets, costs, arrays) +function _mapreduce_block!( + (f), (op), (initop), + dims, strides, offsets, costs, arrays + ) bytestrides = map((s, stride) -> s .* stride, sizeof.(eltype.(arrays)), strides) strideorders = map(indexorder, strides) blocks = _computeblocks(dims, costs, bytestrides, strideorders) @@ -152,7 +170,7 @@ function _mapreduce_block!((f), (op), (initop), _mapreduce_kernel!(f, op, initop, dims, blocks, arrays, strides, offsets) elseif op !== nothing && _length(dims, strides[1]) == 1 # complete reduction T = eltype(arrays[1]) - spacing = isbitstype(T) ? max(1, div(64, sizeof(T))) : 1# to avoid false sharing + spacing = isbitstype(T) ? max(1, div(64, sizeof(T))) : 1 # to avoid false sharing threadedout = similar(arrays[1], spacing * get_num_threads()) a = arrays[1][ParentIndex(1)] if initop !== nothing @@ -161,8 +179,10 @@ function _mapreduce_block!((f), (op), (initop), _init_reduction!(threadedout, f, op, a) newarrays = (threadedout, Base.tail(arrays)...) - _mapreduce_threaded!(f, op, nothing, dims, blocks, strides, offsets, costs, - newarrays, get_num_threads(), spacing, 1) + _mapreduce_threaded!( + f, op, nothing, dims, blocks, strides, offsets, costs, + newarrays, get_num_threads(), spacing, 1 + ) for i in 1:get_num_threads() a = op(a, threadedout[(i - 1) * spacing + 1]) @@ -173,28 +193,32 @@ function _mapreduce_block!((f), (op), (initop), # make cost of dimensions with zero stride in output array (reduction dimensions), # so that they are not divided in threading (which would lead to race conditions) - _mapreduce_threaded!(f, op, initop, dims, blocks, strides, offsets, costs, arrays, - get_num_threads(), 0, 1) + _mapreduce_threaded!( + f, op, initop, dims, blocks, strides, offsets, costs, arrays, + get_num_threads(), 0, 1 + ) end return nothing end -_init_reduction!(out, f, op::Union{typeof(+),typeof(Base.add_sum)}, a) = fill!(out, zero(a)) -_init_reduction!(out, f, op::Union{typeof(*),typeof(Base.mul_prod)}, a) = fill!(out, one(a)) +_init_reduction!(out, f, op::Union{typeof(+), typeof(Base.add_sum)}, a) = fill!(out, zero(a)) +_init_reduction!(out, f, op::Union{typeof(*), typeof(Base.mul_prod)}, a) = fill!(out, one(a)) _init_reduction!(out, f, op::typeof(min), a) = fill!(out, a) _init_reduction!(out, f, op::typeof(max), a) = fill!(out, a) _init_reduction!(out, f, op::typeof(&), a) = fill!(out, true) _init_reduction!(out, f, op::typeof(|), a) = fill!(out, false) function _init_reduction!(out, f, op, a) return op(a, a) == a ? fill!(out, a) : - error("unknown reduction; incompatible with multithreading") + error("unknown reduction; incompatible with multithreading") end # nthreads: number of threads spacing: extra addition to offset of array 1, to account for # reduction -function _mapreduce_threaded!((f), (op), (initop), - dims, blocks, strides, offsets, costs, arrays, nthreads, - spacing, taskindex) +function _mapreduce_threaded!( + (f), (op), (initop), + dims, blocks, strides, offsets, costs, arrays, nthreads, + spacing, taskindex + ) if nthreads == 1 || prod(dims) <= MINTHREADLENGTH offset1 = offsets[1] + spacing * (taskindex - 1) spacedoffsets = (offset1, Base.tail(offsets)...) @@ -211,27 +235,33 @@ function _mapreduce_threaded!((f), (op), (initop), nnthreads = nthreads >> 1 newdims = setindex(dims, ndi, i) newoffsets = offsets - t = Threads.@spawn _mapreduce_threaded!(f, op, initop, newdims, blocks, strides, - newoffsets, costs, arrays, nnthreads, - spacing, taskindex) + t = Threads.@spawn _mapreduce_threaded!( + f, op, initop, newdims, blocks, strides, + newoffsets, costs, arrays, nnthreads, + spacing, taskindex + ) stridesi = getindex.(strides, i) newoffsets2 = offsets .+ ndi .* stridesi newdims2 = setindex(dims, di - ndi, i) nnthreads2 = nthreads - nnthreads - _mapreduce_threaded!(f, op, initop, newdims2, blocks, strides, newoffsets2, - costs, arrays, nnthreads2, spacing, taskindex + nnthreads) + _mapreduce_threaded!( + f, op, initop, newdims2, blocks, strides, newoffsets2, + costs, arrays, nnthreads2, spacing, taskindex + nnthreads + ) wait(t) end end return nothing end -@generated function _mapreduce_kernel!((f), (op), - (initop), dims::NTuple{N,Int}, - blocks::NTuple{N,Int}, - arrays::NTuple{M,StridedView}, - strides::NTuple{M,NTuple{N,Int}}, - offsets::NTuple{M,Int}) where {N,M} +@generated function _mapreduce_kernel!( + (f), (op), + (initop), dims::NTuple{N, Int}, + blocks::NTuple{N, Int}, + arrays::NTuple{M, StridedView}, + strides::NTuple{M, NTuple{N, Int}}, + offsets::NTuple{M, Int} + ) where {N, M} # many variables blockloopvars = Array{Symbol}(undef, N) @@ -353,8 +383,10 @@ end i = 1 if N >= 1 initex = quote - $(initblockdimvars[i]) = ifelse($(stridevars[i, 1]) == 0, 1, - $(blockdimvars[i])) + $(initblockdimvars[i]) = ifelse( + $(stridevars[i, 1]) == 0, 1, + $(blockdimvars[i]) + ) @simd for $(innerloopvars[i]) in Base.OneTo($(initblockdimvars[i])) $initex $(stepstride1ex[i]) @@ -364,8 +396,10 @@ end end for outer i in 2:N initex = quote - $(initblockdimvars[i]) = ifelse($(stridevars[i, 1]) == 0, 1, - $(blockdimvars[i])) + $(initblockdimvars[i]) = ifelse( + $(stridevars[i, 1]) == 0, 1, + $(blockdimvars[i]) + ) for $(innerloopvars[i]) in Base.OneTo($(initblockdimvars[i])) $initex $(stepstride1ex[i]) @@ -424,7 +458,7 @@ end return ex end -function indexorder(strides::NTuple{N,Int}) where {N} +function indexorder(strides::NTuple{N, Int}) where {N} # returns order such that strides[i] is the order[i]th smallest element of strides, not # counting zero strides zero strides have order 1 return ntuple(Val(N)) do i @@ -442,7 +476,7 @@ end function _length(dims::Tuple, strides::Tuple) return ifelse(iszero(strides[1]), 1, dims[1]) * - _length(Base.tail(dims), Base.tail(strides)) + _length(Base.tail(dims), Base.tail(strides)) end _length(dims::Tuple{}, strides::Tuple{}) = 1 function _maxlength(dims::Tuple, strides::Tuple{Vararg{Tuple}}) @@ -460,25 +494,31 @@ function _lastargmax(t::Tuple) end const BLOCKMEMORYSIZE = 1 << 15 # L1 cache size in bytes -function _computeblocks(dims::Tuple{}, costs::Tuple{}, - bytestrides::Tuple{Vararg{Tuple{}}}, - strideorders::Tuple{Vararg{Tuple{}}}, - blocksize::Int=BLOCKMEMORYSIZE) +function _computeblocks( + dims::Tuple{}, costs::Tuple{}, + bytestrides::Tuple{Vararg{Tuple{}}}, + strideorders::Tuple{Vararg{Tuple{}}}, + blocksize::Int = BLOCKMEMORYSIZE + ) return () end -function _computeblocks(dims::NTuple{N,Int}, costs::NTuple{N,Int}, - bytestrides::Tuple{Vararg{NTuple{N,Int}}}, - strideorders::Tuple{Vararg{NTuple{N,Int}}}, - blocksize::Int=BLOCKMEMORYSIZE) where {N} +function _computeblocks( + dims::NTuple{N, Int}, costs::NTuple{N, Int}, + bytestrides::Tuple{Vararg{NTuple{N, Int}}}, + strideorders::Tuple{Vararg{NTuple{N, Int}}}, + blocksize::Int = BLOCKMEMORYSIZE + ) where {N} if totalmemoryregion(dims, bytestrides) <= blocksize return dims end minstrideorder = minimum(minimum.(strideorders)) if all(isequal(minstrideorder), first.(strideorders)) d1 = dims[1] - dr = _computeblocks(tail(dims), tail(costs), - map(tail, bytestrides), map(tail, strideorders), blocksize) + dr = _computeblocks( + tail(dims), tail(costs), + map(tail, bytestrides), map(tail, strideorders), blocksize + ) return (d1, dr...) end diff --git a/test/othertests.jl b/test/othertests.jl index ff813ff..3ca6702 100644 --- a/test/othertests.jl +++ b/test/othertests.jl @@ -24,7 +24,7 @@ end B3 = permutedims(StridedView(R3), randperm(N)) A1 = convert(Array, B1) A2 = convert(Array{T}, B2) # test different converts - A3 = convert(Array{T,N}, B3) + A3 = convert(Array{T, N}, B3) C1 = deepcopy(B1) @test rmul!(B1, 1 // 2) ≈ rmul!(A1, 1 // 2) @@ -34,7 +34,7 @@ end @test axpby!(1 // 3, B1, 1 // 2, B3) ≈ axpby!(1 // 3, A1, 1 // 2, A3) @test axpby!(1, B2, 1, B1) ≈ axpby!(1, A2, 1, A1) @test map((x, y, z) -> sin(x) + y / exp(-abs(z)), B1, B2, B3) ≈ - map((x, y, z) -> sin(x) + y / exp(-abs(z)), A1, A2, A3) + map((x, y, z) -> sin(x) + y / exp(-abs(z)), A1, A2, A3) @test map((x, y, z) -> sin(x) + y / exp(-abs(z)), B1, B2, B3) isa StridedView @test map((x, y, z) -> sin(x) + y / exp(-abs(z)), B1, A2, B3) isa Array @test mul!(B1, 1, B2) ≈ mul!(A1, 1, A2) @@ -51,12 +51,12 @@ end B3 = permutedims(StridedView(R3), randperm(3)) A1 = convert(Array, B1) A2 = convert(Array{T}, B2) - A3 = convert(Array{T,3}, B3) + A3 = convert(Array{T, 3}, B3) @test @inferred(B1 .+ sin.(B2 .- 3)) ≈ A1 .+ sin.(A2 .- 3) @test @inferred(B2' .* B3 .- Ref(0.5)) ≈ A2' .* A3 .- Ref(0.5) @test @inferred(B2' .* B3 .- max.(abs.(B1), real.(B3))) ≈ - A2' .* A3 .- max.(abs.(A1), real.(A3)) + A2' .* A3 .- max.(abs.(A1), real.(A3)) @test (B1 .+ sin.(B2 .- 3)) isa StridedView @test (B2' .* B3 .- Ref(0.5)) isa StridedView @@ -68,41 +68,61 @@ end @testset "mapreduce with StridedView" begin @testset for T in (Float32, Float64, ComplexF32, ComplexF64) R1 = rand(T, (10, 10, 10, 10, 10, 10)) - @test sum(R1; dims=(1, 3, 5)) ≈ sum(StridedView(R1); dims=(1, 3, 5)) - @test mapreduce(sin, +, R1; dims=(1, 3, 5)) ≈ - mapreduce(sin, +, StridedView(R1); dims=(1, 3, 5)) + @test sum(R1; dims = (1, 3, 5)) ≈ sum(StridedView(R1); dims = (1, 3, 5)) + @test mapreduce(sin, +, R1; dims = (1, 3, 5)) ≈ + mapreduce(sin, +, StridedView(R1); dims = (1, 3, 5)) R2 = rand(T, (10, 10, 10)) R2c = copy(R2) - @test Strided._mapreducedim!(sin, +, identity, (10, 10, 10, 10, 10, 10), - (sreshape(StridedView(R2c), (10, 1, 1, 10, 10, 1)), - StridedView(R1))) ≈ - mapreduce(sin, +, R1; dims=(2, 3, 6)) .+ reshape(R2, (10, 1, 1, 10, 10, 1)) + @test Strided._mapreducedim!( + sin, +, identity, (10, 10, 10, 10, 10, 10), + ( + sreshape(StridedView(R2c), (10, 1, 1, 10, 10, 1)), + StridedView(R1), + ) + ) ≈ + mapreduce(sin, +, R1; dims = (2, 3, 6)) .+ reshape(R2, (10, 1, 1, 10, 10, 1)) R2c = copy(R2) - @test Strided._mapreducedim!(sin, +, x -> 0, (10, 10, 10, 10, 10, 10), - (sreshape(StridedView(R2c), (10, 1, 1, 10, 10, 1)), - StridedView(R1))) ≈ - mapreduce(sin, +, R1; dims=(2, 3, 6)) + @test Strided._mapreducedim!( + sin, +, x -> 0, (10, 10, 10, 10, 10, 10), + ( + sreshape(StridedView(R2c), (10, 1, 1, 10, 10, 1)), + StridedView(R1), + ) + ) ≈ + mapreduce(sin, +, R1; dims = (2, 3, 6)) R2c = copy(R2) β = rand(T) - @test Strided._mapreducedim!(sin, +, x -> β * x, (10, 10, 10, 10, 10, 10), - (sreshape(StridedView(R2c), (10, 1, 1, 10, 10, 1)), - StridedView(R1))) ≈ - mapreduce(sin, +, R1; dims=(2, 3, 6)) .+ - β .* reshape(R2, (10, 1, 1, 10, 10, 1)) + @test Strided._mapreducedim!( + sin, +, x -> β * x, (10, 10, 10, 10, 10, 10), + ( + sreshape(StridedView(R2c), (10, 1, 1, 10, 10, 1)), + StridedView(R1), + ) + ) ≈ + mapreduce(sin, +, R1; dims = (2, 3, 6)) .+ + β .* reshape(R2, (10, 1, 1, 10, 10, 1)) R2c = copy(R2) - @test Strided._mapreducedim!(sin, +, x -> β, (10, 10, 10, 10, 10, 10), - (sreshape(StridedView(R2c), (10, 1, 1, 10, 10, 1)), - StridedView(R1))) ≈ - mapreduce(sin, +, R1; dims=(2, 3, 6), init=β) + @test Strided._mapreducedim!( + sin, +, x -> β, (10, 10, 10, 10, 10, 10), + ( + sreshape(StridedView(R2c), (10, 1, 1, 10, 10, 1)), + StridedView(R1), + ) + ) ≈ + mapreduce(sin, +, R1; dims = (2, 3, 6), init = β) R2c = copy(R2) - @test Strided._mapreducedim!(sin, +, conj, (10, 10, 10, 10, 10, 10), - (sreshape(StridedView(R2c), (10, 1, 1, 10, 10, 1)), - StridedView(R1))) ≈ - mapreduce(sin, +, R1; dims=(2, 3, 6)) .+ - conj.(reshape(R2, (10, 1, 1, 10, 10, 1))) + @test Strided._mapreducedim!( + sin, +, conj, (10, 10, 10, 10, 10, 10), + ( + sreshape(StridedView(R2c), (10, 1, 1, 10, 10, 1)), + StridedView(R1), + ) + ) ≈ + mapreduce(sin, +, R1; dims = (2, 3, 6)) .+ + conj.(reshape(R2, (10, 1, 1, 10, 10, 1))) R3 = rand(T, (100, 100, 2)) - @test sum(R3; dims=(1, 2)) ≈ sum(StridedView(R3); dims=(1, 2)) + @test sum(R3; dims = (1, 2)) ≈ sum(StridedView(R3); dims = (1, 2)) end end @@ -136,49 +156,57 @@ end @test (@strided(A1 .+ sin.(A2 .- 3))) ≈ A1 .+ sin.(A2 .- 3) @test (@strided(A2' .* A3 .- Ref(0.5))) ≈ A2' .* A3 .- Ref(0.5) @test (@strided(A2' .* A3 .- max.(abs.(A1), real.(A3)))) ≈ - A2' .* A3 .- max.(abs.(A1), real.(A3)) + A2' .* A3 .- max.(abs.(A1), real.(A3)) B2 = view(A2, :, 1:2:10) @test (@strided(A1 .+ sin.(view(A2, :, 1:2:10) .- 3))) ≈ - (@strided(A1 .+ sin.(B2 .- 3))) ≈ - A1 .+ sin.(view(A2, :, 1:2:10) .- 3) + (@strided(A1 .+ sin.(B2 .- 3))) ≈ + A1 .+ sin.(view(A2, :, 1:2:10) .- 3) B2 = view(A2', :, 1:6) B3 = view(A3, :, 1:6, 4) @test (@strided(view(A2', :, 1:6) .* view(A3, :, 1:6, 4) .- Ref(0.5))) ≈ - (@strided(B2 .* B3 .- Ref(0.5))) ≈ - view(A2', :, 1:6) .* view(A3, :, 1:6, 4) .- Ref(0.5) + (@strided(B2 .* B3 .- Ref(0.5))) ≈ + view(A2', :, 1:6) .* view(A3, :, 1:6, 4) .- Ref(0.5) B2 = view(A2, :, 3) B3 = view(A3, 1:5, :, 2:2:10) B1 = view(A1, 1:5) B3b = view(A3, 4:4, 4:4, 2:2:10) - @test (@strided(view(A2, :, 3)' .* view(A3, 1:5, :, 2:2:10) .- - max.(abs.(view(A1, 1:5)), real.(view(A3, 4:4, 4:4, 2:2:10))))) ≈ - (@strided(B2' .* B3 .- max.(abs.(B1), real.(B3b)))) ≈ - view(A2, :, 3)' .* view(A3, 1:5, :, 2:2:10) .- - max.(abs.(view(A1, 1:5)), real.(view(A3, 4:4, 4:4, 2:2:10))) + @test ( + @strided( + view(A2, :, 3)' .* view(A3, 1:5, :, 2:2:10) .- + max.(abs.(view(A1, 1:5)), real.(view(A3, 4:4, 4:4, 2:2:10))) + ) + ) ≈ + (@strided(B2' .* B3 .- max.(abs.(B1), real.(B3b)))) ≈ + view(A2, :, 3)' .* view(A3, 1:5, :, 2:2:10) .- + max.(abs.(view(A1, 1:5)), real.(view(A3, 4:4, 4:4, 2:2:10))) B2 = reshape(A2, (10, 2, 5)) @test (@strided(A1 .+ sin.(reshape(A2, (10, 2, 5)) .- 3))) ≈ - (@strided(A1 .+ sin.(B2 .- 3))) ≈ - A1 .+ sin.(reshape(A2, (10, 2, 5)) .- 3) + (@strided(A1 .+ sin.(B2 .- 3))) ≈ + A1 .+ sin.(reshape(A2, (10, 2, 5)) .- 3) B2 = reshape(A2, 1, 100) B3 = reshape(A3, 100, 1, 10) @test (@strided(reshape(A2, 1, 100)' .* reshape(A3, 100, 1, 10) .- Ref(0.5))) ≈ - (@strided(B2' .* B3 .- Ref(0.5))) ≈ - reshape(A2, 1, 100)' .* reshape(A3, 100, 1, 10) .- Ref(0.5) + (@strided(B2' .* B3 .- Ref(0.5))) ≈ + reshape(A2, 1, 100)' .* reshape(A3, 100, 1, 10) .- Ref(0.5) B2 = view(A2, :, 3) B3 = reshape(view(A3, 1:5, :, :), 5, 10, 5, 2) B1 = view(A1, 1:5) B3b = view(A3, 4:4, 4:4, 2:2:10) - @test (@strided(view(A2, :, 3)' .* reshape(view(A3, 1:5, :, :), 5, 10, 5, 2) .- - max.(abs.(view(A1, 1:5)), real.(view(A3, 4:4, 4:4, 2:2:10))))) ≈ - (@strided(B2' .* B3 .- max.(abs.(B1), real.(B3b)))) ≈ - view(A2, :, 3)' .* reshape(view(A3, 1:5, :, :), 5, 10, 5, 2) .- - max.(abs.(view(A1, 1:5)), real.(view(A3, 4:4, 4:4, 2:2:10))) + @test ( + @strided( + view(A2, :, 3)' .* reshape(view(A3, 1:5, :, :), 5, 10, 5, 2) .- + max.(abs.(view(A1, 1:5)), real.(view(A3, 4:4, 4:4, 2:2:10))) + ) + ) ≈ + (@strided(B2' .* B3 .- max.(abs.(B1), real.(B3b)))) ≈ + view(A2, :, 3)' .* reshape(view(A3, 1:5, :, :), 5, 10, 5, 2) .- + max.(abs.(view(A1, 1:5)), real.(view(A3, 4:4, 4:4, 2:2:10))) x = @strided begin p = :A => A1 @@ -199,54 +227,70 @@ end @test (@unsafe_strided(A1, A2, A1 .+ sin.(A2 .- 3))) ≈ A1 .+ sin.(A2 .- 3) @test (@unsafe_strided(A2, A3, A2' .* A3 .- Ref(0.5))) ≈ A2' .* A3 .- Ref(0.5) @test (@unsafe_strided(A1, A2, A3, A2' .* A3 .- max.(abs.(A1), real.(A3)))) ≈ - A2' .* A3 .- max.(abs.(A1), real.(A3)) + A2' .* A3 .- max.(abs.(A1), real.(A3)) B2 = view(A2, :, 1:2:10) @test (@unsafe_strided(A1, A2, A1 .+ sin.(view(A2, :, 1:2:10) .- 3))) ≈ - (@unsafe_strided(A1, B2, A1 .+ sin.(B2 .- 3))) ≈ - A1 .+ sin.(view(A2, :, 1:2:10) .- 3) + (@unsafe_strided(A1, B2, A1 .+ sin.(B2 .- 3))) ≈ + A1 .+ sin.(view(A2, :, 1:2:10) .- 3) B2 = view(A2', :, 1:6) B3 = view(A3, :, 1:6, 4) - @test (@unsafe_strided(A2, A3, - view(A2', :, 1:6) .* view(A3, :, 1:6, 4) .- Ref(0.5))) ≈ - (@unsafe_strided(B2, B3, B2 .* B3 .- Ref(0.5))) ≈ - view(A2', :, 1:6) .* view(A3, :, 1:6, 4) .- Ref(0.5) + @test ( + @unsafe_strided( + A2, A3, + view(A2', :, 1:6) .* view(A3, :, 1:6, 4) .- Ref(0.5) + ) + ) ≈ + (@unsafe_strided(B2, B3, B2 .* B3 .- Ref(0.5))) ≈ + view(A2', :, 1:6) .* view(A3, :, 1:6, 4) .- Ref(0.5) B2 = view(A2, :, 3) B3 = view(A3, 1:5, :, 2:2:10) B1 = view(A1, 1:5) B3b = view(A3, 4:4, 4:4, 2:2:10) - @test (@unsafe_strided(A1, A2, A3, - view(A2, :, 3)' .* view(A3, 1:5, :, 2:2:10) .- - max.(abs.(view(A1, 1:5)), real.(view(A3, 4:4, 4:4, 2:2:10))))) ≈ - (@unsafe_strided(B1, B2, B3, B2' .* B3 .- max.(abs.(B1), real.(B3b)))) ≈ - view(A2, :, 3)' .* view(A3, 1:5, :, 2:2:10) .- - max.(abs.(view(A1, 1:5)), real.(view(A3, 4:4, 4:4, 2:2:10))) + @test ( + @unsafe_strided( + A1, A2, A3, + view(A2, :, 3)' .* view(A3, 1:5, :, 2:2:10) .- + max.(abs.(view(A1, 1:5)), real.(view(A3, 4:4, 4:4, 2:2:10))) + ) + ) ≈ + (@unsafe_strided(B1, B2, B3, B2' .* B3 .- max.(abs.(B1), real.(B3b)))) ≈ + view(A2, :, 3)' .* view(A3, 1:5, :, 2:2:10) .- + max.(abs.(view(A1, 1:5)), real.(view(A3, 4:4, 4:4, 2:2:10))) B2 = reshape(A2, (10, 2, 5)) @test (@unsafe_strided(A1, A2, A1 .+ sin.(reshape(A2, (10, 2, 5)) .- 3))) ≈ - (@unsafe_strided(A1, B2, A1 .+ sin.(B2 .- 3))) ≈ - A1 .+ sin.(reshape(A2, (10, 2, 5)) .- 3) + (@unsafe_strided(A1, B2, A1 .+ sin.(B2 .- 3))) ≈ + A1 .+ sin.(reshape(A2, (10, 2, 5)) .- 3) B2 = reshape(A2, 1, 100) B3 = reshape(A3, 100, 1, 10) - @test (@unsafe_strided(A2, A3, - reshape(A2, 1, 100)' .* reshape(A3, 100, 1, 10) .- Ref(0.5))) ≈ - (@unsafe_strided(B2, B3, B2' .* B3 .- Ref(0.5))) ≈ - reshape(A2, 1, 100)' .* reshape(A3, 100, 1, 10) .- Ref(0.5) + @test ( + @unsafe_strided( + A2, A3, + reshape(A2, 1, 100)' .* reshape(A3, 100, 1, 10) .- Ref(0.5) + ) + ) ≈ + (@unsafe_strided(B2, B3, B2' .* B3 .- Ref(0.5))) ≈ + reshape(A2, 1, 100)' .* reshape(A3, 100, 1, 10) .- Ref(0.5) B2 = view(A2, :, 3) B3 = reshape(view(A3, 1:5, :, :), 5, 10, 5, 2) B1 = view(A1, 1:5) B3b = view(A3, 4:4, 4:4, 2:2:10) - @test (@unsafe_strided(A1, A2, A3, - view(A2, :, 3)' .* - reshape(view(A3, 1:5, :, :), 5, 10, 5, 2) .- - max.(abs.(view(A1, 1:5)), real.(view(A3, 4:4, 4:4, 2:2:10))))) ≈ - (@unsafe_strided(B1, B2, B3, B2' .* B3 .- max.(abs.(B1), real.(B3b)))) ≈ - view(A2, :, 3)' .* reshape(view(A3, 1:5, :, :), 5, 10, 5, 2) .- - max.(abs.(view(A1, 1:5)), real.(view(A3, 4:4, 4:4, 2:2:10))) + @test ( + @unsafe_strided( + A1, A2, A3, + view(A2, :, 3)' .* + reshape(view(A3, 1:5, :, :), 5, 10, 5, 2) .- + max.(abs.(view(A1, 1:5)), real.(view(A3, 4:4, 4:4, 2:2:10))) + ) + ) ≈ + (@unsafe_strided(B1, B2, B3, B2' .* B3 .- max.(abs.(B1), real.(B3b)))) ≈ + view(A2, :, 3)' .* reshape(view(A3, 1:5, :, :), 5, 10, 5, 2) .- + max.(abs.(view(A1, 1:5)), real.(view(A3, 4:4, 4:4, 2:2:10))) end end diff --git a/test/runtests.jl b/test/runtests.jl index 3801294..fc411cc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -24,4 +24,4 @@ include("blasmultests.jl") Strided.disable_threaded_mul() using Aqua -Aqua.test_all(Strided; piracies=false) +Aqua.test_all(Strided; piracies = false)