From 0c166315b650e1567e3388c5a29add3d2ba4b55a Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Fri, 13 Feb 2026 02:23:11 -0500 Subject: [PATCH 1/7] Small fixes to let BlockTensorMaps work with GPU arrays --- src/linalg/factorizations.jl | 7 +++---- src/tensors/blocktensor.jl | 12 ++++++------ 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/src/linalg/factorizations.jl b/src/linalg/factorizations.jl index 347754f..499664f 100644 --- a/src/linalg/factorizations.jl +++ b/src/linalg/factorizations.jl @@ -1,5 +1,5 @@ using MatrixAlgebraKit -using MatrixAlgebraKit: AbstractAlgorithm, YALAPACK.BlasMat, Algorithm +using MatrixAlgebraKit: AbstractAlgorithm, YALAPACK.BlasMat, Algorithm, diagview import MatrixAlgebraKit as MAK # Type piracy for defining the MAK rules on BlockArrays! @@ -9,9 +9,8 @@ const BlockBlasMat{T <: MAK.BlasFloat} = BlockMatrix{T} function MatrixAlgebraKit.one!(A::BlockBlasMat) _one, _zero = one(eltype(A)), zero(eltype(A)) - @inbounds for j in axes(A, 2), i in axes(A, 1) - A[i, j] = ifelse(i == j, _one, _zero) - end + A .= _zero + diagview(A) .= _one return A end diff --git a/src/tensors/blocktensor.jl b/src/tensors/blocktensor.jl index 8fb6258..962f6ce 100644 --- a/src/tensors/blocktensor.jl +++ b/src/tensors/blocktensor.jl @@ -123,10 +123,10 @@ end # ------------------------ for (fname, felt) in ((:zeros, :zero), (:ones, :one)) @eval begin - function Base.$fname(::Type{T}, V::TensorMapSumSpace) where {T} - TT = blocktensormaptype(spacetype(V), numout(V), numin(V), T) + function Base.$fname(::Type{TorA}, V::TensorMapSumSpace) where {TorA} + TT = blocktensormaptype(spacetype(V), numout(V), numin(V), TorA) t = TT(undef, V) - fill!(t, $felt(T)) + fill!(t, $felt(scalartype(t))) return t end end @@ -136,9 +136,9 @@ for randfun in (:rand, :randn, :randexp) randfun! = Symbol(randfun, :!) @eval begin function Random.$randfun( - rng::Random.AbstractRNG, ::Type{T}, V::TensorMapSumSpace - ) where {T} - TT = blocktensormaptype(spacetype(V), numout(V), numin(V), T) + rng::Random.AbstractRNG, ::Type{TorA}, V::TensorMapSumSpace + ) where {TorA} + TT = blocktensormaptype(spacetype(V), numout(V), numin(V), TorA) t = TT(undef, V) Random.$randfun!(rng, t) return t From 9ba6e72db7c5a8b25b2521784fab75f5cfc533b5 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Fri, 13 Feb 2026 14:30:54 +0100 Subject: [PATCH 2/7] Indexing horror --- src/linalg/factorizations.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/linalg/factorizations.jl b/src/linalg/factorizations.jl index 499664f..51c90e0 100644 --- a/src/linalg/factorizations.jl +++ b/src/linalg/factorizations.jl @@ -10,7 +10,11 @@ const BlockBlasMat{T <: MAK.BlasFloat} = BlockMatrix{T} function MatrixAlgebraKit.one!(A::BlockBlasMat) _one, _zero = one(eltype(A)), zero(eltype(A)) A .= _zero - diagview(A) .= _one + n_blocks = blocksize(A)[1] + # awful workaround to BlockArrays indexing interface + for bi in 1:n_blocks + A[Block(bi), Block(bi)] .= diagm(fill(_one, blocksizes(A)[bi, bi][1])) + end return A end From 3a8f8bc2a136b23669b1c6e11f9dec598837e5fd Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Fri, 13 Feb 2026 14:54:52 +0100 Subject: [PATCH 3/7] Update src/linalg/factorizations.jl Co-authored-by: Lukas Devos --- src/linalg/factorizations.jl | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/src/linalg/factorizations.jl b/src/linalg/factorizations.jl index 51c90e0..e32da53 100644 --- a/src/linalg/factorizations.jl +++ b/src/linalg/factorizations.jl @@ -1,5 +1,5 @@ using MatrixAlgebraKit -using MatrixAlgebraKit: AbstractAlgorithm, YALAPACK.BlasMat, Algorithm, diagview +using MatrixAlgebraKit: AbstractAlgorithm, YALAPACK.BlasMat, Algorithm, diagview, zero!, one! import MatrixAlgebraKit as MAK # Type piracy for defining the MAK rules on BlockArrays! @@ -7,13 +7,21 @@ import MatrixAlgebraKit as MAK const BlockBlasMat{T <: MAK.BlasFloat} = BlockMatrix{T} +function MatrixAlgebraKit.zero!(A::BlockBlasMat) + for bi in 1:blocksize(A, 1), bj in 1:blocksize(A, 2) + zero!(A[Block(bi), Block(bj)]) + end + return A +end + function MatrixAlgebraKit.one!(A::BlockBlasMat) - _one, _zero = one(eltype(A)), zero(eltype(A)) - A .= _zero - n_blocks = blocksize(A)[1] - # awful workaround to BlockArrays indexing interface - for bi in 1:n_blocks - A[Block(bi), Block(bi)] .= diagm(fill(_one, blocksizes(A)[bi, bi][1])) + A .= zero(eltype(A)) + for bi in 1:blocksize(A, 1), bj in 1:blocksize(A, 2) + if bi == bj + one!(A[Block(bi), Block(bj)]) + #else + # zero!(A[Block(bi), Block(bj)]) + end end return A end From ac30196de4eb5165019c4094b36a9f62e8fd646e Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Fri, 13 Feb 2026 15:37:48 +0100 Subject: [PATCH 4/7] Use views --- src/linalg/factorizations.jl | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/src/linalg/factorizations.jl b/src/linalg/factorizations.jl index e32da53..003711a 100644 --- a/src/linalg/factorizations.jl +++ b/src/linalg/factorizations.jl @@ -1,5 +1,5 @@ using MatrixAlgebraKit -using MatrixAlgebraKit: AbstractAlgorithm, YALAPACK.BlasMat, Algorithm, diagview, zero!, one! +using MatrixAlgebraKit: AbstractAlgorithm, YALAPACK.BlasMat, Algorithm import MatrixAlgebraKit as MAK # Type piracy for defining the MAK rules on BlockArrays! @@ -8,20 +8,17 @@ import MatrixAlgebraKit as MAK const BlockBlasMat{T <: MAK.BlasFloat} = BlockMatrix{T} function MatrixAlgebraKit.zero!(A::BlockBlasMat) - for bi in 1:blocksize(A, 1), bj in 1:blocksize(A, 2) - zero!(A[Block(bi), Block(bj)]) + for bj in blockaxes(A, 2), bi in blockaxes(A, 1) + a = view(A, bi, bj) + MAK.zero!(a) end return A end function MatrixAlgebraKit.one!(A::BlockBlasMat) - A .= zero(eltype(A)) - for bi in 1:blocksize(A, 1), bj in 1:blocksize(A, 2) - if bi == bj - one!(A[Block(bi), Block(bj)]) - #else - # zero!(A[Block(bi), Block(bj)]) - end + for bj in blockaxes(A, 2), bi in blockaxes(A, 1) + a = view(A, bi, bj) + bi == bj ? MAK.one!(a) : MAK.zero!(a) end return A end From d88a5d8c51a10e2feb1a1c6c6e9ef133089581c0 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Fri, 13 Feb 2026 15:48:57 +0100 Subject: [PATCH 5/7] Add a check for blocksquarity --- src/linalg/factorizations.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/linalg/factorizations.jl b/src/linalg/factorizations.jl index 003711a..dfd2f31 100644 --- a/src/linalg/factorizations.jl +++ b/src/linalg/factorizations.jl @@ -16,6 +16,8 @@ function MatrixAlgebraKit.zero!(A::BlockBlasMat) end function MatrixAlgebraKit.one!(A::BlockBlasMat) + mb, nb = blocksize(A) + mb == nb || throw(DimensionMismatch("A is not block-square. Number of row-blocks ($mb) does not match number of column-blocks ($nb)")) for bj in blockaxes(A, 2), bi in blockaxes(A, 1) a = view(A, bi, bj) bi == bj ? MAK.one!(a) : MAK.zero!(a) From a20c3f42b73c15eef5251581cf9222c57ecd8f52 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Fri, 13 Feb 2026 15:55:14 +0100 Subject: [PATCH 6/7] Update src/linalg/factorizations.jl Co-authored-by: Lukas Devos --- src/linalg/factorizations.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/linalg/factorizations.jl b/src/linalg/factorizations.jl index dfd2f31..d589a35 100644 --- a/src/linalg/factorizations.jl +++ b/src/linalg/factorizations.jl @@ -16,8 +16,7 @@ function MatrixAlgebraKit.zero!(A::BlockBlasMat) end function MatrixAlgebraKit.one!(A::BlockBlasMat) - mb, nb = blocksize(A) - mb == nb || throw(DimensionMismatch("A is not block-square. Number of row-blocks ($mb) does not match number of column-blocks ($nb)")) + blockaxes(A, 1) == blockaxes(A, 2) || throw(DimensionMismatch("A is not block-square")) for bj in blockaxes(A, 2), bi in blockaxes(A, 1) a = view(A, bi, bj) bi == bj ? MAK.one!(a) : MAK.zero!(a) From 3bfe9c6e79309555d910e2f2abe92716e7920b50 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Fri, 13 Feb 2026 16:11:51 +0100 Subject: [PATCH 7/7] No checks full gas --- src/linalg/factorizations.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/linalg/factorizations.jl b/src/linalg/factorizations.jl index d589a35..003711a 100644 --- a/src/linalg/factorizations.jl +++ b/src/linalg/factorizations.jl @@ -16,7 +16,6 @@ function MatrixAlgebraKit.zero!(A::BlockBlasMat) end function MatrixAlgebraKit.one!(A::BlockBlasMat) - blockaxes(A, 1) == blockaxes(A, 2) || throw(DimensionMismatch("A is not block-square")) for bj in blockaxes(A, 2), bi in blockaxes(A, 1) a = view(A, bi, bj) bi == bj ? MAK.one!(a) : MAK.zero!(a)