Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 7 additions & 13 deletions ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,34 +4,28 @@ using MatrixAlgebraKit
using MatrixAlgebraKit: @algdef, Algorithm, check_input
using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular!
using MatrixAlgebraKit: diagview, sign_safe
using MatrixAlgebraKit: LQViaTransposedQR, TruncationStrategy, NoTruncation, TruncationByValue, AbstractAlgorithm
using MatrixAlgebraKit: ROCSOLVER, LQViaTransposedQR, TruncationStrategy, NoTruncation, TruncationByValue, AbstractAlgorithm
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eigh_algorithm
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_gesvdj!
import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_gesvdj!
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _gpu_heev!, _gpu_heevx!, _sylvester, svd_rank
using AMDGPU
using LinearAlgebra
using LinearAlgebra: BlasFloat

include("yarocsolver.jl")

function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {T <: StridedROCMatrix}
return ROCSOLVER_HouseholderQR(; kwargs...)
end
function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {T <: StridedROCMatrix}
qr_alg = ROCSOLVER_HouseholderQR(; kwargs...)
return LQViaTransposedQR(qr_alg)
end
MatrixAlgebraKit.default_householder_driver(::StridedROCMatrix{<:BlasFloat}) = ROCSOLVER()
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedROCMatrix}
return ROCSOLVER_QRIteration(; kwargs...)
end
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedROCMatrix}
return ROCSOLVER_DivideAndConquer(; kwargs...)
end

_gpu_geqrf!(A::StridedROCMatrix) = YArocSOLVER.geqrf!(A)
_gpu_ungqr!(A::StridedROCMatrix, τ::StridedROCVector) = YArocSOLVER.ungqr!(A, τ)
_gpu_unmqr!(side::AbstractChar, trans::AbstractChar, A::StridedROCMatrix, τ::StridedROCVector, C::StridedROCVecOrMat) =
YArocSOLVER.unmqr!(side, trans, A, τ, C)
for f in (:geqrf!, :ungqr!, :unmqr!)
@eval $f(::ROCSOLVER, args...) = YArocSOLVER.$f(args...)
end

_gpu_gesvd!(A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix) =
YArocSOLVER.gesvd!(A, S, U, Vᴴ)
# not yet supported
Expand Down
38 changes: 13 additions & 25 deletions ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ using MatrixAlgebraKit
using MatrixAlgebraKit: @algdef, Algorithm, check_input
using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular!
using MatrixAlgebraKit: diagview, sign_safe
using MatrixAlgebraKit: LQViaTransposedQR, TruncationByValue, AbstractAlgorithm
using MatrixAlgebraKit: CUSOLVER, LQViaTransposedQR, TruncationByValue, AbstractAlgorithm
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!, _gpu_geev!
import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!, _gpu_geev!
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _sylvester, svd_rank
using CUDA, CUDA.CUBLAS
using CUDA: i32
Expand All @@ -15,13 +15,7 @@ using LinearAlgebra: BlasFloat

include("yacusolver.jl")

function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuMatrix{TT}}
return CUSOLVER_HouseholderQR(; kwargs...)
end
function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuMatrix{TT}}
qr_alg = CUSOLVER_HouseholderQR(; kwargs...)
return LQViaTransposedQR(qr_alg)
end
MatrixAlgebraKit.default_householder_driver(::StridedCuMatrix{<:BlasFloat}) = CUSOLVER()
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuMatrix{TT}}
return CUSOLVER_QRIteration(; kwargs...)
end
Expand All @@ -33,31 +27,25 @@ function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {TT
end

# include for block sector support
function MatrixAlgebraKit.default_qr_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}}
return CUSOLVER_HouseholderQR(; kwargs...)
end
function MatrixAlgebraKit.default_lq_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}}
qr_alg = CUSOLVER_HouseholderQR(; kwargs...)
return LQViaTransposedQR(qr_alg)
end
function MatrixAlgebraKit.default_svd_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}}
const BlockView{T, A} = Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}}

