Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 33 additions & 23 deletions benchmarks/benchtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ function benchmark_sum(sizes)
return times
end

function benchmark_permute(sizes, p=(4, 3, 2, 1))
function benchmark_permute(sizes, p = (4, 3, 2, 1))
times = zeros(length(sizes), 4)
for (i, s) in enumerate(sizes)
A = randn(Float64, s .* one.(p))
Expand All @@ -41,7 +41,7 @@ permute_times1 = benchmark_permute(sizes, (4, 3, 2, 1))
permute_times2 = benchmark_permute(sizes, (2, 3, 4, 1))
permute_times3 = benchmark_permute(sizes, (3, 4, 1, 2))

function benchmark_mul(sizesm, sizesk=sizesm, sizesn=sizesm)
function benchmark_mul(sizesm, sizesk = sizesm, sizesn = sizesm)
N = Threads.nthreads()
@assert length(sizesm) == length(sizesk) == length(sizesn)
times = zeros(length(sizesm), 4)
Expand All @@ -62,23 +62,23 @@ function benchmark_mul(sizesm, sizesk=sizesm, sizesn=sizesm)
BLAS.set_num_threads(1) # single-threaded blas with strided multithreading
Strided.enable_threaded_mul()
times[i, 4] = @belapsed @strided mul!($C, $A, $B)
println("step $i: sizes $((m,k,n)) => times = $(times[i, :])")
println("step $i: sizes $((m, k, n)) => times = $(times[i, :])")
end
return times
end

function tensorcontraction!(wEnv, hamAB, hamBA, rhoBA, rhoAB, w, v, u)
@tensor wEnv[-1, -2, -3] = hamAB[7, 8, -1, 9] * rhoBA[4, 3, -3, 2] * conj(w[7, 5, 4]) *
u[9, 10, -2, 11] * conj(u[8, 10, 5, 6]) * v[1, 11, 2] *
conj(v[1, 6, 3]) +
hamBA[1, 2, 3, 4] * rhoBA[10, 7, -3, 6] *
conj(w[-1, 11, 10]) * u[3, 4, -2, 8] * conj(u[1, 2, 11, 9]) *
v[5, 8, 6] * conj(v[5, 9, 7]) +
hamAB[5, 7, 3, 1] * rhoBA[10, 9, -3, 8] *
conj(w[-1, 11, 10]) * u[4, 3, -2, 2] * conj(u[4, 5, 11, 6]) *
v[1, 2, 8] * conj(v[7, 6, 9]) +
hamBA[3, 7, 2, -1] * rhoAB[5, 6, 4, -3] * v[2, 1, 4] *
conj(v[3, 1, 5]) * conj(w[7, -2, 6])
u[9, 10, -2, 11] * conj(u[8, 10, 5, 6]) * v[1, 11, 2] *
conj(v[1, 6, 3]) +
hamBA[1, 2, 3, 4] * rhoBA[10, 7, -3, 6] *
conj(w[-1, 11, 10]) * u[3, 4, -2, 8] * conj(u[1, 2, 11, 9]) *
v[5, 8, 6] * conj(v[5, 9, 7]) +
hamAB[5, 7, 3, 1] * rhoBA[10, 9, -3, 8] *
conj(w[-1, 11, 10]) * u[4, 3, -2, 2] * conj(u[4, 5, 11, 6]) *
v[1, 2, 8] * conj(v[7, 6, 9]) +
hamBA[3, 7, 2, -1] * rhoAB[5, 6, 4, -3] * v[2, 1, 4] *
conj(v[3, 1, 5]) * conj(w[7, -2, 6])
return wEnv
end

Expand All @@ -100,32 +100,42 @@ function benchmark_tensorcontraction(sizes)
BLAS.set_num_threads(1)
Strided.disable_threads()
Strided.disable_threaded_mul()
times[i, 1] = @belapsed tensorcontraction!($wEnv, $hamAB, $hamBA, $rhoBA, $rhoAB,
$w, $v, $u)
times[i, 1] = @belapsed tensorcontraction!(
$wEnv, $hamAB, $hamBA, $rhoBA, $rhoAB,
$w, $v, $u
)

BLAS.set_num_threads(1)
Strided.enable_threads()
Strided.disable_threaded_mul()
times[i, 2] = @belapsed tensorcontraction!($wEnv, $hamAB, $hamBA, $rhoBA, $rhoAB,
$w, $v, $u)
times[i, 2] = @belapsed tensorcontraction!(
$wEnv, $hamAB, $hamBA, $rhoBA, $rhoAB,
$w, $v, $u
)

