From d557bd7da988572c65a8b436778c3dff785fd85e Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 24 Mar 2026 23:44:35 -0400 Subject: [PATCH 01/29] Replace AbstractArray lazy types with LinearBroadcasted and Mul MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Redesign the lazy type system to avoid AbstractArray subtyping, which caused method ambiguities with ArrayLayouts and BlockArrays. - Add LinearBroadcasted abstract type with ScaledBroadcasted, ConjBroadcasted, AddBroadcasted concrete subtypes - Add standalone Mul type for lazy matrix multiplication - Add LinearBroadcastFunction constructor API (replaces Unicode operators) - Rename LazyArrayStyle to LinearBroadcastedStyle - Materialization follows Broadcasted protocol: copy → copyto! → add!/mul! - Delete macro system and old AbstractArray lazy types - Rename lazyarrays.jl to linearbroadcasted.jl - Bump version to 0.8.0 (breaking) Co-Authored-By: Claude Opus 4.6 (1M context) --- Project.toml | 2 +- docs/Project.toml | 2 +- examples/Project.toml | 2 +- src/TensorAlgebra.jl | 2 +- src/lazyarrays.jl | 979 --------------------------------------- src/linearbroadcasted.jl | 565 ++++++++++++++++++++++ test/Project.toml | 2 +- test/test_lazy.jl | 161 ++++--- 8 files changed, 678 insertions(+), 1037 deletions(-) delete mode 100644 src/lazyarrays.jl create mode 100644 src/linearbroadcasted.jl diff --git a/Project.toml b/Project.toml index 18575fb3..3c65eb0d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" -version = "0.7.20" +version = "0.8.0" authors = ["ITensor developers and contributors"] [workspace] diff --git a/docs/Project.toml b/docs/Project.toml index 549440ac..20650d9d 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -11,4 +11,4 @@ path = ".." Documenter = "1.8.1" ITensorFormatter = "0.2.27" Literate = "2.20.1" -TensorAlgebra = "0.7" +TensorAlgebra = "0.8" diff --git a/examples/Project.toml b/examples/Project.toml index 9b0b1293..a8006256 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -5,4 +5,4 @@ TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" path = ".." [compat] -TensorAlgebra = "0.7" +TensorAlgebra = "0.8" diff --git a/src/TensorAlgebra.jl b/src/TensorAlgebra.jl index e3e402db..a1c966a0 100644 --- a/src/TensorAlgebra.jl +++ b/src/TensorAlgebra.jl @@ -17,6 +17,6 @@ include("contract/allocate_output.jl") include("contract/contract_matricize.jl") include("factorizations.jl") include("matrixfunctions.jl") -include("lazyarrays.jl") +include("linearbroadcasted.jl") end diff --git a/src/lazyarrays.jl b/src/lazyarrays.jl deleted file mode 100644 index ba71edae..00000000 --- a/src/lazyarrays.jl +++ /dev/null @@ -1,979 +0,0 @@ -import Base.Broadcast as BC -import FunctionImplementations as FI -import LinearAlgebra as LA -import StridedViews as SV - -# TermInterface-like interface. -iscall(x) = false -function operation end -function arguments end - -# Generic logic for lazy array linear algebra operations. -function +ₗ(a::AbstractArray, b::AbstractArray, c::AbstractArray, xs::AbstractArray...) - return Base.afoldl(+ₗ, +ₗ(+ₗ(a, b), c), xs...) -end --ₗ(a::AbstractArray, b::AbstractArray) = a +ₗ (-b) -function *ₗ(a::AbstractArray, b::AbstractArray, c::AbstractArray, xs::AbstractArray...) - return Base.afoldl(*ₗ, *ₗ(*ₗ(a, b), c), xs...) -end -*ₗ(a::AbstractArray, b::Number) = b *ₗ a -\ₗ(a::Number, b::AbstractArray) = inv(a) *ₗ b -/ₗ(a::AbstractArray, b::Number) = a *ₗ inv(b) -+ₗ(a::AbstractArray) = a --ₗ(a::AbstractArray) = -1 *ₗ a -conjed(a::AbstractArray{<:Real}) = a - -lazy_function(f) = error("No lazy function defined for `$f`.") -lazy_function(::typeof(+)) = +ₗ -lazy_function(::typeof(-)) = -ₗ -lazy_function(::typeof(*)) = *ₗ -lazy_function(::typeof(/)) = /ₗ -lazy_function(::typeof(\)) = \ₗ -lazy_function(::typeof(conj)) = conjed -lazy_function(::typeof(identity)) = identity -lazy_function(f::Base.Fix1{typeof(*), <:Number}) = Base.Fix1(*ₗ, f.x) -lazy_function(f::Base.Fix2{typeof(*), <:Number}) = Base.Fix2(*ₗ, f.x) -lazy_function(f::Base.Fix2{typeof(/), <:Number}) = Base.Fix2(/ₗ, f.x) - -broadcast_is_linear(f, args...) = false -broadcast_is_linear(::typeof(identity), ::Base.AbstractArrayOrBroadcasted) = true -broadcast_is_linear(::typeof(+), ::Base.AbstractArrayOrBroadcasted...) = true -broadcast_is_linear(::typeof(-), ::Base.AbstractArrayOrBroadcasted) = true -function broadcast_is_linear( - ::typeof(-), ::Base.AbstractArrayOrBroadcasted, ::Base.AbstractArrayOrBroadcasted - ) - return true -end -broadcast_is_linear(::typeof(*), ::Number, ::Base.AbstractArrayOrBroadcasted) = true -broadcast_is_linear(::typeof(\), ::Number, ::Base.AbstractArrayOrBroadcasted) = true -broadcast_is_linear(::typeof(*), ::Base.AbstractArrayOrBroadcasted, ::Number) = true -broadcast_is_linear(::typeof(/), ::Base.AbstractArrayOrBroadcasted, ::Number) = true -function broadcast_is_linear( - ::typeof(*), ::Base.AbstractArrayOrBroadcasted, ::Base.AbstractArrayOrBroadcasted - ) - return false -end -broadcast_is_linear(::typeof(*), ::Number, ::Number) = true -broadcast_is_linear(::typeof(conj), ::Base.AbstractArrayOrBroadcasted) = true -function broadcast_is_linear( - ::Base.Fix1{typeof(*), <:Number}, ::Base.AbstractArrayOrBroadcasted - ) - return true -end -function broadcast_is_linear( - ::Base.Fix2{typeof(*), <:Number}, ::Base.AbstractArrayOrBroadcasted - ) - return true -end -function broadcast_is_linear( - ::Base.Fix2{typeof(/), <:Number}, ::Base.AbstractArrayOrBroadcasted - ) - return true -end -is_linear(x) = true -function is_linear(bc::BC.Broadcasted) - return broadcast_is_linear(bc.f, bc.args...) && all(is_linear, bc.args) -end - -to_linear(x) = x -to_linear(bc::BC.Broadcasted) = lazy_function(bc.f)(to_linear.(bc.args)...) -function broadcast_error(style, f) - return throw( - ArgumentError( - "Only linear broadcast operations are supported for `$style`, got `$f`." - ) - ) -end -function broadcasted_linear(style::BC.BroadcastStyle, f, args...) - bc = BC.Broadcasted(style, f, args) - is_linear(bc) || broadcast_error(style, f) - return to_linear(bc) -end -broadcasted_linear(f, args...) = broadcasted_linear(BC.combine_styles(args...), f, args...) -# TODO: Use `Broadcast.broadcastable` interface for this? -to_broadcasted(x) = x -function to_broadcasted(a::AbstractArray) - (BC.BroadcastStyle(typeof(a)) isa LazyArrayStyle) || return a - return BC.broadcasted(operation(a), to_broadcasted.(arguments(a))...) -end -to_broadcasted(bc::BC.Broadcasted) = BC.Broadcasted(bc.f, to_broadcasted.(bc.args)) - -# For lazy arrays, define Broadcast methods in terms of lazy operations. -struct LazyArrayStyle{N, Style <: BC.AbstractArrayStyle{N}} <: BC.AbstractArrayStyle{N} - style::Style -end -# TODO: This empty constructor is required in some Julia versions below v1.12 (such as -# Julia v1.10), try deleting it once we drop support for those versions. -function LazyArrayStyle{N, Style}() where {N, Style <: BC.AbstractArrayStyle{N}} - return LazyArrayStyle{N, Style}(Style()) -end -function LazyArrayStyle{N, Style}(::Val{M}) where {M, N, Style <: BC.AbstractArrayStyle{N}} - return LazyArrayStyle(Style(Val(M))) -end -function BC.BroadcastStyle(style1::LazyArrayStyle, style2::LazyArrayStyle) - style = BC.BroadcastStyle(style1.style, style2.style) - style ≡ BC.Unknown() && return BC.Unknown() - return LazyArrayStyle(style) -end -function Base.similar(bc::BC.Broadcasted{<:LazyArrayStyle}, elt::Type, ax) - return similar(BC.Broadcasted(bc.style.style, bc.f, bc.args, bc.axes), elt, ax) -end -# Backup definition, for broadcast operations that don't preserve LazyArrays -# (such as nonlinear operations), convert back to Broadcasted expressions. -BC.broadcasted(::LazyArrayStyle, f, args...) = BC.Broadcasted(f, to_broadcasted.(args)) -BC.broadcasted(::LazyArrayStyle, ::typeof(+), a::AbstractArray, b::AbstractArray) = a +ₗ b -function BC.broadcasted(::LazyArrayStyle, ::typeof(+), a::AbstractArray, b::BC.Broadcasted) - is_linear(b) || return BC.Broadcasted(+, to_broadcasted.((a, b))) - return a +ₗ to_linear(b) -end -function BC.broadcasted(::LazyArrayStyle, ::typeof(+), a::BC.Broadcasted, b::AbstractArray) - is_linear(a) || return BC.Broadcasted(+, to_broadcasted.((a, b))) - return to_linear(a) +ₗ b -end -function BC.broadcasted(::LazyArrayStyle, ::typeof(+), a::BC.Broadcasted, b::BC.Broadcasted) - return error("Not implemented") -end -BC.broadcasted(::LazyArrayStyle, ::typeof(*), α::Number, a::AbstractArray) = α *ₗ a -BC.broadcasted(::LazyArrayStyle, ::typeof(*), a::AbstractArray, α::Number) = a *ₗ α -BC.broadcasted(::LazyArrayStyle, ::typeof(\), α::Number, a::AbstractArray) = α \ₗ a -BC.broadcasted(::LazyArrayStyle, ::typeof(/), a::AbstractArray, α::Number) = a /ₗ α -BC.broadcasted(::LazyArrayStyle, ::typeof(-), a::AbstractArray) = -ₗ(a) -BC.broadcasted(::LazyArrayStyle, ::typeof(conj), a::AbstractArray) = conjed(a) - -# Base overloads for lazy arrays. -function show_lazy(io::IO, a::AbstractArray) - print(io, operation(a), "(", join(arguments(a), ", "), ")") - return nothing -end -function show_lazy(io::IO, mime::MIME"text/plain", a::AbstractArray) - summary(io, a) - println(io, ":") - show(io, a) - return nothing -end - -# Generic constructors, accessors, and properties for ScaledArrays. -*ₗ(α::Number, a::AbstractArray) = ScaledArray(α, a) -unscaled(a::AbstractArray) = a -unscaled_type(arrayt::Type{<:AbstractArray}) = Base.promote_op(unscaled, arrayt) -coeff(a::AbstractArray) = true -coeff_type(arrayt::Type{<:AbstractArray}) = Base.promote_op(coeff, arrayt) -function scaled_eltype(coeff::Number, a::AbstractArray) - return Base.promote_op(*, typeof(coeff), eltype(a)) -end - -# Base overloads for ScaledArrays. -axes_scaled(a::AbstractArray) = axes(unscaled(a)) -size_scaled(a::AbstractArray) = size(unscaled(a)) -similar_scaled(a::AbstractArray) = similar(unscaled(a)) -similar_scaled(a::AbstractArray, elt::Type) = similar(unscaled(a), elt) -similar_scaled(a::AbstractArray, ax) = similar(unscaled(a), ax) -similar_scaled(a::AbstractArray, elt::Type, ax) = similar(unscaled(a), elt, ax) -getindex_scaled(a::AbstractArray, I...) = coeff(a) * getindex(unscaled(a), I...) -copyto!_scaled(dest::AbstractArray, src::AbstractArray) = add!(dest, src, true, false) -show_scaled(io::IO, a::AbstractArray) = show_lazy(io, a) -show_scaled(io::IO, mime::MIME"text/plain", a::AbstractArray) = show_lazy(io, mime, a) - -# Base overloads of adjoint and transpose for ScaledArrays. -adjoint_scaled(a::AbstractArray) = coeff(a) *ₗ adjoint(unscaled(a)) -transpose_scaled(a::AbstractArray) = coeff(a) *ₗ transpose(unscaled(a)) - -# Base.Broadcast overloads for ScaledArrays. -materialize_scaled(a::AbstractArray) = copy(a) -function BroadcastStyle_scaled(arrayt::Type{<:AbstractArray}) - return LazyArrayStyle(BC.BroadcastStyle(unscaled_type(arrayt))) -end - -# LinearAlgebra overloads for ScaledArrays. -function mul!_scaled( - dest::AbstractArray, - a::AbstractArray, - b::AbstractArray, - α::Number, - β::Number - ) - return LA.mul!(dest, unscaled(a), unscaled(b), coeff(a) * coeff(b) * α, β) -end - -# Lazy operations for ScaledArrays. -mulled_scaled(α::Number, a::AbstractArray) = (α * coeff(a)) *ₗ unscaled(a) -function mulled_scaled(a::AbstractArray, b::AbstractArray) - return (coeff(a) * coeff(b)) *ₗ (unscaled(a) *ₗ unscaled(b)) -end -conjed_scaled(a::AbstractArray) = conj(coeff(a)) *ₗ conjed(unscaled(a)) - -# TensorAlgebra overloads for ScaledArrays. -function add!_scaled(dest::AbstractArray, src::AbstractArray, α::Number, β::Number) - return add!(dest, unscaled(src), coeff(src) * α, β) -end - -# TermInterface-like overloads for ScaledArrays. -iscall_scaled(::AbstractArray) = true -operation_scaled(::AbstractArray) = * -arguments_scaled(a::AbstractArray) = (coeff(a), unscaled(a)) - -# FunctionImplementations overloads for ScaledArrays. -permuteddims_scaled(a::AbstractArray, perm) = coeff(a) *ₗ FI.permuteddims(unscaled(a), perm) - -macro scaledarray_type(ScaledArray, AbstractArray = :AbstractArray) - return esc( - quote - struct $ScaledArray{T, N, P <: AbstractArray{<:Any, N}, C <: Number} <: - $AbstractArray{T, N} - coeff::C - parent::P - function $ScaledArray(coeff::Number, a::AbstractArray) - T = $TensorAlgebra.scaled_eltype(coeff, a) - return new{T, ndims(a), typeof(a), typeof(coeff)}(coeff, a) - end - end - $TensorAlgebra.unscaled(a::$ScaledArray) = a.parent - function $TensorAlgebra.unscaled_type(arrayt::Type{<:$ScaledArray}) - return fieldtype(arrayt, :parent) - end - $TensorAlgebra.coeff(a::$ScaledArray) = a.coeff - function $TensorAlgebra.coeff_type(arrayt::Type{<:$ScaledArray}) - return fieldtype(arrayt, :coeff) - end - end - ) -end - -macro scaledarray_base(ScaledArray, AbstractArray = :AbstractArray) - return esc( - quote - Base.axes(a::$ScaledArray) = - $TensorAlgebra.axes_scaled(a) - Base.size(a::$ScaledArray) = - $TensorAlgebra.size_scaled(a) - Base.similar(a::$ScaledArray) = - $TensorAlgebra.similar_scaled(a) - function Base.similar(a::$ScaledArray, elt::Type) - return $TensorAlgebra.similar_scaled(a, elt) - end - Base.similar(a::$ScaledArray, ax) = - $TensorAlgebra.similar_scaled(a, ax) - Base.similar(a::$ScaledArray, ax::Tuple) = - $TensorAlgebra.similar_scaled(a, ax) - function Base.similar(a::$ScaledArray, elt::Type, ax) - return $TensorAlgebra.similar_scaled(a, elt, ax) - end - function Base.similar(a::$ScaledArray, elt::Type, ax::Dims) - return $TensorAlgebra.similar_scaled(a, elt, ax) - end - Base.@propagate_inbounds function Base.getindex(a::$ScaledArray, I...) - return $TensorAlgebra.getindex_scaled(a, I...) - end - function Base.copyto!(dest::$AbstractArray, src::$ScaledArray) - return $TensorAlgebra.copyto!_scaled(dest, src) - end - Base.show(io::IO, a::$ScaledArray) = - $TensorAlgebra.show_scaled(io, a) - function Base.show(io::IO, mime::MIME"text/plain", a::$ScaledArray) - return $TensorAlgebra.show_scaled(io, mime, a) - end - end - ) -end - -macro scaledarray_adjtrans(ScaledArray, AbstractArray = :AbstractArray) - return esc( - quote - Base.adjoint(a::$ScaledArray) = - $TensorAlgebra.adjoint_scaled(a) - Base.transpose(a::$ScaledArray) = - $TensorAlgebra.transpose_scaled(a) - end - ) -end - -macro scaledarray_broadcast(ScaledArray, AbstractArray = :AbstractArray) - return esc( - quote - function Base.Broadcast.materialize(a::$ScaledArray) - return $TensorAlgebra.materialize_scaled(a) - end - function Base.Broadcast.BroadcastStyle(arrayt::Type{<:$ScaledArray}) - return $TensorAlgebra.BroadcastStyle_scaled(arrayt) - end - end - ) -end - -macro scaledarray_linearalgebra(ScaledArray, AbstractArray = :AbstractArray) - return esc( - quote - function $TensorAlgebra.LA.mul!( - dest::$AbstractArray{<:Any, 2}, - a::$ScaledArray{<:Any, 2}, - b::$ScaledArray{<:Any, 2}, - α::Number, β::Number - ) - return $TensorAlgebra.mul!_scaled(dest, a, b, α, β) - end - function $TensorAlgebra.LA.mul!( - dest::$AbstractArray{<:Any, 2}, - a::$AbstractArray{<:Any, 2}, - b::$ScaledArray{<:Any, 2}, - α::Number, β::Number - ) - return $TensorAlgebra.mul!_scaled(dest, a, b, α, β) - end - function $TensorAlgebra.LA.mul!( - dest::$AbstractArray{<:Any, 2}, - a::$ScaledArray{<:Any, 2}, - b::$AbstractArray{<:Any, 2}, - α::Number, β::Number - ) - return $TensorAlgebra.mul!_scaled(dest, a, b, α, β) - end - end - ) -end - -macro scaledarray_lazy(ScaledArray, AbstractArray = :AbstractArray) - return esc( - quote - function $TensorAlgebra.:*ₗ(α::Number, a::$ScaledArray) - return $TensorAlgebra.mulled_scaled(α, a) - end - function $TensorAlgebra.:*ₗ(a::$ScaledArray, b::$ScaledArray) - return $TensorAlgebra.mulled_scaled(a, b) - end - function $TensorAlgebra.:*ₗ(a::$AbstractArray, b::$ScaledArray) - return $TensorAlgebra.mulled_scaled(a, b) - end - function $TensorAlgebra.:*ₗ(a::$ScaledArray, b::$AbstractArray) - return $TensorAlgebra.mulled_scaled(a, b) - end - $TensorAlgebra.conjed(a::$ScaledArray) = - $TensorAlgebra.conjed_scaled(a) - end - ) -end - -macro scaledarray_tensoralgebra(ScaledArray, AbstractArray = :AbstractArray) - return esc( - quote - function $TensorAlgebra.add!( - dest::$AbstractArray, src::$ScaledArray, α::Number, β::Number - ) - return $TensorAlgebra.add!_scaled(dest, src, α, β) - end - end - ) -end - -macro scaledarray_terminterface(ScaledArray, AbstractArray = :AbstractArray) - return esc( - quote - $TensorAlgebra.iscall(a::$ScaledArray) = $TensorAlgebra.iscall_scaled(a) - $TensorAlgebra.operation(a::$ScaledArray) = $TensorAlgebra.operation_scaled(a) - $TensorAlgebra.arguments(a::$ScaledArray) = $TensorAlgebra.arguments_scaled(a) - end - ) -end - -macro scaledarray_functionimplementations(ScaledArray, AbstractArray = :AbstractArray) - return esc( - quote - function $TensorAlgebra.FI.permuteddims(a::$ScaledArray, perm) - return $TensorAlgebra.permuteddims_scaled(a, perm) - end - end - ) -end - -macro scaledarray(ScaledArray, AbstractArray = :AbstractArray) - return esc( - quote - $TensorAlgebra.@scaledarray_base $ScaledArray $AbstractArray - $TensorAlgebra.@scaledarray_adjtrans $ScaledArray $AbstractArray - $TensorAlgebra.@scaledarray_broadcast $ScaledArray $AbstractArray - $TensorAlgebra.@scaledarray_lazy $ScaledArray $AbstractArray - $TensorAlgebra.@scaledarray_linearalgebra $ScaledArray $AbstractArray - $TensorAlgebra.@scaledarray_tensoralgebra $ScaledArray $AbstractArray - $TensorAlgebra.@scaledarray_terminterface $ScaledArray $AbstractArray - $TensorAlgebra.@scaledarray_functionimplementations $ScaledArray $AbstractArray - end - ) -end - -# Generic constructors for ConjArrays. -conjed(a::AbstractArray) = ConjArray(a) -conjed_type(arrayt::Type{<:AbstractArray}) = Base.promote_op(conjed, arrayt) - -# Base overloads for ConjArrays. -axes_conj(a::AbstractArray) = axes(conjed(a)) -size_conj(a::AbstractArray) = size(conjed(a)) -similar_conj(a::AbstractArray, elt::Type) = similar(conjed(a), elt) -similar_conj(a::AbstractArray, elt::Type, ax) = similar(conjed(a), elt, ax) -similar_conj(a::AbstractArray, ax) = similar(conjed(a), ax) -getindex_conj(a::AbstractArray, I...) = conj(getindex(conjed(a), I...)) -copyto!_conj(dest::AbstractArray, src::AbstractArray) = add!(dest, src, true, false) -show_conj(io::IO, a::AbstractArray) = show_lazy(io, a) -show_conj(io::IO, mime::MIME"text/plain", a::AbstractArray) = show_lazy(io, mime, a) - -# Base overloads of adjoint and transpose for ConjArrays. -adjoint_conj(a::AbstractArray) = transpose(conjed(a)) -transpose_conj(a::AbstractArray) = adjoint(conjed(a)) - -# Base.Broadcast overloads for ConjArrays. -materialize_conj(a::AbstractArray) = copy(a) -function BroadcastStyle_conj(arrayt::Type{<:AbstractArray}) - return LazyArrayStyle(BC.BroadcastStyle(conjed_type(arrayt))) -end - -# StridedViews overloads for ConjArrays. -isstrided_conj(a::AbstractArray) = SV.isstrided(conjed(a)) -StridedView_conj(a::AbstractArray) = conj(SV.StridedView(conjed(a))) - -# TermInterface-like overloads for ConjArrays. -iscall_conj(::AbstractArray) = true -operation_conj(::AbstractArray) = conj -arguments_conj(a::AbstractArray) = (conjed(a),) - -# FunctionImplementations overloads for ConjArrays. -permuteddims_conj(a::AbstractArray, perm) = conjed(FI.permuteddims(conjed(a), perm)) - -macro conjarray_type(ConjArray, AbstractArray = :AbstractArray) - return esc( - quote - struct $ConjArray{T, N, P <: AbstractArray{T, N}} <: $AbstractArray{T, N} - parent::P - end - $TensorAlgebra.conjed(a::$ConjArray) = a.parent - end - ) -end - -macro conjarray_base(ConjArray, AbstractArray = :AbstractArray) - return esc( - quote - Base.axes(a::$ConjArray) = - $TensorAlgebra.axes_conj(a) - Base.size(a::$ConjArray) = - $TensorAlgebra.size_conj(a) - Base.similar(a::$ConjArray, elt::Type) = - $TensorAlgebra.similar_conj(a, elt) - function Base.similar(a::$ConjArray, elt::Type, ax) - return $TensorAlgebra.similar_conj(a, elt, ax) - end - function Base.similar(a::$ConjArray, elt::Type, ax::Dims) - return $TensorAlgebra.similar_conj(a, elt, ax) - end - Base.@propagate_inbounds function Base.getindex(a::$ConjArray, I...) - return $TensorAlgebra.getindex_conj(a, I...) - end - function Base.copyto!(dest::$AbstractArray, src::$ConjArray) - return $TensorAlgebra.copyto!_conj(dest, src) - end - Base.show(io::IO, a::$ConjArray) = $TensorAlgebra.show_conj(io, a) - function Base.show(io::IO, mime::MIME"text/plain", a::$ConjArray) - return $TensorAlgebra.show_conj(io, mime, a) - end - end - ) -end - -macro conjarray_adjtrans(ConjArray, AbstractArray = :AbstractArray) - return esc( - quote - Base.adjoint(a::$ConjArray) = - $TensorAlgebra.adjoint_conj(a) - Base.transpose(a::$ConjArray) = - $TensorAlgebra.transpose_conj(a) - end - ) -end - -macro conjarray_broadcast(ConjArray, AbstractArray = :AbstractArray) - return esc( - quote - Base.Broadcast.materialize(a::$ConjArray) = $TensorAlgebra.materialize_conj(a) - function Base.Broadcast.BroadcastStyle(arrayt::Type{<:$ConjArray}) - return $TensorAlgebra.BroadcastStyle_conj(arrayt) - end - end - ) -end - -macro conjarray_stridedviews(ConjArray, AbstractArray = :AbstractArray) - return esc( - quote - $TensorAlgebra.SV.isstrided(a::$ConjArray) = - $TensorAlgebra.isstrided_conj(a) - function $TensorAlgebra.SV.StridedView(a::$ConjArray) - return $TensorAlgebra.StridedView_conj(a) - end - end - ) -end - -macro conjarray_terminterface(ConjArray, AbstractArray = :AbstractArray) - return esc( - quote - $TensorAlgebra.iscall(a::$ConjArray) = $TensorAlgebra.iscall_conj(a) - $TensorAlgebra.operation(a::$ConjArray) = $TensorAlgebra.operation_conj(a) - $TensorAlgebra.arguments(a::$ConjArray) = $TensorAlgebra.arguments_conj(a) - end - ) -end - -macro conjarray_functionimplementations(ConjArray, AbstractArray = :AbstractArray) - return esc( - quote - function $TensorAlgebra.FI.permuteddims(a::$ConjArray, perm) - return $TensorAlgebra.permuteddims_conj(a, perm) - end - end - ) -end - -macro conjarray(ConjArray, AbstractArray = :AbstractArray) - return esc( - quote - $TensorAlgebra.@conjarray_base $ConjArray $AbstractArray - $TensorAlgebra.@conjarray_adjtrans $ConjArray $AbstractArray - $TensorAlgebra.@conjarray_broadcast $ConjArray $AbstractArray - $TensorAlgebra.@conjarray_stridedviews $ConjArray $AbstractArray - $TensorAlgebra.@conjarray_terminterface $ConjArray $AbstractArray - $TensorAlgebra.@conjarray_functionimplementations $ConjArray $AbstractArray - end - ) -end - -# Generic constructors, accessors, and properties for AddArrays. -+ₗ(a::AbstractArray, b::AbstractArray) = AddArray(a, b) -addends(a::AbstractArray) = (a,) -addends_type(arrayt::Type{<:AbstractArray}) = Tuple{arrayt} -add_eltype(args::AbstractArray...) = Base.promote_op(+, eltype.(args)...) -function add_ndims(args::AbstractArray...) - return if allequal(ndims, args) - ndims(first(args)) - else - error("All addends must have the same number of dimensions.") - end -end - -# Base overloads for AddArrays. -add_axes(args::AbstractArray...) = BC.combine_axes(args...) -axes_add(a::AbstractArray) = add_axes(addends(a)...) -size_add(a::AbstractArray) = length.(axes_add(a)) -similar_add(a::AbstractArray) = similar(a, eltype(a)) -similar_add(a::AbstractArray, ax::Tuple) = similar(a, eltype(a), ax) -similar_add(a::AbstractArray, elt::Type) = similar(BC.Broadcasted(+, addends(a)), elt) -function similar_add(a::AbstractArray, elt::Type, ax) - return similar(BC.Broadcasted(+, addends(a)), elt, ax) -end -getindex_add(a::AbstractArray, I...) = sum(addend -> getindex(addend, I...), addends(a)) -copyto!_add(dest::AbstractArray, src::AbstractArray) = add!(dest, src, true, false) -show_add(io::IO, a::AbstractArray) = show_lazy(io, a) -show_add(io::IO, mime::MIME"text/plain", a::AbstractArray) = show_lazy(io, mime, a) - -# Base overloads of adjoint and transpose for AddArrays. -adjoint_add(a::AbstractArray) = +ₗ(adjoint.(addends(a))...) -transpose_add(a::AbstractArray) = +ₗ(transpose.(addends(a))...) - -# Base.Broadcast overloads for AddArrays. -materialize_add(a::AbstractArray) = copy(a) -function BroadcastStyle_add(arrayt::Type{<:AbstractArray}) - args_type = addends_type(arrayt) - style = Base.promote_op(BC.combine_styles, fieldtypes(args_type)...)() - return LazyArrayStyle(style) -end - -# TensorAlgebra overloads for AddArrays. -function add!_add(dest::AbstractArray, src::AbstractArray, α::Number, β::Number) - args = addends(src) - add!(dest, first(args), α, β) - for a in Base.tail(args) - add!(dest, a, α, true) - end - return dest -end - -# Lazy operations for AddArrays. -added_add(a::AbstractArray, b::AbstractArray) = AddArray((addends(a)..., addends(b)...)...) -mulled_add(α::Number, a::AbstractArray) = +ₗ((α .*ₗ addends(a))...) -## TODO: Define multiplication of added arrays by expanding all combinations, treating -## both inputs as AddArrays. -## mulled_add(a::AbstractArray, b::AbstractArray) = +ₗ((Ref(a) .*ₗ addends(b))...) -## mulled_add(a::AddArray, b::AbstractArray) = +ₗ((addends(a) .*ₗ Ref(b))...) -## mulled_add(a::AddArray, b::AddArray) = +ₗ((Ref(a) .*ₗ addends(b))...) -conjed_add(a::AbstractArray) = +ₗ(conjed.(addends(a))...) - -# TermInterface-like overloads for AddArrays. -iscall_add(::AbstractArray) = true -operation_add(::AbstractArray) = + -arguments_add(a::AbstractArray) = addends(a) - -# FunctionImplementations overloads for AddArrays. -function permuteddims_add(a::AbstractArray, perm) - return +ₗ(Base.Fix2(FI.permuteddims, perm).(addends(a))...) -end - -macro addarray_type(AddArray, AbstractArray = :AbstractArray) - return esc( - quote - struct $AddArray{T, N, Args <: Tuple{Vararg{AbstractArray{<:Any, N}}}} <: - $AbstractArray{T, N} - args::Args - function $AddArray(args::AbstractArray...) - T = $TensorAlgebra.add_eltype(args...) - N = $TensorAlgebra.add_ndims(args...) - return new{T, N, typeof(args)}(args) - end - end - $TensorAlgebra.addends(a::$AddArray) = a.args - function $TensorAlgebra.addends_type(arrayt::Type{<:$AddArray}) - return fieldtype(arrayt, :args) - end - end - ) -end - -macro addarray_base(AddArray, AbstractArray = :AbstractArray) - return esc( - quote - Base.axes(a::$AddArray) = $TensorAlgebra.axes_add(a) - Base.size(a::$AddArray) = $TensorAlgebra.size_add(a) - Base.similar(a::$AddArray) = $TensorAlgebra.similar_add(a) - Base.similar(a::$AddArray, ax::Tuple) = $TensorAlgebra.similar_add(a, ax) - Base.similar(a::$AddArray, elt::Type) = $TensorAlgebra.similar_add(a, elt) - function Base.similar( - a::$AddArray, elt::Type, - ax::Tuple{Union{Integer, Base.OneTo}, Vararg{Union{Integer, Base.OneTo}}} - ) - return $TensorAlgebra.similar_add(a, elt, ax) - end - function Base.similar(a::$AddArray, elt::Type, ax::Dims) - return $TensorAlgebra.similar_add(a, elt, ax) - end - function Base.similar(a::$AddArray, elt::Type, ax) - return $TensorAlgebra.similar_add(a, elt, ax) - end - Base.@propagate_inbounds function Base.getindex(a::$AddArray, I...) - return $TensorAlgebra.getindex_add(a, I...) - end - function Base.copyto!(dest::$AbstractArray, src::$AddArray) - return $TensorAlgebra.copyto!_add(dest, src) - end - Base.show(io::IO, a::$AddArray) = - $TensorAlgebra.show_add(io, a) - function Base.show(io::IO, mime::MIME"text/plain", a::$AddArray) - return $TensorAlgebra.show_add(io, mime, a) - end - end - ) -end - -macro addarray_adjtrans(AddArray, AbstractArray = :AbstractArray) - return esc( - quote - Base.adjoint(a::$AddArray) = - $TensorAlgebra.adjoint_add(a) - Base.transpose(a::$AddArray) = - $TensorAlgebra.transpose_add(a) - end - ) -end - -macro addarray_broadcast(AddArray, AbstractArray = :AbstractArray) - return esc( - quote - Base.Broadcast.materialize(a::$AddArray) = $TensorAlgebra.materialize_add(a) - function Base.Broadcast.BroadcastStyle(arrayt::Type{<:$AddArray}) - return $TensorAlgebra.BroadcastStyle_add(arrayt) - end - end - ) -end - -macro addarray_lazy(AddArray, AbstractArray = :AbstractArray) - return esc( - quote - function $TensorAlgebra.:+ₗ(a::$AbstractArray, b::$AddArray) - return $TensorAlgebra.added_add(a, b) - end - function $TensorAlgebra.:+ₗ(a::$AddArray, b::$AbstractArray) - return $TensorAlgebra.added_add(a, b) - end - $TensorAlgebra.:+ₗ(a::$AddArray, b::$AddArray) = - $TensorAlgebra.added_add(a, b) - $TensorAlgebra.:*ₗ(α::Number, a::$AddArray) = - $TensorAlgebra.mulled_add(α, a) - function $TensorAlgebra.:*ₗ(a::$AbstractArray, b::$AddArray) - return $TensorAlgebra.mulled_add(a, b) - end - function $TensorAlgebra.:*ₗ(a::$AddArray, b::$AbstractArray) - return $TensorAlgebra.mulled_add(a, b) - end - function $TensorAlgebra.:*ₗ(a::$AddArray, b::$AddArray) - return $TensorAlgebra.mulled_add(a, b) - end - $TensorAlgebra.conjed(a::$AddArray) = - $TensorAlgebra.conjed_add(a) - end - ) -end - -macro addarray_tensoralgebra(AddArray, AbstractArray = :AbstractArray) - return esc( - quote - function $TensorAlgebra.add!( - dest::$AbstractArray, src::$AddArray, α::Number, β::Number - ) - return $TensorAlgebra.add!_add(dest, src, α, β) - end - end - ) -end - -macro addarray_terminterface(AddArray, AbstractArray = :AbstractArray) - return esc( - quote - $TensorAlgebra.iscall(a::$AddArray) = $TensorAlgebra.iscall_add(a) - $TensorAlgebra.operation(a::$AddArray) = $TensorAlgebra.operation_add(a) - $TensorAlgebra.arguments(a::$AddArray) = $TensorAlgebra.arguments_add(a) - end - ) -end - -macro addarray_functionimplementations(AddArray, AbstractArray = :AbstractArray) - return esc( - quote - function $TensorAlgebra.FI.permuteddims(a::$AddArray, perm) - return $TensorAlgebra.permuteddims_add(a, perm) - end - end - ) -end - -macro addarray(AddArray, AbstractArray = :AbstractArray) - return esc( - quote - $TensorAlgebra.@addarray_base $AddArray $AbstractArray - $TensorAlgebra.@addarray_adjtrans $AddArray $AbstractArray - $TensorAlgebra.@addarray_broadcast $AddArray $AbstractArray - $TensorAlgebra.@addarray_lazy $AddArray $AbstractArray - $TensorAlgebra.@addarray_tensoralgebra $AddArray $AbstractArray - $TensorAlgebra.@addarray_terminterface $AddArray $AbstractArray - $TensorAlgebra.@addarray_functionimplementations $AddArray $AbstractArray - end - ) -end - -# Generic constructors, accessors, and properties for MulArrays. -*ₗ(a::AbstractArray, b::AbstractArray) = MulArray(a, b) -factors(a::AbstractArray) = (a,) -factor_types(arrayt::Type{<:AbstractArray}) = Base.promote_op(factors, arrayt) -# Same as `LinearAlgebra.matprod`, but duplicated here since it is private. -matprod(x, y) = x * y + x * y -function mul_eltype(a::AbstractArray, b::AbstractArray) - return Base.promote_op(matprod, eltype(a), eltype(b)) -end -mul_ndims(a::AbstractArray, b::AbstractArray) = ndims(b) -mul_axes(a::AbstractArray, b::AbstractArray) = (axes(a, 1), axes(b, ndims(b))) - -# Base overloads for MulArrays. -eltype_mul(a::AbstractArray{T}) where {T} = T -axes_mul(a::AbstractArray) = mul_axes(factors(a)...) -size_mul(a::AbstractArray) = length.(axes_mul(a)) -similar_mul(a::AbstractArray) = similar(a, eltype(a)) -similar_mul(a::AbstractArray, ax::Tuple) = similar(a, eltype(a), ax) -similar_mul(a::AbstractArray, elt::Type) = similar(a, elt, axes(a)) -# TODO: Make use of both arguments to determine the output, maybe -# using `LinearAlgebra.matprod_dest(factors(a)..., elt)`? -similar_mul(a::AbstractArray, elt::Type, ax) = similar(last(factors(a)), elt, ax) -function mul_getindex(a1::AbstractMatrix, a2::AbstractMatrix, i::Int, j::Int) - return transpose(view(a1, i, :)) * view(a2, :, j) -end -function mul_getindex(a1::AbstractMatrix, a2::AbstractVector, i::Int) - return transpose(view(a1, i, :)) * a2 -end -function mul_getindex(a1::AbstractVector, a2::AbstractMatrix, j::Int) - return transpose(a1) * view(a2, :, j) -end -function getindex_mul(a::AbstractArray, i::Int) - I = Tuple(CartesianIndices(axes(a))[i]) - return getindex_mul(a, I...) -end -getindex_mul(a::AbstractArray, I::Vararg{Int}) = mul_getindex(factors(a)..., I...) -copyto!_mul(dest::AbstractArray, src::AbstractArray) = add!(dest, src, true, false) -show_mul(io::IO, a::AbstractArray) = show_lazy(io, a) -show_mul(io::IO, mime::MIME"text/plain", a::AbstractArray) = show_lazy(io, mime, a) - -# Base overloads of adjoint and transpose for MulArrays. -adjoint_mul(a::AbstractArray) = *ₗ(reverse(adjoint.(factors(a)))...) -transpose_mul(a::AbstractArray) = *ₗ(reverse(transpose.(factors(a)))...) - -# Base.Broadcast overloads for MulArrays. -materialize_mul(a::AbstractArray) = copy(a) -function BroadcastStyle_mul(arrayt::Type{<:AbstractArray}) - style = Base.promote_op(BC.combine_styles, factor_types(arrayt)...)() - return LazyArrayStyle(style) -end - -# TensorAlgebra overloads for MulArrays. -# We materialize the arguments here to avoid nested lazy evaluation. -# Rewrite rules should make it so that `MulArray` is a "leaf` node of the -# expression tree. -function add!_mul(dest::AbstractArray, src::AbstractArray, α::Number, β::Number) - return LA.mul!(dest, BC.materialize.(factors(src))..., α, β) -end - -# Lazy operations for MulArrays. -conjed_mul(a::AbstractArray) = *ₗ(conjed.(factors(a))...) -# Matmul isn't a broadcasting operation so we materialize (i.e. -# perform the matrix multiplication) when building a broadcast -# expression involving a `MulArray`. -# TODO: Use `Broadcast.broadcastable` interface for this? -to_broadcasted_mul(a::AbstractArray) = *(factors(a)...) - -# TermInterface-like overloads for MulArrays. -iscall_mul(::AbstractArray) = true -operation_mul(::AbstractArray) = * -arguments_mul(a::AbstractArray) = factors(a) - -macro mularray_type(MulArray, AbstractArray = :AbstractArray) - return esc( - quote - struct $MulArray{T, N, A <: AbstractArray, B <: AbstractArray} <: - $AbstractArray{T, N} - a::A - b::B - function $MulArray(a::AbstractArray, b::AbstractArray) - T = $TensorAlgebra.mul_eltype(a, b) - N = $TensorAlgebra.mul_ndims(a, b) - return new{T, N, typeof(a), typeof(b)}(a, b) - end - end - $TensorAlgebra.factors(a::$MulArray) = (a.a, a.b) - function $TensorAlgebra.factor_types(arrayt::Type{<:$MulArray}) - return (fieldtype(arrayt, :a), fieldtype(arrayt, :b)) - end - end - ) -end - -function copy_permuteddims(a::PermutedDimsArray{<:Any, 2, perm}) where {perm} - perm == (1, 2) && return copy(parent(a)) - return copy(transpose(parent(a))) -end - -macro mularray_base(MulArray, AbstractArray = :AbstractArray) - return esc( - quote - Base.eltype(a::$MulArray) = $TensorAlgebra.eltype_mul(a) - Base.axes(a::$MulArray) = $TensorAlgebra.axes_mul(a) - Base.size(a::$MulArray) = $TensorAlgebra.size_mul(a) - Base.similar(a::$MulArray) = $TensorAlgebra.similar_mul(a) - Base.similar(a::$MulArray, ax::Tuple) = $TensorAlgebra.similar_mul(a, ax) - Base.similar(a::$MulArray, elt::Type) = $TensorAlgebra.similar_mul(a, elt) - function Base.similar( - a::$MulArray, elt::Type, - ax::Tuple{Union{Integer, Base.OneTo}, Vararg{Union{Integer, Base.OneTo}}} - ) - return $TensorAlgebra.similar_mul(a, elt, ax) - end - function Base.similar(a::$MulArray, elt::Type, ax) - return $TensorAlgebra.similar_mul(a, elt, ax) - end - function Base.similar(a::$MulArray, elt::Type, ax::Dims) - return $TensorAlgebra.similar_mul(a, elt, ax) - end - Base.@propagate_inbounds function Base.getindex(a::$MulArray, I...) - return $TensorAlgebra.getindex_mul(a, I...) - end - function Base.copyto!(dest::$AbstractArray, src::$MulArray) - return $TensorAlgebra.copyto!_mul(dest, src) - end - Base.show(io::IO, a::$MulArray) = $TensorAlgebra.show_mul(io, a) - function Base.show(io::IO, mime::MIME"text/plain", a::$MulArray) - return $TensorAlgebra.show_mul(io, mime, a) - end - end - ) -end - -macro mularray_adjtrans(MulArray, AbstractArray = :AbstractArray) - return esc( - quote - Base.adjoint(a::$MulArray) = - $TensorAlgebra.adjoint_mul(a) - Base.transpose(a::$MulArray) = - $TensorAlgebra.transpose_mul(a) - end - ) -end - -macro mularray_broadcast(MulArray, AbstractArray = :AbstractArray) - return esc( - quote - Base.Broadcast.materialize(a::$MulArray) = $TensorAlgebra.materialize_mul(a) - function Base.Broadcast.BroadcastStyle(arrayt::Type{<:$MulArray}) - return $TensorAlgebra.BroadcastStyle_mul(arrayt) - end - end - ) -end - -macro mularray_lazy(MulArray, AbstractArray = :AbstractArray) - return esc( - quote - $TensorAlgebra.conjed(a::$MulArray) = $TensorAlgebra.conjed_mul(a) - function $TensorAlgebra.to_broadcasted(a::$MulArray) - return $TensorAlgebra.to_broadcasted_mul(a) - end - end - ) -end - -macro mularray_tensoralgebra(MulArray, AbstractArray = :AbstractArray) - return esc( - quote - function $TensorAlgebra.add!( - dest::$AbstractArray, src::$MulArray, α::Number, β::Number - ) - return $TensorAlgebra.add!_mul(dest, src, α, β) - end - end - ) -end - -macro mularray_terminterface(MulArray, AbstractArray = :AbstractArray) - return esc( - quote - $TensorAlgebra.iscall(a::$MulArray) = $TensorAlgebra.iscall_mul(a) - $TensorAlgebra.operation(a::$MulArray) = $TensorAlgebra.operation_mul(a) - $TensorAlgebra.arguments(a::$MulArray) = $TensorAlgebra.arguments_mul(a) - function Base.copy(a::PermutedDimsArray{<:Any, 2, <:Any, <:Any, $MulArray}) - return $TensorAlgebra.copy_permuteddims(a) - end - end - ) -end - -macro mularray(MulArray, AbstractArray = :AbstractArray) - return esc( - quote - $TensorAlgebra.@mularray_base $MulArray $AbstractArray - $TensorAlgebra.@mularray_adjtrans $MulArray $AbstractArray - $TensorAlgebra.@mularray_broadcast $MulArray $AbstractArray - $TensorAlgebra.@mularray_lazy $MulArray $AbstractArray - $TensorAlgebra.@mularray_tensoralgebra $MulArray $AbstractArray - $TensorAlgebra.@mularray_terminterface $MulArray $AbstractArray - end - ) -end - -# Define types. -@scaledarray_type ScaledArray -@scaledarray ScaledArray -@conjarray_type ConjArray -@conjarray ConjArray -@addarray_type AddArray -@addarray AddArray -@mularray_type MulArray -@mularray MulArray diff --git a/src/linearbroadcasted.jl b/src/linearbroadcasted.jl new file mode 100644 index 00000000..8ee7c5ae --- /dev/null +++ b/src/linearbroadcasted.jl @@ -0,0 +1,565 @@ +import Base.Broadcast as BC +import FunctionImplementations as FI +import LinearAlgebra as LA +import StridedViews as SV + +# TermInterface-like interface. +iscall(x) = false +function operation end +function arguments end + +# ---------------------------------------------------------------------------- # +# LinearBroadcasted — lazy linear broadcast expressions (not <: AbstractArray) +# ---------------------------------------------------------------------------- # + +""" + LinearBroadcasted + +Abstract supertype for lazy linear broadcast expressions. Analogous to +`Base.Broadcast.Broadcasted` but restricted to linear operations. + +Materializes via the protocol: +copy(lb) = copyto!(similar(lb), lb) +copyto!(dest, lb) → add!(dest, lb, 1, 0) +""" +abstract type LinearBroadcasted end + +# Generic axes(a, d) for LinearBroadcasted subtypes. +Base.axes(a::LinearBroadcasted, d::Int) = axes(a)[d] + +# --- ScaledBroadcasted -------------------------------------------------------- + +struct ScaledBroadcasted{T, N, C <: Number, A} <: LinearBroadcasted + coeff::C + parent::A + function ScaledBroadcasted(coeff::Number, a) + T = Base.promote_op(*, typeof(coeff), eltype(a)) + return new{T, ndims(a), typeof(coeff), typeof(a)}(coeff, a) + end +end + +unscaled(a::ScaledBroadcasted) = a.parent +coeff(a::ScaledBroadcasted) = a.coeff + +Base.axes(a::ScaledBroadcasted) = axes(unscaled(a)) +Base.eltype(::Type{<:ScaledBroadcasted{T}}) where {T} = T +Base.eltype(::ScaledBroadcasted{T}) where {T} = T +Base.ndims(::Type{<:ScaledBroadcasted{<:Any, N}}) where {N} = N +Base.ndims(::ScaledBroadcasted{<:Any, N}) where {N} = N + +function Base.similar(a::ScaledBroadcasted) + return similar(unscaled(a), eltype(a), axes(a)) +end +function Base.similar(a::ScaledBroadcasted, elt::Type) + return similar(unscaled(a), elt, axes(a)) +end + +function Base.show(io::IO, a::ScaledBroadcasted) + print(io, "*(", coeff(a), ", ", unscaled(a), ")") + return nothing +end + +function Base.adjoint(a::ScaledBroadcasted) + return ScaledBroadcasted(coeff(a), adjoint(unscaled(a))) +end +function Base.transpose(a::ScaledBroadcasted) + return ScaledBroadcasted(coeff(a), transpose(unscaled(a))) +end + +function FI.permuteddims(a::ScaledBroadcasted, perm) + return ScaledBroadcasted(coeff(a), FI.permuteddims(unscaled(a), perm)) +end + +iscall(::ScaledBroadcasted) = true +operation(::ScaledBroadcasted) = * +arguments(a::ScaledBroadcasted) = (coeff(a), unscaled(a)) + +# --- ConjBroadcasted ---------------------------------------------------------- + +struct ConjBroadcasted{T, N, A} <: LinearBroadcasted + parent::A + function ConjBroadcasted(a) + return new{eltype(a), ndims(a), typeof(a)}(a) + end +end + +unconj(a::ConjBroadcasted) = a.parent + +Base.axes(a::ConjBroadcasted) = axes(unconj(a)) +Base.eltype(::Type{<:ConjBroadcasted{T}}) where {T} = T +Base.eltype(::ConjBroadcasted{T}) where {T} = T +Base.ndims(::Type{<:ConjBroadcasted{<:Any, N}}) where {N} = N +Base.ndims(::ConjBroadcasted{<:Any, N}) where {N} = N + +function Base.similar(a::ConjBroadcasted) + return similar(unconj(a), eltype(a), axes(a)) +end +function Base.similar(a::ConjBroadcasted, elt::Type) + return similar(unconj(a), elt, axes(a)) +end + +function Base.show(io::IO, a::ConjBroadcasted) + print(io, "conj(", unconj(a), ")") + return nothing +end + +Base.conj(a::ConjBroadcasted) = unconj(a) +Base.adjoint(a::ConjBroadcasted) = transpose(unconj(a)) +Base.transpose(a::ConjBroadcasted) = adjoint(unconj(a)) + +function FI.permuteddims(a::ConjBroadcasted, perm) + return ConjBroadcasted(FI.permuteddims(unconj(a), perm)) +end + +SV.isstrided(a::ConjBroadcasted) = SV.isstrided(unconj(a)) +SV.StridedView(a::ConjBroadcasted) = conj(SV.StridedView(unconj(a))) + +iscall(::ConjBroadcasted) = true +operation(::ConjBroadcasted) = conj +arguments(a::ConjBroadcasted) = (unconj(a),) + +# --- AddBroadcasted ----------------------------------------------------------- + +struct AddBroadcasted{T, N, Args <: Tuple} <: LinearBroadcasted + args::Args + function AddBroadcasted(args...) + T = Base.promote_op(+, eltype.(args)...) + N = if allequal(ndims, args) + ndims(first(args)) + else + error("All addends must have the same number of dimensions.") + end + return new{T, N, typeof(args)}(args) + end +end + +addends(a::AddBroadcasted) = a.args + +Base.axes(a::AddBroadcasted) = BC.combine_axes(addends(a)...) +Base.eltype(::Type{<:AddBroadcasted{T}}) where {T} = T +Base.eltype(::AddBroadcasted{T}) where {T} = T +Base.ndims(::Type{<:AddBroadcasted{<:Any, N}}) where {N} = N +Base.ndims(::AddBroadcasted{<:Any, N}) where {N} = N + +function Base.similar(a::AddBroadcasted) + return similar(BC.Broadcasted(+, addends(a)), eltype(a)) +end +function Base.similar(a::AddBroadcasted, elt::Type) + return similar(BC.Broadcasted(+, addends(a)), elt) +end + +function Base.show(io::IO, a::AddBroadcasted) + print(io, "+(", join(addends(a), ", "), ")") + return nothing +end + +function Base.adjoint(a::AddBroadcasted) + return AddBroadcasted(adjoint.(addends(a))...) +end +function Base.transpose(a::AddBroadcasted) + return AddBroadcasted(transpose.(addends(a))...) +end + +function FI.permuteddims(a::AddBroadcasted, perm) + return AddBroadcasted(Base.Fix2(FI.permuteddims, perm).(addends(a))...) +end + +iscall(::AddBroadcasted) = true +operation(::AddBroadcasted) = + +arguments(a::AddBroadcasted) = addends(a) + +# ---------------------------------------------------------------------------- # +# Mul — lazy matrix multiplication (standalone, not LinearBroadcasted) +# ---------------------------------------------------------------------------- # + +# Same as `LinearAlgebra.matprod`, but duplicated here since it is private. +matprod(x, y) = x * y + x * y + +struct Mul{T, N, A, B} + a::A + b::B + function Mul(a, b) + T = Base.promote_op(matprod, eltype(a), eltype(b)) + N = ndims(b) + return new{T, N, typeof(a), typeof(b)}(a, b) + end +end + +factors(a::Mul) = (a.a, a.b) + +Base.axes(a::Mul) = (axes(a.a, 1), axes(a.b, ndims(a.b))) +Base.axes(a::Mul, d::Int) = axes(a)[d] +Base.eltype(::Type{<:Mul{T}}) where {T} = T +Base.eltype(::Mul{T}) where {T} = T +Base.ndims(::Type{<:Mul{<:Any, N}}) where {N} = N +Base.ndims(::Mul{<:Any, N}) where {N} = N +Base.size(a::Mul) = length.(axes(a)) + +function Base.similar(a::Mul) + return similar(BC.materialize(last(factors(a))), eltype(a), axes(a)) +end +function Base.similar(a::Mul, elt::Type) + return similar(BC.materialize(last(factors(a))), elt, axes(a)) +end + +function Base.show(io::IO, a::Mul) + f = factors(a) + print(io, "*(", f[1], ", ", f[2], ")") + return nothing +end + +function Base.adjoint(a::Mul) + f = factors(a) + return Mul(adjoint(f[2]), adjoint(f[1])) +end +function Base.transpose(a::Mul) + f = factors(a) + return Mul(transpose(f[2]), transpose(f[1])) +end + +function FI.permuteddims(a::Mul{<:Any, 2}, perm) + perm == (1, 2) && return a + return transpose(a) +end + +iscall(::Mul) = true +operation(::Mul) = * +arguments(a::Mul) = factors(a) + +# ---------------------------------------------------------------------------- # +# Materialization protocol: copy, copyto!, add! +# ---------------------------------------------------------------------------- # + +function Base.copy(a::LinearBroadcasted) + return copyto!(similar(a), a) +end + +function Base.copy(a::Mul) + return copyto!(similar(a), a) +end + +# copyto! for LinearBroadcasted dispatches to add!. +function Base.copyto!(dest::AbstractArray, src::ScaledBroadcasted) + return add!(dest, src, true, false) +end +function Base.copyto!(dest::AbstractArray, src::ConjBroadcasted) + return add!(dest, src, true, false) +end +function Base.copyto!(dest::AbstractArray, src::AddBroadcasted) + return add!(dest, src, true, false) +end + +# copyto! for Mul dispatches to mul!. Materialize factors first since +# they may be LinearBroadcasted types. +function Base.copyto!(dest::AbstractArray, src::Mul) + return LA.mul!(dest, BC.materialize.(factors(src))..., true, false) +end + +# add! for LinearBroadcasted subtypes. +function add!(dest::AbstractArray, src::ScaledBroadcasted, α::Number, β::Number) + return add!(dest, unscaled(src), coeff(src) * α, β) +end + +function add!(dest::AbstractArray, src::ConjBroadcasted, α::Number, β::Number) + return add!(dest, unconj(src), α, β, Val(:conj)) +end + +# Default conj add! falls back to materializing conj. +function add!(dest::AbstractArray, src::AbstractArray, α::Number, β::Number, ::Val{:conj}) + return add!(dest, conj(src), α, β) +end + +function add!(dest::AbstractArray, src::AddBroadcasted, α::Number, β::Number) + args = addends(src) + add!(dest, first(args), α, β) + for a in Base.tail(args) + add!(dest, a, α, true) + end + return dest +end + +# add! for Mul materializes the factors and calls mul!. +function add!(dest::AbstractArray, src::Mul, α::Number, β::Number) + return LA.mul!(dest, BC.materialize.(factors(src))..., α, β) +end + +# ---------------------------------------------------------------------------- # +# LinearBroadcastFunction — constructor API +# ---------------------------------------------------------------------------- # + +""" + LinearBroadcastFunction(f) + +Wrap a function `f` so that calling it produces a `LinearBroadcasted` expression +instead of eagerly computing. Analogous to `Base.BroadcastFunction`. + +# Examples + +```julia +LinearBroadcastFunction(*)(2.0, a) # ScaledBroadcasted(2.0, a) +LinearBroadcastFunction(conj)(a) # ConjBroadcasted(a) +LinearBroadcastFunction(+)(a, b) # AddBroadcasted(a, b) +``` +""" +struct LinearBroadcastFunction{F} <: Function + f::F +end + +# Scaling: Number * AbstractArray +function (::LinearBroadcastFunction{typeof(*)})(α::Number, a::AbstractArray) + return ScaledBroadcasted(α, a) +end +function (::LinearBroadcastFunction{typeof(*)})(a::AbstractArray, α::Number) + return ScaledBroadcasted(α, a) +end +# Scaling of ScaledBroadcasted: absorb coefficient. +function (::LinearBroadcastFunction{typeof(*)})(α::Number, a::ScaledBroadcasted) + return ScaledBroadcasted(α * coeff(a), unscaled(a)) +end + +# Conjugation. +function (::LinearBroadcastFunction{typeof(conj)})(a::AbstractArray) + return ConjBroadcasted(a) +end +(::LinearBroadcastFunction{typeof(conj)})(a::AbstractArray{<:Real}) = a +(::LinearBroadcastFunction{typeof(conj)})(a::ConjBroadcasted) = unconj(a) +function (::LinearBroadcastFunction{typeof(conj)})(a::ScaledBroadcasted) + return ScaledBroadcasted( + conj(coeff(a)), LinearBroadcastFunction(conj)(unscaled(a)) + ) +end + +# Addition. +function (lf::LinearBroadcastFunction{typeof(+)})(a, b) + return AddBroadcasted(a, b) +end +function (lf::LinearBroadcastFunction{typeof(+)})(a, b, c, xs...) + return Base.afoldl(lf, lf(lf(a, b), c), xs...) +end +# Flatten AddBroadcasted + anything. +function (::LinearBroadcastFunction{typeof(+)})(a::AddBroadcasted, b) + return AddBroadcasted(addends(a)..., b) +end +function (::LinearBroadcastFunction{typeof(+)})(a, b::AddBroadcasted) + return AddBroadcasted(a, addends(b)...) +end +function (::LinearBroadcastFunction{typeof(+)})(a::AddBroadcasted, b::AddBroadcasted) + return AddBroadcasted(addends(a)..., addends(b)...) +end +(::LinearBroadcastFunction{typeof(+)})(a) = a + +# Subtraction. +function (::LinearBroadcastFunction{typeof(-)})(a, b) + return LinearBroadcastFunction(+)(a, LinearBroadcastFunction(*)(- 1, b)) +end +(::LinearBroadcastFunction{typeof(-)})(a) = LinearBroadcastFunction(*)(-1, a) + +# Division / left-division by scalars. +function (::LinearBroadcastFunction{typeof(/)})(a, b::Number) + return LinearBroadcastFunction(*)(inv(b), a) +end +function (::LinearBroadcastFunction{typeof(\)})(a::Number, b) + return LinearBroadcastFunction(*)(inv(a), b) +end + +# Identity. +(::LinearBroadcastFunction{typeof(identity)})(a) = a + +# Fix1/Fix2 wrappers for scalar multiplication/division. +function (lf::LinearBroadcastFunction{<:Base.Fix1{typeof(*)}})(a) + return LinearBroadcastFunction(*)(lf.f.x, a) +end +function (lf::LinearBroadcastFunction{<:Base.Fix2{typeof(*)}})(a) + return LinearBroadcastFunction(*)(a, lf.f.x) +end +function (lf::LinearBroadcastFunction{<:Base.Fix2{typeof(/)}})(a) + return LinearBroadcastFunction(/)(a, lf.f.x) +end + +# Scaling of AddBroadcasted distributes. +function (::LinearBroadcastFunction{typeof(*)})(α::Number, a::AddBroadcasted) + return LinearBroadcastFunction(+)( + map(x -> LinearBroadcastFunction(*)(α, x), addends(a))... + ) +end + +# Conjugation of AddBroadcasted distributes. +function (::LinearBroadcastFunction{typeof(conj)})(a::AddBroadcasted) + return LinearBroadcastFunction(+)( + map(x -> LinearBroadcastFunction(conj)(x), addends(a))... + ) +end + +# Conjugation of Mul distributes. +function (::LinearBroadcastFunction{typeof(conj)})(a::Mul) + f = factors(a) + return Mul(LinearBroadcastFunction(conj)(f[1]), LinearBroadcastFunction(conj)(f[2])) +end + +# Scaling of Mul: wrap in ScaledBroadcasted. +function (::LinearBroadcastFunction{typeof(*)})(α::Number, a::Mul) + return ScaledBroadcasted(α, a) +end + +# Number * Number passthrough (for broadcast lowering). +(::LinearBroadcastFunction{typeof(*)})(a::Number, b::Number) = a * b + +# ---------------------------------------------------------------------------- # +# Broadcast integration +# ---------------------------------------------------------------------------- # + +broadcast_is_linear(f, args...) = false +broadcast_is_linear(::typeof(identity), ::Base.AbstractArrayOrBroadcasted) = true +broadcast_is_linear(::typeof(+), ::Base.AbstractArrayOrBroadcasted...) = true +broadcast_is_linear(::typeof(-), ::Base.AbstractArrayOrBroadcasted) = true +function broadcast_is_linear( + ::typeof(-), ::Base.AbstractArrayOrBroadcasted, ::Base.AbstractArrayOrBroadcasted + ) + return true +end +broadcast_is_linear(::typeof(*), ::Number, ::Base.AbstractArrayOrBroadcasted) = true +broadcast_is_linear(::typeof(\), ::Number, ::Base.AbstractArrayOrBroadcasted) = true +broadcast_is_linear(::typeof(*), ::Base.AbstractArrayOrBroadcasted, ::Number) = true +broadcast_is_linear(::typeof(/), ::Base.AbstractArrayOrBroadcasted, ::Number) = true +function broadcast_is_linear( + ::typeof(*), ::Base.AbstractArrayOrBroadcasted, ::Base.AbstractArrayOrBroadcasted + ) + return false +end +broadcast_is_linear(::typeof(*), ::Number, ::Number) = true +broadcast_is_linear(::typeof(conj), ::Base.AbstractArrayOrBroadcasted) = true +function broadcast_is_linear( + ::Base.Fix1{typeof(*), <:Number}, ::Base.AbstractArrayOrBroadcasted + ) + return true +end +function broadcast_is_linear( + ::Base.Fix2{typeof(*), <:Number}, ::Base.AbstractArrayOrBroadcasted + ) + return true +end +function broadcast_is_linear( + ::Base.Fix2{typeof(/), <:Number}, ::Base.AbstractArrayOrBroadcasted + ) + return true +end +is_linear(x) = true +function is_linear(bc::BC.Broadcasted) + return broadcast_is_linear(bc.f, bc.args...) && all(is_linear, bc.args) +end + +to_linear(x) = x +function to_linear(bc::BC.Broadcasted) + return LinearBroadcastFunction(bc.f)(to_linear.(bc.args)...) +end + +function broadcast_error(style, f) + return throw( + ArgumentError( + "Only linear broadcast operations are supported for `$style`, got `$f`." + ) + ) +end +function broadcasted_linear(style::BC.BroadcastStyle, f, args...) + bc = BC.Broadcasted(style, f, args) + is_linear(bc) || broadcast_error(style, f) + return to_linear(bc) +end +function broadcasted_linear(f, args...) + return broadcasted_linear(BC.combine_styles(args...), f, args...) +end + +# Convert LinearBroadcasted / Mul back to Broadcasted for non-linear contexts. +to_broadcasted(x) = x +function to_broadcasted(a::AbstractArray) + (BC.BroadcastStyle(typeof(a)) isa LinearBroadcastedStyle) || return a + return BC.broadcasted(operation(a), to_broadcasted.(arguments(a))...) +end +function to_broadcasted(a::LinearBroadcasted) + return BC.broadcasted(operation(a), to_broadcasted.(arguments(a))...) +end +# Matmul isn't a broadcasting operation so we materialize when building a +# broadcast expression involving a Mul. +to_broadcasted(a::Mul) = *(factors(a)...) +to_broadcasted(bc::BC.Broadcasted) = BC.Broadcasted(bc.f, to_broadcasted.(bc.args)) + +# LinearBroadcastedStyle for broadcast interop. +struct LinearBroadcastedStyle{N, Style <: BC.AbstractArrayStyle{N}} <: BC.AbstractArrayStyle{N} + style::Style +end +# TODO: This empty constructor is required in some Julia versions below v1.12 (such as +# Julia v1.10), try deleting it once we drop support for those versions. +function LinearBroadcastedStyle{N, Style}() where {N, Style <: BC.AbstractArrayStyle{N}} + return LinearBroadcastedStyle{N, Style}(Style()) +end +function LinearBroadcastedStyle{N, Style}(::Val{M}) where {M, N, Style <: BC.AbstractArrayStyle{N}} + return LinearBroadcastedStyle(Style(Val(M))) +end +function BC.BroadcastStyle(style1::LinearBroadcastedStyle, style2::LinearBroadcastedStyle) + style = BC.BroadcastStyle(style1.style, style2.style) + style ≡ BC.Unknown() && return BC.Unknown() + return LinearBroadcastedStyle(style) +end +function Base.similar(bc::BC.Broadcasted{<:LinearBroadcastedStyle}, elt::Type, ax) + return similar(BC.Broadcasted(bc.style.style, bc.f, bc.args, bc.axes), elt, ax) +end + +# BroadcastStyle for LinearBroadcasted subtypes. +function BC.BroadcastStyle(::Type{<:ScaledBroadcasted{<:Any, <:Any, <:Any, A}}) where {A} + return LinearBroadcastedStyle(BC.BroadcastStyle(A)) +end +function BC.BroadcastStyle(::Type{<:ConjBroadcasted{<:Any, <:Any, A}}) where {A} + return LinearBroadcastedStyle(BC.BroadcastStyle(A)) +end +function BC.BroadcastStyle(::Type{<:AddBroadcasted{<:Any, <:Any, Args}}) where {Args} + style = Base.promote_op(BC.combine_styles, fieldtypes(Args)...)() + return LinearBroadcastedStyle(style) +end +function BC.BroadcastStyle(::Type{<:Mul{<:Any, <:Any, A, B}}) where {A, B} + style = BC.BroadcastStyle(BC.BroadcastStyle(A), BC.BroadcastStyle(B)) + return LinearBroadcastedStyle(style) +end + +# Broadcast.materialize for LinearBroadcasted and Mul. +BC.materialize(a::LinearBroadcasted) = copy(a) +BC.materialize(a::Mul) = copy(a) + +# Backup definition: for broadcast operations that don't preserve lazy types +# (such as nonlinear operations), convert back to Broadcasted expressions. +BC.broadcasted(::LinearBroadcastedStyle, f, args...) = BC.Broadcasted(f, to_broadcasted.(args)) + +# Linear broadcast operations produce LinearBroadcasted / Mul types. +function BC.broadcasted(::LinearBroadcastedStyle, ::typeof(+), a::AbstractArray, b::AbstractArray) + return LinearBroadcastFunction(+)(a, b) +end +function BC.broadcasted(::LinearBroadcastedStyle, ::typeof(+), a::AbstractArray, b::BC.Broadcasted) + is_linear(b) || return BC.Broadcasted(+, to_broadcasted.((a, b))) + return LinearBroadcastFunction(+)(a, to_linear(b)) +end +function BC.broadcasted(::LinearBroadcastedStyle, ::typeof(+), a::BC.Broadcasted, b::AbstractArray) + is_linear(a) || return BC.Broadcasted(+, to_broadcasted.((a, b))) + return LinearBroadcastFunction(+)(to_linear(a), b) +end +function BC.broadcasted( + ::LinearBroadcastedStyle, ::typeof(+), a::BC.Broadcasted, b::BC.Broadcasted + ) + return error("Not implemented") +end +function BC.broadcasted(::LinearBroadcastedStyle, ::typeof(*), α::Number, a::AbstractArray) + return LinearBroadcastFunction(*)(α, a) +end +function BC.broadcasted(::LinearBroadcastedStyle, ::typeof(*), a::AbstractArray, α::Number) + return LinearBroadcastFunction(*)(a, α) +end +function BC.broadcasted(::LinearBroadcastedStyle, ::typeof(\), α::Number, a::AbstractArray) + return LinearBroadcastFunction(\)(α, a) +end +function BC.broadcasted(::LinearBroadcastedStyle, ::typeof(/), a::AbstractArray, α::Number) + return LinearBroadcastFunction(/)(a, α) +end +function BC.broadcasted(::LinearBroadcastedStyle, ::typeof(-), a::AbstractArray) + return LinearBroadcastFunction(-)(a) +end +function BC.broadcasted(::LinearBroadcastedStyle, ::typeof(conj), a::AbstractArray) + return LinearBroadcastFunction(conj)(a) +end diff --git a/test/Project.toml b/test/Project.toml index 5a4045d2..5c0c3426 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -36,7 +36,7 @@ Random = "1.10" SafeTestsets = "0.1" StableRNGs = "1.0.2" Suppressor = "0.2" -TensorAlgebra = "0.7" +TensorAlgebra = "0.8" TensorOperations = "5.1.4" Test = "1.10" TestExtras = "0.3.1" diff --git a/test/test_lazy.jl b/test/test_lazy.jl index 4d3265fc..4435dd3e 100644 --- a/test/test_lazy.jl +++ b/test/test_lazy.jl @@ -1,77 +1,80 @@ import FunctionImplementations as FI using Base.Broadcast: Broadcast as BC -using TensorAlgebra: TensorAlgebra as TA, *ₗ, +ₗ, /ₗ, conjed -using Test: @test, @test_broken, @test_throws, @testset +using TensorAlgebra: TensorAlgebra as TA +using Test: @test, @test_throws, @testset -@testset "lazy arrays" begin - @testset "lazy array operations" begin +const lbf = TA.LinearBroadcastFunction + +@testset "LinearBroadcasted and Mul" begin + @testset "construction and materialization" begin a = randn(ComplexF64, 3, 3) b = randn(ComplexF64, 3, 3) c = randn(ComplexF64, 3, 3) - x = 2 *ₗ a - @test x ≡ TA.ScaledArray(2, a) + x = lbf(*)(2, a) + @test x ≡ TA.ScaledBroadcasted(2, a) @test copy(x) ≈ 2a - x = conjed(a) - @test x ≡ TA.ConjArray(a) + x = lbf(conj)(a) + @test x ≡ TA.ConjBroadcasted(a) @test copy(x) ≈ conj(a) @test conj(x) ≈ a - x = a +ₗ b - @test x ≡ TA.AddArray(a, b) + x = lbf(+)(a, b) + @test x ≡ TA.AddBroadcasted(a, b) @test copy(x) ≈ a + b - x = a *ₗ b - @test x ≡ TA.MulArray(a, b) + x = TA.Mul(a, b) @test copy(x) ≈ a * b - x = a *ₗ b +ₗ c - @test x ≡ TA.AddArray(TA.MulArray(a, b), c) - @test copy(x) ≈ a *ₗ b .+ c ≈ a * b + c + x = lbf(+)(TA.Mul(a, b), c) + @test x ≡ TA.AddBroadcasted(TA.Mul(a, b), c) + @test copy(x) ≈ a * b + c - x = 2 *ₗ a *ₗ b +ₗ 3 *ₗ c - @test x ≡ TA.AddArray(TA.ScaledArray(2, TA.MulArray(a, b)), TA.ScaledArray(3, c)) - @test copy(x) ≈ 2 .* a *ₗ b .+ 3 .* c ≈ 2 * a * b + 3 * c + x = lbf(+)(lbf(*)(2, TA.Mul(a, b)), lbf(*)(3, c)) + @test x ≡ TA.AddBroadcasted( + TA.ScaledBroadcasted(2, TA.Mul(a, b)), TA.ScaledBroadcasted(3, c) + ) + @test copy(x) ≈ 2 * a * b + 3 * c end @testset "adjoint" begin a = randn(ComplexF64, 2, 2) b = randn(ComplexF64, 2, 2) - x = (2 *ₗ a)' - @test x ≡ 2 *ₗ a' + x = lbf(*)(2, a)' + @test x ≡ lbf(*)(2, a') @test copy(x) ≈ 2a' - x = conjed(a)' + x = lbf(conj)(a)' @test x ≡ transpose(a) @test copy(x) ≈ permutedims(a) - x = (a +ₗ b)' - @test x ≡ a' +ₗ b' + x = lbf(+)(a, b)' + @test x ≡ lbf(+)(a', b') @test copy(x) ≈ a' + b' - x = (a *ₗ b)' - @test x ≡ b' *ₗ a' + x = TA.Mul(a, b)' + @test x ≡ TA.Mul(b', a') @test copy(x) ≈ b' * a' end @testset "transpose" begin a = randn(ComplexF64, 2, 2) b = randn(ComplexF64, 2, 2) - x = transpose(2 *ₗ a) - @test x ≡ 2 *ₗ transpose(a) + x = transpose(lbf(*)(2, a)) + @test x ≡ lbf(*)(2, transpose(a)) @test copy(x) ≈ 2transpose(a) - x = transpose(conjed(a)) + x = transpose(lbf(conj)(a)) @test x ≡ adjoint(a) @test copy(x) ≈ permutedims(conj(a)) - x = transpose(a +ₗ b) - @test x ≡ transpose(a) +ₗ transpose(b) + x = transpose(lbf(+)(a, b)) + @test x ≡ lbf(+)(transpose(a), transpose(b)) @test copy(x) ≈ transpose(a) + transpose(b) - x = transpose(a *ₗ b) - @test x ≡ transpose(b) *ₗ transpose(a) + x = transpose(TA.Mul(a, b)) + @test x ≡ TA.Mul(transpose(b), transpose(a)) @test copy(x) ≈ transpose(b) * transpose(a) end @testset "permuteddims" begin @@ -79,20 +82,19 @@ using Test: @test, @test_broken, @test_throws, @testset b = randn(ComplexF64, 2, 2) perm = (2, 1) - x = FI.permuteddims(2 *ₗ a, perm) - @test x ≡ 2 *ₗ FI.permuteddims(a, perm) + x = FI.permuteddims(lbf(*)(2, a), perm) + @test x ≡ lbf(*)(2, FI.permuteddims(a, perm)) @test copy(x) ≈ 2permutedims(a, perm) - x = FI.permuteddims(conjed(a), perm) - @test x ≡ conjed(FI.permuteddims(a, perm)) + x = FI.permuteddims(lbf(conj)(a), perm) + @test x ≡ lbf(conj)(FI.permuteddims(a, perm)) @test copy(x) ≈ conj(permutedims(a, perm)) - x = FI.permuteddims(a +ₗ b, perm) - @test x ≡ FI.permuteddims(a, perm) +ₗ FI.permuteddims(b, perm) + x = FI.permuteddims(lbf(+)(a, b), perm) + @test x ≡ lbf(+)(FI.permuteddims(a, perm), FI.permuteddims(b, perm)) @test copy(x) ≈ permutedims(a, perm) + permutedims(b, perm) - x = FI.permuteddims(a *ₗ b, perm) - @test x ≡ PermutedDimsArray(a *ₗ b, perm) + x = FI.permuteddims(TA.Mul(a, b), perm) @test copy(x) ≈ permutedims(a * b, perm) end @testset "linear broadcast lowering" begin @@ -100,24 +102,77 @@ using Test: @test, @test_broken, @test_throws, @testset style = BC.DefaultArrayStyle{2}() @test TA.broadcasted_linear(identity, a) ≡ a - @test TA.broadcasted_linear(Base.Fix1(*, 2), a) ≡ 2 *ₗ a - @test TA.broadcasted_linear(Base.Fix2(*, 2), a) ≡ a *ₗ 2 - @test TA.broadcasted_linear(Base.Fix2(/, 2), a) ≡ a /ₗ 2 + @test TA.broadcasted_linear(Base.Fix1(*, 2), a) ≡ lbf(*)(2, a) + @test TA.broadcasted_linear(Base.Fix2(*, 2), a) ≡ lbf(*)(a, 2) + @test TA.broadcasted_linear(Base.Fix2(/, 2), a) ≡ lbf(/)(a, 2) @test TA.broadcasted_linear(style, identity, a) ≡ a - @test TA.broadcasted_linear(style, Base.Fix1(*, 2), a) ≡ 2 *ₗ a - @test TA.broadcasted_linear(style, Base.Fix2(*, 2), a) ≡ a *ₗ 2 - @test TA.broadcasted_linear(style, Base.Fix2(/, 2), a) ≡ a /ₗ 2 - @test TA.broadcasted_linear(style, conj, a) ≡ conjed(a) + @test TA.broadcasted_linear(style, Base.Fix1(*, 2), a) ≡ lbf(*)(2, a) + @test TA.broadcasted_linear(style, Base.Fix2(*, 2), a) ≡ lbf(*)(a, 2) + @test TA.broadcasted_linear(style, Base.Fix2(/, 2), a) ≡ lbf(/)(a, 2) + @test TA.broadcasted_linear(style, conj, a) ≡ lbf(conj)(a) @test_throws ArgumentError TA.broadcasted_linear(style, exp, a) end - @testset "scalar getindex" begin + @testset "LinearBroadcastFunction algebra" begin + a = randn(ComplexF64, 3, 3) + + # Scaling absorbs coefficients + @test lbf(*)(3, lbf(*)(2, a)) ≡ TA.ScaledBroadcasted(6, a) + + # Conjugation of scaled + x = lbf(conj)(lbf(*)(2im, a)) + @test x ≡ TA.ScaledBroadcasted(-2im, TA.ConjBroadcasted(a)) + + # Double conjugation cancels + @test lbf(conj)(lbf(conj)(a)) ≡ a + + # Subtraction + b = randn(ComplexF64, 3, 3) + x = lbf(-)(a, b) + @test copy(x) ≈ a - b + + # Unary minus + x = lbf(-)(a) + @test copy(x) ≈ -a + + # Division + x = lbf(/)(a, 2) + @test copy(x) ≈ a / 2 + + # Left division + x = lbf(\)(2, a) + @test copy(x) ≈ a / 2 + + # Scaling distributes over AddBroadcasted + ab = lbf(+)(a, b) + x = lbf(*)(3, ab) + @test copy(x) ≈ 3a + 3b + + # Conjugation distributes over AddBroadcasted + x = lbf(conj)(ab) + @test copy(x) ≈ conj(a) + conj(b) + + # Conjugation distributes over Mul + m = TA.Mul(a, b) + x = lbf(conj)(m) + @test copy(x) ≈ conj(a) * conj(b) + end + @testset "AddBroadcasted flattening" begin a = randn(ComplexF64, 2, 2) b = randn(ComplexF64, 2, 2) + c = randn(ComplexF64, 2, 2) + + # AddBroadcasted + array flattens + ab = lbf(+)(a, b) + x = lbf(+)(ab, c) + @test TA.addends(x) === (a, b, c) + + # array + AddBroadcasted flattens + x = lbf(+)(c, ab) + @test TA.addends(x) === (c, a, b) - @test (2 *ₗ a)[1, 2] == 2 * a[1, 2] - @test conjed(a)[2, 1] == conj(a[2, 1]) - @test (a +ₗ b)[2, 2] == a[2, 2] + b[2, 2] - @test (a *ₗ b)[1, 2] ≈ (a * b)[1, 2] - @test (a *ₗ b)[3] ≈ (a * b)[3] + # AddBroadcasted + AddBroadcasted flattens + cd = lbf(+)(c, a) + x = lbf(+)(ab, cd) + @test TA.addends(x) === (a, b, c, a) end end From bcc35d5048bc106709dd36085c8b4c5a4969d987 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 25 Mar 2026 00:02:24 -0400 Subject: [PATCH 02/29] Remove T and N type parameters from LinearBroadcasted types and Mul These were needed for AbstractArray{T,N} subtyping and are harder to define for ITensorBase where N isn't well-defined. eltype/ndims are now computed from the wrapped data instead. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/linearbroadcasted.jl | 95 ++++++++++++++++++++-------------------- 1 file changed, 47 insertions(+), 48 deletions(-) diff --git a/src/linearbroadcasted.jl b/src/linearbroadcasted.jl index 8ee7c5ae..eaf9b24a 100644 --- a/src/linearbroadcasted.jl +++ b/src/linearbroadcasted.jl @@ -29,23 +29,19 @@ Base.axes(a::LinearBroadcasted, d::Int) = axes(a)[d] # --- ScaledBroadcasted -------------------------------------------------------- -struct ScaledBroadcasted{T, N, C <: Number, A} <: LinearBroadcasted +struct ScaledBroadcasted{C <: Number, A} <: LinearBroadcasted coeff::C parent::A - function ScaledBroadcasted(coeff::Number, a) - T = Base.promote_op(*, typeof(coeff), eltype(a)) - return new{T, ndims(a), typeof(coeff), typeof(a)}(coeff, a) - end end unscaled(a::ScaledBroadcasted) = a.parent coeff(a::ScaledBroadcasted) = a.coeff Base.axes(a::ScaledBroadcasted) = axes(unscaled(a)) -Base.eltype(::Type{<:ScaledBroadcasted{T}}) where {T} = T -Base.eltype(::ScaledBroadcasted{T}) where {T} = T -Base.ndims(::Type{<:ScaledBroadcasted{<:Any, N}}) where {N} = N -Base.ndims(::ScaledBroadcasted{<:Any, N}) where {N} = N +function Base.eltype(a::ScaledBroadcasted) + return Base.promote_op(*, typeof(coeff(a)), eltype(unscaled(a))) +end +Base.ndims(a::ScaledBroadcasted) = ndims(unscaled(a)) function Base.similar(a::ScaledBroadcasted) return similar(unscaled(a), eltype(a), axes(a)) @@ -76,20 +72,15 @@ arguments(a::ScaledBroadcasted) = (coeff(a), unscaled(a)) # --- ConjBroadcasted ---------------------------------------------------------- -struct ConjBroadcasted{T, N, A} <: LinearBroadcasted +struct ConjBroadcasted{A} <: LinearBroadcasted parent::A - function ConjBroadcasted(a) - return new{eltype(a), ndims(a), typeof(a)}(a) - end end unconj(a::ConjBroadcasted) = a.parent Base.axes(a::ConjBroadcasted) = axes(unconj(a)) -Base.eltype(::Type{<:ConjBroadcasted{T}}) where {T} = T -Base.eltype(::ConjBroadcasted{T}) where {T} = T -Base.ndims(::Type{<:ConjBroadcasted{<:Any, N}}) where {N} = N -Base.ndims(::ConjBroadcasted{<:Any, N}) where {N} = N +Base.eltype(a::ConjBroadcasted) = eltype(unconj(a)) +Base.ndims(a::ConjBroadcasted) = ndims(unconj(a)) function Base.similar(a::ConjBroadcasted) return similar(unconj(a), eltype(a), axes(a)) @@ -120,26 +111,21 @@ arguments(a::ConjBroadcasted) = (unconj(a),) # --- AddBroadcasted ----------------------------------------------------------- -struct AddBroadcasted{T, N, Args <: Tuple} <: LinearBroadcasted +struct AddBroadcasted{Args <: Tuple} <: LinearBroadcasted args::Args function AddBroadcasted(args...) - T = Base.promote_op(+, eltype.(args)...) - N = if allequal(ndims, args) - ndims(first(args)) - else + if !allequal(ndims, args) error("All addends must have the same number of dimensions.") end - return new{T, N, typeof(args)}(args) + return new{typeof(args)}(args) end end addends(a::AddBroadcasted) = a.args Base.axes(a::AddBroadcasted) = BC.combine_axes(addends(a)...) -Base.eltype(::Type{<:AddBroadcasted{T}}) where {T} = T -Base.eltype(::AddBroadcasted{T}) where {T} = T -Base.ndims(::Type{<:AddBroadcasted{<:Any, N}}) where {N} = N -Base.ndims(::AddBroadcasted{<:Any, N}) where {N} = N +Base.eltype(a::AddBroadcasted) = Base.promote_op(+, eltype.(addends(a))...) +Base.ndims(a::AddBroadcasted) = ndims(first(addends(a))) function Base.similar(a::AddBroadcasted) return similar(BC.Broadcasted(+, addends(a)), eltype(a)) @@ -175,24 +161,17 @@ arguments(a::AddBroadcasted) = addends(a) # Same as `LinearAlgebra.matprod`, but duplicated here since it is private. matprod(x, y) = x * y + x * y -struct Mul{T, N, A, B} +struct Mul{A, B} a::A b::B - function Mul(a, b) - T = Base.promote_op(matprod, eltype(a), eltype(b)) - N = ndims(b) - return new{T, N, typeof(a), typeof(b)}(a, b) - end end factors(a::Mul) = (a.a, a.b) Base.axes(a::Mul) = (axes(a.a, 1), axes(a.b, ndims(a.b))) Base.axes(a::Mul, d::Int) = axes(a)[d] -Base.eltype(::Type{<:Mul{T}}) where {T} = T -Base.eltype(::Mul{T}) where {T} = T -Base.ndims(::Type{<:Mul{<:Any, N}}) where {N} = N -Base.ndims(::Mul{<:Any, N}) where {N} = N +Base.eltype(a::Mul) = Base.promote_op(matprod, eltype(a.a), eltype(a.b)) +Base.ndims(a::Mul) = ndims(a.b) Base.size(a::Mul) = length.(axes(a)) function Base.similar(a::Mul) @@ -217,7 +196,7 @@ function Base.transpose(a::Mul) return Mul(transpose(f[2]), transpose(f[1])) end -function FI.permuteddims(a::Mul{<:Any, 2}, perm) +function FI.permuteddims(a::Mul, perm) perm == (1, 2) && return a return transpose(a) end @@ -484,7 +463,8 @@ to_broadcasted(a::Mul) = *(factors(a)...) to_broadcasted(bc::BC.Broadcasted) = BC.Broadcasted(bc.f, to_broadcasted.(bc.args)) # LinearBroadcastedStyle for broadcast interop. -struct LinearBroadcastedStyle{N, Style <: BC.AbstractArrayStyle{N}} <: BC.AbstractArrayStyle{N} +struct LinearBroadcastedStyle{N, Style <: BC.AbstractArrayStyle{N}} <: + BC.AbstractArrayStyle{N} style::Style end # TODO: This empty constructor is required in some Julia versions below v1.12 (such as @@ -492,7 +472,9 @@ end function LinearBroadcastedStyle{N, Style}() where {N, Style <: BC.AbstractArrayStyle{N}} return LinearBroadcastedStyle{N, Style}(Style()) end -function LinearBroadcastedStyle{N, Style}(::Val{M}) where {M, N, Style <: BC.AbstractArrayStyle{N}} +function LinearBroadcastedStyle{N, Style}( + ::Val{M} + ) where {M, N, Style <: BC.AbstractArrayStyle{N}} return LinearBroadcastedStyle(Style(Val(M))) end function BC.BroadcastStyle(style1::LinearBroadcastedStyle, style2::LinearBroadcastedStyle) @@ -505,17 +487,17 @@ function Base.similar(bc::BC.Broadcasted{<:LinearBroadcastedStyle}, elt::Type, a end # BroadcastStyle for LinearBroadcasted subtypes. -function BC.BroadcastStyle(::Type{<:ScaledBroadcasted{<:Any, <:Any, <:Any, A}}) where {A} +function BC.BroadcastStyle(::Type{<:ScaledBroadcasted{<:Any, A}}) where {A} return LinearBroadcastedStyle(BC.BroadcastStyle(A)) end -function BC.BroadcastStyle(::Type{<:ConjBroadcasted{<:Any, <:Any, A}}) where {A} +function BC.BroadcastStyle(::Type{<:ConjBroadcasted{A}}) where {A} return LinearBroadcastedStyle(BC.BroadcastStyle(A)) end -function BC.BroadcastStyle(::Type{<:AddBroadcasted{<:Any, <:Any, Args}}) where {Args} +function BC.BroadcastStyle(::Type{<:AddBroadcasted{Args}}) where {Args} style = Base.promote_op(BC.combine_styles, fieldtypes(Args)...)() return LinearBroadcastedStyle(style) end -function BC.BroadcastStyle(::Type{<:Mul{<:Any, <:Any, A, B}}) where {A, B} +function BC.BroadcastStyle(::Type{<:Mul{A, B}}) where {A, B} style = BC.BroadcastStyle(BC.BroadcastStyle(A), BC.BroadcastStyle(B)) return LinearBroadcastedStyle(style) end @@ -526,17 +508,34 @@ BC.materialize(a::Mul) = copy(a) # Backup definition: for broadcast operations that don't preserve lazy types # (such as nonlinear operations), convert back to Broadcasted expressions. -BC.broadcasted(::LinearBroadcastedStyle, f, args...) = BC.Broadcasted(f, to_broadcasted.(args)) +function BC.broadcasted(::LinearBroadcastedStyle, f, args...) + return BC.Broadcasted(f, to_broadcasted.(args)) +end # Linear broadcast operations produce LinearBroadcasted / Mul types. -function BC.broadcasted(::LinearBroadcastedStyle, ::typeof(+), a::AbstractArray, b::AbstractArray) +function BC.broadcasted( + ::LinearBroadcastedStyle, + ::typeof(+), + a::AbstractArray, + b::AbstractArray + ) return LinearBroadcastFunction(+)(a, b) end -function BC.broadcasted(::LinearBroadcastedStyle, ::typeof(+), a::AbstractArray, b::BC.Broadcasted) +function BC.broadcasted( + ::LinearBroadcastedStyle, + ::typeof(+), + a::AbstractArray, + b::BC.Broadcasted + ) is_linear(b) || return BC.Broadcasted(+, to_broadcasted.((a, b))) return LinearBroadcastFunction(+)(a, to_linear(b)) end -function BC.broadcasted(::LinearBroadcastedStyle, ::typeof(+), a::BC.Broadcasted, b::AbstractArray) +function BC.broadcasted( + ::LinearBroadcastedStyle, + ::typeof(+), + a::BC.Broadcasted, + b::AbstractArray + ) is_linear(a) || return BC.Broadcasted(+, to_broadcasted.((a, b))) return LinearBroadcastFunction(+)(to_linear(a), b) end From e5c8ef2493d40804c0e9bd6779a92f799ef78144 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 25 Mar 2026 00:11:59 -0400 Subject: [PATCH 03/29] Simplify similar hierarchy and copyto! for Mul MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add generic similar(::LinearBroadcasted) → similar(a, eltype(a)) and similar(::LinearBroadcasted, elt) → similar(a, elt, axes(a)) fallbacks - Each subtype now only defines the 3-arg similar(a, elt, ax) - Remove redundant true/false defaults from copyto!(dest, ::Mul) Co-Authored-By: Claude Opus 4.6 (1M context) --- src/linearbroadcasted.jl | 36 ++++++++++++++---------------------- 1 file changed, 14 insertions(+), 22 deletions(-) diff --git a/src/linearbroadcasted.jl b/src/linearbroadcasted.jl index eaf9b24a..3b58f446 100644 --- a/src/linearbroadcasted.jl +++ b/src/linearbroadcasted.jl @@ -24,8 +24,10 @@ copyto!(dest, lb) → add!(dest, lb, 1, 0) """ abstract type LinearBroadcasted end -# Generic axes(a, d) for LinearBroadcasted subtypes. +# Generic interface for LinearBroadcasted subtypes. Base.axes(a::LinearBroadcasted, d::Int) = axes(a)[d] +Base.similar(a::LinearBroadcasted) = similar(a, eltype(a)) +Base.similar(a::LinearBroadcasted, elt::Type) = similar(a, elt, axes(a)) # --- ScaledBroadcasted -------------------------------------------------------- @@ -43,11 +45,8 @@ function Base.eltype(a::ScaledBroadcasted) end Base.ndims(a::ScaledBroadcasted) = ndims(unscaled(a)) -function Base.similar(a::ScaledBroadcasted) - return similar(unscaled(a), eltype(a), axes(a)) -end -function Base.similar(a::ScaledBroadcasted, elt::Type) - return similar(unscaled(a), elt, axes(a)) +function Base.similar(a::ScaledBroadcasted, elt::Type, ax) + return similar(unscaled(a), elt, ax) end function Base.show(io::IO, a::ScaledBroadcasted) @@ -82,11 +81,8 @@ Base.axes(a::ConjBroadcasted) = axes(unconj(a)) Base.eltype(a::ConjBroadcasted) = eltype(unconj(a)) Base.ndims(a::ConjBroadcasted) = ndims(unconj(a)) -function Base.similar(a::ConjBroadcasted) - return similar(unconj(a), eltype(a), axes(a)) -end -function Base.similar(a::ConjBroadcasted, elt::Type) - return similar(unconj(a), elt, axes(a)) +function Base.similar(a::ConjBroadcasted, elt::Type, ax) + return similar(unconj(a), elt, ax) end function Base.show(io::IO, a::ConjBroadcasted) @@ -127,11 +123,8 @@ Base.axes(a::AddBroadcasted) = BC.combine_axes(addends(a)...) Base.eltype(a::AddBroadcasted) = Base.promote_op(+, eltype.(addends(a))...) Base.ndims(a::AddBroadcasted) = ndims(first(addends(a))) -function Base.similar(a::AddBroadcasted) - return similar(BC.Broadcasted(+, addends(a)), eltype(a)) -end -function Base.similar(a::AddBroadcasted, elt::Type) - return similar(BC.Broadcasted(+, addends(a)), elt) +function Base.similar(a::AddBroadcasted, elt::Type, ax) + return similar(BC.Broadcasted(+, addends(a)), elt, ax) end function Base.show(io::IO, a::AddBroadcasted) @@ -174,11 +167,10 @@ Base.eltype(a::Mul) = Base.promote_op(matprod, eltype(a.a), eltype(a.b)) Base.ndims(a::Mul) = ndims(a.b) Base.size(a::Mul) = length.(axes(a)) -function Base.similar(a::Mul) - return similar(BC.materialize(last(factors(a))), eltype(a), axes(a)) -end -function Base.similar(a::Mul, elt::Type) - return similar(BC.materialize(last(factors(a))), elt, axes(a)) +Base.similar(a::Mul) = similar(a, eltype(a)) +Base.similar(a::Mul, elt::Type) = similar(a, elt, axes(a)) +function Base.similar(a::Mul, elt::Type, ax) + return similar(BC.materialize(last(factors(a))), elt, ax) end function Base.show(io::IO, a::Mul) @@ -231,7 +223,7 @@ end # copyto! for Mul dispatches to mul!. Materialize factors first since # they may be LinearBroadcasted types. function Base.copyto!(dest::AbstractArray, src::Mul) - return LA.mul!(dest, BC.materialize.(factors(src))..., true, false) + return LA.mul!(dest, BC.materialize.(factors(src))...) end # add! for LinearBroadcasted subtypes. From 0c15f1779093765f73562550481d29e04afeb43b Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 25 Mar 2026 00:14:09 -0400 Subject: [PATCH 04/29] Simplify add! for ConjBroadcasted Remove Val(:conj) hack. Just pass conj(unconj(src)) to add!, which uses Base's lazy conj wrapper on the underlying array. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/linearbroadcasted.jl | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/linearbroadcasted.jl b/src/linearbroadcasted.jl index 3b58f446..139435a9 100644 --- a/src/linearbroadcasted.jl +++ b/src/linearbroadcasted.jl @@ -232,12 +232,7 @@ function add!(dest::AbstractArray, src::ScaledBroadcasted, α::Number, β::Numbe end function add!(dest::AbstractArray, src::ConjBroadcasted, α::Number, β::Number) - return add!(dest, unconj(src), α, β, Val(:conj)) -end - -# Default conj add! falls back to materializing conj. -function add!(dest::AbstractArray, src::AbstractArray, α::Number, β::Number, ::Val{:conj}) - return add!(dest, conj(src), α, β) + return add!(dest, conj(unconj(src)), α, β) end function add!(dest::AbstractArray, src::AddBroadcasted, α::Number, β::Number) From 1fba31cdd04eb45b5adcdf414c27601ed2549c47 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 25 Mar 2026 00:26:25 -0400 Subject: [PATCH 05/29] Lift copyto!, show, and iscall to generic LinearBroadcasted methods - copyto!(dest, src::LinearBroadcasted) generic instead of per-subtype - show(io, a::LinearBroadcasted) generic via operation/arguments - iscall(::LinearBroadcasted) generic instead of per-subtype Co-Authored-By: Claude Opus 4.6 (1M context) --- src/linearbroadcasted.jl | 31 ++++++------------------------- 1 file changed, 6 insertions(+), 25 deletions(-) diff --git a/src/linearbroadcasted.jl b/src/linearbroadcasted.jl index 139435a9..e186a8ef 100644 --- a/src/linearbroadcasted.jl +++ b/src/linearbroadcasted.jl @@ -28,6 +28,11 @@ abstract type LinearBroadcasted end Base.axes(a::LinearBroadcasted, d::Int) = axes(a)[d] Base.similar(a::LinearBroadcasted) = similar(a, eltype(a)) Base.similar(a::LinearBroadcasted, elt::Type) = similar(a, elt, axes(a)) +function Base.show(io::IO, a::LinearBroadcasted) + print(io, operation(a), "(", join(arguments(a), ", "), ")") + return nothing +end +iscall(::LinearBroadcasted) = true # --- ScaledBroadcasted -------------------------------------------------------- @@ -49,11 +54,6 @@ function Base.similar(a::ScaledBroadcasted, elt::Type, ax) return similar(unscaled(a), elt, ax) end -function Base.show(io::IO, a::ScaledBroadcasted) - print(io, "*(", coeff(a), ", ", unscaled(a), ")") - return nothing -end - function Base.adjoint(a::ScaledBroadcasted) return ScaledBroadcasted(coeff(a), adjoint(unscaled(a))) end @@ -65,7 +65,6 @@ function FI.permuteddims(a::ScaledBroadcasted, perm) return ScaledBroadcasted(coeff(a), FI.permuteddims(unscaled(a), perm)) end -iscall(::ScaledBroadcasted) = true operation(::ScaledBroadcasted) = * arguments(a::ScaledBroadcasted) = (coeff(a), unscaled(a)) @@ -85,11 +84,6 @@ function Base.similar(a::ConjBroadcasted, elt::Type, ax) return similar(unconj(a), elt, ax) end -function Base.show(io::IO, a::ConjBroadcasted) - print(io, "conj(", unconj(a), ")") - return nothing -end - Base.conj(a::ConjBroadcasted) = unconj(a) Base.adjoint(a::ConjBroadcasted) = transpose(unconj(a)) Base.transpose(a::ConjBroadcasted) = adjoint(unconj(a)) @@ -101,7 +95,6 @@ end SV.isstrided(a::ConjBroadcasted) = SV.isstrided(unconj(a)) SV.StridedView(a::ConjBroadcasted) = conj(SV.StridedView(unconj(a))) -iscall(::ConjBroadcasted) = true operation(::ConjBroadcasted) = conj arguments(a::ConjBroadcasted) = (unconj(a),) @@ -127,11 +120,6 @@ function Base.similar(a::AddBroadcasted, elt::Type, ax) return similar(BC.Broadcasted(+, addends(a)), elt, ax) end -function Base.show(io::IO, a::AddBroadcasted) - print(io, "+(", join(addends(a), ", "), ")") - return nothing -end - function Base.adjoint(a::AddBroadcasted) return AddBroadcasted(adjoint.(addends(a))...) end @@ -143,7 +131,6 @@ function FI.permuteddims(a::AddBroadcasted, perm) return AddBroadcasted(Base.Fix2(FI.permuteddims, perm).(addends(a))...) end -iscall(::AddBroadcasted) = true operation(::AddBroadcasted) = + arguments(a::AddBroadcasted) = addends(a) @@ -210,13 +197,7 @@ function Base.copy(a::Mul) end # copyto! for LinearBroadcasted dispatches to add!. -function Base.copyto!(dest::AbstractArray, src::ScaledBroadcasted) - return add!(dest, src, true, false) -end -function Base.copyto!(dest::AbstractArray, src::ConjBroadcasted) - return add!(dest, src, true, false) -end -function Base.copyto!(dest::AbstractArray, src::AddBroadcasted) +function Base.copyto!(dest::AbstractArray, src::LinearBroadcasted) return add!(dest, src, true, false) end From 84f1627d84453c32fde26871ca96dd2453be3875 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 25 Mar 2026 00:31:05 -0400 Subject: [PATCH 06/29] Add LBF as shorthand for LinearBroadcastFunction Use LBF (uppercase since it aliases a type) throughout source and tests to reduce verbosity of LinearBroadcastFunction call sites. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/linearbroadcasted.jl | 61 +++++++++++++++----------- test/test_lazy.jl | 92 ++++++++++++++++++++-------------------- 2 files changed, 81 insertions(+), 72 deletions(-) diff --git a/src/linearbroadcasted.jl b/src/linearbroadcasted.jl index e186a8ef..2638da6f 100644 --- a/src/linearbroadcasted.jl +++ b/src/linearbroadcasted.jl @@ -252,6 +252,21 @@ struct LinearBroadcastFunction{F} <: Function f::F end +""" + LBF + +Shorthand for `LinearBroadcastFunction`. + +# Examples + +```julia +LBF(*)(2.0, a) # ScaledBroadcasted(2.0, a) +LBF(conj)(a) # ConjBroadcasted(a) +LBF(+)(a, b) # AddBroadcasted(a, b) +``` +""" +const LBF = LinearBroadcastFunction + # Scaling: Number * AbstractArray function (::LinearBroadcastFunction{typeof(*)})(α::Number, a::AbstractArray) return ScaledBroadcasted(α, a) @@ -272,7 +287,7 @@ end (::LinearBroadcastFunction{typeof(conj)})(a::ConjBroadcasted) = unconj(a) function (::LinearBroadcastFunction{typeof(conj)})(a::ScaledBroadcasted) return ScaledBroadcasted( - conj(coeff(a)), LinearBroadcastFunction(conj)(unscaled(a)) + conj(coeff(a)), LBF(conj)(unscaled(a)) ) end @@ -297,16 +312,16 @@ end # Subtraction. function (::LinearBroadcastFunction{typeof(-)})(a, b) - return LinearBroadcastFunction(+)(a, LinearBroadcastFunction(*)(- 1, b)) + return LBF(+)(a, LBF(*)(- 1, b)) end -(::LinearBroadcastFunction{typeof(-)})(a) = LinearBroadcastFunction(*)(-1, a) +(::LinearBroadcastFunction{typeof(-)})(a) = LBF(*)(-1, a) # Division / left-division by scalars. function (::LinearBroadcastFunction{typeof(/)})(a, b::Number) - return LinearBroadcastFunction(*)(inv(b), a) + return LBF(*)(inv(b), a) end function (::LinearBroadcastFunction{typeof(\)})(a::Number, b) - return LinearBroadcastFunction(*)(inv(a), b) + return LBF(*)(inv(a), b) end # Identity. @@ -314,33 +329,29 @@ end # Fix1/Fix2 wrappers for scalar multiplication/division. function (lf::LinearBroadcastFunction{<:Base.Fix1{typeof(*)}})(a) - return LinearBroadcastFunction(*)(lf.f.x, a) + return LBF(*)(lf.f.x, a) end function (lf::LinearBroadcastFunction{<:Base.Fix2{typeof(*)}})(a) - return LinearBroadcastFunction(*)(a, lf.f.x) + return LBF(*)(a, lf.f.x) end function (lf::LinearBroadcastFunction{<:Base.Fix2{typeof(/)}})(a) - return LinearBroadcastFunction(/)(a, lf.f.x) + return LBF(/)(a, lf.f.x) end # Scaling of AddBroadcasted distributes. function (::LinearBroadcastFunction{typeof(*)})(α::Number, a::AddBroadcasted) - return LinearBroadcastFunction(+)( - map(x -> LinearBroadcastFunction(*)(α, x), addends(a))... - ) + return LBF(+)(map(x -> LBF(*)(α, x), addends(a))...) end # Conjugation of AddBroadcasted distributes. function (::LinearBroadcastFunction{typeof(conj)})(a::AddBroadcasted) - return LinearBroadcastFunction(+)( - map(x -> LinearBroadcastFunction(conj)(x), addends(a))... - ) + return LBF(+)(map(x -> LBF(conj)(x), addends(a))...) end # Conjugation of Mul distributes. function (::LinearBroadcastFunction{typeof(conj)})(a::Mul) f = factors(a) - return Mul(LinearBroadcastFunction(conj)(f[1]), LinearBroadcastFunction(conj)(f[2])) + return Mul(LBF(conj)(f[1]), LBF(conj)(f[2])) end # Scaling of Mul: wrap in ScaledBroadcasted. @@ -397,7 +408,7 @@ end to_linear(x) = x function to_linear(bc::BC.Broadcasted) - return LinearBroadcastFunction(bc.f)(to_linear.(bc.args)...) + return LBF(bc.f)(to_linear.(bc.args)...) end function broadcast_error(style, f) @@ -487,7 +498,7 @@ function BC.broadcasted( a::AbstractArray, b::AbstractArray ) - return LinearBroadcastFunction(+)(a, b) + return LBF(+)(a, b) end function BC.broadcasted( ::LinearBroadcastedStyle, @@ -496,7 +507,7 @@ function BC.broadcasted( b::BC.Broadcasted ) is_linear(b) || return BC.Broadcasted(+, to_broadcasted.((a, b))) - return LinearBroadcastFunction(+)(a, to_linear(b)) + return LBF(+)(a, to_linear(b)) end function BC.broadcasted( ::LinearBroadcastedStyle, @@ -505,7 +516,7 @@ function BC.broadcasted( b::AbstractArray ) is_linear(a) || return BC.Broadcasted(+, to_broadcasted.((a, b))) - return LinearBroadcastFunction(+)(to_linear(a), b) + return LBF(+)(to_linear(a), b) end function BC.broadcasted( ::LinearBroadcastedStyle, ::typeof(+), a::BC.Broadcasted, b::BC.Broadcasted @@ -513,20 +524,20 @@ function BC.broadcasted( return error("Not implemented") end function BC.broadcasted(::LinearBroadcastedStyle, ::typeof(*), α::Number, a::AbstractArray) - return LinearBroadcastFunction(*)(α, a) + return LBF(*)(α, a) end function BC.broadcasted(::LinearBroadcastedStyle, ::typeof(*), a::AbstractArray, α::Number) - return LinearBroadcastFunction(*)(a, α) + return LBF(*)(a, α) end function BC.broadcasted(::LinearBroadcastedStyle, ::typeof(\), α::Number, a::AbstractArray) - return LinearBroadcastFunction(\)(α, a) + return LBF(\)(α, a) end function BC.broadcasted(::LinearBroadcastedStyle, ::typeof(/), a::AbstractArray, α::Number) - return LinearBroadcastFunction(/)(a, α) + return LBF(/)(a, α) end function BC.broadcasted(::LinearBroadcastedStyle, ::typeof(-), a::AbstractArray) - return LinearBroadcastFunction(-)(a) + return LBF(-)(a) end function BC.broadcasted(::LinearBroadcastedStyle, ::typeof(conj), a::AbstractArray) - return LinearBroadcastFunction(conj)(a) + return LBF(conj)(a) end diff --git a/test/test_lazy.jl b/test/test_lazy.jl index 4435dd3e..83424e5b 100644 --- a/test/test_lazy.jl +++ b/test/test_lazy.jl @@ -1,37 +1,35 @@ import FunctionImplementations as FI using Base.Broadcast: Broadcast as BC -using TensorAlgebra: TensorAlgebra as TA +using TensorAlgebra: TensorAlgebra as TA, LBF using Test: @test, @test_throws, @testset -const lbf = TA.LinearBroadcastFunction - @testset "LinearBroadcasted and Mul" begin @testset "construction and materialization" begin a = randn(ComplexF64, 3, 3) b = randn(ComplexF64, 3, 3) c = randn(ComplexF64, 3, 3) - x = lbf(*)(2, a) + x = LBF(*)(2, a) @test x ≡ TA.ScaledBroadcasted(2, a) @test copy(x) ≈ 2a - x = lbf(conj)(a) + x = LBF(conj)(a) @test x ≡ TA.ConjBroadcasted(a) @test copy(x) ≈ conj(a) @test conj(x) ≈ a - x = lbf(+)(a, b) + x = LBF(+)(a, b) @test x ≡ TA.AddBroadcasted(a, b) @test copy(x) ≈ a + b x = TA.Mul(a, b) @test copy(x) ≈ a * b - x = lbf(+)(TA.Mul(a, b), c) + x = LBF(+)(TA.Mul(a, b), c) @test x ≡ TA.AddBroadcasted(TA.Mul(a, b), c) @test copy(x) ≈ a * b + c - x = lbf(+)(lbf(*)(2, TA.Mul(a, b)), lbf(*)(3, c)) + x = LBF(+)(LBF(*)(2, TA.Mul(a, b)), LBF(*)(3, c)) @test x ≡ TA.AddBroadcasted( TA.ScaledBroadcasted(2, TA.Mul(a, b)), TA.ScaledBroadcasted(3, c) ) @@ -41,16 +39,16 @@ const lbf = TA.LinearBroadcastFunction a = randn(ComplexF64, 2, 2) b = randn(ComplexF64, 2, 2) - x = lbf(*)(2, a)' - @test x ≡ lbf(*)(2, a') + x = LBF(*)(2, a)' + @test x ≡ LBF(*)(2, a') @test copy(x) ≈ 2a' - x = lbf(conj)(a)' + x = LBF(conj)(a)' @test x ≡ transpose(a) @test copy(x) ≈ permutedims(a) - x = lbf(+)(a, b)' - @test x ≡ lbf(+)(a', b') + x = LBF(+)(a, b)' + @test x ≡ LBF(+)(a', b') @test copy(x) ≈ a' + b' x = TA.Mul(a, b)' @@ -61,16 +59,16 @@ const lbf = TA.LinearBroadcastFunction a = randn(ComplexF64, 2, 2) b = randn(ComplexF64, 2, 2) - x = transpose(lbf(*)(2, a)) - @test x ≡ lbf(*)(2, transpose(a)) + x = transpose(LBF(*)(2, a)) + @test x ≡ LBF(*)(2, transpose(a)) @test copy(x) ≈ 2transpose(a) - x = transpose(lbf(conj)(a)) + x = transpose(LBF(conj)(a)) @test x ≡ adjoint(a) @test copy(x) ≈ permutedims(conj(a)) - x = transpose(lbf(+)(a, b)) - @test x ≡ lbf(+)(transpose(a), transpose(b)) + x = transpose(LBF(+)(a, b)) + @test x ≡ LBF(+)(transpose(a), transpose(b)) @test copy(x) ≈ transpose(a) + transpose(b) x = transpose(TA.Mul(a, b)) @@ -82,16 +80,16 @@ const lbf = TA.LinearBroadcastFunction b = randn(ComplexF64, 2, 2) perm = (2, 1) - x = FI.permuteddims(lbf(*)(2, a), perm) - @test x ≡ lbf(*)(2, FI.permuteddims(a, perm)) + x = FI.permuteddims(LBF(*)(2, a), perm) + @test x ≡ LBF(*)(2, FI.permuteddims(a, perm)) @test copy(x) ≈ 2permutedims(a, perm) - x = FI.permuteddims(lbf(conj)(a), perm) - @test x ≡ lbf(conj)(FI.permuteddims(a, perm)) + x = FI.permuteddims(LBF(conj)(a), perm) + @test x ≡ LBF(conj)(FI.permuteddims(a, perm)) @test copy(x) ≈ conj(permutedims(a, perm)) - x = FI.permuteddims(lbf(+)(a, b), perm) - @test x ≡ lbf(+)(FI.permuteddims(a, perm), FI.permuteddims(b, perm)) + x = FI.permuteddims(LBF(+)(a, b), perm) + @test x ≡ LBF(+)(FI.permuteddims(a, perm), FI.permuteddims(b, perm)) @test copy(x) ≈ permutedims(a, perm) + permutedims(b, perm) x = FI.permuteddims(TA.Mul(a, b), perm) @@ -102,58 +100,58 @@ const lbf = TA.LinearBroadcastFunction style = BC.DefaultArrayStyle{2}() @test TA.broadcasted_linear(identity, a) ≡ a - @test TA.broadcasted_linear(Base.Fix1(*, 2), a) ≡ lbf(*)(2, a) - @test TA.broadcasted_linear(Base.Fix2(*, 2), a) ≡ lbf(*)(a, 2) - @test TA.broadcasted_linear(Base.Fix2(/, 2), a) ≡ lbf(/)(a, 2) + @test TA.broadcasted_linear(Base.Fix1(*, 2), a) ≡ LBF(*)(2, a) + @test TA.broadcasted_linear(Base.Fix2(*, 2), a) ≡ LBF(*)(a, 2) + @test TA.broadcasted_linear(Base.Fix2(/, 2), a) ≡ LBF(/)(a, 2) @test TA.broadcasted_linear(style, identity, a) ≡ a - @test TA.broadcasted_linear(style, Base.Fix1(*, 2), a) ≡ lbf(*)(2, a) - @test TA.broadcasted_linear(style, Base.Fix2(*, 2), a) ≡ lbf(*)(a, 2) - @test TA.broadcasted_linear(style, Base.Fix2(/, 2), a) ≡ lbf(/)(a, 2) - @test TA.broadcasted_linear(style, conj, a) ≡ lbf(conj)(a) + @test TA.broadcasted_linear(style, Base.Fix1(*, 2), a) ≡ LBF(*)(2, a) + @test TA.broadcasted_linear(style, Base.Fix2(*, 2), a) ≡ LBF(*)(a, 2) + @test TA.broadcasted_linear(style, Base.Fix2(/, 2), a) ≡ LBF(/)(a, 2) + @test TA.broadcasted_linear(style, conj, a) ≡ LBF(conj)(a) @test_throws ArgumentError TA.broadcasted_linear(style, exp, a) end @testset "LinearBroadcastFunction algebra" begin a = randn(ComplexF64, 3, 3) # Scaling absorbs coefficients - @test lbf(*)(3, lbf(*)(2, a)) ≡ TA.ScaledBroadcasted(6, a) + @test LBF(*)(3, LBF(*)(2, a)) ≡ TA.ScaledBroadcasted(6, a) # Conjugation of scaled - x = lbf(conj)(lbf(*)(2im, a)) + x = LBF(conj)(LBF(*)(2im, a)) @test x ≡ TA.ScaledBroadcasted(-2im, TA.ConjBroadcasted(a)) # Double conjugation cancels - @test lbf(conj)(lbf(conj)(a)) ≡ a + @test LBF(conj)(LBF(conj)(a)) ≡ a # Subtraction b = randn(ComplexF64, 3, 3) - x = lbf(-)(a, b) + x = LBF(-)(a, b) @test copy(x) ≈ a - b # Unary minus - x = lbf(-)(a) + x = LBF(-)(a) @test copy(x) ≈ -a # Division - x = lbf(/)(a, 2) + x = LBF(/)(a, 2) @test copy(x) ≈ a / 2 # Left division - x = lbf(\)(2, a) + x = LBF(\)(2, a) @test copy(x) ≈ a / 2 # Scaling distributes over AddBroadcasted - ab = lbf(+)(a, b) - x = lbf(*)(3, ab) + ab = LBF(+)(a, b) + x = LBF(*)(3, ab) @test copy(x) ≈ 3a + 3b # Conjugation distributes over AddBroadcasted - x = lbf(conj)(ab) + x = LBF(conj)(ab) @test copy(x) ≈ conj(a) + conj(b) # Conjugation distributes over Mul m = TA.Mul(a, b) - x = lbf(conj)(m) + x = LBF(conj)(m) @test copy(x) ≈ conj(a) * conj(b) end @testset "AddBroadcasted flattening" begin @@ -162,17 +160,17 @@ const lbf = TA.LinearBroadcastFunction c = randn(ComplexF64, 2, 2) # AddBroadcasted + array flattens - ab = lbf(+)(a, b) - x = lbf(+)(ab, c) + ab = LBF(+)(a, b) + x = LBF(+)(ab, c) @test TA.addends(x) === (a, b, c) # array + AddBroadcasted flattens - x = lbf(+)(c, ab) + x = LBF(+)(c, ab) @test TA.addends(x) === (c, a, b) # AddBroadcasted + AddBroadcasted flattens - cd = lbf(+)(c, a) - x = lbf(+)(ab, cd) + cd = LBF(+)(c, a) + x = LBF(+)(ab, cd) @test TA.addends(x) === (a, b, c, a) end end From b66f9ea19a08c7a505b818c25cb88a6349ed7cdf Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 25 Mar 2026 00:33:05 -0400 Subject: [PATCH 07/29] Use typeof(LBF(f)) in dispatch signatures Replace `::LinearBroadcastFunction{typeof(f)}` with `::typeof(LBF(f))` in method signatures for consistency with the LBF shorthand. Fix1/Fix2 patterns remain as LinearBroadcastFunction{<:...} since those can't be expressed via typeof(LBF(...)). Co-Authored-By: Claude Opus 4.6 (1M context) --- src/linearbroadcasted.jl | 46 ++++++++++++++++++++-------------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/src/linearbroadcasted.jl b/src/linearbroadcasted.jl index 2638da6f..569e2422 100644 --- a/src/linearbroadcasted.jl +++ b/src/linearbroadcasted.jl @@ -268,64 +268,64 @@ LBF(+)(a, b) # AddBroadcasted(a, b) const LBF = LinearBroadcastFunction # Scaling: Number * AbstractArray -function (::LinearBroadcastFunction{typeof(*)})(α::Number, a::AbstractArray) +function (::typeof(LBF(*)))(α::Number, a::AbstractArray) return ScaledBroadcasted(α, a) end -function (::LinearBroadcastFunction{typeof(*)})(a::AbstractArray, α::Number) +function (::typeof(LBF(*)))(a::AbstractArray, α::Number) return ScaledBroadcasted(α, a) end # Scaling of ScaledBroadcasted: absorb coefficient. -function (::LinearBroadcastFunction{typeof(*)})(α::Number, a::ScaledBroadcasted) +function (::typeof(LBF(*)))(α::Number, a::ScaledBroadcasted) return ScaledBroadcasted(α * coeff(a), unscaled(a)) end # Conjugation. -function (::LinearBroadcastFunction{typeof(conj)})(a::AbstractArray) +function (::typeof(LBF(conj)))(a::AbstractArray) return ConjBroadcasted(a) end -(::LinearBroadcastFunction{typeof(conj)})(a::AbstractArray{<:Real}) = a -(::LinearBroadcastFunction{typeof(conj)})(a::ConjBroadcasted) = unconj(a) -function (::LinearBroadcastFunction{typeof(conj)})(a::ScaledBroadcasted) +(::typeof(LBF(conj)))(a::AbstractArray{<:Real}) = a +(::typeof(LBF(conj)))(a::ConjBroadcasted) = unconj(a) +function (::typeof(LBF(conj)))(a::ScaledBroadcasted) return ScaledBroadcasted( conj(coeff(a)), LBF(conj)(unscaled(a)) ) end # Addition. -function (lf::LinearBroadcastFunction{typeof(+)})(a, b) +function (lf::typeof(LBF(+)))(a, b) return AddBroadcasted(a, b) end -function (lf::LinearBroadcastFunction{typeof(+)})(a, b, c, xs...) +function (lf::typeof(LBF(+)))(a, b, c, xs...) return Base.afoldl(lf, lf(lf(a, b), c), xs...) end # Flatten AddBroadcasted + anything. -function (::LinearBroadcastFunction{typeof(+)})(a::AddBroadcasted, b) +function (::typeof(LBF(+)))(a::AddBroadcasted, b) return AddBroadcasted(addends(a)..., b) end -function (::LinearBroadcastFunction{typeof(+)})(a, b::AddBroadcasted) +function (::typeof(LBF(+)))(a, b::AddBroadcasted) return AddBroadcasted(a, addends(b)...) end -function (::LinearBroadcastFunction{typeof(+)})(a::AddBroadcasted, b::AddBroadcasted) +function (::typeof(LBF(+)))(a::AddBroadcasted, b::AddBroadcasted) return AddBroadcasted(addends(a)..., addends(b)...) end -(::LinearBroadcastFunction{typeof(+)})(a) = a +(::typeof(LBF(+)))(a) = a # Subtraction. -function (::LinearBroadcastFunction{typeof(-)})(a, b) +function (::typeof(LBF(-)))(a, b) return LBF(+)(a, LBF(*)(- 1, b)) end -(::LinearBroadcastFunction{typeof(-)})(a) = LBF(*)(-1, a) +(::typeof(LBF(-)))(a) = LBF(*)(-1, a) # Division / left-division by scalars. -function (::LinearBroadcastFunction{typeof(/)})(a, b::Number) +function (::typeof(LBF(/)))(a, b::Number) return LBF(*)(inv(b), a) end -function (::LinearBroadcastFunction{typeof(\)})(a::Number, b) +function (::typeof(LBF(\)))(a::Number, b) return LBF(*)(inv(a), b) end # Identity. -(::LinearBroadcastFunction{typeof(identity)})(a) = a +(::typeof(LBF(identity)))(a) = a # Fix1/Fix2 wrappers for scalar multiplication/division. function (lf::LinearBroadcastFunction{<:Base.Fix1{typeof(*)}})(a) @@ -339,28 +339,28 @@ function (lf::LinearBroadcastFunction{<:Base.Fix2{typeof(/)}})(a) end # Scaling of AddBroadcasted distributes. -function (::LinearBroadcastFunction{typeof(*)})(α::Number, a::AddBroadcasted) +function (::typeof(LBF(*)))(α::Number, a::AddBroadcasted) return LBF(+)(map(x -> LBF(*)(α, x), addends(a))...) end # Conjugation of AddBroadcasted distributes. -function (::LinearBroadcastFunction{typeof(conj)})(a::AddBroadcasted) +function (::typeof(LBF(conj)))(a::AddBroadcasted) return LBF(+)(map(x -> LBF(conj)(x), addends(a))...) end # Conjugation of Mul distributes. -function (::LinearBroadcastFunction{typeof(conj)})(a::Mul) +function (::typeof(LBF(conj)))(a::Mul) f = factors(a) return Mul(LBF(conj)(f[1]), LBF(conj)(f[2])) end # Scaling of Mul: wrap in ScaledBroadcasted. -function (::LinearBroadcastFunction{typeof(*)})(α::Number, a::Mul) +function (::typeof(LBF(*)))(α::Number, a::Mul) return ScaledBroadcasted(α, a) end # Number * Number passthrough (for broadcast lowering). -(::LinearBroadcastFunction{typeof(*)})(a::Number, b::Number) = a * b +(::typeof(LBF(*)))(a::Number, b::Number) = a * b # ---------------------------------------------------------------------------- # # Broadcast integration From ef6bf2cd0ee0720d0964377beb2dec0df0f83b26 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 25 Mar 2026 00:35:24 -0400 Subject: [PATCH 08/29] Rename LBF to LinearFunc and use in Fix dispatch signatures LinearFunc is less terse than LBF while still shorter than LinearBroadcastFunction. Also use LinearFunc{<:Base.Fix1{typeof(*)}} etc. in Fix1/Fix2 dispatch signatures for consistency. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/linearbroadcasted.jl | 102 +++++++++++++++++++-------------------- test/test_lazy.jl | 90 +++++++++++++++++----------------- 2 files changed, 96 insertions(+), 96 deletions(-) diff --git a/src/linearbroadcasted.jl b/src/linearbroadcasted.jl index 569e2422..aec22713 100644 --- a/src/linearbroadcasted.jl +++ b/src/linearbroadcasted.jl @@ -253,114 +253,114 @@ struct LinearBroadcastFunction{F} <: Function end """ - LBF + LinearFunc Shorthand for `LinearBroadcastFunction`. # Examples ```julia -LBF(*)(2.0, a) # ScaledBroadcasted(2.0, a) -LBF(conj)(a) # ConjBroadcasted(a) -LBF(+)(a, b) # AddBroadcasted(a, b) +LinearFunc(*)(2.0, a) # ScaledBroadcasted(2.0, a) +LinearFunc(conj)(a) # ConjBroadcasted(a) +LinearFunc(+)(a, b) # AddBroadcasted(a, b) ``` """ -const LBF = LinearBroadcastFunction +const LinearFunc = LinearBroadcastFunction # Scaling: Number * AbstractArray -function (::typeof(LBF(*)))(α::Number, a::AbstractArray) +function (::typeof(LinearFunc(*)))(α::Number, a::AbstractArray) return ScaledBroadcasted(α, a) end -function (::typeof(LBF(*)))(a::AbstractArray, α::Number) +function (::typeof(LinearFunc(*)))(a::AbstractArray, α::Number) return ScaledBroadcasted(α, a) end # Scaling of ScaledBroadcasted: absorb coefficient. -function (::typeof(LBF(*)))(α::Number, a::ScaledBroadcasted) +function (::typeof(LinearFunc(*)))(α::Number, a::ScaledBroadcasted) return ScaledBroadcasted(α * coeff(a), unscaled(a)) end # Conjugation. -function (::typeof(LBF(conj)))(a::AbstractArray) +function (::typeof(LinearFunc(conj)))(a::AbstractArray) return ConjBroadcasted(a) end -(::typeof(LBF(conj)))(a::AbstractArray{<:Real}) = a -(::typeof(LBF(conj)))(a::ConjBroadcasted) = unconj(a) -function (::typeof(LBF(conj)))(a::ScaledBroadcasted) +(::typeof(LinearFunc(conj)))(a::AbstractArray{<:Real}) = a +(::typeof(LinearFunc(conj)))(a::ConjBroadcasted) = unconj(a) +function (::typeof(LinearFunc(conj)))(a::ScaledBroadcasted) return ScaledBroadcasted( - conj(coeff(a)), LBF(conj)(unscaled(a)) + conj(coeff(a)), LinearFunc(conj)(unscaled(a)) ) end # Addition. -function (lf::typeof(LBF(+)))(a, b) +function (lf::typeof(LinearFunc(+)))(a, b) return AddBroadcasted(a, b) end -function (lf::typeof(LBF(+)))(a, b, c, xs...) +function (lf::typeof(LinearFunc(+)))(a, b, c, xs...) return Base.afoldl(lf, lf(lf(a, b), c), xs...) end # Flatten AddBroadcasted + anything. -function (::typeof(LBF(+)))(a::AddBroadcasted, b) +function (::typeof(LinearFunc(+)))(a::AddBroadcasted, b) return AddBroadcasted(addends(a)..., b) end -function (::typeof(LBF(+)))(a, b::AddBroadcasted) +function (::typeof(LinearFunc(+)))(a, b::AddBroadcasted) return AddBroadcasted(a, addends(b)...) end -function (::typeof(LBF(+)))(a::AddBroadcasted, b::AddBroadcasted) +function (::typeof(LinearFunc(+)))(a::AddBroadcasted, b::AddBroadcasted) return AddBroadcasted(addends(a)..., addends(b)...) end -(::typeof(LBF(+)))(a) = a +(::typeof(LinearFunc(+)))(a) = a # Subtraction. -function (::typeof(LBF(-)))(a, b) - return LBF(+)(a, LBF(*)(- 1, b)) +function (::typeof(LinearFunc(-)))(a, b) + return LinearFunc(+)(a, LinearFunc(*)(- 1, b)) end -(::typeof(LBF(-)))(a) = LBF(*)(-1, a) +(::typeof(LinearFunc(-)))(a) = LinearFunc(*)(-1, a) # Division / left-division by scalars. -function (::typeof(LBF(/)))(a, b::Number) - return LBF(*)(inv(b), a) +function (::typeof(LinearFunc(/)))(a, b::Number) + return LinearFunc(*)(inv(b), a) end -function (::typeof(LBF(\)))(a::Number, b) - return LBF(*)(inv(a), b) +function (::typeof(LinearFunc(\)))(a::Number, b) + return LinearFunc(*)(inv(a), b) end # Identity. -(::typeof(LBF(identity)))(a) = a +(::typeof(LinearFunc(identity)))(a) = a # Fix1/Fix2 wrappers for scalar multiplication/division. -function (lf::LinearBroadcastFunction{<:Base.Fix1{typeof(*)}})(a) - return LBF(*)(lf.f.x, a) +function (lf::LinearFunc{<:Base.Fix1{typeof(*)}})(a) + return LinearFunc(*)(lf.f.x, a) end -function (lf::LinearBroadcastFunction{<:Base.Fix2{typeof(*)}})(a) - return LBF(*)(a, lf.f.x) +function (lf::LinearFunc{<:Base.Fix2{typeof(*)}})(a) + return LinearFunc(*)(a, lf.f.x) end -function (lf::LinearBroadcastFunction{<:Base.Fix2{typeof(/)}})(a) - return LBF(/)(a, lf.f.x) +function (lf::LinearFunc{<:Base.Fix2{typeof(/)}})(a) + return LinearFunc(/)(a, lf.f.x) end # Scaling of AddBroadcasted distributes. -function (::typeof(LBF(*)))(α::Number, a::AddBroadcasted) - return LBF(+)(map(x -> LBF(*)(α, x), addends(a))...) +function (::typeof(LinearFunc(*)))(α::Number, a::AddBroadcasted) + return LinearFunc(+)(map(x -> LinearFunc(*)(α, x), addends(a))...) end # Conjugation of AddBroadcasted distributes. -function (::typeof(LBF(conj)))(a::AddBroadcasted) - return LBF(+)(map(x -> LBF(conj)(x), addends(a))...) +function (::typeof(LinearFunc(conj)))(a::AddBroadcasted) + return LinearFunc(+)(map(x -> LinearFunc(conj)(x), addends(a))...) end # Conjugation of Mul distributes. -function (::typeof(LBF(conj)))(a::Mul) +function (::typeof(LinearFunc(conj)))(a::Mul) f = factors(a) - return Mul(LBF(conj)(f[1]), LBF(conj)(f[2])) + return Mul(LinearFunc(conj)(f[1]), LinearFunc(conj)(f[2])) end # Scaling of Mul: wrap in ScaledBroadcasted. -function (::typeof(LBF(*)))(α::Number, a::Mul) +function (::typeof(LinearFunc(*)))(α::Number, a::Mul) return ScaledBroadcasted(α, a) end # Number * Number passthrough (for broadcast lowering). -(::typeof(LBF(*)))(a::Number, b::Number) = a * b +(::typeof(LinearFunc(*)))(a::Number, b::Number) = a * b # ---------------------------------------------------------------------------- # # Broadcast integration @@ -408,7 +408,7 @@ end to_linear(x) = x function to_linear(bc::BC.Broadcasted) - return LBF(bc.f)(to_linear.(bc.args)...) + return LinearFunc(bc.f)(to_linear.(bc.args)...) end function broadcast_error(style, f) @@ -498,7 +498,7 @@ function BC.broadcasted( a::AbstractArray, b::AbstractArray ) - return LBF(+)(a, b) + return LinearFunc(+)(a, b) end function BC.broadcasted( ::LinearBroadcastedStyle, @@ -507,7 +507,7 @@ function BC.broadcasted( b::BC.Broadcasted ) is_linear(b) || return BC.Broadcasted(+, to_broadcasted.((a, b))) - return LBF(+)(a, to_linear(b)) + return LinearFunc(+)(a, to_linear(b)) end function BC.broadcasted( ::LinearBroadcastedStyle, @@ -516,7 +516,7 @@ function BC.broadcasted( b::AbstractArray ) is_linear(a) || return BC.Broadcasted(+, to_broadcasted.((a, b))) - return LBF(+)(to_linear(a), b) + return LinearFunc(+)(to_linear(a), b) end function BC.broadcasted( ::LinearBroadcastedStyle, ::typeof(+), a::BC.Broadcasted, b::BC.Broadcasted @@ -524,20 +524,20 @@ function BC.broadcasted( return error("Not implemented") end function BC.broadcasted(::LinearBroadcastedStyle, ::typeof(*), α::Number, a::AbstractArray) - return LBF(*)(α, a) + return LinearFunc(*)(α, a) end function BC.broadcasted(::LinearBroadcastedStyle, ::typeof(*), a::AbstractArray, α::Number) - return LBF(*)(a, α) + return LinearFunc(*)(a, α) end function BC.broadcasted(::LinearBroadcastedStyle, ::typeof(\), α::Number, a::AbstractArray) - return LBF(\)(α, a) + return LinearFunc(\)(α, a) end function BC.broadcasted(::LinearBroadcastedStyle, ::typeof(/), a::AbstractArray, α::Number) - return LBF(/)(a, α) + return LinearFunc(/)(a, α) end function BC.broadcasted(::LinearBroadcastedStyle, ::typeof(-), a::AbstractArray) - return LBF(-)(a) + return LinearFunc(-)(a) end function BC.broadcasted(::LinearBroadcastedStyle, ::typeof(conj), a::AbstractArray) - return LBF(conj)(a) + return LinearFunc(conj)(a) end diff --git a/test/test_lazy.jl b/test/test_lazy.jl index 83424e5b..e158b6e4 100644 --- a/test/test_lazy.jl +++ b/test/test_lazy.jl @@ -1,6 +1,6 @@ import FunctionImplementations as FI using Base.Broadcast: Broadcast as BC -using TensorAlgebra: TensorAlgebra as TA, LBF +using TensorAlgebra: TensorAlgebra as TA, LinearFunc using Test: @test, @test_throws, @testset @testset "LinearBroadcasted and Mul" begin @@ -9,27 +9,27 @@ using Test: @test, @test_throws, @testset b = randn(ComplexF64, 3, 3) c = randn(ComplexF64, 3, 3) - x = LBF(*)(2, a) + x = LinearFunc(*)(2, a) @test x ≡ TA.ScaledBroadcasted(2, a) @test copy(x) ≈ 2a - x = LBF(conj)(a) + x = LinearFunc(conj)(a) @test x ≡ TA.ConjBroadcasted(a) @test copy(x) ≈ conj(a) @test conj(x) ≈ a - x = LBF(+)(a, b) + x = LinearFunc(+)(a, b) @test x ≡ TA.AddBroadcasted(a, b) @test copy(x) ≈ a + b x = TA.Mul(a, b) @test copy(x) ≈ a * b - x = LBF(+)(TA.Mul(a, b), c) + x = LinearFunc(+)(TA.Mul(a, b), c) @test x ≡ TA.AddBroadcasted(TA.Mul(a, b), c) @test copy(x) ≈ a * b + c - x = LBF(+)(LBF(*)(2, TA.Mul(a, b)), LBF(*)(3, c)) + x = LinearFunc(+)(LinearFunc(*)(2, TA.Mul(a, b)), LinearFunc(*)(3, c)) @test x ≡ TA.AddBroadcasted( TA.ScaledBroadcasted(2, TA.Mul(a, b)), TA.ScaledBroadcasted(3, c) ) @@ -39,16 +39,16 @@ using Test: @test, @test_throws, @testset a = randn(ComplexF64, 2, 2) b = randn(ComplexF64, 2, 2) - x = LBF(*)(2, a)' - @test x ≡ LBF(*)(2, a') + x = LinearFunc(*)(2, a)' + @test x ≡ LinearFunc(*)(2, a') @test copy(x) ≈ 2a' - x = LBF(conj)(a)' + x = LinearFunc(conj)(a)' @test x ≡ transpose(a) @test copy(x) ≈ permutedims(a) - x = LBF(+)(a, b)' - @test x ≡ LBF(+)(a', b') + x = LinearFunc(+)(a, b)' + @test x ≡ LinearFunc(+)(a', b') @test copy(x) ≈ a' + b' x = TA.Mul(a, b)' @@ -59,16 +59,16 @@ using Test: @test, @test_throws, @testset a = randn(ComplexF64, 2, 2) b = randn(ComplexF64, 2, 2) - x = transpose(LBF(*)(2, a)) - @test x ≡ LBF(*)(2, transpose(a)) + x = transpose(LinearFunc(*)(2, a)) + @test x ≡ LinearFunc(*)(2, transpose(a)) @test copy(x) ≈ 2transpose(a) - x = transpose(LBF(conj)(a)) + x = transpose(LinearFunc(conj)(a)) @test x ≡ adjoint(a) @test copy(x) ≈ permutedims(conj(a)) - x = transpose(LBF(+)(a, b)) - @test x ≡ LBF(+)(transpose(a), transpose(b)) + x = transpose(LinearFunc(+)(a, b)) + @test x ≡ LinearFunc(+)(transpose(a), transpose(b)) @test copy(x) ≈ transpose(a) + transpose(b) x = transpose(TA.Mul(a, b)) @@ -80,16 +80,16 @@ using Test: @test, @test_throws, @testset b = randn(ComplexF64, 2, 2) perm = (2, 1) - x = FI.permuteddims(LBF(*)(2, a), perm) - @test x ≡ LBF(*)(2, FI.permuteddims(a, perm)) + x = FI.permuteddims(LinearFunc(*)(2, a), perm) + @test x ≡ LinearFunc(*)(2, FI.permuteddims(a, perm)) @test copy(x) ≈ 2permutedims(a, perm) - x = FI.permuteddims(LBF(conj)(a), perm) - @test x ≡ LBF(conj)(FI.permuteddims(a, perm)) + x = FI.permuteddims(LinearFunc(conj)(a), perm) + @test x ≡ LinearFunc(conj)(FI.permuteddims(a, perm)) @test copy(x) ≈ conj(permutedims(a, perm)) - x = FI.permuteddims(LBF(+)(a, b), perm) - @test x ≡ LBF(+)(FI.permuteddims(a, perm), FI.permuteddims(b, perm)) + x = FI.permuteddims(LinearFunc(+)(a, b), perm) + @test x ≡ LinearFunc(+)(FI.permuteddims(a, perm), FI.permuteddims(b, perm)) @test copy(x) ≈ permutedims(a, perm) + permutedims(b, perm) x = FI.permuteddims(TA.Mul(a, b), perm) @@ -100,58 +100,58 @@ using Test: @test, @test_throws, @testset style = BC.DefaultArrayStyle{2}() @test TA.broadcasted_linear(identity, a) ≡ a - @test TA.broadcasted_linear(Base.Fix1(*, 2), a) ≡ LBF(*)(2, a) - @test TA.broadcasted_linear(Base.Fix2(*, 2), a) ≡ LBF(*)(a, 2) - @test TA.broadcasted_linear(Base.Fix2(/, 2), a) ≡ LBF(/)(a, 2) + @test TA.broadcasted_linear(Base.Fix1(*, 2), a) ≡ LinearFunc(*)(2, a) + @test TA.broadcasted_linear(Base.Fix2(*, 2), a) ≡ LinearFunc(*)(a, 2) + @test TA.broadcasted_linear(Base.Fix2(/, 2), a) ≡ LinearFunc(/)(a, 2) @test TA.broadcasted_linear(style, identity, a) ≡ a - @test TA.broadcasted_linear(style, Base.Fix1(*, 2), a) ≡ LBF(*)(2, a) - @test TA.broadcasted_linear(style, Base.Fix2(*, 2), a) ≡ LBF(*)(a, 2) - @test TA.broadcasted_linear(style, Base.Fix2(/, 2), a) ≡ LBF(/)(a, 2) - @test TA.broadcasted_linear(style, conj, a) ≡ LBF(conj)(a) + @test TA.broadcasted_linear(style, Base.Fix1(*, 2), a) ≡ LinearFunc(*)(2, a) + @test TA.broadcasted_linear(style, Base.Fix2(*, 2), a) ≡ LinearFunc(*)(a, 2) + @test TA.broadcasted_linear(style, Base.Fix2(/, 2), a) ≡ LinearFunc(/)(a, 2) + @test TA.broadcasted_linear(style, conj, a) ≡ LinearFunc(conj)(a) @test_throws ArgumentError TA.broadcasted_linear(style, exp, a) end @testset "LinearBroadcastFunction algebra" begin a = randn(ComplexF64, 3, 3) # Scaling absorbs coefficients - @test LBF(*)(3, LBF(*)(2, a)) ≡ TA.ScaledBroadcasted(6, a) + @test LinearFunc(*)(3, LinearFunc(*)(2, a)) ≡ TA.ScaledBroadcasted(6, a) # Conjugation of scaled - x = LBF(conj)(LBF(*)(2im, a)) + x = LinearFunc(conj)(LinearFunc(*)(2im, a)) @test x ≡ TA.ScaledBroadcasted(-2im, TA.ConjBroadcasted(a)) # Double conjugation cancels - @test LBF(conj)(LBF(conj)(a)) ≡ a + @test LinearFunc(conj)(LinearFunc(conj)(a)) ≡ a # Subtraction b = randn(ComplexF64, 3, 3) - x = LBF(-)(a, b) + x = LinearFunc(-)(a, b) @test copy(x) ≈ a - b # Unary minus - x = LBF(-)(a) + x = LinearFunc(-)(a) @test copy(x) ≈ -a # Division - x = LBF(/)(a, 2) + x = LinearFunc(/)(a, 2) @test copy(x) ≈ a / 2 # Left division - x = LBF(\)(2, a) + x = LinearFunc(\)(2, a) @test copy(x) ≈ a / 2 # Scaling distributes over AddBroadcasted - ab = LBF(+)(a, b) - x = LBF(*)(3, ab) + ab = LinearFunc(+)(a, b) + x = LinearFunc(*)(3, ab) @test copy(x) ≈ 3a + 3b # Conjugation distributes over AddBroadcasted - x = LBF(conj)(ab) + x = LinearFunc(conj)(ab) @test copy(x) ≈ conj(a) + conj(b) # Conjugation distributes over Mul m = TA.Mul(a, b) - x = LBF(conj)(m) + x = LinearFunc(conj)(m) @test copy(x) ≈ conj(a) * conj(b) end @testset "AddBroadcasted flattening" begin @@ -160,17 +160,17 @@ using Test: @test, @test_throws, @testset c = randn(ComplexF64, 2, 2) # AddBroadcasted + array flattens - ab = LBF(+)(a, b) - x = LBF(+)(ab, c) + ab = LinearFunc(+)(a, b) + x = LinearFunc(+)(ab, c) @test TA.addends(x) === (a, b, c) # array + AddBroadcasted flattens - x = LBF(+)(c, ab) + x = LinearFunc(+)(c, ab) @test TA.addends(x) === (c, a, b) # AddBroadcasted + AddBroadcasted flattens - cd = LBF(+)(c, a) - x = LBF(+)(ab, cd) + cd = LinearFunc(+)(c, a) + x = LinearFunc(+)(ab, cd) @test TA.addends(x) === (a, b, c, a) end end From ff0d095d34cf71237394dd46ea324d7fa4d9dd28 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 25 Mar 2026 00:36:38 -0400 Subject: [PATCH 09/29] Drop LinearFunc alias, use LinearBroadcastFunction everywhere MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit No shorthand — just use the full name consistently, matching the Base.BroadcastFunction convention. Keep typeof(LinearBroadcastFunction(f)) in dispatch signatures. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/linearbroadcasted.jl | 111 ++++++++++++++++++--------------------- test/test_lazy.jl | 98 ++++++++++++++++++---------------- 2 files changed, 103 insertions(+), 106 deletions(-) diff --git a/src/linearbroadcasted.jl b/src/linearbroadcasted.jl index aec22713..9a363ec5 100644 --- a/src/linearbroadcasted.jl +++ b/src/linearbroadcasted.jl @@ -252,115 +252,104 @@ struct LinearBroadcastFunction{F} <: Function f::F end -""" - LinearFunc - -Shorthand for `LinearBroadcastFunction`. - -# Examples - -```julia -LinearFunc(*)(2.0, a) # ScaledBroadcasted(2.0, a) -LinearFunc(conj)(a) # ConjBroadcasted(a) -LinearFunc(+)(a, b) # AddBroadcasted(a, b) -``` -""" -const LinearFunc = LinearBroadcastFunction - # Scaling: Number * AbstractArray -function (::typeof(LinearFunc(*)))(α::Number, a::AbstractArray) +function (::typeof(LinearBroadcastFunction(*)))(α::Number, a::AbstractArray) return ScaledBroadcasted(α, a) end -function (::typeof(LinearFunc(*)))(a::AbstractArray, α::Number) +function (::typeof(LinearBroadcastFunction(*)))(a::AbstractArray, α::Number) return ScaledBroadcasted(α, a) end # Scaling of ScaledBroadcasted: absorb coefficient. -function (::typeof(LinearFunc(*)))(α::Number, a::ScaledBroadcasted) +function (::typeof(LinearBroadcastFunction(*)))(α::Number, a::ScaledBroadcasted) return ScaledBroadcasted(α * coeff(a), unscaled(a)) end # Conjugation. -function (::typeof(LinearFunc(conj)))(a::AbstractArray) +function (::typeof(LinearBroadcastFunction(conj)))(a::AbstractArray) return ConjBroadcasted(a) end -(::typeof(LinearFunc(conj)))(a::AbstractArray{<:Real}) = a -(::typeof(LinearFunc(conj)))(a::ConjBroadcasted) = unconj(a) -function (::typeof(LinearFunc(conj)))(a::ScaledBroadcasted) +(::typeof(LinearBroadcastFunction(conj)))(a::AbstractArray{<:Real}) = a +(::typeof(LinearBroadcastFunction(conj)))(a::ConjBroadcasted) = unconj(a) +function (::typeof(LinearBroadcastFunction(conj)))(a::ScaledBroadcasted) return ScaledBroadcasted( - conj(coeff(a)), LinearFunc(conj)(unscaled(a)) + conj(coeff(a)), LinearBroadcastFunction(conj)(unscaled(a)) ) end # Addition. -function (lf::typeof(LinearFunc(+)))(a, b) +function (lf::typeof(LinearBroadcastFunction(+)))(a, b) return AddBroadcasted(a, b) end -function (lf::typeof(LinearFunc(+)))(a, b, c, xs...) +function (lf::typeof(LinearBroadcastFunction(+)))(a, b, c, xs...) return Base.afoldl(lf, lf(lf(a, b), c), xs...) end # Flatten AddBroadcasted + anything. -function (::typeof(LinearFunc(+)))(a::AddBroadcasted, b) +function (::typeof(LinearBroadcastFunction(+)))(a::AddBroadcasted, b) return AddBroadcasted(addends(a)..., b) end -function (::typeof(LinearFunc(+)))(a, b::AddBroadcasted) +function (::typeof(LinearBroadcastFunction(+)))(a, b::AddBroadcasted) return AddBroadcasted(a, addends(b)...) end -function (::typeof(LinearFunc(+)))(a::AddBroadcasted, b::AddBroadcasted) +function (::typeof(LinearBroadcastFunction(+)))(a::AddBroadcasted, b::AddBroadcasted) return AddBroadcasted(addends(a)..., addends(b)...) end -(::typeof(LinearFunc(+)))(a) = a +(::typeof(LinearBroadcastFunction(+)))(a) = a # Subtraction. -function (::typeof(LinearFunc(-)))(a, b) - return LinearFunc(+)(a, LinearFunc(*)(- 1, b)) +function (::typeof(LinearBroadcastFunction(-)))(a, b) + return LinearBroadcastFunction(+)(a, LinearBroadcastFunction(*)(- 1, b)) end -(::typeof(LinearFunc(-)))(a) = LinearFunc(*)(-1, a) +(::typeof(LinearBroadcastFunction(-)))(a) = LinearBroadcastFunction(*)(-1, a) # Division / left-division by scalars. -function (::typeof(LinearFunc(/)))(a, b::Number) - return LinearFunc(*)(inv(b), a) +function (::typeof(LinearBroadcastFunction(/)))(a, b::Number) + return LinearBroadcastFunction(*)(inv(b), a) end -function (::typeof(LinearFunc(\)))(a::Number, b) - return LinearFunc(*)(inv(a), b) +function (::typeof(LinearBroadcastFunction(\)))(a::Number, b) + return LinearBroadcastFunction(*)(inv(a), b) end # Identity. -(::typeof(LinearFunc(identity)))(a) = a +(::typeof(LinearBroadcastFunction(identity)))(a) = a # Fix1/Fix2 wrappers for scalar multiplication/division. -function (lf::LinearFunc{<:Base.Fix1{typeof(*)}})(a) - return LinearFunc(*)(lf.f.x, a) +function (lf::LinearBroadcastFunction{<:Base.Fix1{typeof(*)}})(a) + return LinearBroadcastFunction(*)(lf.f.x, a) end -function (lf::LinearFunc{<:Base.Fix2{typeof(*)}})(a) - return LinearFunc(*)(a, lf.f.x) +function (lf::LinearBroadcastFunction{<:Base.Fix2{typeof(*)}})(a) + return LinearBroadcastFunction(*)(a, lf.f.x) end -function (lf::LinearFunc{<:Base.Fix2{typeof(/)}})(a) - return LinearFunc(/)(a, lf.f.x) +function (lf::LinearBroadcastFunction{<:Base.Fix2{typeof(/)}})(a) + return LinearBroadcastFunction(/)(a, lf.f.x) end # Scaling of AddBroadcasted distributes. -function (::typeof(LinearFunc(*)))(α::Number, a::AddBroadcasted) - return LinearFunc(+)(map(x -> LinearFunc(*)(α, x), addends(a))...) +function (::typeof(LinearBroadcastFunction(*)))(α::Number, a::AddBroadcasted) + return LinearBroadcastFunction(+)( + map(x -> LinearBroadcastFunction(*)(α, x), addends(a))... + ) end # Conjugation of AddBroadcasted distributes. -function (::typeof(LinearFunc(conj)))(a::AddBroadcasted) - return LinearFunc(+)(map(x -> LinearFunc(conj)(x), addends(a))...) +function (::typeof(LinearBroadcastFunction(conj)))(a::AddBroadcasted) + return LinearBroadcastFunction(+)( + map(x -> LinearBroadcastFunction(conj)(x), addends(a))... + ) end # Conjugation of Mul distributes. -function (::typeof(LinearFunc(conj)))(a::Mul) +function (::typeof(LinearBroadcastFunction(conj)))(a::Mul) f = factors(a) - return Mul(LinearFunc(conj)(f[1]), LinearFunc(conj)(f[2])) + return Mul(LinearBroadcastFunction(conj)(f[1]), LinearBroadcastFunction(conj)(f[2])) end # Scaling of Mul: wrap in ScaledBroadcasted. -function (::typeof(LinearFunc(*)))(α::Number, a::Mul) +function (::typeof(LinearBroadcastFunction(*)))(α::Number, a::Mul) return ScaledBroadcasted(α, a) end # Number * Number passthrough (for broadcast lowering). -(::typeof(LinearFunc(*)))(a::Number, b::Number) = a * b +(::typeof(LinearBroadcastFunction(*)))(a::Number, b::Number) = a * b # ---------------------------------------------------------------------------- # # Broadcast integration @@ -408,7 +397,7 @@ end to_linear(x) = x function to_linear(bc::BC.Broadcasted) - return LinearFunc(bc.f)(to_linear.(bc.args)...) + return LinearBroadcastFunction(bc.f)(to_linear.(bc.args)...) end function broadcast_error(style, f) @@ -498,7 +487,7 @@ function BC.broadcasted( a::AbstractArray, b::AbstractArray ) - return LinearFunc(+)(a, b) + return LinearBroadcastFunction(+)(a, b) end function BC.broadcasted( ::LinearBroadcastedStyle, @@ -507,7 +496,7 @@ function BC.broadcasted( b::BC.Broadcasted ) is_linear(b) || return BC.Broadcasted(+, to_broadcasted.((a, b))) - return LinearFunc(+)(a, to_linear(b)) + return LinearBroadcastFunction(+)(a, to_linear(b)) end function BC.broadcasted( ::LinearBroadcastedStyle, @@ -516,7 +505,7 @@ function BC.broadcasted( b::AbstractArray ) is_linear(a) || return BC.Broadcasted(+, to_broadcasted.((a, b))) - return LinearFunc(+)(to_linear(a), b) + return LinearBroadcastFunction(+)(to_linear(a), b) end function BC.broadcasted( ::LinearBroadcastedStyle, ::typeof(+), a::BC.Broadcasted, b::BC.Broadcasted @@ -524,20 +513,20 @@ function BC.broadcasted( return error("Not implemented") end function BC.broadcasted(::LinearBroadcastedStyle, ::typeof(*), α::Number, a::AbstractArray) - return LinearFunc(*)(α, a) + return LinearBroadcastFunction(*)(α, a) end function BC.broadcasted(::LinearBroadcastedStyle, ::typeof(*), a::AbstractArray, α::Number) - return LinearFunc(*)(a, α) + return LinearBroadcastFunction(*)(a, α) end function BC.broadcasted(::LinearBroadcastedStyle, ::typeof(\), α::Number, a::AbstractArray) - return LinearFunc(\)(α, a) + return LinearBroadcastFunction(\)(α, a) end function BC.broadcasted(::LinearBroadcastedStyle, ::typeof(/), a::AbstractArray, α::Number) - return LinearFunc(/)(a, α) + return LinearBroadcastFunction(/)(a, α) end function BC.broadcasted(::LinearBroadcastedStyle, ::typeof(-), a::AbstractArray) - return LinearFunc(-)(a) + return LinearBroadcastFunction(-)(a) end function BC.broadcasted(::LinearBroadcastedStyle, ::typeof(conj), a::AbstractArray) - return LinearFunc(conj)(a) + return LinearBroadcastFunction(conj)(a) end diff --git a/test/test_lazy.jl b/test/test_lazy.jl index e158b6e4..16983bf1 100644 --- a/test/test_lazy.jl +++ b/test/test_lazy.jl @@ -1,6 +1,6 @@ import FunctionImplementations as FI using Base.Broadcast: Broadcast as BC -using TensorAlgebra: TensorAlgebra as TA, LinearFunc +using TensorAlgebra: TensorAlgebra as TA, LinearBroadcastFunction using Test: @test, @test_throws, @testset @testset "LinearBroadcasted and Mul" begin @@ -9,27 +9,30 @@ using Test: @test, @test_throws, @testset b = randn(ComplexF64, 3, 3) c = randn(ComplexF64, 3, 3) - x = LinearFunc(*)(2, a) + x = LinearBroadcastFunction(*)(2, a) @test x ≡ TA.ScaledBroadcasted(2, a) @test copy(x) ≈ 2a - x = LinearFunc(conj)(a) + x = LinearBroadcastFunction(conj)(a) @test x ≡ TA.ConjBroadcasted(a) @test copy(x) ≈ conj(a) @test conj(x) ≈ a - x = LinearFunc(+)(a, b) + x = LinearBroadcastFunction(+)(a, b) @test x ≡ TA.AddBroadcasted(a, b) @test copy(x) ≈ a + b x = TA.Mul(a, b) @test copy(x) ≈ a * b - x = LinearFunc(+)(TA.Mul(a, b), c) + x = LinearBroadcastFunction(+)(TA.Mul(a, b), c) @test x ≡ TA.AddBroadcasted(TA.Mul(a, b), c) @test copy(x) ≈ a * b + c - x = LinearFunc(+)(LinearFunc(*)(2, TA.Mul(a, b)), LinearFunc(*)(3, c)) + x = LinearBroadcastFunction(+)( + LinearBroadcastFunction(*)(2, TA.Mul(a, b)), + LinearBroadcastFunction(*)(3, c) + ) @test x ≡ TA.AddBroadcasted( TA.ScaledBroadcasted(2, TA.Mul(a, b)), TA.ScaledBroadcasted(3, c) ) @@ -39,16 +42,16 @@ using Test: @test, @test_throws, @testset a = randn(ComplexF64, 2, 2) b = randn(ComplexF64, 2, 2) - x = LinearFunc(*)(2, a)' - @test x ≡ LinearFunc(*)(2, a') + x = LinearBroadcastFunction(*)(2, a)' + @test x ≡ LinearBroadcastFunction(*)(2, a') @test copy(x) ≈ 2a' - x = LinearFunc(conj)(a)' + x = LinearBroadcastFunction(conj)(a)' @test x ≡ transpose(a) @test copy(x) ≈ permutedims(a) - x = LinearFunc(+)(a, b)' - @test x ≡ LinearFunc(+)(a', b') + x = LinearBroadcastFunction(+)(a, b)' + @test x ≡ LinearBroadcastFunction(+)(a', b') @test copy(x) ≈ a' + b' x = TA.Mul(a, b)' @@ -59,16 +62,16 @@ using Test: @test, @test_throws, @testset a = randn(ComplexF64, 2, 2) b = randn(ComplexF64, 2, 2) - x = transpose(LinearFunc(*)(2, a)) - @test x ≡ LinearFunc(*)(2, transpose(a)) + x = transpose(LinearBroadcastFunction(*)(2, a)) + @test x ≡ LinearBroadcastFunction(*)(2, transpose(a)) @test copy(x) ≈ 2transpose(a) - x = transpose(LinearFunc(conj)(a)) + x = transpose(LinearBroadcastFunction(conj)(a)) @test x ≡ adjoint(a) @test copy(x) ≈ permutedims(conj(a)) - x = transpose(LinearFunc(+)(a, b)) - @test x ≡ LinearFunc(+)(transpose(a), transpose(b)) + x = transpose(LinearBroadcastFunction(+)(a, b)) + @test x ≡ LinearBroadcastFunction(+)(transpose(a), transpose(b)) @test copy(x) ≈ transpose(a) + transpose(b) x = transpose(TA.Mul(a, b)) @@ -80,16 +83,17 @@ using Test: @test, @test_throws, @testset b = randn(ComplexF64, 2, 2) perm = (2, 1) - x = FI.permuteddims(LinearFunc(*)(2, a), perm) - @test x ≡ LinearFunc(*)(2, FI.permuteddims(a, perm)) + x = FI.permuteddims(LinearBroadcastFunction(*)(2, a), perm) + @test x ≡ LinearBroadcastFunction(*)(2, FI.permuteddims(a, perm)) @test copy(x) ≈ 2permutedims(a, perm) - x = FI.permuteddims(LinearFunc(conj)(a), perm) - @test x ≡ LinearFunc(conj)(FI.permuteddims(a, perm)) + x = FI.permuteddims(LinearBroadcastFunction(conj)(a), perm) + @test x ≡ LinearBroadcastFunction(conj)(FI.permuteddims(a, perm)) @test copy(x) ≈ conj(permutedims(a, perm)) - x = FI.permuteddims(LinearFunc(+)(a, b), perm) - @test x ≡ LinearFunc(+)(FI.permuteddims(a, perm), FI.permuteddims(b, perm)) + x = FI.permuteddims(LinearBroadcastFunction(+)(a, b), perm) + @test x ≡ + LinearBroadcastFunction(+)(FI.permuteddims(a, perm), FI.permuteddims(b, perm)) @test copy(x) ≈ permutedims(a, perm) + permutedims(b, perm) x = FI.permuteddims(TA.Mul(a, b), perm) @@ -100,58 +104,62 @@ using Test: @test, @test_throws, @testset style = BC.DefaultArrayStyle{2}() @test TA.broadcasted_linear(identity, a) ≡ a - @test TA.broadcasted_linear(Base.Fix1(*, 2), a) ≡ LinearFunc(*)(2, a) - @test TA.broadcasted_linear(Base.Fix2(*, 2), a) ≡ LinearFunc(*)(a, 2) - @test TA.broadcasted_linear(Base.Fix2(/, 2), a) ≡ LinearFunc(/)(a, 2) + @test TA.broadcasted_linear(Base.Fix1(*, 2), a) ≡ LinearBroadcastFunction(*)(2, a) + @test TA.broadcasted_linear(Base.Fix2(*, 2), a) ≡ LinearBroadcastFunction(*)(a, 2) + @test TA.broadcasted_linear(Base.Fix2(/, 2), a) ≡ LinearBroadcastFunction(/)(a, 2) @test TA.broadcasted_linear(style, identity, a) ≡ a - @test TA.broadcasted_linear(style, Base.Fix1(*, 2), a) ≡ LinearFunc(*)(2, a) - @test TA.broadcasted_linear(style, Base.Fix2(*, 2), a) ≡ LinearFunc(*)(a, 2) - @test TA.broadcasted_linear(style, Base.Fix2(/, 2), a) ≡ LinearFunc(/)(a, 2) - @test TA.broadcasted_linear(style, conj, a) ≡ LinearFunc(conj)(a) + @test TA.broadcasted_linear(style, Base.Fix1(*, 2), a) ≡ + LinearBroadcastFunction(*)(2, a) + @test TA.broadcasted_linear(style, Base.Fix2(*, 2), a) ≡ + LinearBroadcastFunction(*)(a, 2) + @test TA.broadcasted_linear(style, Base.Fix2(/, 2), a) ≡ + LinearBroadcastFunction(/)(a, 2) + @test TA.broadcasted_linear(style, conj, a) ≡ LinearBroadcastFunction(conj)(a) @test_throws ArgumentError TA.broadcasted_linear(style, exp, a) end @testset "LinearBroadcastFunction algebra" begin a = randn(ComplexF64, 3, 3) # Scaling absorbs coefficients - @test LinearFunc(*)(3, LinearFunc(*)(2, a)) ≡ TA.ScaledBroadcasted(6, a) + @test LinearBroadcastFunction(*)(3, LinearBroadcastFunction(*)(2, a)) ≡ + TA.ScaledBroadcasted(6, a) # Conjugation of scaled - x = LinearFunc(conj)(LinearFunc(*)(2im, a)) + x = LinearBroadcastFunction(conj)(LinearBroadcastFunction(*)(2im, a)) @test x ≡ TA.ScaledBroadcasted(-2im, TA.ConjBroadcasted(a)) # Double conjugation cancels - @test LinearFunc(conj)(LinearFunc(conj)(a)) ≡ a + @test LinearBroadcastFunction(conj)(LinearBroadcastFunction(conj)(a)) ≡ a # Subtraction b = randn(ComplexF64, 3, 3) - x = LinearFunc(-)(a, b) + x = LinearBroadcastFunction(-)(a, b) @test copy(x) ≈ a - b # Unary minus - x = LinearFunc(-)(a) + x = LinearBroadcastFunction(-)(a) @test copy(x) ≈ -a # Division - x = LinearFunc(/)(a, 2) + x = LinearBroadcastFunction(/)(a, 2) @test copy(x) ≈ a / 2 # Left division - x = LinearFunc(\)(2, a) + x = LinearBroadcastFunction(\)(2, a) @test copy(x) ≈ a / 2 # Scaling distributes over AddBroadcasted - ab = LinearFunc(+)(a, b) - x = LinearFunc(*)(3, ab) + ab = LinearBroadcastFunction(+)(a, b) + x = LinearBroadcastFunction(*)(3, ab) @test copy(x) ≈ 3a + 3b # Conjugation distributes over AddBroadcasted - x = LinearFunc(conj)(ab) + x = LinearBroadcastFunction(conj)(ab) @test copy(x) ≈ conj(a) + conj(b) # Conjugation distributes over Mul m = TA.Mul(a, b) - x = LinearFunc(conj)(m) + x = LinearBroadcastFunction(conj)(m) @test copy(x) ≈ conj(a) * conj(b) end @testset "AddBroadcasted flattening" begin @@ -160,17 +168,17 @@ using Test: @test, @test_throws, @testset c = randn(ComplexF64, 2, 2) # AddBroadcasted + array flattens - ab = LinearFunc(+)(a, b) - x = LinearFunc(+)(ab, c) + ab = LinearBroadcastFunction(+)(a, b) + x = LinearBroadcastFunction(+)(ab, c) @test TA.addends(x) === (a, b, c) # array + AddBroadcasted flattens - x = LinearFunc(+)(c, ab) + x = LinearBroadcastFunction(+)(c, ab) @test TA.addends(x) === (c, a, b) # AddBroadcasted + AddBroadcasted flattens - cd = LinearFunc(+)(c, a) - x = LinearFunc(+)(ab, cd) + cd = LinearBroadcastFunction(+)(c, a) + x = LinearBroadcastFunction(+)(ab, cd) @test TA.addends(x) === (a, b, c, a) end end From 9154717bb89d7914346d45cbbff3e7da1e12e51a Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 25 Mar 2026 00:36:46 -0400 Subject: [PATCH 10/29] Rename test_lazy.jl to test_linearbroadcasted.jl Co-Authored-By: Claude Opus 4.6 (1M context) --- test/{test_lazy.jl => test_linearbroadcasted.jl} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename test/{test_lazy.jl => test_linearbroadcasted.jl} (100%) diff --git a/test/test_lazy.jl b/test/test_linearbroadcasted.jl similarity index 100% rename from test/test_lazy.jl rename to test/test_linearbroadcasted.jl From 7f3b02633aa104a2367e90f3bc45e4a0fb908d20 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 25 Mar 2026 00:46:40 -0400 Subject: [PATCH 11/29] Replace LinearBroadcastFunction with linearbroadcasted function Remove the LinearBroadcastFunction callable struct. Instead, use linearbroadcasted(f, args...) as a generic constructor that dispatches on f to create the appropriate LinearBroadcasted subtype. This mirrors the Base.Broadcast.broadcasted(f, args...) pattern. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/linearbroadcasted.jl | 136 +++++++++++++-------------------- test/test_linearbroadcasted.jl | 95 +++++++++++------------ 2 files changed, 101 insertions(+), 130 deletions(-) diff --git a/src/linearbroadcasted.jl b/src/linearbroadcasted.jl index 9a363ec5..57214339 100644 --- a/src/linearbroadcasted.jl +++ b/src/linearbroadcasted.jl @@ -231,125 +231,95 @@ function add!(dest::AbstractArray, src::Mul, α::Number, β::Number) end # ---------------------------------------------------------------------------- # -# LinearBroadcastFunction — constructor API +# linearbroadcasted — construct LinearBroadcasted subtypes by dispatching on f # ---------------------------------------------------------------------------- # """ - LinearBroadcastFunction(f) + linearbroadcasted(f, args...) -Wrap a function `f` so that calling it produces a `LinearBroadcasted` expression -instead of eagerly computing. Analogous to `Base.BroadcastFunction`. +Construct a `LinearBroadcasted` subtype from function `f` and arguments. +Analogous to `Base.Broadcast.broadcasted(f, args...)`. # Examples ```julia -LinearBroadcastFunction(*)(2.0, a) # ScaledBroadcasted(2.0, a) -LinearBroadcastFunction(conj)(a) # ConjBroadcasted(a) -LinearBroadcastFunction(+)(a, b) # AddBroadcasted(a, b) +linearbroadcasted(*, 2.0, a) # ScaledBroadcasted(2.0, a) +linearbroadcasted(conj, a) # ConjBroadcasted(a) +linearbroadcasted(+, a, b) # AddBroadcasted(a, b) ``` """ -struct LinearBroadcastFunction{F} <: Function - f::F -end +function linearbroadcasted end # Scaling: Number * AbstractArray -function (::typeof(LinearBroadcastFunction(*)))(α::Number, a::AbstractArray) - return ScaledBroadcasted(α, a) -end -function (::typeof(LinearBroadcastFunction(*)))(a::AbstractArray, α::Number) - return ScaledBroadcasted(α, a) -end +linearbroadcasted(::typeof(*), α::Number, a::AbstractArray) = ScaledBroadcasted(α, a) +linearbroadcasted(::typeof(*), a::AbstractArray, α::Number) = ScaledBroadcasted(α, a) # Scaling of ScaledBroadcasted: absorb coefficient. -function (::typeof(LinearBroadcastFunction(*)))(α::Number, a::ScaledBroadcasted) +function linearbroadcasted(::typeof(*), α::Number, a::ScaledBroadcasted) return ScaledBroadcasted(α * coeff(a), unscaled(a)) end # Conjugation. -function (::typeof(LinearBroadcastFunction(conj)))(a::AbstractArray) - return ConjBroadcasted(a) -end -(::typeof(LinearBroadcastFunction(conj)))(a::AbstractArray{<:Real}) = a -(::typeof(LinearBroadcastFunction(conj)))(a::ConjBroadcasted) = unconj(a) -function (::typeof(LinearBroadcastFunction(conj)))(a::ScaledBroadcasted) - return ScaledBroadcasted( - conj(coeff(a)), LinearBroadcastFunction(conj)(unscaled(a)) - ) +linearbroadcasted(::typeof(conj), a::AbstractArray) = ConjBroadcasted(a) +linearbroadcasted(::typeof(conj), a::AbstractArray{<:Real}) = a +linearbroadcasted(::typeof(conj), a::ConjBroadcasted) = unconj(a) +function linearbroadcasted(::typeof(conj), a::ScaledBroadcasted) + return ScaledBroadcasted(conj(coeff(a)), linearbroadcasted(conj, unscaled(a))) end # Addition. -function (lf::typeof(LinearBroadcastFunction(+)))(a, b) - return AddBroadcasted(a, b) -end -function (lf::typeof(LinearBroadcastFunction(+)))(a, b, c, xs...) - return Base.afoldl(lf, lf(lf(a, b), c), xs...) +linearbroadcasted(::typeof(+), a, b) = AddBroadcasted(a, b) +function linearbroadcasted(f::typeof(+), a, b, c, xs...) + return Base.afoldl( + (x, y) -> linearbroadcasted(f, x, y), + linearbroadcasted(f, linearbroadcasted(f, a, b), c), + xs... + ) end # Flatten AddBroadcasted + anything. -function (::typeof(LinearBroadcastFunction(+)))(a::AddBroadcasted, b) - return AddBroadcasted(addends(a)..., b) -end -function (::typeof(LinearBroadcastFunction(+)))(a, b::AddBroadcasted) - return AddBroadcasted(a, addends(b)...) -end -function (::typeof(LinearBroadcastFunction(+)))(a::AddBroadcasted, b::AddBroadcasted) +linearbroadcasted(::typeof(+), a::AddBroadcasted, b) = AddBroadcasted(addends(a)..., b) +linearbroadcasted(::typeof(+), a, b::AddBroadcasted) = AddBroadcasted(a, addends(b)...) +function linearbroadcasted(::typeof(+), a::AddBroadcasted, b::AddBroadcasted) return AddBroadcasted(addends(a)..., addends(b)...) end -(::typeof(LinearBroadcastFunction(+)))(a) = a +linearbroadcasted(::typeof(+), a) = a # Subtraction. -function (::typeof(LinearBroadcastFunction(-)))(a, b) - return LinearBroadcastFunction(+)(a, LinearBroadcastFunction(*)(- 1, b)) -end -(::typeof(LinearBroadcastFunction(-)))(a) = LinearBroadcastFunction(*)(-1, a) +linearbroadcasted(::typeof(-), a, b) = linearbroadcasted(+, a, linearbroadcasted(*, -1, b)) +linearbroadcasted(::typeof(-), a) = linearbroadcasted(*, -1, a) # Division / left-division by scalars. -function (::typeof(LinearBroadcastFunction(/)))(a, b::Number) - return LinearBroadcastFunction(*)(inv(b), a) -end -function (::typeof(LinearBroadcastFunction(\)))(a::Number, b) - return LinearBroadcastFunction(*)(inv(a), b) -end +linearbroadcasted(::typeof(/), a, b::Number) = linearbroadcasted(*, inv(b), a) +linearbroadcasted(::typeof(\), a::Number, b) = linearbroadcasted(*, inv(a), b) # Identity. -(::typeof(LinearBroadcastFunction(identity)))(a) = a +linearbroadcasted(::typeof(identity), a) = a # Fix1/Fix2 wrappers for scalar multiplication/division. -function (lf::LinearBroadcastFunction{<:Base.Fix1{typeof(*)}})(a) - return LinearBroadcastFunction(*)(lf.f.x, a) -end -function (lf::LinearBroadcastFunction{<:Base.Fix2{typeof(*)}})(a) - return LinearBroadcastFunction(*)(a, lf.f.x) -end -function (lf::LinearBroadcastFunction{<:Base.Fix2{typeof(/)}})(a) - return LinearBroadcastFunction(/)(a, lf.f.x) -end +linearbroadcasted(f::Base.Fix1{typeof(*)}, a) = linearbroadcasted(*, f.x, a) +linearbroadcasted(f::Base.Fix2{typeof(*)}, a) = linearbroadcasted(*, a, f.x) +linearbroadcasted(f::Base.Fix2{typeof(/)}, a) = linearbroadcasted(/, a, f.x) # Scaling of AddBroadcasted distributes. -function (::typeof(LinearBroadcastFunction(*)))(α::Number, a::AddBroadcasted) - return LinearBroadcastFunction(+)( - map(x -> LinearBroadcastFunction(*)(α, x), addends(a))... - ) +function linearbroadcasted(::typeof(*), α::Number, a::AddBroadcasted) + return linearbroadcasted(+, map(x -> linearbroadcasted(*, α, x), addends(a))...) end # Conjugation of AddBroadcasted distributes. -function (::typeof(LinearBroadcastFunction(conj)))(a::AddBroadcasted) - return LinearBroadcastFunction(+)( - map(x -> LinearBroadcastFunction(conj)(x), addends(a))... - ) +function linearbroadcasted(::typeof(conj), a::AddBroadcasted) + return linearbroadcasted(+, map(x -> linearbroadcasted(conj, x), addends(a))...) end # Conjugation of Mul distributes. -function (::typeof(LinearBroadcastFunction(conj)))(a::Mul) +function linearbroadcasted(::typeof(conj), a::Mul) f = factors(a) - return Mul(LinearBroadcastFunction(conj)(f[1]), LinearBroadcastFunction(conj)(f[2])) + return Mul(linearbroadcasted(conj, f[1]), linearbroadcasted(conj, f[2])) end # Scaling of Mul: wrap in ScaledBroadcasted. -function (::typeof(LinearBroadcastFunction(*)))(α::Number, a::Mul) - return ScaledBroadcasted(α, a) -end +linearbroadcasted(::typeof(*), α::Number, a::Mul) = ScaledBroadcasted(α, a) # Number * Number passthrough (for broadcast lowering). -(::typeof(LinearBroadcastFunction(*)))(a::Number, b::Number) = a * b +linearbroadcasted(::typeof(*), a::Number, b::Number) = a * b # ---------------------------------------------------------------------------- # # Broadcast integration @@ -397,7 +367,7 @@ end to_linear(x) = x function to_linear(bc::BC.Broadcasted) - return LinearBroadcastFunction(bc.f)(to_linear.(bc.args)...) + return linearbroadcasted(bc.f, to_linear.(bc.args)...) end function broadcast_error(style, f) @@ -487,7 +457,7 @@ function BC.broadcasted( a::AbstractArray, b::AbstractArray ) - return LinearBroadcastFunction(+)(a, b) + return linearbroadcasted(+, a, b) end function BC.broadcasted( ::LinearBroadcastedStyle, @@ -496,7 +466,7 @@ function BC.broadcasted( b::BC.Broadcasted ) is_linear(b) || return BC.Broadcasted(+, to_broadcasted.((a, b))) - return LinearBroadcastFunction(+)(a, to_linear(b)) + return linearbroadcasted(+, a, to_linear(b)) end function BC.broadcasted( ::LinearBroadcastedStyle, @@ -505,7 +475,7 @@ function BC.broadcasted( b::AbstractArray ) is_linear(a) || return BC.Broadcasted(+, to_broadcasted.((a, b))) - return LinearBroadcastFunction(+)(to_linear(a), b) + return linearbroadcasted(+, to_linear(a), b) end function BC.broadcasted( ::LinearBroadcastedStyle, ::typeof(+), a::BC.Broadcasted, b::BC.Broadcasted @@ -513,20 +483,20 @@ function BC.broadcasted( return error("Not implemented") end function BC.broadcasted(::LinearBroadcastedStyle, ::typeof(*), α::Number, a::AbstractArray) - return LinearBroadcastFunction(*)(α, a) + return linearbroadcasted(*, α, a) end function BC.broadcasted(::LinearBroadcastedStyle, ::typeof(*), a::AbstractArray, α::Number) - return LinearBroadcastFunction(*)(a, α) + return linearbroadcasted(*, a, α) end function BC.broadcasted(::LinearBroadcastedStyle, ::typeof(\), α::Number, a::AbstractArray) - return LinearBroadcastFunction(\)(α, a) + return linearbroadcasted(\, α, a) end function BC.broadcasted(::LinearBroadcastedStyle, ::typeof(/), a::AbstractArray, α::Number) - return LinearBroadcastFunction(/)(a, α) + return linearbroadcasted(/, a, α) end function BC.broadcasted(::LinearBroadcastedStyle, ::typeof(-), a::AbstractArray) - return LinearBroadcastFunction(-)(a) + return linearbroadcasted(-, a) end function BC.broadcasted(::LinearBroadcastedStyle, ::typeof(conj), a::AbstractArray) - return LinearBroadcastFunction(conj)(a) + return linearbroadcasted(conj, a) end diff --git a/test/test_linearbroadcasted.jl b/test/test_linearbroadcasted.jl index 16983bf1..7fc017fe 100644 --- a/test/test_linearbroadcasted.jl +++ b/test/test_linearbroadcasted.jl @@ -1,6 +1,6 @@ import FunctionImplementations as FI using Base.Broadcast: Broadcast as BC -using TensorAlgebra: TensorAlgebra as TA, LinearBroadcastFunction +using TensorAlgebra: TensorAlgebra as TA, linearbroadcasted using Test: @test, @test_throws, @testset @testset "LinearBroadcasted and Mul" begin @@ -9,29 +9,30 @@ using Test: @test, @test_throws, @testset b = randn(ComplexF64, 3, 3) c = randn(ComplexF64, 3, 3) - x = LinearBroadcastFunction(*)(2, a) + x = linearbroadcasted(*, 2, a) @test x ≡ TA.ScaledBroadcasted(2, a) @test copy(x) ≈ 2a - x = LinearBroadcastFunction(conj)(a) + x = linearbroadcasted(conj, a) @test x ≡ TA.ConjBroadcasted(a) @test copy(x) ≈ conj(a) @test conj(x) ≈ a - x = LinearBroadcastFunction(+)(a, b) + x = linearbroadcasted(+, a, b) @test x ≡ TA.AddBroadcasted(a, b) @test copy(x) ≈ a + b x = TA.Mul(a, b) @test copy(x) ≈ a * b - x = LinearBroadcastFunction(+)(TA.Mul(a, b), c) + x = linearbroadcasted(+, TA.Mul(a, b), c) @test x ≡ TA.AddBroadcasted(TA.Mul(a, b), c) @test copy(x) ≈ a * b + c - x = LinearBroadcastFunction(+)( - LinearBroadcastFunction(*)(2, TA.Mul(a, b)), - LinearBroadcastFunction(*)(3, c) + x = linearbroadcasted( + +, + linearbroadcasted(*, 2, TA.Mul(a, b)), + linearbroadcasted(*, 3, c) ) @test x ≡ TA.AddBroadcasted( TA.ScaledBroadcasted(2, TA.Mul(a, b)), TA.ScaledBroadcasted(3, c) @@ -42,16 +43,16 @@ using Test: @test, @test_throws, @testset a = randn(ComplexF64, 2, 2) b = randn(ComplexF64, 2, 2) - x = LinearBroadcastFunction(*)(2, a)' - @test x ≡ LinearBroadcastFunction(*)(2, a') + x = linearbroadcasted(*, 2, a)' + @test x ≡ linearbroadcasted(*, 2, a') @test copy(x) ≈ 2a' - x = LinearBroadcastFunction(conj)(a)' + x = linearbroadcasted(conj, a)' @test x ≡ transpose(a) @test copy(x) ≈ permutedims(a) - x = LinearBroadcastFunction(+)(a, b)' - @test x ≡ LinearBroadcastFunction(+)(a', b') + x = linearbroadcasted(+, a, b)' + @test x ≡ linearbroadcasted(+, a', b') @test copy(x) ≈ a' + b' x = TA.Mul(a, b)' @@ -62,16 +63,16 @@ using Test: @test, @test_throws, @testset a = randn(ComplexF64, 2, 2) b = randn(ComplexF64, 2, 2) - x = transpose(LinearBroadcastFunction(*)(2, a)) - @test x ≡ LinearBroadcastFunction(*)(2, transpose(a)) + x = transpose(linearbroadcasted(*, 2, a)) + @test x ≡ linearbroadcasted(*, 2, transpose(a)) @test copy(x) ≈ 2transpose(a) - x = transpose(LinearBroadcastFunction(conj)(a)) + x = transpose(linearbroadcasted(conj, a)) @test x ≡ adjoint(a) @test copy(x) ≈ permutedims(conj(a)) - x = transpose(LinearBroadcastFunction(+)(a, b)) - @test x ≡ LinearBroadcastFunction(+)(transpose(a), transpose(b)) + x = transpose(linearbroadcasted(+, a, b)) + @test x ≡ linearbroadcasted(+, transpose(a), transpose(b)) @test copy(x) ≈ transpose(a) + transpose(b) x = transpose(TA.Mul(a, b)) @@ -83,17 +84,17 @@ using Test: @test, @test_throws, @testset b = randn(ComplexF64, 2, 2) perm = (2, 1) - x = FI.permuteddims(LinearBroadcastFunction(*)(2, a), perm) - @test x ≡ LinearBroadcastFunction(*)(2, FI.permuteddims(a, perm)) + x = FI.permuteddims(linearbroadcasted(*, 2, a), perm) + @test x ≡ linearbroadcasted(*, 2, FI.permuteddims(a, perm)) @test copy(x) ≈ 2permutedims(a, perm) - x = FI.permuteddims(LinearBroadcastFunction(conj)(a), perm) - @test x ≡ LinearBroadcastFunction(conj)(FI.permuteddims(a, perm)) + x = FI.permuteddims(linearbroadcasted(conj, a), perm) + @test x ≡ linearbroadcasted(conj, FI.permuteddims(a, perm)) @test copy(x) ≈ conj(permutedims(a, perm)) - x = FI.permuteddims(LinearBroadcastFunction(+)(a, b), perm) + x = FI.permuteddims(linearbroadcasted(+, a, b), perm) @test x ≡ - LinearBroadcastFunction(+)(FI.permuteddims(a, perm), FI.permuteddims(b, perm)) + linearbroadcasted(+, FI.permuteddims(a, perm), FI.permuteddims(b, perm)) @test copy(x) ≈ permutedims(a, perm) + permutedims(b, perm) x = FI.permuteddims(TA.Mul(a, b), perm) @@ -104,62 +105,62 @@ using Test: @test, @test_throws, @testset style = BC.DefaultArrayStyle{2}() @test TA.broadcasted_linear(identity, a) ≡ a - @test TA.broadcasted_linear(Base.Fix1(*, 2), a) ≡ LinearBroadcastFunction(*)(2, a) - @test TA.broadcasted_linear(Base.Fix2(*, 2), a) ≡ LinearBroadcastFunction(*)(a, 2) - @test TA.broadcasted_linear(Base.Fix2(/, 2), a) ≡ LinearBroadcastFunction(/)(a, 2) + @test TA.broadcasted_linear(Base.Fix1(*, 2), a) ≡ linearbroadcasted(*, 2, a) + @test TA.broadcasted_linear(Base.Fix2(*, 2), a) ≡ linearbroadcasted(*, a, 2) + @test TA.broadcasted_linear(Base.Fix2(/, 2), a) ≡ linearbroadcasted(/, a, 2) @test TA.broadcasted_linear(style, identity, a) ≡ a @test TA.broadcasted_linear(style, Base.Fix1(*, 2), a) ≡ - LinearBroadcastFunction(*)(2, a) + linearbroadcasted(*, 2, a) @test TA.broadcasted_linear(style, Base.Fix2(*, 2), a) ≡ - LinearBroadcastFunction(*)(a, 2) + linearbroadcasted(*, a, 2) @test TA.broadcasted_linear(style, Base.Fix2(/, 2), a) ≡ - LinearBroadcastFunction(/)(a, 2) - @test TA.broadcasted_linear(style, conj, a) ≡ LinearBroadcastFunction(conj)(a) + linearbroadcasted(/, a, 2) + @test TA.broadcasted_linear(style, conj, a) ≡ linearbroadcasted(conj, a) @test_throws ArgumentError TA.broadcasted_linear(style, exp, a) end @testset "LinearBroadcastFunction algebra" begin a = randn(ComplexF64, 3, 3) # Scaling absorbs coefficients - @test LinearBroadcastFunction(*)(3, LinearBroadcastFunction(*)(2, a)) ≡ + @test linearbroadcasted(*, 3, linearbroadcasted(*, 2, a)) ≡ TA.ScaledBroadcasted(6, a) # Conjugation of scaled - x = LinearBroadcastFunction(conj)(LinearBroadcastFunction(*)(2im, a)) + x = linearbroadcasted(conj, linearbroadcasted(*, 2im, a)) @test x ≡ TA.ScaledBroadcasted(-2im, TA.ConjBroadcasted(a)) # Double conjugation cancels - @test LinearBroadcastFunction(conj)(LinearBroadcastFunction(conj)(a)) ≡ a + @test linearbroadcasted(conj, linearbroadcasted(conj, a)) ≡ a # Subtraction b = randn(ComplexF64, 3, 3) - x = LinearBroadcastFunction(-)(a, b) + x = linearbroadcasted(-, a, b) @test copy(x) ≈ a - b # Unary minus - x = LinearBroadcastFunction(-)(a) + x = linearbroadcasted(-, a) @test copy(x) ≈ -a # Division - x = LinearBroadcastFunction(/)(a, 2) + x = linearbroadcasted(/, a, 2) @test copy(x) ≈ a / 2 # Left division - x = LinearBroadcastFunction(\)(2, a) + x = linearbroadcasted(\, 2, a) @test copy(x) ≈ a / 2 # Scaling distributes over AddBroadcasted - ab = LinearBroadcastFunction(+)(a, b) - x = LinearBroadcastFunction(*)(3, ab) + ab = linearbroadcasted(+, a, b) + x = linearbroadcasted(*, 3, ab) @test copy(x) ≈ 3a + 3b # Conjugation distributes over AddBroadcasted - x = LinearBroadcastFunction(conj)(ab) + x = linearbroadcasted(conj, ab) @test copy(x) ≈ conj(a) + conj(b) # Conjugation distributes over Mul m = TA.Mul(a, b) - x = LinearBroadcastFunction(conj)(m) + x = linearbroadcasted(conj, m) @test copy(x) ≈ conj(a) * conj(b) end @testset "AddBroadcasted flattening" begin @@ -168,17 +169,17 @@ using Test: @test, @test_throws, @testset c = randn(ComplexF64, 2, 2) # AddBroadcasted + array flattens - ab = LinearBroadcastFunction(+)(a, b) - x = LinearBroadcastFunction(+)(ab, c) + ab = linearbroadcasted(+, a, b) + x = linearbroadcasted(+, ab, c) @test TA.addends(x) === (a, b, c) # array + AddBroadcasted flattens - x = LinearBroadcastFunction(+)(c, ab) + x = linearbroadcasted(+, c, ab) @test TA.addends(x) === (c, a, b) # AddBroadcasted + AddBroadcasted flattens - cd = LinearBroadcastFunction(+)(c, a) - x = LinearBroadcastFunction(+)(ab, cd) + cd = linearbroadcasted(+, c, a) + x = linearbroadcasted(+, ab, cd) @test TA.addends(x) === (a, b, c, a) end end From ea8a252612e89d9dda3c27a15ee465969e6a690b Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 25 Mar 2026 09:06:48 -0400 Subject: [PATCH 12/29] Temporarily revert to v0.7.21 for cross-package development Mark as non-breaking (0.7.x) so downstream packages (NamedDimsArrays, GradedArrays) can Pkg.develop this branch without compat clashes. Will bump to 0.8.0 before merging. Co-Authored-By: Claude Opus 4.6 (1M context) --- Project.toml | 2 +- docs/Project.toml | 2 +- examples/Project.toml | 2 +- test/Project.toml | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index 3c65eb0d..40eac37e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" -version = "0.8.0" +version = "0.7.21" authors = ["ITensor developers and contributors"] [workspace] diff --git a/docs/Project.toml b/docs/Project.toml index 20650d9d..549440ac 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -11,4 +11,4 @@ path = ".." Documenter = "1.8.1" ITensorFormatter = "0.2.27" Literate = "2.20.1" -TensorAlgebra = "0.8" +TensorAlgebra = "0.7" diff --git a/examples/Project.toml b/examples/Project.toml index a8006256..9b0b1293 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -5,4 +5,4 @@ TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" path = ".." [compat] -TensorAlgebra = "0.8" +TensorAlgebra = "0.7" diff --git a/test/Project.toml b/test/Project.toml index 5c0c3426..5a4045d2 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -36,7 +36,7 @@ Random = "1.10" SafeTestsets = "0.1" StableRNGs = "1.0.2" Suppressor = "0.2" -TensorAlgebra = "0.8" +TensorAlgebra = "0.7" TensorOperations = "5.1.4" Test = "1.10" TestExtras = "0.3.1" From c2fdb4040684d3d8e1ec11dd26ac12b07471b3d2 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 25 Mar 2026 09:41:47 -0400 Subject: [PATCH 13/29] Strip LinearBroadcasted types to minimal interface Remove adjoint, transpose, permuteddims, conj, StridedViews methods from LinearBroadcasted types and Mul. These were algebraic rewrite rules carried over from the AbstractArray era. Will bring back as needed when validating against NamedDimsArrays and GradedArrays PRs. Each type now only defines: axes, eltype, ndims, similar, operation, arguments. Plus the materialization chain (copy, copyto!, add!). Co-Authored-By: Claude Opus 4.6 (1M context) --- src/linearbroadcasted.jl | 49 ------------------------- test/test_linearbroadcasted.jl | 65 +--------------------------------- 2 files changed, 1 insertion(+), 113 deletions(-) diff --git a/src/linearbroadcasted.jl b/src/linearbroadcasted.jl index 57214339..c68ddf77 100644 --- a/src/linearbroadcasted.jl +++ b/src/linearbroadcasted.jl @@ -1,7 +1,5 @@ import Base.Broadcast as BC -import FunctionImplementations as FI import LinearAlgebra as LA -import StridedViews as SV # TermInterface-like interface. iscall(x) = false @@ -54,17 +52,6 @@ function Base.similar(a::ScaledBroadcasted, elt::Type, ax) return similar(unscaled(a), elt, ax) end -function Base.adjoint(a::ScaledBroadcasted) - return ScaledBroadcasted(coeff(a), adjoint(unscaled(a))) -end -function Base.transpose(a::ScaledBroadcasted) - return ScaledBroadcasted(coeff(a), transpose(unscaled(a))) -end - -function FI.permuteddims(a::ScaledBroadcasted, perm) - return ScaledBroadcasted(coeff(a), FI.permuteddims(unscaled(a), perm)) -end - operation(::ScaledBroadcasted) = * arguments(a::ScaledBroadcasted) = (coeff(a), unscaled(a)) @@ -84,17 +71,6 @@ function Base.similar(a::ConjBroadcasted, elt::Type, ax) return similar(unconj(a), elt, ax) end -Base.conj(a::ConjBroadcasted) = unconj(a) -Base.adjoint(a::ConjBroadcasted) = transpose(unconj(a)) -Base.transpose(a::ConjBroadcasted) = adjoint(unconj(a)) - -function FI.permuteddims(a::ConjBroadcasted, perm) - return ConjBroadcasted(FI.permuteddims(unconj(a), perm)) -end - -SV.isstrided(a::ConjBroadcasted) = SV.isstrided(unconj(a)) -SV.StridedView(a::ConjBroadcasted) = conj(SV.StridedView(unconj(a))) - operation(::ConjBroadcasted) = conj arguments(a::ConjBroadcasted) = (unconj(a),) @@ -120,17 +96,6 @@ function Base.similar(a::AddBroadcasted, elt::Type, ax) return similar(BC.Broadcasted(+, addends(a)), elt, ax) end -function Base.adjoint(a::AddBroadcasted) - return AddBroadcasted(adjoint.(addends(a))...) -end -function Base.transpose(a::AddBroadcasted) - return AddBroadcasted(transpose.(addends(a))...) -end - -function FI.permuteddims(a::AddBroadcasted, perm) - return AddBroadcasted(Base.Fix2(FI.permuteddims, perm).(addends(a))...) -end - operation(::AddBroadcasted) = + arguments(a::AddBroadcasted) = addends(a) @@ -166,20 +131,6 @@ function Base.show(io::IO, a::Mul) return nothing end -function Base.adjoint(a::Mul) - f = factors(a) - return Mul(adjoint(f[2]), adjoint(f[1])) -end -function Base.transpose(a::Mul) - f = factors(a) - return Mul(transpose(f[2]), transpose(f[1])) -end - -function FI.permuteddims(a::Mul, perm) - perm == (1, 2) && return a - return transpose(a) -end - iscall(::Mul) = true operation(::Mul) = * arguments(a::Mul) = factors(a) diff --git a/test/test_linearbroadcasted.jl b/test/test_linearbroadcasted.jl index 7fc017fe..5dc23b6b 100644 --- a/test/test_linearbroadcasted.jl +++ b/test/test_linearbroadcasted.jl @@ -1,4 +1,3 @@ -import FunctionImplementations as FI using Base.Broadcast: Broadcast as BC using TensorAlgebra: TensorAlgebra as TA, linearbroadcasted using Test: @test, @test_throws, @testset @@ -16,7 +15,6 @@ using Test: @test, @test_throws, @testset x = linearbroadcasted(conj, a) @test x ≡ TA.ConjBroadcasted(a) @test copy(x) ≈ conj(a) - @test conj(x) ≈ a x = linearbroadcasted(+, a, b) @test x ≡ TA.AddBroadcasted(a, b) @@ -39,67 +37,6 @@ using Test: @test, @test_throws, @testset ) @test copy(x) ≈ 2 * a * b + 3 * c end - @testset "adjoint" begin - a = randn(ComplexF64, 2, 2) - b = randn(ComplexF64, 2, 2) - - x = linearbroadcasted(*, 2, a)' - @test x ≡ linearbroadcasted(*, 2, a') - @test copy(x) ≈ 2a' - - x = linearbroadcasted(conj, a)' - @test x ≡ transpose(a) - @test copy(x) ≈ permutedims(a) - - x = linearbroadcasted(+, a, b)' - @test x ≡ linearbroadcasted(+, a', b') - @test copy(x) ≈ a' + b' - - x = TA.Mul(a, b)' - @test x ≡ TA.Mul(b', a') - @test copy(x) ≈ b' * a' - end - @testset "transpose" begin - a = randn(ComplexF64, 2, 2) - b = randn(ComplexF64, 2, 2) - - x = transpose(linearbroadcasted(*, 2, a)) - @test x ≡ linearbroadcasted(*, 2, transpose(a)) - @test copy(x) ≈ 2transpose(a) - - x = transpose(linearbroadcasted(conj, a)) - @test x ≡ adjoint(a) - @test copy(x) ≈ permutedims(conj(a)) - - x = transpose(linearbroadcasted(+, a, b)) - @test x ≡ linearbroadcasted(+, transpose(a), transpose(b)) - @test copy(x) ≈ transpose(a) + transpose(b) - - x = transpose(TA.Mul(a, b)) - @test x ≡ TA.Mul(transpose(b), transpose(a)) - @test copy(x) ≈ transpose(b) * transpose(a) - end - @testset "permuteddims" begin - a = randn(ComplexF64, 2, 2) - b = randn(ComplexF64, 2, 2) - perm = (2, 1) - - x = FI.permuteddims(linearbroadcasted(*, 2, a), perm) - @test x ≡ linearbroadcasted(*, 2, FI.permuteddims(a, perm)) - @test copy(x) ≈ 2permutedims(a, perm) - - x = FI.permuteddims(linearbroadcasted(conj, a), perm) - @test x ≡ linearbroadcasted(conj, FI.permuteddims(a, perm)) - @test copy(x) ≈ conj(permutedims(a, perm)) - - x = FI.permuteddims(linearbroadcasted(+, a, b), perm) - @test x ≡ - linearbroadcasted(+, FI.permuteddims(a, perm), FI.permuteddims(b, perm)) - @test copy(x) ≈ permutedims(a, perm) + permutedims(b, perm) - - x = FI.permuteddims(TA.Mul(a, b), perm) - @test copy(x) ≈ permutedims(a * b, perm) - end @testset "linear broadcast lowering" begin a = randn(ComplexF64, 2, 2) style = BC.DefaultArrayStyle{2}() @@ -118,7 +55,7 @@ using Test: @test, @test_throws, @testset @test TA.broadcasted_linear(style, conj, a) ≡ linearbroadcasted(conj, a) @test_throws ArgumentError TA.broadcasted_linear(style, exp, a) end - @testset "LinearBroadcastFunction algebra" begin + @testset "linearbroadcasted algebra" begin a = randn(ComplexF64, 3, 3) # Scaling absorbs coefficients From b6c476eda0978578244208360b27a0ea8325d991 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 25 Mar 2026 09:52:37 -0400 Subject: [PATCH 14/29] Add permutedimsopadd! and wire ConjBroadcasted to use it MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit permutedimsopadd!(dest, op, src, perm, α, β) is the op-parameterized materialization primitive. Default eagerly applies op. ConjBroadcasted now materializes through permutedimsopadd!(dest, conj, ...) giving downstream types a dispatch point to fuse conjugation. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/linearbroadcasted.jl | 2 +- src/permutedimsadd.jl | 17 +++++++++++++++++ test/test_permutedimsadd.jl | 23 ++++++++++++++++++++++- 3 files changed, 40 insertions(+), 2 deletions(-) diff --git a/src/linearbroadcasted.jl b/src/linearbroadcasted.jl index c68ddf77..bda4a8ba 100644 --- a/src/linearbroadcasted.jl +++ b/src/linearbroadcasted.jl @@ -164,7 +164,7 @@ function add!(dest::AbstractArray, src::ScaledBroadcasted, α::Number, β::Numbe end function add!(dest::AbstractArray, src::ConjBroadcasted, α::Number, β::Number) - return add!(dest, conj(unconj(src)), α, β) + return permutedimsopadd!(dest, conj, unconj(src), ntuple(identity, ndims(dest)), α, β) end function add!(dest::AbstractArray, src::AddBroadcasted, α::Number, β::Number) diff --git a/src/permutedimsadd.jl b/src/permutedimsadd.jl index e96bf32a..a5f90ee7 100644 --- a/src/permutedimsadd.jl +++ b/src/permutedimsadd.jl @@ -43,6 +43,23 @@ function add!_broadcast(dest::AbstractArray, src::AbstractArray, α::Number, β: return dest end +""" + permutedimsopadd!(dest, op, src, perm, α, β) + +`dest = β * dest + α * permutedims(op(src), perm)`. + +The `op` is an element-wise linear map applied to `src` before permutation and +accumulation. Downstream array types can specialize on specific `op`s (e.g., +`op::typeof(conj)`) to fuse the operation without allocating. + +The default implementation eagerly applies `op`: `permutedimsadd!(dest, op.(src), perm, α, β)`. +""" +function permutedimsopadd!( + dest::AbstractArray, op, src::AbstractArray, perm, α::Number, β::Number + ) + return permutedimsadd!(dest, op.(src), perm, α, β) +end + """ permutedimsadd!(dest, src, perm, α, β) diff --git a/test/test_permutedimsadd.jl b/test/test_permutedimsadd.jl index 01b68165..48491e07 100644 --- a/test/test_permutedimsadd.jl +++ b/test/test_permutedimsadd.jl @@ -1,6 +1,6 @@ using Adapt: adapt using JLArrays: JLArray -using TensorAlgebra: add!, permutedimsadd! +using TensorAlgebra: add!, permutedimsadd!, permutedimsopadd! using Test: @test, @testset @testset "[permutedims]add!" begin @@ -47,4 +47,25 @@ using Test: @test, @testset @test b′ ≈ β * b + α * permutedims(a, perm) end end + @testset "permutedimsopadd! (arraytype=$arrayt)" for arrayt in (Array,) + dev = adapt(arrayt) + a = dev(randn(ComplexF64, 2, 2, 2)) + perm = (3, 1, 2) + α = 2 + for β in (0, 3) + b = dev(randn(ComplexF64, 2, 2, 2)) + b′ = copy(b) + permutedimsopadd!(b′, conj, a, perm, α, β) + @test b′ ≈ β * b + α * permutedims(conj(a), perm) + end + # identity op should match permutedimsadd! + for β in (0, 3) + b = dev(randn(ComplexF64, 2, 2, 2)) + b′ = copy(b) + b″ = copy(b) + permutedimsopadd!(b′, identity, a, perm, α, β) + permutedimsadd!(b″, a, perm, α, β) + @test b′ ≈ b″ + end + end end From d0aa5dda12b24dc03d222629522b02b32a30c529 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 25 Mar 2026 09:56:55 -0400 Subject: [PATCH 15/29] Make permutedimsopadd! the single materialization primitive MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit All LinearBroadcasted materialization now funnels through permutedimsopadd!(dest, op, src, perm, α, β). add! and permutedimsadd! are convenience functions that call it with identity op and/or identity perm. This gives downstream array types (GradedArrays, etc.) a single function to implement for full LinearBroadcasted support. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/permutedimsadd.jl | 71 +++++++++++++++++++++++++++---------------- 1 file changed, 44 insertions(+), 27 deletions(-) diff --git a/src/permutedimsadd.jl b/src/permutedimsadd.jl index a5f90ee7..a36701c9 100644 --- a/src/permutedimsadd.jl +++ b/src/permutedimsadd.jl @@ -9,26 +9,8 @@ function maybestrided(as::AbstractArray...) return all(a -> SV.isstrided(a) && iscpu(a), as) ? SV.StridedView.(as) : as end -""" - add!(dest, src) - -Equivalent to `dest .+= src`, but maybe with a more optimized/specialized implementation. -Generally calls `add!(dest, src, true, true)`. -""" -add!(dest::AbstractArray, src::AbstractArray) = add!(dest, src, true, true) - -""" - add!(dest, src, α, β) - -Equivalent to `dest .= β .* dest .+ α .* src`, but maybe with a more optimized/specialized -implementation. -""" -function add!(dest::AbstractArray, src::AbstractArray, α::Number, β::Number) - add!_broadcast(maybestrided(dest, src)..., α, β) - return dest -end - -# Broadcasting implementation of add!. +# Low-level broadcasting kernel: dest .= β .* dest .+ α .* src. +# This is the leaf implementation that does actual computation. function add!_broadcast(dest::AbstractArray, src::AbstractArray, α::Number, β::Number) # This works around a bug in Strided.jl v2.3.4 and below when broadcasting # empty StridedViews: https://github.com/QuantumKitHub/Strided.jl/pull/50 @@ -43,23 +25,42 @@ function add!_broadcast(dest::AbstractArray, src::AbstractArray, α::Number, β: return dest end +# ---------------------------------------------------------------------------- # +# permutedimsopadd! — the single materialization primitive +# ---------------------------------------------------------------------------- # + """ permutedimsopadd!(dest, op, src, perm, α, β) -`dest = β * dest + α * permutedims(op(src), perm)`. +`dest = β * dest + α * permutedims(op.(src), perm)`. -The `op` is an element-wise linear map applied to `src` before permutation and -accumulation. Downstream array types can specialize on specific `op`s (e.g., -`op::typeof(conj)`) to fuse the operation without allocating. +This is the single materialization primitive for `LinearBroadcasted` types. +Downstream array types should implement this function. The `op` is an element-wise +linear map (e.g., `identity`, `conj`, `adjoint`, `transpose`, `Float32`). -The default implementation eagerly applies `op`: `permutedimsadd!(dest, op.(src), perm, α, β)`. +The default implementation eagerly applies `op` and `permutedims`, then accumulates +via broadcasting. """ function permutedimsopadd!( dest::AbstractArray, op, src::AbstractArray, perm, α::Number, β::Number ) - return permutedimsadd!(dest, op.(src), perm, α, β) + add!_broadcast(maybestrided(dest, permuteddims(op.(src), perm))..., α, β) + return dest end +# Optimization: identity op skips the broadcast of op. +function permutedimsopadd!( + dest::AbstractArray, ::typeof(identity), src::AbstractArray, perm, α::Number, + β::Number + ) + add!_broadcast(maybestrided(dest, permuteddims(src, perm))..., α, β) + return dest +end + +# ---------------------------------------------------------------------------- # +# Convenience functions that lower to permutedimsopadd! +# ---------------------------------------------------------------------------- # + """ permutedimsadd!(dest, src, perm, α, β) @@ -68,5 +69,21 @@ end function permutedimsadd!( dest::AbstractArray, src::AbstractArray, perm, α::Number, β::Number ) - return add!(dest, permuteddims(src, perm), α, β) + return permutedimsopadd!(dest, identity, src, perm, α, β) end + +""" + add!(dest, src, α, β) + +`dest = β * dest + α * src`. +""" +function add!(dest::AbstractArray, src::AbstractArray, α::Number, β::Number) + return permutedimsopadd!(dest, identity, src, ntuple(identity, ndims(src)), α, β) +end + +""" + add!(dest, src) + +`dest .+= src`. +""" +add!(dest::AbstractArray, src::AbstractArray) = add!(dest, src, true, true) From af3582b204c2fb81b80706a801350243a1f1958b Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 25 Mar 2026 12:24:37 -0400 Subject: [PATCH 16/29] Remove eager ndims check from AddBroadcasted constructor Match Base.Broadcast's pattern of deferring axes checks to materialization time. The check is redundant with combine_axes in axes(::AddBroadcasted). Co-Authored-By: Claude Opus 4.6 (1M context) --- src/linearbroadcasted.jl | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/linearbroadcasted.jl b/src/linearbroadcasted.jl index bda4a8ba..8980556e 100644 --- a/src/linearbroadcasted.jl +++ b/src/linearbroadcasted.jl @@ -78,12 +78,7 @@ arguments(a::ConjBroadcasted) = (unconj(a),) struct AddBroadcasted{Args <: Tuple} <: LinearBroadcasted args::Args - function AddBroadcasted(args...) - if !allequal(ndims, args) - error("All addends must have the same number of dimensions.") - end - return new{typeof(args)}(args) - end + AddBroadcasted(args...) = new{typeof(args)}(args) end addends(a::AddBroadcasted) = a.args From 1fd07e8981e41a57b4a843f84ba69fa29ccd13a6 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 25 Mar 2026 12:27:13 -0400 Subject: [PATCH 17/29] Keep to_linear/broadcasted_linear as separate internal helpers linearbroadcasted(f, args...) is the constructor (parallels broadcasted). to_linear(bc::Broadcasted) is the internal rewriter (parallels flatten). broadcasted_linear(style, f, args...) validates + converts. Mixing construction and rewriting in one function was confusing. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/linearbroadcasted.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/linearbroadcasted.jl b/src/linearbroadcasted.jl index 8980556e..07a2a95c 100644 --- a/src/linearbroadcasted.jl +++ b/src/linearbroadcasted.jl @@ -311,6 +311,8 @@ function is_linear(bc::BC.Broadcasted) return broadcast_is_linear(bc.f, bc.args...) && all(is_linear, bc.args) end +# Rewrite a Broadcasted tree as a LinearBroadcasted tree. +# Internal helper, analogous to Broadcast.flatten for Broadcasted trees. to_linear(x) = x function to_linear(bc::BC.Broadcasted) return linearbroadcasted(bc.f, to_linear.(bc.args)...) @@ -323,6 +325,8 @@ function broadcast_error(style, f) ) ) end + +# Validate linearity and convert Broadcasted to LinearBroadcasted. function broadcasted_linear(style::BC.BroadcastStyle, f, args...) bc = BC.Broadcasted(style, f, args) is_linear(bc) || broadcast_error(style, f) From 51372afab2fe01cc04a4cf76c76e1f7f5eaa467f Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 25 Mar 2026 12:30:59 -0400 Subject: [PATCH 18/29] Merge is_linear + to_linear into single _to_linear pass Single recursive function validates linearity and converts in one pass, returning nothing on failure. Eliminates redundant tree traversal. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/linearbroadcasted.jl | 48 ++++++++++++++++++++++------------------ 1 file changed, 26 insertions(+), 22 deletions(-) diff --git a/src/linearbroadcasted.jl b/src/linearbroadcasted.jl index 07a2a95c..7e4d8171 100644 --- a/src/linearbroadcasted.jl +++ b/src/linearbroadcasted.jl @@ -306,31 +306,33 @@ function broadcast_is_linear( ) return true end -is_linear(x) = true -function is_linear(bc::BC.Broadcasted) - return broadcast_is_linear(bc.f, bc.args...) && all(is_linear, bc.args) +# Check if a Broadcasted tree is linear and convert it to LinearBroadcasted +# in a single recursive pass. Returns `nothing` if nonlinear. +_to_linear(x) = x +function _to_linear(bc::BC.Broadcasted) + broadcast_is_linear(bc.f, bc.args...) || return nothing + args = map(_to_linear, bc.args) + any(isnothing, args) && return nothing + return linearbroadcasted(bc.f, args...) end -# Rewrite a Broadcasted tree as a LinearBroadcasted tree. -# Internal helper, analogous to Broadcast.flatten for Broadcasted trees. -to_linear(x) = x -function to_linear(bc::BC.Broadcasted) - return linearbroadcasted(bc.f, to_linear.(bc.args)...) -end +""" + broadcasted_linear(style, f, args...) -function broadcast_error(style, f) - return throw( +Validate that a broadcast expression is linear and convert it to a `LinearBroadcasted` +expression tree. Throws `ArgumentError` if the expression is not linear. + +This is the entry point called by `BC.broadcasted(::LinearBroadcastedStyle, ...)` and +downstream broadcast styles that opt into linear broadcasting. +""" +function broadcasted_linear(style::BC.BroadcastStyle, f, args...) + result = _to_linear(BC.Broadcasted(style, f, args)) + result === nothing && throw( ArgumentError( "Only linear broadcast operations are supported for `$style`, got `$f`." ) ) -end - -# Validate linearity and convert Broadcasted to LinearBroadcasted. -function broadcasted_linear(style::BC.BroadcastStyle, f, args...) - bc = BC.Broadcasted(style, f, args) - is_linear(bc) || broadcast_error(style, f) - return to_linear(bc) + return result end function broadcasted_linear(f, args...) return broadcasted_linear(BC.combine_styles(args...), f, args...) @@ -415,8 +417,9 @@ function BC.broadcasted( a::AbstractArray, b::BC.Broadcasted ) - is_linear(b) || return BC.Broadcasted(+, to_broadcasted.((a, b))) - return linearbroadcasted(+, a, to_linear(b)) + b_linear = _to_linear(b) + b_linear === nothing && return BC.Broadcasted(+, to_broadcasted.((a, b))) + return linearbroadcasted(+, a, b_linear) end function BC.broadcasted( ::LinearBroadcastedStyle, @@ -424,8 +427,9 @@ function BC.broadcasted( a::BC.Broadcasted, b::AbstractArray ) - is_linear(a) || return BC.Broadcasted(+, to_broadcasted.((a, b))) - return linearbroadcasted(+, to_linear(a), b) + a_linear = _to_linear(a) + a_linear === nothing && return BC.Broadcasted(+, to_broadcasted.((a, b))) + return linearbroadcasted(+, a_linear, b) end function BC.broadcasted( ::LinearBroadcastedStyle, ::typeof(+), a::BC.Broadcasted, b::BC.Broadcasted From 6dc3d48faa14fc0d431f3b8b14678fd0aa90cb5c Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 25 Mar 2026 13:02:03 -0400 Subject: [PATCH 19/29] Delete LinearBroadcastedStyle and BC.broadcasted overloads MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Move to instantiation-time conversion design. Downstream styles call tryflattenlinear(bc) from their own copy method instead of routing through LinearBroadcastedStyle. - Delete LinearBroadcastedStyle and all constructors - Delete all BC.broadcasted(::LinearBroadcastedStyle, ...) overloads - Delete to_broadcasted, broadcasted_linear - Rename broadcast_is_linear → islinearbroadcast - Rename _to_linear → tryflattenlinear - BroadcastStyle for LinearBroadcasted types delegates to wrapped type Co-Authored-By: Claude Opus 4.6 (1M context) --- src/linearbroadcasted.jl | 186 ++++++++------------------------- test/test_linearbroadcasted.jl | 32 +++--- 2 files changed, 57 insertions(+), 161 deletions(-) diff --git a/src/linearbroadcasted.jl b/src/linearbroadcasted.jl index 7e4d8171..0014bd7b 100644 --- a/src/linearbroadcasted.jl +++ b/src/linearbroadcasted.jl @@ -268,189 +268,85 @@ linearbroadcasted(::typeof(*), α::Number, a::Mul) = ScaledBroadcasted(α, a) linearbroadcasted(::typeof(*), a::Number, b::Number) = a * b # ---------------------------------------------------------------------------- # -# Broadcast integration +# Broadcast integration — instantiation-time conversion # ---------------------------------------------------------------------------- # -broadcast_is_linear(f, args...) = false -broadcast_is_linear(::typeof(identity), ::Base.AbstractArrayOrBroadcasted) = true -broadcast_is_linear(::typeof(+), ::Base.AbstractArrayOrBroadcasted...) = true -broadcast_is_linear(::typeof(-), ::Base.AbstractArrayOrBroadcasted) = true -function broadcast_is_linear( +""" + islinearbroadcast(f, args...) -> Bool + +Per-node trait: can `(f, args...)` be expressed as a `LinearBroadcasted`? +Extensible by downstream packages for additional linear operations. +""" +islinearbroadcast(f, args...) = false +islinearbroadcast(::typeof(identity), ::Base.AbstractArrayOrBroadcasted) = true +islinearbroadcast(::typeof(+), ::Base.AbstractArrayOrBroadcasted...) = true +islinearbroadcast(::typeof(-), ::Base.AbstractArrayOrBroadcasted) = true +function islinearbroadcast( ::typeof(-), ::Base.AbstractArrayOrBroadcasted, ::Base.AbstractArrayOrBroadcasted ) return true end -broadcast_is_linear(::typeof(*), ::Number, ::Base.AbstractArrayOrBroadcasted) = true -broadcast_is_linear(::typeof(\), ::Number, ::Base.AbstractArrayOrBroadcasted) = true -broadcast_is_linear(::typeof(*), ::Base.AbstractArrayOrBroadcasted, ::Number) = true -broadcast_is_linear(::typeof(/), ::Base.AbstractArrayOrBroadcasted, ::Number) = true -function broadcast_is_linear( +islinearbroadcast(::typeof(*), ::Number, ::Base.AbstractArrayOrBroadcasted) = true +islinearbroadcast(::typeof(\), ::Number, ::Base.AbstractArrayOrBroadcasted) = true +islinearbroadcast(::typeof(*), ::Base.AbstractArrayOrBroadcasted, ::Number) = true +islinearbroadcast(::typeof(/), ::Base.AbstractArrayOrBroadcasted, ::Number) = true +function islinearbroadcast( ::typeof(*), ::Base.AbstractArrayOrBroadcasted, ::Base.AbstractArrayOrBroadcasted ) return false end -broadcast_is_linear(::typeof(*), ::Number, ::Number) = true -broadcast_is_linear(::typeof(conj), ::Base.AbstractArrayOrBroadcasted) = true -function broadcast_is_linear( +islinearbroadcast(::typeof(*), ::Number, ::Number) = true +islinearbroadcast(::typeof(conj), ::Base.AbstractArrayOrBroadcasted) = true +function islinearbroadcast( ::Base.Fix1{typeof(*), <:Number}, ::Base.AbstractArrayOrBroadcasted ) return true end -function broadcast_is_linear( +function islinearbroadcast( ::Base.Fix2{typeof(*), <:Number}, ::Base.AbstractArrayOrBroadcasted ) return true end -function broadcast_is_linear( +function islinearbroadcast( ::Base.Fix2{typeof(/), <:Number}, ::Base.AbstractArrayOrBroadcasted ) return true end -# Check if a Broadcasted tree is linear and convert it to LinearBroadcasted -# in a single recursive pass. Returns `nothing` if nonlinear. -_to_linear(x) = x -function _to_linear(bc::BC.Broadcasted) - broadcast_is_linear(bc.f, bc.args...) || return nothing - args = map(_to_linear, bc.args) - any(isnothing, args) && return nothing - return linearbroadcasted(bc.f, args...) -end """ - broadcasted_linear(style, f, args...) - -Validate that a broadcast expression is linear and convert it to a `LinearBroadcasted` -expression tree. Throws `ArgumentError` if the expression is not linear. + tryflattenlinear(bc::Broadcasted) -> LinearBroadcasted or nothing -This is the entry point called by `BC.broadcasted(::LinearBroadcastedStyle, ...)` and -downstream broadcast styles that opt into linear broadcasting. -""" -function broadcasted_linear(style::BC.BroadcastStyle, f, args...) - result = _to_linear(BC.Broadcasted(style, f, args)) - result === nothing && throw( - ArgumentError( - "Only linear broadcast operations are supported for `$style`, got `$f`." - ) - ) - return result -end -function broadcasted_linear(f, args...) - return broadcasted_linear(BC.combine_styles(args...), f, args...) -end +Recursively convert a `Broadcasted` tree to a `LinearBroadcasted` tree. +Returns `nothing` if any node is not linear (as determined by `islinearbroadcast`). -# Convert LinearBroadcasted / Mul back to Broadcasted for non-linear contexts. -to_broadcasted(x) = x -function to_broadcasted(a::AbstractArray) - (BC.BroadcastStyle(typeof(a)) isa LinearBroadcastedStyle) || return a - return BC.broadcasted(operation(a), to_broadcasted.(arguments(a))...) -end -function to_broadcasted(a::LinearBroadcasted) - return BC.broadcasted(operation(a), to_broadcasted.(arguments(a))...) -end -# Matmul isn't a broadcasting operation so we materialize when building a -# broadcast expression involving a Mul. -to_broadcasted(a::Mul) = *(factors(a)...) -to_broadcasted(bc::BC.Broadcasted) = BC.Broadcasted(bc.f, to_broadcasted.(bc.args)) +Analogous to `Broadcast.flatten` for `Broadcasted` trees, but converts to +`LinearBroadcasted` subtypes via `linearbroadcasted`. -# LinearBroadcastedStyle for broadcast interop. -struct LinearBroadcastedStyle{N, Style <: BC.AbstractArrayStyle{N}} <: - BC.AbstractArrayStyle{N} - style::Style -end -# TODO: This empty constructor is required in some Julia versions below v1.12 (such as -# Julia v1.10), try deleting it once we drop support for those versions. -function LinearBroadcastedStyle{N, Style}() where {N, Style <: BC.AbstractArrayStyle{N}} - return LinearBroadcastedStyle{N, Style}(Style()) -end -function LinearBroadcastedStyle{N, Style}( - ::Val{M} - ) where {M, N, Style <: BC.AbstractArrayStyle{N}} - return LinearBroadcastedStyle(Style(Val(M))) -end -function BC.BroadcastStyle(style1::LinearBroadcastedStyle, style2::LinearBroadcastedStyle) - style = BC.BroadcastStyle(style1.style, style2.style) - style ≡ BC.Unknown() && return BC.Unknown() - return LinearBroadcastedStyle(style) -end -function Base.similar(bc::BC.Broadcasted{<:LinearBroadcastedStyle}, elt::Type, ax) - return similar(BC.Broadcasted(bc.style.style, bc.f, bc.args, bc.axes), elt, ax) +Downstream styles call this from `Base.copy(::Broadcasted{MyStyle})` to +opt into linear broadcasting at materialization time. +""" +tryflattenlinear(x) = x +function tryflattenlinear(bc::BC.Broadcasted) + islinearbroadcast(bc.f, bc.args...) || return nothing + args = map(tryflattenlinear, bc.args) + any(isnothing, args) && return nothing + return linearbroadcasted(bc.f, args...) end -# BroadcastStyle for LinearBroadcasted subtypes. +# BroadcastStyle for LinearBroadcasted subtypes — delegate to the wrapped array type. function BC.BroadcastStyle(::Type{<:ScaledBroadcasted{<:Any, A}}) where {A} - return LinearBroadcastedStyle(BC.BroadcastStyle(A)) + return BC.BroadcastStyle(A) end function BC.BroadcastStyle(::Type{<:ConjBroadcasted{A}}) where {A} - return LinearBroadcastedStyle(BC.BroadcastStyle(A)) + return BC.BroadcastStyle(A) end function BC.BroadcastStyle(::Type{<:AddBroadcasted{Args}}) where {Args} - style = Base.promote_op(BC.combine_styles, fieldtypes(Args)...)() - return LinearBroadcastedStyle(style) + return Base.promote_op(BC.combine_styles, fieldtypes(Args)...)() end function BC.BroadcastStyle(::Type{<:Mul{A, B}}) where {A, B} - style = BC.BroadcastStyle(BC.BroadcastStyle(A), BC.BroadcastStyle(B)) - return LinearBroadcastedStyle(style) + return BC.BroadcastStyle(BC.BroadcastStyle(A), BC.BroadcastStyle(B)) end # Broadcast.materialize for LinearBroadcasted and Mul. BC.materialize(a::LinearBroadcasted) = copy(a) BC.materialize(a::Mul) = copy(a) - -# Backup definition: for broadcast operations that don't preserve lazy types -# (such as nonlinear operations), convert back to Broadcasted expressions. -function BC.broadcasted(::LinearBroadcastedStyle, f, args...) - return BC.Broadcasted(f, to_broadcasted.(args)) -end - -# Linear broadcast operations produce LinearBroadcasted / Mul types. -function BC.broadcasted( - ::LinearBroadcastedStyle, - ::typeof(+), - a::AbstractArray, - b::AbstractArray - ) - return linearbroadcasted(+, a, b) -end -function BC.broadcasted( - ::LinearBroadcastedStyle, - ::typeof(+), - a::AbstractArray, - b::BC.Broadcasted - ) - b_linear = _to_linear(b) - b_linear === nothing && return BC.Broadcasted(+, to_broadcasted.((a, b))) - return linearbroadcasted(+, a, b_linear) -end -function BC.broadcasted( - ::LinearBroadcastedStyle, - ::typeof(+), - a::BC.Broadcasted, - b::AbstractArray - ) - a_linear = _to_linear(a) - a_linear === nothing && return BC.Broadcasted(+, to_broadcasted.((a, b))) - return linearbroadcasted(+, a_linear, b) -end -function BC.broadcasted( - ::LinearBroadcastedStyle, ::typeof(+), a::BC.Broadcasted, b::BC.Broadcasted - ) - return error("Not implemented") -end -function BC.broadcasted(::LinearBroadcastedStyle, ::typeof(*), α::Number, a::AbstractArray) - return linearbroadcasted(*, α, a) -end -function BC.broadcasted(::LinearBroadcastedStyle, ::typeof(*), a::AbstractArray, α::Number) - return linearbroadcasted(*, a, α) -end -function BC.broadcasted(::LinearBroadcastedStyle, ::typeof(\), α::Number, a::AbstractArray) - return linearbroadcasted(\, α, a) -end -function BC.broadcasted(::LinearBroadcastedStyle, ::typeof(/), a::AbstractArray, α::Number) - return linearbroadcasted(/, a, α) -end -function BC.broadcasted(::LinearBroadcastedStyle, ::typeof(-), a::AbstractArray) - return linearbroadcasted(-, a) -end -function BC.broadcasted(::LinearBroadcastedStyle, ::typeof(conj), a::AbstractArray) - return linearbroadcasted(conj, a) -end diff --git a/test/test_linearbroadcasted.jl b/test/test_linearbroadcasted.jl index 5dc23b6b..ca19f41b 100644 --- a/test/test_linearbroadcasted.jl +++ b/test/test_linearbroadcasted.jl @@ -37,23 +37,23 @@ using Test: @test, @test_throws, @testset ) @test copy(x) ≈ 2 * a * b + 3 * c end - @testset "linear broadcast lowering" begin + @testset "tryflattenlinear" begin a = randn(ComplexF64, 2, 2) - style = BC.DefaultArrayStyle{2}() - - @test TA.broadcasted_linear(identity, a) ≡ a - @test TA.broadcasted_linear(Base.Fix1(*, 2), a) ≡ linearbroadcasted(*, 2, a) - @test TA.broadcasted_linear(Base.Fix2(*, 2), a) ≡ linearbroadcasted(*, a, 2) - @test TA.broadcasted_linear(Base.Fix2(/, 2), a) ≡ linearbroadcasted(/, a, 2) - @test TA.broadcasted_linear(style, identity, a) ≡ a - @test TA.broadcasted_linear(style, Base.Fix1(*, 2), a) ≡ - linearbroadcasted(*, 2, a) - @test TA.broadcasted_linear(style, Base.Fix2(*, 2), a) ≡ - linearbroadcasted(*, a, 2) - @test TA.broadcasted_linear(style, Base.Fix2(/, 2), a) ≡ - linearbroadcasted(/, a, 2) - @test TA.broadcasted_linear(style, conj, a) ≡ linearbroadcasted(conj, a) - @test_throws ArgumentError TA.broadcasted_linear(style, exp, a) + b = randn(ComplexF64, 2, 2) + + # Linear expressions convert successfully + @test TA.tryflattenlinear(BC.broadcasted(*, 2, a)) ≡ linearbroadcasted(*, 2, a) + @test TA.tryflattenlinear(BC.broadcasted(conj, a)) ≡ linearbroadcasted(conj, a) + @test TA.tryflattenlinear(BC.broadcasted(+, a, b)) ≡ linearbroadcasted(+, a, b) + @test TA.tryflattenlinear(BC.broadcasted(identity, a)) ≡ a + + # Nested linear expression + bc = BC.broadcasted(+, BC.broadcasted(*, 2, a), BC.broadcasted(*, 3, b)) + @test copy(TA.tryflattenlinear(bc)) ≈ 2a + 3b + + # Nonlinear expression returns nothing + @test TA.tryflattenlinear(BC.broadcasted(exp, a)) === nothing + @test TA.tryflattenlinear(BC.broadcasted(+, a, BC.broadcasted(exp, b))) === nothing end @testset "linearbroadcasted algebra" begin a = randn(ComplexF64, 3, 3) From 19026aeb5c44a85d121b80a7cdf64f7919020336 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 25 Mar 2026 13:28:47 -0400 Subject: [PATCH 20/29] Inline add!_broadcast into permutedimsopadd!, remove identity dispatch MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Single implementation of permutedimsopadd! with the broadcasting logic inlined. Fuse op into the broadcast expression (op.(src′) inside .=) to avoid intermediate allocation. Branch on β first (NaN avoidance), then op === identity (skip no-op). Co-Authored-By: Claude Opus 4.6 (1M context) --- src/permutedimsadd.jl | 47 ++++++++++++++++++------------------------- 1 file changed, 20 insertions(+), 27 deletions(-) diff --git a/src/permutedimsadd.jl b/src/permutedimsadd.jl index a36701c9..a9125eae 100644 --- a/src/permutedimsadd.jl +++ b/src/permutedimsadd.jl @@ -9,22 +9,6 @@ function maybestrided(as::AbstractArray...) return all(a -> SV.isstrided(a) && iscpu(a), as) ? SV.StridedView.(as) : as end -# Low-level broadcasting kernel: dest .= β .* dest .+ α .* src. -# This is the leaf implementation that does actual computation. -function add!_broadcast(dest::AbstractArray, src::AbstractArray, α::Number, β::Number) - # This works around a bug in Strided.jl v2.3.4 and below when broadcasting - # empty StridedViews: https://github.com/QuantumKitHub/Strided.jl/pull/50 - # TODO: Delete this and bump the version of Strided.jl once that is fixed. - isempty(dest) && return dest - - if iszero(β) - dest .= α .* src - else - dest .= β .* dest .+ α .* src - end - return dest -end - # ---------------------------------------------------------------------------- # # permutedimsopadd! — the single materialization primitive # ---------------------------------------------------------------------------- # @@ -38,22 +22,31 @@ This is the single materialization primitive for `LinearBroadcasted` types. Downstream array types should implement this function. The `op` is an element-wise linear map (e.g., `identity`, `conj`, `adjoint`, `transpose`, `Float32`). -The default implementation eagerly applies `op` and `permutedims`, then accumulates -via broadcasting. +The default implementation applies `op` element-wise, permutes, then accumulates +via broadcasting with Strided.jl optimization when possible. """ function permutedimsopadd!( dest::AbstractArray, op, src::AbstractArray, perm, α::Number, β::Number ) - add!_broadcast(maybestrided(dest, permuteddims(op.(src), perm))..., α, β) - return dest -end + # This works around a bug in Strided.jl v2.3.4 and below when broadcasting + # empty StridedViews: https://github.com/QuantumKitHub/Strided.jl/pull/50 + # TODO: Delete this and bump the version of Strided.jl once that is fixed. + isempty(dest) && return dest -# Optimization: identity op skips the broadcast of op. -function permutedimsopadd!( - dest::AbstractArray, ::typeof(identity), src::AbstractArray, perm, α::Number, - β::Number - ) - add!_broadcast(maybestrided(dest, permuteddims(src, perm))..., α, β) + dest′, src′ = maybestrided(dest, permuteddims(src, perm)) + if iszero(β) + if op === identity + dest′ .= α .* src′ + else + dest′ .= α .* op.(src′) + end + else + if op === identity + dest′ .= β .* dest′ .+ α .* src′ + else + dest′ .= β .* dest′ .+ α .* op.(src′) + end + end return dest end From 39e6105a01203d911013f6ce6ffedc09ec707d12 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 25 Mar 2026 17:29:20 -0400 Subject: [PATCH 21/29] Swap check ordering in permutedimsadd --- src/permutedimsadd.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/permutedimsadd.jl b/src/permutedimsadd.jl index a9125eae..16ae8f7f 100644 --- a/src/permutedimsadd.jl +++ b/src/permutedimsadd.jl @@ -34,15 +34,15 @@ function permutedimsopadd!( isempty(dest) && return dest dest′, src′ = maybestrided(dest, permuteddims(src, perm)) - if iszero(β) - if op === identity + if op === identity + if iszero(β) dest′ .= α .* src′ else - dest′ .= α .* op.(src′) + dest′ .= β .* dest′ .+ α .* src′ end else - if op === identity - dest′ .= β .* dest′ .+ α .* src′ + if iszero(β) + dest′ .= α .* op.(src′) else dest′ .= β .* dest′ .+ α .* op.(src′) end From 931715ed4777b971a565b7b1513b7fc916da6ea9 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 25 Mar 2026 18:06:26 -0400 Subject: [PATCH 22/29] Add operator composition --- src/linearbroadcasted.jl | 38 +++++++++++++++++++++++++++----------- 1 file changed, 27 insertions(+), 11 deletions(-) diff --git a/src/linearbroadcasted.jl b/src/linearbroadcasted.jl index 0014bd7b..57970f7e 100644 --- a/src/linearbroadcasted.jl +++ b/src/linearbroadcasted.jl @@ -153,27 +153,43 @@ function Base.copyto!(dest::AbstractArray, src::Mul) return LA.mul!(dest, BC.materialize.(factors(src))...) end -# add! for LinearBroadcasted subtypes. -function add!(dest::AbstractArray, src::ScaledBroadcasted, α::Number, β::Number) - return add!(dest, unscaled(src), coeff(src) * α, β) +# Op composition with simplification rules. +_compose_op(::typeof(identity), g) = g +_compose_op(f, ::typeof(identity)) = f +_compose_op(::typeof(identity), ::typeof(identity)) = identity +_compose_op(::typeof(conj), ::typeof(conj)) = identity +_compose_op(f, g) = f ∘ g + +# permutedimsopadd! for LinearBroadcasted subtypes. +function permutedimsopadd!( + dest::AbstractArray, op, src::ScaledBroadcasted, perm, α::Number, β::Number + ) + return permutedimsopadd!(dest, op, unscaled(src), perm, op(coeff(src)) * α, β) end -function add!(dest::AbstractArray, src::ConjBroadcasted, α::Number, β::Number) - return permutedimsopadd!(dest, conj, unconj(src), ntuple(identity, ndims(dest)), α, β) +function permutedimsopadd!( + dest::AbstractArray, op, src::ConjBroadcasted, perm, α::Number, β::Number + ) + return permutedimsopadd!(dest, _compose_op(op, conj), unconj(src), perm, α, β) end -function add!(dest::AbstractArray, src::AddBroadcasted, α::Number, β::Number) +function permutedimsopadd!( + dest::AbstractArray, op, src::AddBroadcasted, perm, α::Number, β::Number + ) args = addends(src) - add!(dest, first(args), α, β) + permutedimsopadd!(dest, op, first(args), perm, α, β) for a in Base.tail(args) - add!(dest, a, α, true) + permutedimsopadd!(dest, op, a, perm, α, true) end return dest end -# add! for Mul materializes the factors and calls mul!. -function add!(dest::AbstractArray, src::Mul, α::Number, β::Number) - return LA.mul!(dest, BC.materialize.(factors(src))..., α, β) +# TODO: Replace with contractopadd! once that interface exists, +# to avoid materializing the Mul intermediate. +function permutedimsopadd!( + dest::AbstractArray, op, src::Mul, perm, α::Number, β::Number + ) + return permutedimsopadd!(dest, op, copy(src), perm, α, β) end # ---------------------------------------------------------------------------- # From 51f2a449c0d931a69eb90262df0bfcf00373f2e7 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 25 Mar 2026 19:08:53 -0400 Subject: [PATCH 23/29] Widen add! signature to accept any src type MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This allows add! to accept LinearBroadcasted types, which is needed for copyto!(dest, ::LinearBroadcasted) to work through the add! → permutedimsopadd! chain. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/permutedimsadd.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/permutedimsadd.jl b/src/permutedimsadd.jl index 16ae8f7f..81ed7150 100644 --- a/src/permutedimsadd.jl +++ b/src/permutedimsadd.jl @@ -70,7 +70,7 @@ end `dest = β * dest + α * src`. """ -function add!(dest::AbstractArray, src::AbstractArray, α::Number, β::Number) +function add!(dest::AbstractArray, src, α::Number, β::Number) return permutedimsopadd!(dest, identity, src, ntuple(identity, ndims(src)), α, β) end @@ -79,4 +79,4 @@ end `dest .+= src`. """ -add!(dest::AbstractArray, src::AbstractArray) = add!(dest, src, true, true) +add!(dest::AbstractArray, src) = add!(dest, src, true, true) From 41e9a2a008d5d2b7a7b625e0fa9a9c832a4e8849 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 25 Mar 2026 22:48:22 -0400 Subject: [PATCH 24/29] Handle 0-dimensional arrays in permutedimsopadd! Add runtime branch for ndims == 0 to avoid wrapping in PermutedDimsArray, which doesn't support getindex() on 0-dimensional BlockSparseArray. Needed because GradedArray is a type alias and 0-dimensional contraction results don't match the alias. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/permutedimsadd.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/permutedimsadd.jl b/src/permutedimsadd.jl index 81ed7150..dac9490a 100644 --- a/src/permutedimsadd.jl +++ b/src/permutedimsadd.jl @@ -28,6 +28,14 @@ via broadcasting with Strided.jl optimization when possible. function permutedimsopadd!( dest::AbstractArray, op, src::AbstractArray, perm, α::Number, β::Number ) + # TODO: Remove this 0-dimensional special case once GradedArray is its own type + # (not an alias for BlockSparseArray), so the GradedArray permutedimsopadd! overload + # catches the 0-dimensional contraction result. + if iszero(ndims(dest)) + dest[] = β * dest[] + α * op(src[]) + return dest + end + # This works around a bug in Strided.jl v2.3.4 and below when broadcasting # empty StridedViews: https://github.com/QuantumKitHub/Strided.jl/pull/50 # TODO: Delete this and bump the version of Strided.jl once that is fixed. From bf16e9cdd65da7b0d8f9a71c9ceb2c229a41d378 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 26 Mar 2026 10:12:24 -0400 Subject: [PATCH 25/29] Fix similar(::LinearBroadcasted) via _to_broadcasted converter Add _to_broadcasted to convert LinearBroadcasted tree back to Broadcasted (inverse of tryflattenlinear). Uses BC.Broadcasted constructor directly to avoid style-based dispatch that could re-enter LinearBroadcasted. Replace per-subtype similar methods with a single generic method that delegates to similar(_to_broadcasted(a), elt, ax). Also handle 0-dimensional arrays in permutedimsopadd! and widen add! signature to accept any src type. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/linearbroadcasted.jl | 35 +++++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/src/linearbroadcasted.jl b/src/linearbroadcasted.jl index 57970f7e..427ab45a 100644 --- a/src/linearbroadcasted.jl +++ b/src/linearbroadcasted.jl @@ -32,6 +32,29 @@ function Base.show(io::IO, a::LinearBroadcasted) end iscall(::LinearBroadcasted) = true +# Convert LinearBroadcasted back to Broadcasted (inverse of tryflattenlinear). +# Used by similar(::LinearBroadcasted) to delegate allocation to the Broadcasted system. +# Uses BC.Broadcasted constructor directly (not BC.broadcasted) to avoid style-based +# dispatch that could re-enter LinearBroadcasted conversion. +_to_broadcasted(a::AbstractArray) = a +_to_broadcasted(a::Number) = a +function _to_broadcasted(a::ScaledBroadcasted) + args = (coeff(a), _to_broadcasted(unscaled(a))) + return BC.Broadcasted(BC.combine_styles(args...), *, args) +end +function _to_broadcasted(a::ConjBroadcasted) + args = (_to_broadcasted(unconj(a)),) + return BC.Broadcasted(BC.combine_styles(args...), conj, args) +end +function _to_broadcasted(a::AddBroadcasted) + args = map(_to_broadcasted, addends(a)) + return BC.Broadcasted(BC.combine_styles(args...), +, args) +end + +function Base.similar(a::LinearBroadcasted, elt::Type, ax) + return similar(_to_broadcasted(a), elt, ax) +end + # --- ScaledBroadcasted -------------------------------------------------------- struct ScaledBroadcasted{C <: Number, A} <: LinearBroadcasted @@ -48,10 +71,6 @@ function Base.eltype(a::ScaledBroadcasted) end Base.ndims(a::ScaledBroadcasted) = ndims(unscaled(a)) -function Base.similar(a::ScaledBroadcasted, elt::Type, ax) - return similar(unscaled(a), elt, ax) -end - operation(::ScaledBroadcasted) = * arguments(a::ScaledBroadcasted) = (coeff(a), unscaled(a)) @@ -67,10 +86,6 @@ Base.axes(a::ConjBroadcasted) = axes(unconj(a)) Base.eltype(a::ConjBroadcasted) = eltype(unconj(a)) Base.ndims(a::ConjBroadcasted) = ndims(unconj(a)) -function Base.similar(a::ConjBroadcasted, elt::Type, ax) - return similar(unconj(a), elt, ax) -end - operation(::ConjBroadcasted) = conj arguments(a::ConjBroadcasted) = (unconj(a),) @@ -87,10 +102,6 @@ Base.axes(a::AddBroadcasted) = BC.combine_axes(addends(a)...) Base.eltype(a::AddBroadcasted) = Base.promote_op(+, eltype.(addends(a))...) Base.ndims(a::AddBroadcasted) = ndims(first(addends(a))) -function Base.similar(a::AddBroadcasted, elt::Type, ax) - return similar(BC.Broadcasted(+, addends(a)), elt, ax) -end - operation(::AddBroadcasted) = + arguments(a::AddBroadcasted) = addends(a) From 303926774e00960a05d7866f3c51acd10b890949 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 26 Mar 2026 10:16:18 -0400 Subject: [PATCH 26/29] Use generic _to_broadcasted via TermInterface-like operation/arguments Co-Authored-By: Claude Opus 4.6 (1M context) --- src/linearbroadcasted.jl | 21 +++++++-------------- 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/src/linearbroadcasted.jl b/src/linearbroadcasted.jl index 427ab45a..bd030fb2 100644 --- a/src/linearbroadcasted.jl +++ b/src/linearbroadcasted.jl @@ -33,26 +33,19 @@ end iscall(::LinearBroadcasted) = true # Convert LinearBroadcasted back to Broadcasted (inverse of tryflattenlinear). -# Used by similar(::LinearBroadcasted) to delegate allocation to the Broadcasted system. # Uses BC.Broadcasted constructor directly (not BC.broadcasted) to avoid style-based # dispatch that could re-enter LinearBroadcasted conversion. -_to_broadcasted(a::AbstractArray) = a -_to_broadcasted(a::Number) = a -function _to_broadcasted(a::ScaledBroadcasted) - args = (coeff(a), _to_broadcasted(unscaled(a))) - return BC.Broadcasted(BC.combine_styles(args...), *, args) +_to_broadcasted(a) = a +function _to_broadcasted(a::LinearBroadcasted) + args = map(_to_broadcasted, arguments(a)) + return BC.Broadcasted(BC.combine_styles(args...), operation(a), args) end -function _to_broadcasted(a::ConjBroadcasted) - args = (_to_broadcasted(unconj(a)),) - return BC.Broadcasted(BC.combine_styles(args...), conj, args) -end -function _to_broadcasted(a::AddBroadcasted) - args = map(_to_broadcasted, addends(a)) - return BC.Broadcasted(BC.combine_styles(args...), +, args) +function BC.Broadcasted(a::LinearBroadcasted) + return _to_broadcasted(a) end function Base.similar(a::LinearBroadcasted, elt::Type, ax) - return similar(_to_broadcasted(a), elt, ax) + return similar(BC.Broadcasted(a), elt, ax) end # --- ScaledBroadcasted -------------------------------------------------------- From 62c4680c4cde918f3936dd4368bd77b03361f56d Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 26 Mar 2026 10:25:50 -0400 Subject: [PATCH 27/29] Simplify Broadcasted(::LinearBroadcasted) converter Use generic implementation via operation/arguments instead of per-subtype _to_broadcasted methods. Remove per-subtype similar methods in favor of a single generic one. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/linearbroadcasted.jl | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/linearbroadcasted.jl b/src/linearbroadcasted.jl index bd030fb2..d1bfadf9 100644 --- a/src/linearbroadcasted.jl +++ b/src/linearbroadcasted.jl @@ -35,13 +35,11 @@ iscall(::LinearBroadcasted) = true # Convert LinearBroadcasted back to Broadcasted (inverse of tryflattenlinear). # Uses BC.Broadcasted constructor directly (not BC.broadcasted) to avoid style-based # dispatch that could re-enter LinearBroadcasted conversion. -_to_broadcasted(a) = a -function _to_broadcasted(a::LinearBroadcasted) - args = map(_to_broadcasted, arguments(a)) - return BC.Broadcasted(BC.combine_styles(args...), operation(a), args) -end function BC.Broadcasted(a::LinearBroadcasted) - return _to_broadcasted(a) + args = map(arguments(a)) do arg + arg isa LinearBroadcasted ? BC.Broadcasted(arg) : arg + end + return BC.Broadcasted(BC.combine_styles(args...), operation(a), args) end function Base.similar(a::LinearBroadcasted, elt::Type, ax) From 202c4c8237835bd65ca597dabcc96978b1c0aa91 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 26 Mar 2026 12:27:03 -0400 Subject: [PATCH 28/29] Bump versions --- Project.toml | 2 +- docs/Project.toml | 2 +- examples/Project.toml | 2 +- src/linearbroadcasted.jl | 2 +- test/Project.toml | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index 40eac37e..3c65eb0d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" -version = "0.7.21" +version = "0.8.0" authors = ["ITensor developers and contributors"] [workspace] diff --git a/docs/Project.toml b/docs/Project.toml index 549440ac..20650d9d 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -11,4 +11,4 @@ path = ".." Documenter = "1.8.1" ITensorFormatter = "0.2.27" Literate = "2.20.1" -TensorAlgebra = "0.7" +TensorAlgebra = "0.8" diff --git a/examples/Project.toml b/examples/Project.toml index 9b0b1293..a8006256 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -5,4 +5,4 @@ TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" path = ".." [compat] -TensorAlgebra = "0.7" +TensorAlgebra = "0.8" diff --git a/src/linearbroadcasted.jl b/src/linearbroadcasted.jl index d1bfadf9..637afbf7 100644 --- a/src/linearbroadcasted.jl +++ b/src/linearbroadcasted.jl @@ -37,7 +37,7 @@ iscall(::LinearBroadcasted) = true # dispatch that could re-enter LinearBroadcasted conversion. function BC.Broadcasted(a::LinearBroadcasted) args = map(arguments(a)) do arg - arg isa LinearBroadcasted ? BC.Broadcasted(arg) : arg + return arg isa LinearBroadcasted ? BC.Broadcasted(arg) : arg end return BC.Broadcasted(BC.combine_styles(args...), operation(a), args) end diff --git a/test/Project.toml b/test/Project.toml index 5a4045d2..5c0c3426 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -36,7 +36,7 @@ Random = "1.10" SafeTestsets = "0.1" StableRNGs = "1.0.2" Suppressor = "0.2" -TensorAlgebra = "0.7" +TensorAlgebra = "0.8" TensorOperations = "5.1.4" Test = "1.10" TestExtras = "0.3.1" From 6d7ab0cd98cafa4579cb0974396f5a439a9fe877 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 26 Mar 2026 13:16:17 -0400 Subject: [PATCH 29/29] Add tests for _compose_op, Broadcasted round-trip, add! with LinearBroadcasted, similar(AddBroadcasted), and 0-dim permutedimsopadd! Co-Authored-By: Claude Opus 4.6 (1M context) --- test/test_linearbroadcasted.jl | 61 ++++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/test/test_linearbroadcasted.jl b/test/test_linearbroadcasted.jl index ca19f41b..e2afc16c 100644 --- a/test/test_linearbroadcasted.jl +++ b/test/test_linearbroadcasted.jl @@ -119,4 +119,65 @@ using Test: @test, @test_throws, @testset x = linearbroadcasted(+, ab, cd) @test TA.addends(x) === (a, b, c, a) end + @testset "similar(::AddBroadcasted) with LinearBroadcasted addends" begin + a = randn(ComplexF64, 3, 4) + b = randn(ComplexF64, 3, 4) + + # Addends are ScaledBroadcasted, not AbstractArray + lb = linearbroadcasted(+, linearbroadcasted(*, 2, a), linearbroadcasted(*, 3, b)) + s = similar(lb) + @test size(s) == (3, 4) + @test eltype(s) === ComplexF64 + end + @testset "_compose_op" begin + @test TA._compose_op(identity, identity) === identity + @test TA._compose_op(identity, conj) === conj + @test TA._compose_op(conj, identity) === conj + @test TA._compose_op(conj, conj) === identity + f = TA._compose_op(sqrt, conj) + @test f isa ComposedFunction + end + @testset "Broadcasted(::LinearBroadcasted) round-trip" begin + a = randn(ComplexF64, 3, 3) + b = randn(ComplexF64, 3, 3) + + lb = linearbroadcasted(+, linearbroadcasted(*, 2, a), linearbroadcasted(conj, b)) + bc = BC.Broadcasted(lb) + @test bc isa BC.Broadcasted + @test copy(bc) ≈ 2a + conj(b) + end + @testset "add! and copyto! with LinearBroadcasted" begin + a = randn(ComplexF64, 3, 3) + b = randn(ComplexF64, 3, 3) + + # add! with ScaledBroadcasted + dest = zeros(ComplexF64, 3, 3) + TA.add!(dest, linearbroadcasted(*, 2, a), true, false) + @test dest ≈ 2a + + # add! with AddBroadcasted + dest = zeros(ComplexF64, 3, 3) + TA.add!(dest, linearbroadcasted(+, a, b), true, false) + @test dest ≈ a + b + + # add! with ConjBroadcasted + dest = zeros(ComplexF64, 3, 3) + TA.add!(dest, linearbroadcasted(conj, a), true, false) + @test dest ≈ conj(a) + + # add! with β accumulation + dest = ones(ComplexF64, 3, 3) + TA.add!(dest, linearbroadcasted(*, 2, a), 3, 1) + @test dest ≈ ones(ComplexF64, 3, 3) + 6a + end + @testset "0-dimensional permutedimsopadd!" begin + a = fill(3.0 + 2.0im) + dest = fill(1.0 + 0.0im) + TA.permutedimsopadd!(dest, identity, a, (), 2, 3) + @test dest[] ≈ 3 * 1 + 2 * a[] + + dest = fill(0.0 + 0.0im) + TA.permutedimsopadd!(dest, conj, a, (), 1, 0) + @test dest[] ≈ conj(a[]) + end end