MatrixAlgebraKit.default_householder_driver(::BlockView{T, A}) where {T <: BlasFloat, A <: CuVecOrMat{T}} = CUSOLVER()
function MatrixAlgebraKit.default_svd_algorithm(::Type{BlockView{T, A}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}}
return CUSOLVER_Jacobi(; kwargs...)
end
function MatrixAlgebraKit.default_eig_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}}
function MatrixAlgebraKit.default_eig_algorithm(::Type{BlockView{T, A}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}}
return CUSOLVER_Simple(; kwargs...)
end
function MatrixAlgebraKit.default_eigh_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}}
function MatrixAlgebraKit.default_eigh_algorithm(::Type{BlockView{T, A}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}}
return CUSOLVER_DivideAndConquer(; kwargs...)
end

for f in (:geqrf!, :ungqr!, :unmqr!)
@eval $f(::CUSOLVER, args...) = YACUSOLVER.$f(args...)
end

_gpu_geev!(A::StridedCuMatrix, D::StridedCuVector, V::StridedCuMatrix) =
YACUSOLVER.Xgeev!(A, D, V)
_gpu_geqrf!(A::StridedCuMatrix) =
YACUSOLVER.geqrf!(A)
_gpu_ungqr!(A::StridedCuMatrix, τ::StridedCuVector) =
YACUSOLVER.ungqr!(A, τ)
_gpu_unmqr!(side::AbstractChar, trans::AbstractChar, A::StridedCuMatrix, τ::StridedCuVector, C::StridedCuVecOrMat) =
YACUSOLVER.unmqr!(side, trans, A, τ, C)
_gpu_gesvd!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix) =
YACUSOLVER.gesvd!(A, S, U, Vᴴ)
_gpu_Xgesvdp!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) =
Expand Down
73 changes: 28 additions & 45 deletions ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
module MatrixAlgebraKitGenericLinearAlgebraExt

using MatrixAlgebraKit
using MatrixAlgebraKit: sign_safe, check_input, diagview, gaugefix!, one!, default_fixgauge
using MatrixAlgebraKit: left_orth_alg
using MatrixAlgebraKit: sign_safe, check_input, diagview, gaugefix!, one!, zero!, default_fixgauge
using GenericLinearAlgebra: svd!, svdvals!, eigen!, eigvals!, Hermitian, qr!
using LinearAlgebra: I, Diagonal, lmul!

Expand Down Expand Up @@ -57,81 +56,65 @@ function MatrixAlgebraKit.eigh_vals!(A::AbstractMatrix, D, ::GLA_QRIteration)
return eigvals!(Hermitian(A); sortby = real)
end

function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{Float16, ComplexF16, BigFloat, Complex{BigFloat}}}}
return GLA_HouseholderQR(; kwargs...)
end

function MatrixAlgebraKit.qr_full!(A::AbstractMatrix, QR, alg::GLA_HouseholderQR)
check_input(qr_full!, A, QR, alg)
Q, R = QR
return _gla_householder_qr!(A, Q, R; alg.kwargs...)
end

function MatrixAlgebraKit.qr_compact!(A::AbstractMatrix, QR, alg::GLA_HouseholderQR)
check_input(qr_compact!, A, QR, alg)
Q, R = QR
return _gla_householder_qr!(A, Q, R; alg.kwargs...)
end

function MatrixAlgebraKit.qr_null!(A::AbstractMatrix, N, alg::GLA_HouseholderQR)
check_input(qr_null!, A, N, alg)
return _gla_householder_qr_null!(A, N; alg.kwargs...)
end