BLAS.set_num_threads(N)
Strided.disable_threads()
Strided.disable_threaded_mul()
times[i, 3] = @belapsed tensorcontraction!($wEnv, $hamAB, $hamBA, $rhoBA, $rhoAB,
$w, $v, $u)
times[i, 3] = @belapsed tensorcontraction!(
$wEnv, $hamAB, $hamBA, $rhoBA, $rhoAB,
$w, $v, $u
)

BLAS.set_num_threads(N)
Strided.enable_threads()
Strided.disable_threaded_mul()
times[i, 4] = @belapsed tensorcontraction!($wEnv, $hamAB, $hamBA, $rhoBA, $rhoAB,
$w, $v, $u)
times[i, 4] = @belapsed tensorcontraction!(
$wEnv, $hamAB, $hamBA, $rhoBA, $rhoAB,
$w, $v, $u
)

BLAS.set_num_threads(1)
Strided.enable_threads()
Strided.enable_threaded_mul()
times[i, 5] = @belapsed tensorcontraction!($wEnv, $hamAB, $hamBA, $rhoBA, $rhoAB,
$w, $v, $u)
times[i, 5] = @belapsed tensorcontraction!(
$wEnv, $hamAB, $hamBA, $rhoBA, $rhoAB,
$w, $v, $u
)

println("step $i: size $s => times = $(times[i, :])")
end
Expand Down
4 changes: 2 additions & 2 deletions src/Strided.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module Strided

import Base: parent, size, strides, tail, setindex
using Base: @propagate_inbounds, RangeIndex, Dims
const SliceIndex = Union{RangeIndex,Colon}
const SliceIndex = Union{RangeIndex, Colon}

using LinearAlgebra

Expand All @@ -27,7 +27,7 @@ function set_num_threads(n::Int)
return _NTHREADS[] = n
end
@noinline function _set_num_threads_warn(n)
@warn "Maximal number of threads limited by number of Julia threads,
return @warn "Maximal number of threads limited by number of Julia threads,
setting number of threads equal to Threads.nthreads() = $n"
end

Expand Down
26 changes: 15 additions & 11 deletions src/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,33 @@ using Base.Broadcast: BroadcastStyle, AbstractArrayStyle, DefaultArrayStyle, Bro
struct StridedArrayStyle{N} <: AbstractArrayStyle{N}
end

Broadcast.BroadcastStyle(::Type{<:StridedView{<:Any,N}}) where {N} = StridedArrayStyle{N}()
Broadcast.BroadcastStyle(::Type{<:StridedView{<:Any, N}}) where {N} = StridedArrayStyle{N}()

StridedArrayStyle(::Val{N}) where {N} = StridedArrayStyle{N}()
StridedArrayStyle{M}(::Val{N}) where {M,N} = StridedArrayStyle{N}()
StridedArrayStyle{M}(::Val{N}) where {M, N} = StridedArrayStyle{N}()

Broadcast.BroadcastStyle(a::StridedArrayStyle, ::DefaultArrayStyle{0}) = a
function Broadcast.BroadcastStyle(::StridedArrayStyle{N}, a::DefaultArrayStyle) where {N}
return BroadcastStyle(DefaultArrayStyle{N}(), a)
end
function Broadcast.BroadcastStyle(::StridedArrayStyle{N},
::Broadcast.Style{Tuple}) where {N}
function Broadcast.BroadcastStyle(
::StridedArrayStyle{N},
::Broadcast.Style{Tuple}
) where {N}
return DefaultArrayStyle{N}()
end

function Base.similar(bc::Broadcasted{<:StridedArrayStyle{N}}, ::Type{T}) where {N,T}
function Base.similar(bc::Broadcasted{<:StridedArrayStyle{N}}, ::Type{T}) where {N, T}
return StridedView(similar(convert(Broadcasted{DefaultArrayStyle{N}}, bc), T))
end

Base.dotview(a::StridedView{<:Any,N}, I::Vararg{SliceIndex,N}) where {N} = getindex(a, I...)
Base.dotview(a::StridedView{<:Any, N}, I::Vararg{SliceIndex, N}) where {N} = getindex(a, I...)

# Broadcasting implementation
@inline function Base.copyto!(dest::StridedView{<:Any,N},
bc::Broadcasted{StridedArrayStyle{N}}) where {N}
@inline function Base.copyto!(
dest::StridedView{<:Any, N},
bc::Broadcasted{StridedArrayStyle{N}}
) where {N}
# convert to map

# flatten and only keep the StridedView arguments
Expand All @@ -36,7 +40,7 @@ Base.dotview(a::StridedView{<:Any,N}, I::Vararg{SliceIndex,N}) where {N} = getin
return dest
end

