diff --git a/src/linalg/factorizations.jl b/src/linalg/factorizations.jl index 347754f..003711a 100644 --- a/src/linalg/factorizations.jl +++ b/src/linalg/factorizations.jl @@ -7,10 +7,18 @@ import MatrixAlgebraKit as MAK const BlockBlasMat{T <: MAK.BlasFloat} = BlockMatrix{T} +function MatrixAlgebraKit.zero!(A::BlockBlasMat) + for bj in blockaxes(A, 2), bi in blockaxes(A, 1) + a = view(A, bi, bj) + MAK.zero!(a) + end + return A +end + function MatrixAlgebraKit.one!(A::BlockBlasMat) - _one, _zero = one(eltype(A)), zero(eltype(A)) - @inbounds for j in axes(A, 2), i in axes(A, 1) - A[i, j] = ifelse(i == j, _one, _zero) + for bj in blockaxes(A, 2), bi in blockaxes(A, 1) + a = view(A, bi, bj) + bi == bj ? MAK.one!(a) : MAK.zero!(a) end return A end diff --git a/src/tensors/blocktensor.jl b/src/tensors/blocktensor.jl index 8fb6258..962f6ce 100644 --- a/src/tensors/blocktensor.jl +++ b/src/tensors/blocktensor.jl @@ -123,10 +123,10 @@ end # ------------------------ for (fname, felt) in ((:zeros, :zero), (:ones, :one)) @eval begin - function Base.$fname(::Type{T}, V::TensorMapSumSpace) where {T} - TT = blocktensormaptype(spacetype(V), numout(V), numin(V), T) + function Base.$fname(::Type{TorA}, V::TensorMapSumSpace) where {TorA} + TT = blocktensormaptype(spacetype(V), numout(V), numin(V), TorA) t = TT(undef, V) - fill!(t, $felt(T)) + fill!(t, $felt(scalartype(t))) return t end end @@ -136,9 +136,9 @@ for randfun in (:rand, :randn, :randexp) randfun! = Symbol(randfun, :!) @eval begin function Random.$randfun( - rng::Random.AbstractRNG, ::Type{T}, V::TensorMapSumSpace - ) where {T} - TT = blocktensormaptype(spacetype(V), numout(V), numin(V), T) + rng::Random.AbstractRNG, ::Type{TorA}, V::TensorMapSumSpace + ) where {TorA} + TT = blocktensormaptype(spacetype(V), numout(V), numin(V), TorA) t = TT(undef, V) Random.$randfun!(rng, t) return t