function _gla_householder_qr!(A::AbstractMatrix, Q, R; positive = true, blocksize = 1, pivoted = false)
pivoted && throw(ArgumentError("Only pivoted = false implemented for GLA_HouseholderQR."))
(blocksize == 1) || throw(ArgumentError("Only blocksize = 1 implemented for GLA_HouseholderQR."))
function MatrixAlgebraKit.householder_qr!(
driver::MatrixAlgebraKit.GLA, A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix;
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 1
)
blocksize == 1 ||
throw(ArgumentError(lazy"$driver does not provide a blocked QR decomposition"))
pivoted &&
throw(ArgumentError(lazy"$driver does not provide a pivoted QR decomposition"))

m, n = size(A)
k = min(m, n)
minmn = min(m, n)
computeR = length(R) > 0

# compute QR
Q̃, R̃ = qr!(A)
lmul!(Q̃, MatrixAlgebraKit.one!(Q))

if positive
@inbounds for j in 1:k
@inbounds for j in 1:minmn
s = sign_safe(R̃[j, j])
@simd for i in 1:m
Q[i, j] *= s
end
end
end

computeR = length(R) > 0
if computeR
if positive
@inbounds for j in n:-1:1
@simd for i in 1:min(k, j)
@simd for i in 1:min(minmn, j)
R[i, j] = R̃[i, j] * conj(sign_safe(R̃[i, i]))
end
@simd for i in (min(k, j) + 1):size(R, 1)
@simd for i in (min(minmn, j) + 1):size(R, 1)
R[i, j] = zero(eltype(R))
end
end
else
R[1:k, :] .= R̃
MatrixAlgebraKit.zero!(@view(R[(k + 1):end, :]))
R[1:minmn, :] .= R̃
MatrixAlgebraKit.zero!(@view(R[(minmn + 1):end, :]))
end
end
return Q, R
end

function _gla_householder_qr_null!(
A::AbstractMatrix, N::AbstractMatrix;
positive = true, blocksize = 1, pivoted = false
function MatrixAlgebraKit.householder_qr_null!(
driver::MatrixAlgebraKit.GLA, A::AbstractMatrix, N::AbstractMatrix;
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 1
)
pivoted && throw(ArgumentError("Only pivoted = false implemented for GLA_HouseholderQR."))
(blocksize == 1) || throw(ArgumentError("Only blocksize = 1 implemented for GLA_HouseholderQR."))
blocksize == 1 ||
throw(ArgumentError(lazy"$driver does not provide a blocked QR decomposition"))
pivoted &&
throw(ArgumentError(lazy"$driver does not provide a pivoted QR decomposition"))

m, n = size(A)
minmn = min(m, n)
fill!(N, zero(eltype(N)))
zero!(N)
one!(view(N, (minmn + 1):m, 1:(m - minmn)))
Q̃, = qr!(A)
lmul!(Q̃, N)
return N
end

function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{Float16, ComplexF16, BigFloat, Complex{BigFloat}}}}
return MatrixAlgebraKit.LQViaTransposedQR(GLA_HouseholderQR(; kwargs...))
return lmul!(Q̃, N)
end

MatrixAlgebraKit.left_orth_alg(alg::GLA_HouseholderQR) = MatrixAlgebraKit.LeftOrthViaQR(alg)
Expand Down
55 changes: 54 additions & 1 deletion src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ Finally, the same behavior is obtained when the keyword arguments are
passed as the third positional argument in the form of a `NamedTuple`.
""" select_algorithm

function select_algorithm(f::F, A, alg::Alg = nothing; kwargs...) where {F, Alg}
Base.@assume_effects :foldable function select_algorithm(f::F, A, alg::Alg = nothing; kwargs...) where {F, Alg}
if isnothing(alg)
return default_algorithm(f, A; kwargs...)
elseif alg isa Symbol
Expand Down Expand Up @@ -143,6 +143,59 @@ If this is not possible, for example when the output size is not known a priori
this function may return `nothing`.
""" initialize_output


# Drivers
# -------
"""
abstract type Driver

Supertype used for customizing various implementations of the same algorithm.
"""
abstract type Driver end

"""
DefaultDriver <: Driver

Select a default driver at runtime, based on the input matrix.
"""
struct DefaultDriver <: Driver end