const WrappedScalarArgs = Union{AbstractArray{<:Any,0},Ref{<:Any}}
const WrappedScalarArgs = Union{AbstractArray{<:Any, 0}, Ref{<:Any}}

@inline function capturestridedargs(t::Broadcasted, rest...)
return (capturestridedargs(t.args...)..., capturestridedargs(rest...)...)
Expand Down Expand Up @@ -64,7 +68,7 @@ function promoteshape1(sz::Dims{N}, a::StridedView) where {N}
return StridedView(a.parent, sz, newstrides, a.offset, a.op)
end

struct CaptureArgs{F,Args<:Tuple}
struct CaptureArgs{F, Args <: Tuple}
f::F
args::Args
end
Expand All @@ -84,7 +88,7 @@ end

# Evaluate CaptureArgs
(c::CaptureArgs)(vals...) = consume(c, vals)[1]
@inline function consume(c::CaptureArgs{F,Args}, vals) where {F,Args}
@inline function consume(c::CaptureArgs{F, Args}, vals) where {F, Args}
args, newvals = t_consume(c.args, vals)
return c.f(args...), newvals
end
Expand Down
4 changes: 2 additions & 2 deletions src/convert.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ function Base.Array(a::StridedView)
return b
end

function (Base.Array{T})(a::StridedView{S,N}) where {T,S,N}
function (Base.Array{T})(a::StridedView{S, N}) where {T, S, N}
b = Array{T}(undef, size(a))
copy!(StridedView(b), a)
return b
end

function (Base.Array{T,N})(a::StridedView{S,N}) where {T,S,N}
function (Base.Array{T, N})(a::StridedView{S, N}) where {T, S, N}
b = Array{T}(undef, size(a))
copy!(StridedView(b), a)
return b
Expand Down
82 changes: 53 additions & 29 deletions src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,35 +2,43 @@
LinearAlgebra.rmul!(dst::StridedView, α::Number) = mul!(dst, dst, α)
LinearAlgebra.lmul!(α::Number, dst::StridedView) = mul!(dst, α, dst)

