From ff0ac3feb47c32762c9eb72f65c4907816b7746c Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 17 Mar 2026 11:01:10 -0400 Subject: [PATCH 1/4] broadcasted_linear, getindex on lazy arrays --- Project.toml | 2 +- src/lazyarrays.jl | 45 +++++++++++++++++++++++++++++++++++++++++++++ test/test_lazy.jl | 28 ++++++++++++++++++++++++++-- 3 files changed, 72 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index 7d690926..18575fb3 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" -version = "0.7.19" +version = "0.7.20" authors = ["ITensor developers and contributors"] [workspace] diff --git a/src/lazyarrays.jl b/src/lazyarrays.jl index 6b832613..e687b980 100644 --- a/src/lazyarrays.jl +++ b/src/lazyarrays.jl @@ -30,8 +30,13 @@ lazy_function(::typeof(*)) = *ₗ lazy_function(::typeof(/)) = /ₗ lazy_function(::typeof(\)) = \ₗ lazy_function(::typeof(conj)) = conjed +lazy_function(::typeof(identity)) = identity +lazy_function(f::Base.Fix1{typeof(*), <:Number}) = Base.Fix1(*ₗ, f.x) +lazy_function(f::Base.Fix2{typeof(*), <:Number}) = Base.Fix2(*ₗ, f.x) +lazy_function(f::Base.Fix2{typeof(/), <:Number}) = Base.Fix2(/ₗ, f.x) broadcast_is_linear(f, args...) = false +broadcast_is_linear(::typeof(identity), ::Base.AbstractArrayOrBroadcasted) = true broadcast_is_linear(::typeof(+), ::Base.AbstractArrayOrBroadcasted...) = true broadcast_is_linear(::typeof(-), ::Base.AbstractArrayOrBroadcasted) = true function broadcast_is_linear( @@ -50,6 +55,21 @@ function broadcast_is_linear( end broadcast_is_linear(::typeof(*), ::Number, ::Number) = true broadcast_is_linear(::typeof(conj), ::Base.AbstractArrayOrBroadcasted) = true +function broadcast_is_linear( + ::Base.Fix1{typeof(*), <:Number}, ::Base.AbstractArrayOrBroadcasted + ) + return true +end +function broadcast_is_linear( + ::Base.Fix2{typeof(*), <:Number}, ::Base.AbstractArrayOrBroadcasted + ) + return true +end +function broadcast_is_linear( + ::Base.Fix2{typeof(/), <:Number}, ::Base.AbstractArrayOrBroadcasted + ) + return true +end is_linear(x) = true function is_linear(bc::BC.Broadcasted) return broadcast_is_linear(bc.f, bc.args...) && all(is_linear, bc.args) @@ -57,6 +77,19 @@ end to_linear(x) = x to_linear(bc::BC.Broadcasted) = lazy_function(bc.f)(to_linear.(bc.args)...) +function broadcast_error(style, f) + return throw( + ArgumentError( + "Only linear broadcast operations are supported for `$style`, got `$f`." + ) + ) +end +function broadcasted_linear(style::BC.BroadcastStyle, f, args...) + bc = BC.Broadcasted(style, f, args) + is_linear(bc) || broadcast_error(style, f) + return to_linear(bc) +end +broadcasted_linear(f, args...) = broadcasted_linear(BC.combine_styles(args...), f, args...) # TODO: Use `Broadcast.broadcastable` interface for this? to_broadcasted(x) = x function to_broadcasted(a::AbstractArray) @@ -136,6 +169,7 @@ similar_scaled(a::AbstractArray) = similar(unscaled(a)) similar_scaled(a::AbstractArray, elt::Type) = similar(unscaled(a), elt) similar_scaled(a::AbstractArray, ax) = similar(unscaled(a), ax) similar_scaled(a::AbstractArray, elt::Type, ax) = similar(unscaled(a), elt, ax) +getindex_scaled(a::AbstractArray, I...) = coeff(a) * getindex(unscaled(a), I...) copyto!_scaled(dest::AbstractArray, src::AbstractArray) = add!(dest, src, true, false) show_scaled(io::IO, a::AbstractArray) = show_lazy(io, a) show_scaled(io::IO, mime::MIME"text/plain", a::AbstractArray) = show_lazy(io, mime, a) @@ -227,6 +261,9 @@ macro scaledarray_base(ScaledArray, AbstractArray = :AbstractArray) function Base.similar(a::$ScaledArray, elt::Type, ax::Dims) return $TensorAlgebra.similar_scaled(a, elt, ax) end + Base.@propagate_inbounds function Base.getindex(a::$ScaledArray, I...) + return $TensorAlgebra.getindex_scaled(a, I...) + end function Base.copyto!(dest::$AbstractArray, src::$ScaledArray) return $TensorAlgebra.copyto!_scaled(dest, src) end @@ -372,6 +409,7 @@ size_conj(a::AbstractArray) = size(conjed(a)) similar_conj(a::AbstractArray, elt::Type) = similar(conjed(a), elt) similar_conj(a::AbstractArray, elt::Type, ax) = similar(conjed(a), elt, ax) similar_conj(a::AbstractArray, ax) = similar(conjed(a), ax) +getindex_conj(a::AbstractArray, I...) = conj(getindex(conjed(a), I...)) copyto!_conj(dest::AbstractArray, src::AbstractArray) = add!(dest, src, true, false) show_conj(io::IO, a::AbstractArray) = show_lazy(io, a) show_conj(io::IO, mime::MIME"text/plain", a::AbstractArray) = show_lazy(io, mime, a) @@ -424,6 +462,9 @@ macro conjarray_base(ConjArray, AbstractArray = :AbstractArray) function Base.similar(a::$ConjArray, elt::Type, ax::Dims) return $TensorAlgebra.similar_conj(a, elt, ax) end + Base.@propagate_inbounds function Base.getindex(a::$ConjArray, I...) + return $TensorAlgebra.getindex_conj(a, I...) + end function Base.copyto!(dest::$AbstractArray, src::$ConjArray) return $TensorAlgebra.copyto!_conj(dest, src) end @@ -525,6 +566,7 @@ similar_add(a::AbstractArray, elt::Type) = similar(BC.Broadcasted(+, addends(a)) function similar_add(a::AbstractArray, elt::Type, ax) return similar(BC.Broadcasted(+, addends(a)), elt, ax) end +getindex_add(a::AbstractArray, I...) = sum(addend -> getindex(addend, I...), addends(a)) copyto!_add(dest::AbstractArray, src::AbstractArray) = add!(dest, src, true, false) show_add(io::IO, a::AbstractArray) = show_lazy(io, a) show_add(io::IO, mime::MIME"text/plain", a::AbstractArray) = show_lazy(io, mime, a) @@ -611,6 +653,9 @@ macro addarray_base(AddArray, AbstractArray = :AbstractArray) function Base.similar(a::$AddArray, elt::Type, ax) return $TensorAlgebra.similar_add(a, elt, ax) end + Base.@propagate_inbounds function Base.getindex(a::$AddArray, I...) + return $TensorAlgebra.getindex_add(a, I...) + end function Base.copyto!(dest::$AbstractArray, src::$AddArray) return $TensorAlgebra.copyto!_add(dest, src) end diff --git a/test/test_lazy.jl b/test/test_lazy.jl index 24517c9b..841b3125 100644 --- a/test/test_lazy.jl +++ b/test/test_lazy.jl @@ -1,6 +1,7 @@ import FunctionImplementations as FI -using TensorAlgebra: TensorAlgebra as TA, *ₗ, +ₗ, conjed -using Test: @test, @test_broken, @testset +using Base.Broadcast: Broadcast as BC +using TensorAlgebra: TensorAlgebra as TA, *ₗ, +ₗ, /ₗ, conjed +using Test: @test, @test_broken, @test_throws, @testset @testset "lazy arrays" begin @testset "lazy array operations" begin @@ -94,4 +95,27 @@ using Test: @test, @test_broken, @testset @test x ≡ PermutedDimsArray(a *ₗ b, perm) @test_broken copy(x) ≈ permutedims(a * b, perm) end + @testset "linear broadcast lowering" begin + a = randn(ComplexF64, 2, 2) + style = BC.DefaultArrayStyle{2}() + + @test TA.broadcasted_linear(identity, a) ≡ a + @test TA.broadcasted_linear(Base.Fix1(*, 2), a) ≡ 2 *ₗ a + @test TA.broadcasted_linear(Base.Fix2(*, 2), a) ≡ a *ₗ 2 + @test TA.broadcasted_linear(Base.Fix2(/, 2), a) ≡ a /ₗ 2 + @test TA.broadcasted_linear(style, identity, a) ≡ a + @test TA.broadcasted_linear(style, Base.Fix1(*, 2), a) ≡ 2 *ₗ a + @test TA.broadcasted_linear(style, Base.Fix2(*, 2), a) ≡ a *ₗ 2 + @test TA.broadcasted_linear(style, Base.Fix2(/, 2), a) ≡ a /ₗ 2 + @test TA.broadcasted_linear(style, conj, a) ≡ conjed(a) + @test_throws ArgumentError TA.broadcasted_linear(style, exp, a) + end + @testset "scalar getindex" begin + a = randn(ComplexF64, 2, 2) + b = randn(ComplexF64, 2, 2) + + @test (2 *ₗ a)[1, 2] == 2 * a[1, 2] + @test conjed(a)[2, 1] == conj(a[2, 1]) + @test (a +ₗ b)[2, 2] == a[2, 2] + b[2, 2] + end end From add6826cf9526c53d642b3cf081c9e9d472eec9d Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 17 Mar 2026 11:32:48 -0400 Subject: [PATCH 2/4] Scalar indexing of lazy mul --- src/lazyarrays.jl | 26 ++++++++++++++++++++++++++ test/test_lazy.jl | 4 +++- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/src/lazyarrays.jl b/src/lazyarrays.jl index e687b980..94ad1970 100644 --- a/src/lazyarrays.jl +++ b/src/lazyarrays.jl @@ -786,6 +786,20 @@ similar_mul(a::AbstractArray, elt::Type) = similar(a, elt, axes(a)) # TODO: Make use of both arguments to determine the output, maybe # using `LinearAlgebra.matprod_dest(factors(a)..., elt)`? similar_mul(a::AbstractArray, elt::Type, ax) = similar(last(factors(a)), elt, ax) +function mul_getindex(a1::AbstractMatrix, a2::AbstractMatrix, i::Int, j::Int) + return transpose(view(a1, i, :)) * view(a2, :, j) +end +function mul_getindex(a1::AbstractMatrix, a2::AbstractVector, i::Int) + return transpose(view(a1, i, :)) * a2 +end +function mul_getindex(a1::AbstractVector, a2::AbstractMatrix, j::Int) + return transpose(a1) * view(a2, :, j) +end +function getindex_mul(a::AbstractArray, i::Int) + I = Tuple(CartesianIndices(axes(a))[i]) + return getindex_mul(a, I...) +end +getindex_mul(a::AbstractArray, I::Vararg{Int}) = mul_getindex(factors(a)..., I...) copyto!_mul(dest::AbstractArray, src::AbstractArray) = add!(dest, src, true, false) show_mul(io::IO, a::AbstractArray) = show_lazy(io, a) show_mul(io::IO, mime::MIME"text/plain", a::AbstractArray) = show_lazy(io, mime, a) @@ -843,6 +857,12 @@ macro mularray_type(MulArray, AbstractArray = :AbstractArray) ) end +function copy_permuteddims_mul(a::PermutedDimsArray{<:Any, 2, perm}) where {perm} + perm == (1, 2) && return copy(parent(a)) + perm == (2, 1) && return copy(transpose(parent(a))) + throw(ArgumentError("Unsupported permutation $perm")) +end + macro mularray_base(MulArray, AbstractArray = :AbstractArray) return esc( quote @@ -864,6 +884,9 @@ macro mularray_base(MulArray, AbstractArray = :AbstractArray) function Base.similar(a::$MulArray, elt::Type, ax::Dims) return $TensorAlgebra.similar_mul(a, elt, ax) end + Base.@propagate_inbounds function Base.getindex(a::$MulArray, I...) + return $TensorAlgebra.getindex_mul(a, I...) + end function Base.copyto!(dest::$AbstractArray, src::$MulArray) return $TensorAlgebra.copyto!_mul(dest, src) end @@ -926,6 +949,9 @@ macro mularray_terminterface(MulArray, AbstractArray = :AbstractArray) $TensorAlgebra.iscall(a::$MulArray) = $TensorAlgebra.iscall_mul(a) $TensorAlgebra.operation(a::$MulArray) = $TensorAlgebra.operation_mul(a) $TensorAlgebra.arguments(a::$MulArray) = $TensorAlgebra.arguments_mul(a) + function Base.copy(a::PermutedDimsArray{<:Any, 2, <:Any, <:Any, $MulArray}) + return $TensorAlgebra.copy_permuteddims_mul(a) + end end ) end diff --git a/test/test_lazy.jl b/test/test_lazy.jl index 841b3125..4d3265fc 100644 --- a/test/test_lazy.jl +++ b/test/test_lazy.jl @@ -93,7 +93,7 @@ using Test: @test, @test_broken, @test_throws, @testset x = FI.permuteddims(a *ₗ b, perm) @test x ≡ PermutedDimsArray(a *ₗ b, perm) - @test_broken copy(x) ≈ permutedims(a * b, perm) + @test copy(x) ≈ permutedims(a * b, perm) end @testset "linear broadcast lowering" begin a = randn(ComplexF64, 2, 2) @@ -117,5 +117,7 @@ using Test: @test, @test_broken, @test_throws, @testset @test (2 *ₗ a)[1, 2] == 2 * a[1, 2] @test conjed(a)[2, 1] == conj(a[2, 1]) @test (a +ₗ b)[2, 2] == a[2, 2] + b[2, 2] + @test (a *ₗ b)[1, 2] ≈ (a * b)[1, 2] + @test (a *ₗ b)[3] ≈ (a * b)[3] end end From 3a9d7eb5bab78dc3afbecc736a598c58965a60bf Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 17 Mar 2026 11:37:53 -0400 Subject: [PATCH 3/4] Simplify --- src/lazyarrays.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/lazyarrays.jl b/src/lazyarrays.jl index 94ad1970..5b250727 100644 --- a/src/lazyarrays.jl +++ b/src/lazyarrays.jl @@ -859,8 +859,7 @@ end function copy_permuteddims_mul(a::PermutedDimsArray{<:Any, 2, perm}) where {perm} perm == (1, 2) && return copy(parent(a)) - perm == (2, 1) && return copy(transpose(parent(a))) - throw(ArgumentError("Unsupported permutation $perm")) + return copy(transpose(parent(a))) end macro mularray_base(MulArray, AbstractArray = :AbstractArray) From 167e4e1fdbabd056ac4ad775db6cd0a3ac94cd20 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 17 Mar 2026 11:39:08 -0400 Subject: [PATCH 4/4] Simplify --- src/lazyarrays.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lazyarrays.jl b/src/lazyarrays.jl index 5b250727..ba71edae 100644 --- a/src/lazyarrays.jl +++ b/src/lazyarrays.jl @@ -857,7 +857,7 @@ macro mularray_type(MulArray, AbstractArray = :AbstractArray) ) end -function copy_permuteddims_mul(a::PermutedDimsArray{<:Any, 2, perm}) where {perm} +function copy_permuteddims(a::PermutedDimsArray{<:Any, 2, perm}) where {perm} perm == (1, 2) && return copy(parent(a)) return copy(transpose(parent(a))) end @@ -949,7 +949,7 @@ macro mularray_terminterface(MulArray, AbstractArray = :AbstractArray) $TensorAlgebra.operation(a::$MulArray) = $TensorAlgebra.operation_mul(a) $TensorAlgebra.arguments(a::$MulArray) = $TensorAlgebra.arguments_mul(a) function Base.copy(a::PermutedDimsArray{<:Any, 2, <:Any, <:Any, $MulArray}) - return $TensorAlgebra.copy_permuteddims_mul(a) + return $TensorAlgebra.copy_permuteddims(a) end end )