"""
LAPACK <: Driver

Driver to select LAPACK as the implementation strategy.
"""
struct LAPACK <: Driver end

"""
CUSOLVER <: Driver

Driver to select CUSOLVER as the implementation strategy.
"""
struct CUSOLVER <: Driver end

"""
ROCSOLVER <: Driver

Driver to select ROCSOLVER as the implementation strategy.
"""
struct ROCSOLVER <: Driver end

"""
GLA <: Driver

Driver to select GenericLinearAlgebra.jl as the implementation strategy.
"""
struct GLA <: Driver end

"""
Native <: Driver

Driver to select a native implementation in MatrixAlgebraKit as the implementation strategy.
"""
struct Native <: Driver end


# Truncation strategy
# -------------------
"""
Expand Down
16 changes: 8 additions & 8 deletions src/common/householder.jl
Original file line number Diff line number Diff line change
@@ -1,32 +1,32 @@
const IndexRange{T <: Integer} = Base.AbstractRange{T}

# Elementary Householder reflection
struct Householder{T, V <: AbstractVector, R <: IndexRange}
struct HouseholderReflection{T, V <: AbstractVector, R <: IndexRange}
β::T
v::V
r::R
end
Base.adjoint(H::Householder) = Householder(conj(H.β), H.v, H.r)
Base.adjoint(H::HouseholderReflection) = HouseholderReflection(conj(H.β), H.v, H.r)

function householder(x::AbstractVector, r::IndexRange = axes(x, 1), k = first(r))
i = findfirst(==(k), r)
i == nothing && error("k = $k should be in the range r = $r")
β, v, ν = _householder!(x[r], i)
return Householder(β, v, r), ν
return HouseholderReflection(β, v, r), ν
end
# Householder reflector h that zeros the elements A[r,col] (except for A[k,col]) upon lmul!(h,A)
function householder(A::AbstractMatrix, r::IndexRange, col::Int, k = first(r))
i = findfirst(==(k), r)
i == nothing && error("k = $k should be in the range r = $r")
β, v, ν = _householder!(A[r, col], i)
return Householder(β, v, r), ν
return HouseholderReflection(β, v, r), ν
end
# Householder reflector that zeros the elements A[row,r] (except for A[row,k]) upon rmul!(A,h')
function householder(A::AbstractMatrix, row::Int, r::IndexRange, k = first(r))
i = findfirst(==(k), r)
i == nothing && error("k = $k should be in the range r = $r")
β, v, ν = _householder!(conj!(A[row, r]), i)
return Householder(β, v, r), ν
return HouseholderReflection(β, v, r), ν
end

# generate Householder vector based on vector v, such that applying the reflection
Expand Down Expand Up @@ -66,7 +66,7 @@ function _householder!(v::AbstractVector{T}, i::Int = 1) where {T}
return β, v, ν
end

function LinearAlgebra.lmul!(H::Householder, x::AbstractVector)
function LinearAlgebra.lmul!(H::HouseholderReflection, x::AbstractVector)
v = H.v
r = H.r
β = H.β
Expand All @@ -87,7 +87,7 @@ function LinearAlgebra.lmul!(H::Householder, x::AbstractVector)
end
return x
end
function LinearAlgebra.lmul!(H::Householder, A::AbstractMatrix; cols = axes(A, 2))
function LinearAlgebra.lmul!(H::HouseholderReflection, A::AbstractMatrix; cols = axes(A, 2))
v = H.v
r = H.r
β = H.β
Expand All @@ -110,7 +110,7 @@ function LinearAlgebra.lmul!(H::Householder, A::AbstractMatrix; cols = axes(A, 2
end
return A
end
function LinearAlgebra.rmul!(A::AbstractMatrix, H::Householder; rows = axes(A, 1))
function LinearAlgebra.rmul!(A::AbstractMatrix, H::HouseholderReflection; rows = axes(A, 1))
v = H.v
r = H.r
β = H.β
Expand Down
Loading