function LinearAlgebra.mul!(dst::StridedView{<:Number,N}, α::Number,
src::StridedView{<:Number,N}) where {N}
function LinearAlgebra.mul!(
dst::StridedView{<:Number, N}, α::Number,
src::StridedView{<:Number, N}
) where {N}
if α == 1
copy!(dst, src)
else
dst .= α .* src
end
return dst
end
function LinearAlgebra.mul!(dst::StridedView{<:Number,N}, src::StridedView{<:Number,N},
α::Number) where {N}
function LinearAlgebra.mul!(
dst::StridedView{<:Number, N}, src::StridedView{<:Number, N},
α::Number
) where {N}
if α == 1
copy!(dst, src)
else
dst .= src .* α
end
return dst
end
function LinearAlgebra.axpy!(a::Number, X::StridedView{<:Number,N},
Y::StridedView{<:Number,N}) where {N}
function LinearAlgebra.axpy!(
a::Number, X::StridedView{<:Number, N},
Y::StridedView{<:Number, N}
) where {N}
if a == 1
Y .= X .+ Y
else
Y .= a .* X .+ Y
end
return Y
end
function LinearAlgebra.axpby!(a::Number, X::StridedView{<:Number,N},
b::Number, Y::StridedView{<:Number,N}) where {N}
function LinearAlgebra.axpby!(
a::Number, X::StridedView{<:Number, N},
b::Number, Y::StridedView{<:Number, N}
) where {N}
if b == 1
axpy!(a, X, Y)
elseif b == 0
Expand All @@ -41,9 +49,11 @@ function LinearAlgebra.axpby!(a::Number, X::StridedView{<:Number,N},
return Y
end

function LinearAlgebra.mul!(C::StridedView{T,2},
A::StridedView{<:Any,2}, B::StridedView{<:Any,2},
α::Number=true, β::Number=false) where {T}
function LinearAlgebra.mul!(
C::StridedView{T, 2},
A::StridedView{<:Any, 2}, B::StridedView{<:Any, 2},
α::Number = true, β::Number = false
) where {T}
if !(eltype(C) <: LinearAlgebra.BlasFloat && eltype(A) == eltype(B) == eltype(C))
return __mul!(C, A, B, α, β)
end
Expand All @@ -62,7 +72,7 @@ function LinearAlgebra.mul!(C::StridedView{T,2},
return C
end

function isblasmatrix(A::StridedView{T,2}) where {T<:LinearAlgebra.BlasFloat}
function isblasmatrix(A::StridedView{T, 2}) where {T <: LinearAlgebra.BlasFloat}
if A.op == identity
return stride(A, 1) == 1 || stride(A, 2) == 1
elseif A.op == conj
Expand All @@ -71,7 +81,7 @@ function isblasmatrix(A::StridedView{T,2}) where {T<:LinearAlgebra.BlasFloat}
return false
end
end
function getblasmatrix(A::StridedView{T,2}) where {T<:LinearAlgebra.BlasFloat}
function getblasmatrix(A::StridedView{T, 2}) where {T <: LinearAlgebra.BlasFloat}
if A.op == identity
if stride(A, 1) == 1
return A, 'N'
Expand All @@ -84,8 +94,10 @@ function getblasmatrix(A::StridedView{T,2}) where {T<:LinearAlgebra.BlasFloat}
end

# here we will have C.op == :identity && stride(C,1) < stride(C,2)
function _mul!(C::StridedView{T,2}, A::StridedView{T,2}, B::StridedView{T,2},
α::Number, β::Number) where {T<:LinearAlgebra.BlasFloat}
function _mul!(
C::StridedView{T, 2}, A::StridedView{T, 2}, B::StridedView{T, 2},
α::Number, β::Number
) where {T <: LinearAlgebra.BlasFloat}
if stride(C, 1) == 1 && isblasmatrix(A) && isblasmatrix(B)
nthreads = use_threaded_mul() ? get_num_threads() : 1
_threaded_blas_mul!(C, A, B, α, β, nthreads)
Expand All @@ -94,41 +106,53 @@ function _mul!(C::StridedView{T,2}, A::StridedView{T,2}, B::StridedView{T,2},
end
end

function _threaded_blas_mul!(C::StridedView{T,2}, A::StridedView{T,2}, B::StridedView{T,2},
α::Number, β::Number,
nthreads) where {T<:LinearAlgebra.BlasFloat}
function _threaded_blas_mul!(
C::StridedView{T, 2}, A::StridedView{T, 2}, B::StridedView{T, 2},
α::Number, β::Number,
nthreads
) where {T <: LinearAlgebra.BlasFloat}
m, n = size(C)
m == size(A, 1) && n == size(B, 2) || throw(DimensionMismatch())
if nthreads == 1 || m * n < 1024
return if nthreads == 1 || m * n < 1024
A2, CA = getblasmatrix(A)
B2, CB = getblasmatrix(B)
LinearAlgebra.BLAS.gemm!(CA, CB, convert(T, α), A2, B2, convert(T, β), C)
else
if m > n
m2 = round(Int, m / 16) * 8
nthreads2 = nthreads >> 1
t = Threads.@spawn _threaded_blas_mul!(C[1:($m2), :], A[1:($m2), :], B, α, β,
$nthreads2)
_threaded_blas_mul!(C[(m2 + 1):m, :], A[(m2 + 1):m, :], B, α, β,
nthreads - nthreads2)
t = Threads.@spawn _threaded_blas_mul!(
C[1:($m2), :], A[1:($m2), :], B, α, β,
$nthreads2
)
_threaded_blas_mul!(
C[(m2 + 1):m, :], A[(m2 + 1):m, :], B, α, β,
nthreads - nthreads2
)
wait(t)
return C
else
n2 = round(Int, n / 16) * 8
nthreads2 = nthreads >> 1
t = Threads.@spawn _threaded_blas_mul!(C[:, 1:($n2)], A, B[:, 1:($n2)], α, β,
$nthreads2)
_threaded_blas_mul!(C[:, (n2 + 1):n], A, B[:, (n2 + 1):n], α, β,
nthreads - nthreads2)
t = Threads.@spawn _threaded_blas_mul!(
C[:, 1:($n2)], A, B[:, 1:($n2)], α, β,
$nthreads2
)
_threaded_blas_mul!(
C[:, (n2 + 1):n], A, B[:, (n2 + 1):n], α, β,
nthreads - nthreads2
)
wait(t)
return C
end
end
end

# This implementation is faster than LinearAlgebra.generic_matmatmul
function __mul!(C::StridedView{<:Any,2}, A::StridedView{<:Any,2}, B::StridedView{<:Any,2},
α::Number, β::Number)
function __mul!(
C::StridedView{<:Any, 2}, A::StridedView{<:Any, 2}, B::StridedView{<:Any, 2},
α::Number, β::Number
)
(size(C, 1) == size(A, 1) && size(C, 2) == size(B, 2) && size(A, 2) == size(B, 1)) ||
throw(DimensionMismatch("A has size $(size(A)), B has size $(size(B)), C has size $(size(C))"))

Expand Down
Loading
Loading