diff --git a/Project.toml b/Project.toml index 18575fb..3c65eb0 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" -version = "0.7.20" +version = "0.8.0" authors = ["ITensor developers and contributors"] [workspace] diff --git a/docs/Project.toml b/docs/Project.toml index 549440a..20650d9 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -11,4 +11,4 @@ path = ".." Documenter = "1.8.1" ITensorFormatter = "0.2.27" Literate = "2.20.1" -TensorAlgebra = "0.7" +TensorAlgebra = "0.8" diff --git a/examples/Project.toml b/examples/Project.toml index 9b0b129..a800625 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -5,4 +5,4 @@ TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" path = ".." [compat] -TensorAlgebra = "0.7" +TensorAlgebra = "0.8" diff --git a/src/TensorAlgebra.jl b/src/TensorAlgebra.jl index e3e402d..a1c966a 100644 --- a/src/TensorAlgebra.jl +++ b/src/TensorAlgebra.jl @@ -17,6 +17,6 @@ include("contract/allocate_output.jl") include("contract/contract_matricize.jl") include("factorizations.jl") include("matrixfunctions.jl") -include("lazyarrays.jl") +include("linearbroadcasted.jl") end diff --git a/src/lazyarrays.jl b/src/lazyarrays.jl deleted file mode 100644 index ba71eda..0000000 --- a/src/lazyarrays.jl +++ /dev/null @@ -1,979 +0,0 @@ -import Base.Broadcast as BC -import FunctionImplementations as FI -import LinearAlgebra as LA -import StridedViews as SV - -# TermInterface-like interface. -iscall(x) = false -function operation end -function arguments end - -# Generic logic for lazy array linear algebra operations. -function +ₗ(a::AbstractArray, b::AbstractArray, c::AbstractArray, xs::AbstractArray...) - return Base.afoldl(+ₗ, +ₗ(+ₗ(a, b), c), xs...) -end --ₗ(a::AbstractArray, b::AbstractArray) = a +ₗ (-b) -function *ₗ(a::AbstractArray, b::AbstractArray, c::AbstractArray, xs::AbstractArray...) - return Base.afoldl(*ₗ, *ₗ(*ₗ(a, b), c), xs...) -end -*ₗ(a::AbstractArray, b::Number) = b *ₗ a -\ₗ(a::Number, b::AbstractArray) = inv(a) *ₗ b -/ₗ(a::AbstractArray, b::Number) = a *ₗ inv(b) -+ₗ(a::AbstractArray) = a --ₗ(a::AbstractArray) = -1 *ₗ a -conjed(a::AbstractArray{<:Real}) = a - -lazy_function(f) = error("No lazy function defined for `$f`.") -lazy_function(::typeof(+)) = +ₗ -lazy_function(::typeof(-)) = -ₗ -lazy_function(::typeof(*)) = *ₗ -lazy_function(::typeof(/)) = /ₗ -lazy_function(::typeof(\)) = \ₗ -lazy_function(::typeof(conj)) = conjed -lazy_function(::typeof(identity)) = identity -lazy_function(f::Base.Fix1{typeof(*), <:Number}) = Base.Fix1(*ₗ, f.x) -lazy_function(f::Base.Fix2{typeof(*), <:Number}) = Base.Fix2(*ₗ, f.x) -lazy_function(f::Base.Fix2{typeof(/), <:Number}) = Base.Fix2(/ₗ, f.x) - -broadcast_is_linear(f, args...) = false -broadcast_is_linear(::typeof(identity), ::Base.AbstractArrayOrBroadcasted) = true -broadcast_is_linear(::typeof(+), ::Base.AbstractArrayOrBroadcasted...) = true -broadcast_is_linear(::typeof(-), ::Base.AbstractArrayOrBroadcasted) = true -function broadcast_is_linear( - ::typeof(-), ::Base.AbstractArrayOrBroadcasted, ::Base.AbstractArrayOrBroadcasted - ) - return true -end -broadcast_is_linear(::typeof(*), ::Number, ::Base.AbstractArrayOrBroadcasted) = true -broadcast_is_linear(::typeof(\), ::Number, ::Base.AbstractArrayOrBroadcasted) = true -broadcast_is_linear(::typeof(*), ::Base.AbstractArrayOrBroadcasted, ::Number) = true -broadcast_is_linear(::typeof(/), ::Base.AbstractArrayOrBroadcasted, ::Number) = true -function broadcast_is_linear( - ::typeof(*), ::Base.AbstractArrayOrBroadcasted, ::Base.AbstractArrayOrBroadcasted - ) - return false -end -broadcast_is_linear(::typeof(*), ::Number, ::Number) = true -broadcast_is_linear(::typeof(conj), ::Base.AbstractArrayOrBroadcasted) = true -function broadcast_is_linear( - ::Base.Fix1{typeof(*), <:Number}, ::Base.AbstractArrayOrBroadcasted - ) - return true -end -function broadcast_is_linear( - ::Base.Fix2{typeof(*), <:Number}, ::Base.AbstractArrayOrBroadcasted - ) - return true -end -function broadcast_is_linear( - ::Base.Fix2{typeof(/), <:Number}, ::Base.AbstractArrayOrBroadcasted - ) - return true -end -is_linear(x) = true -function is_linear(bc::BC.Broadcasted) - return broadcast_is_linear(bc.f, bc.args...) && all(is_linear, bc.args) -end - -to_linear(x) = x -to_linear(bc::BC.Broadcasted) = lazy_function(bc.f)(to_linear.(bc.args)...) -function broadcast_error(style, f) - return throw( - ArgumentError( - "Only linear broadcast operations are supported for `$style`, got `$f`." - ) - ) -end -function broadcasted_linear(style::BC.BroadcastStyle, f, args...) - bc = BC.Broadcasted(style, f, args) - is_linear(bc) || broadcast_error(style, f) - return to_linear(bc) -end -broadcasted_linear(f, args...) = broadcasted_linear(BC.combine_styles(args...), f, args...) -# TODO: Use `Broadcast.broadcastable` interface for this? -to_broadcasted(x) = x -function to_broadcasted(a::AbstractArray) - (BC.BroadcastStyle(typeof(a)) isa LazyArrayStyle) || return a - return BC.broadcasted(operation(a), to_broadcasted.(arguments(a))...) -end -to_broadcasted(bc::BC.Broadcasted) = BC.Broadcasted(bc.f, to_broadcasted.(bc.args)) - -# For lazy arrays, define Broadcast methods in terms of lazy operations. -struct LazyArrayStyle{N, Style <: BC.AbstractArrayStyle{N}} <: BC.AbstractArrayStyle{N} - style::Style -end -# TODO: This empty constructor is required in some Julia versions below v1.12 (such as -# Julia v1.10), try deleting it once we drop support for those versions. -function LazyArrayStyle{N, Style}() where {N, Style <: BC.AbstractArrayStyle{N}} - return LazyArrayStyle{N, Style}(Style()) -end -function LazyArrayStyle{N, Style}(::Val{M}) where {M, N, Style <: BC.AbstractArrayStyle{N}} - return LazyArrayStyle(Style(Val(M))) -end -function BC.BroadcastStyle(style1::LazyArrayStyle, style2::LazyArrayStyle) - style = BC.BroadcastStyle(style1.style, style2.style) - style ≡ BC.Unknown() && return BC.Unknown() - return LazyArrayStyle(style) -end -function Base.similar(bc::BC.Broadcasted{<:LazyArrayStyle}, elt::Type, ax) - return similar(BC.Broadcasted(bc.style.style, bc.f, bc.args, bc.axes), elt, ax) -end -# Backup definition, for broadcast operations that don't preserve LazyArrays -# (such as nonlinear operations), convert back to Broadcasted expressions. -BC.broadcasted(::LazyArrayStyle, f, args...) = BC.Broadcasted(f, to_broadcasted.(args)) -BC.broadcasted(::LazyArrayStyle, ::typeof(+), a::AbstractArray, b::AbstractArray) = a +ₗ b -function BC.broadcasted(::LazyArrayStyle, ::typeof(+), a::AbstractArray, b::BC.Broadcasted) - is_linear(b) || return BC.Broadcasted(+, to_broadcasted.((a, b))) - return a +ₗ to_linear(b) -end -function BC.broadcasted(::LazyArrayStyle, ::typeof(+), a::BC.Broadcasted, b::AbstractArray) - is_linear(a) || return BC.Broadcasted(+, to_broadcasted.((a, b))) - return to_linear(a) +ₗ b -end -function BC.broadcasted(::LazyArrayStyle, ::typeof(+), a::BC.Broadcasted, b::BC.Broadcasted) - return error("Not implemented") -end -BC.broadcasted(::LazyArrayStyle, ::typeof(*), α::Number, a::AbstractArray) = α *ₗ a -BC.broadcasted(::LazyArrayStyle, ::typeof(*), a::AbstractArray, α::Number) = a *ₗ α -BC.broadcasted(::LazyArrayStyle, ::typeof(\), α::Number, a::AbstractArray) = α \ₗ a -BC.broadcasted(::LazyArrayStyle, ::typeof(/), a::AbstractArray, α::Number) = a /ₗ α -BC.broadcasted(::LazyArrayStyle, ::typeof(-), a::AbstractArray) = -ₗ(a) -BC.broadcasted(::LazyArrayStyle, ::typeof(conj), a::AbstractArray) = conjed(a) - -# Base overloads for lazy arrays. -function show_lazy(io::IO, a::AbstractArray) - print(io, operation(a), "(", join(arguments(a), ", "), ")") - return nothing -end -function show_lazy(io::IO, mime::MIME"text/plain", a::AbstractArray) - summary(io, a) - println(io, ":") - show(io, a) - return nothing -end - -# Generic constructors, accessors, and properties for ScaledArrays. -*ₗ(α::Number, a::AbstractArray) = ScaledArray(α, a) -unscaled(a::AbstractArray) = a -unscaled_type(arrayt::Type{<:AbstractArray}) = Base.promote_op(unscaled, arrayt) -coeff(a::AbstractArray) = true -coeff_type(arrayt::Type{<:AbstractArray}) = Base.promote_op(coeff, arrayt) -function scaled_eltype(coeff::Number, a::AbstractArray) - return Base.promote_op(*, typeof(coeff), eltype(a)) -end - -# Base overloads for ScaledArrays. -axes_scaled(a::AbstractArray) = axes(unscaled(a)) -size_scaled(a::AbstractArray) = size(unscaled(a)) -similar_scaled(a::AbstractArray) = similar(unscaled(a)) -similar_scaled(a::AbstractArray, elt::Type) = similar(unscaled(a), elt) -similar_scaled(a::AbstractArray, ax) = similar(unscaled(a), ax) -similar_scaled(a::AbstractArray, elt::Type, ax) = similar(unscaled(a), elt, ax) -getindex_scaled(a::AbstractArray, I...) = coeff(a) * getindex(unscaled(a), I...) -copyto!_scaled(dest::AbstractArray, src::AbstractArray) = add!(dest, src, true, false) -show_scaled(io::IO, a::AbstractArray) = show_lazy(io, a) -show_scaled(io::IO, mime::MIME"text/plain", a::AbstractArray) = show_lazy(io, mime, a) - -# Base overloads of adjoint and transpose for ScaledArrays. -adjoint_scaled(a::AbstractArray) = coeff(a) *ₗ adjoint(unscaled(a)) -transpose_scaled(a::AbstractArray) = coeff(a) *ₗ transpose(unscaled(a)) - -# Base.Broadcast overloads for ScaledArrays. -materialize_scaled(a::AbstractArray) = copy(a) -function BroadcastStyle_scaled(arrayt::Type{<:AbstractArray}) - return LazyArrayStyle(BC.BroadcastStyle(unscaled_type(arrayt))) -end - -# LinearAlgebra overloads for ScaledArrays. -function mul!_scaled( - dest::AbstractArray, - a::AbstractArray, - b::AbstractArray, - α::Number, - β::Number - ) - return LA.mul!(dest, unscaled(a), unscaled(b), coeff(a) * coeff(b) * α, β) -end - -# Lazy operations for ScaledArrays. -mulled_scaled(α::Number, a::AbstractArray) = (α * coeff(a)) *ₗ unscaled(a) -function mulled_scaled(a::AbstractArray, b::AbstractArray) - return (coeff(a) * coeff(b)) *ₗ (unscaled(a) *ₗ unscaled(b)) -end -conjed_scaled(a::AbstractArray) = conj(coeff(a)) *ₗ conjed(unscaled(a)) - -# TensorAlgebra overloads for ScaledArrays. -function add!_scaled(dest::AbstractArray, src::AbstractArray, α::Number, β::Number) - return add!(dest, unscaled(src), coeff(src) * α, β) -end - -# TermInterface-like overloads for ScaledArrays. -iscall_scaled(::AbstractArray) = true -operation_scaled(::AbstractArray) = * -arguments_scaled(a::AbstractArray) = (coeff(a), unscaled(a)) - -# FunctionImplementations overloads for ScaledArrays. -permuteddims_scaled(a::AbstractArray, perm) = coeff(a) *ₗ FI.permuteddims(unscaled(a), perm) - -macro scaledarray_type(ScaledArray, AbstractArray = :AbstractArray) - return esc( - quote - struct $ScaledArray{T, N, P <: AbstractArray{<:Any, N}, C <: Number} <: - $AbstractArray{T, N} - coeff::C - parent::P - function $ScaledArray(coeff::Number, a::AbstractArray) - T = $TensorAlgebra.scaled_eltype(coeff, a) - return new{T, ndims(a), typeof(a), typeof(coeff)}(coeff, a) - end - end - $TensorAlgebra.unscaled(a::$ScaledArray) = a.parent - function $TensorAlgebra.unscaled_type(arrayt::Type{<:$ScaledArray}) - return fieldtype(arrayt, :parent) - end - $TensorAlgebra.coeff(a::$ScaledArray) = a.coeff - function $TensorAlgebra.coeff_type(arrayt::Type{<:$ScaledArray}) - return fieldtype(arrayt, :coeff) - end - end - ) -end - -macro scaledarray_base(ScaledArray, AbstractArray = :AbstractArray) - return esc( - quote - Base.axes(a::$ScaledArray) = - $TensorAlgebra.axes_scaled(a) - Base.size(a::$ScaledArray) = - $TensorAlgebra.size_scaled(a) - Base.similar(a::$ScaledArray) = - $TensorAlgebra.similar_scaled(a) - function Base.similar(a::$ScaledArray, elt::Type) - return $TensorAlgebra.similar_scaled(a, elt) - end - Base.similar(a::$ScaledArray, ax) = - $TensorAlgebra.similar_scaled(a, ax) - Base.similar(a::$ScaledArray, ax::Tuple) = - $TensorAlgebra.similar_scaled(a, ax) - function Base.similar(a::$ScaledArray, elt::Type, ax) - return $TensorAlgebra.similar_scaled(a, elt, ax) - end - function Base.similar(a::$ScaledArray, elt::Type, ax::Dims) - return $TensorAlgebra.similar_scaled(a, elt, ax) - end - Base.@propagate_inbounds function Base.getindex(a::$ScaledArray, I...) - return $TensorAlgebra.getindex_scaled(a, I...) - end - function Base.copyto!(dest::$AbstractArray, src::$ScaledArray) - return $TensorAlgebra.copyto!_scaled(dest, src) - end - Base.show(io::IO, a::$ScaledArray) = - $TensorAlgebra.show_scaled(io, a) - function Base.show(io::IO, mime::MIME"text/plain", a::$ScaledArray) - return $TensorAlgebra.show_scaled(io, mime, a) - end - end - ) -end - -macro scaledarray_adjtrans(ScaledArray, AbstractArray = :AbstractArray) - return esc( - quote - Base.adjoint(a::$ScaledArray) = - $TensorAlgebra.adjoint_scaled(a) - Base.transpose(a::$ScaledArray) = - $TensorAlgebra.transpose_scaled(a) - end - ) -end - -macro scaledarray_broadcast(ScaledArray, AbstractArray = :AbstractArray) - return esc( - quote - function Base.Broadcast.materialize(a::$ScaledArray) - return $TensorAlgebra.materialize_scaled(a) - end - function Base.Broadcast.BroadcastStyle(arrayt::Type{<:$ScaledArray}) - return $TensorAlgebra.BroadcastStyle_scaled(arrayt) - end - end - ) -end - -macro scaledarray_linearalgebra(ScaledArray, AbstractArray = :AbstractArray) - return esc( - quote - function $TensorAlgebra.LA.mul!( - dest::$AbstractArray{<:Any, 2}, - a::$ScaledArray{<:Any, 2}, - b::$ScaledArray{<:Any, 2}, - α::Number, β::Number - ) - return $TensorAlgebra.mul!_scaled(dest, a, b, α, β) - end - function $TensorAlgebra.LA.mul!( - dest::$AbstractArray{<:Any, 2}, - a::$AbstractArray{<:Any, 2}, - b::$ScaledArray{<:Any, 2}, - α::Number, β::Number - ) - return $TensorAlgebra.mul!_scaled(dest, a, b, α, β) - end - function $TensorAlgebra.LA.mul!( - dest::$AbstractArray{<:Any, 2}, - a::$ScaledArray{<:Any, 2}, - b::$AbstractArray{<:Any, 2}, - α::Number, β::Number - ) - return $TensorAlgebra.mul!_scaled(dest, a, b, α, β) - end - end - ) -end - -macro scaledarray_lazy(ScaledArray, AbstractArray = :AbstractArray) - return esc( - quote - function $TensorAlgebra.:*ₗ(α::Number, a::$ScaledArray) - return $TensorAlgebra.mulled_scaled(α, a) - end - function $TensorAlgebra.:*ₗ(a::$ScaledArray, b::$ScaledArray) - return $TensorAlgebra.mulled_scaled(a, b) - end - function $TensorAlgebra.:*ₗ(a::$AbstractArray, b::$ScaledArray) - return $TensorAlgebra.mulled_scaled(a, b) - end - function $TensorAlgebra.:*ₗ(a::$ScaledArray, b::$AbstractArray) - return $TensorAlgebra.mulled_scaled(a, b) - end - $TensorAlgebra.conjed(a::$ScaledArray) = - $TensorAlgebra.conjed_scaled(a) - end - ) -end - -macro scaledarray_tensoralgebra(ScaledArray, AbstractArray = :AbstractArray) - return esc( - quote - function $TensorAlgebra.add!( - dest::$AbstractArray, src::$ScaledArray, α::Number, β::Number - ) - return $TensorAlgebra.add!_scaled(dest, src, α, β) - end - end - ) -end - -macro scaledarray_terminterface(ScaledArray, AbstractArray = :AbstractArray) - return esc( - quote - $TensorAlgebra.iscall(a::$ScaledArray) = $TensorAlgebra.iscall_scaled(a) - $TensorAlgebra.operation(a::$ScaledArray) = $TensorAlgebra.operation_scaled(a) - $TensorAlgebra.arguments(a::$ScaledArray) = $TensorAlgebra.arguments_scaled(a) - end - ) -end - -macro scaledarray_functionimplementations(ScaledArray, AbstractArray = :AbstractArray) - return esc( - quote - function $TensorAlgebra.FI.permuteddims(a::$ScaledArray, perm) - return $TensorAlgebra.permuteddims_scaled(a, perm) - end - end - ) -end - -macro scaledarray(ScaledArray, AbstractArray = :AbstractArray) - return esc( - quote - $TensorAlgebra.@scaledarray_base $ScaledArray $AbstractArray - $TensorAlgebra.@scaledarray_adjtrans $ScaledArray $AbstractArray - $TensorAlgebra.@scaledarray_broadcast $ScaledArray $AbstractArray - $TensorAlgebra.@scaledarray_lazy $ScaledArray $AbstractArray - $TensorAlgebra.@scaledarray_linearalgebra $ScaledArray $AbstractArray - $TensorAlgebra.@scaledarray_tensoralgebra $ScaledArray $AbstractArray - $TensorAlgebra.@scaledarray_terminterface $ScaledArray $AbstractArray - $TensorAlgebra.@scaledarray_functionimplementations $ScaledArray $AbstractArray - end - ) -end - -# Generic constructors for ConjArrays. -conjed(a::AbstractArray) = ConjArray(a) -conjed_type(arrayt::Type{<:AbstractArray}) = Base.promote_op(conjed, arrayt) - -# Base overloads for ConjArrays. -axes_conj(a::AbstractArray) = axes(conjed(a)) -size_conj(a::AbstractArray) = size(conjed(a)) -similar_conj(a::AbstractArray, elt::Type) = similar(conjed(a), elt) -similar_conj(a::AbstractArray, elt::Type, ax) = similar(conjed(a), elt, ax) -similar_conj(a::AbstractArray, ax) = similar(conjed(a), ax) -getindex_conj(a::AbstractArray, I...) = conj(getindex(conjed(a), I...)) -copyto!_conj(dest::AbstractArray, src::AbstractArray) = add!(dest, src, true, false) -show_conj(io::IO, a::AbstractArray) = show_lazy(io, a) -show_conj(io::IO, mime::MIME"text/plain", a::AbstractArray) = show_lazy(io, mime, a) - -# Base overloads of adjoint and transpose for ConjArrays. -adjoint_conj(a::AbstractArray) = transpose(conjed(a)) -transpose_conj(a::AbstractArray) = adjoint(conjed(a)) - -# Base.Broadcast overloads for ConjArrays. -materialize_conj(a::AbstractArray) = copy(a) -function BroadcastStyle_conj(arrayt::Type{<:AbstractArray}) - return LazyArrayStyle(BC.BroadcastStyle(conjed_type(arrayt))) -end - -# StridedViews overloads for ConjArrays. -isstrided_conj(a::AbstractArray) = SV.isstrided(conjed(a)) -StridedView_conj(a::AbstractArray) = conj(SV.StridedView(conjed(a))) - -# TermInterface-like overloads for ConjArrays. -iscall_conj(::AbstractArray) = true -operation_conj(::AbstractArray) = conj -arguments_conj(a::AbstractArray) = (conjed(a),) - -# FunctionImplementations overloads for ConjArrays. -permuteddims_conj(a::AbstractArray, perm) = conjed(FI.permuteddims(conjed(a), perm)) - -macro conjarray_type(ConjArray, AbstractArray = :AbstractArray) - return esc( - quote - struct $ConjArray{T, N, P <: AbstractArray{T, N}} <: $AbstractArray{T, N} - parent::P - end - $TensorAlgebra.conjed(a::$ConjArray) = a.parent - end - ) -end - -macro conjarray_base(ConjArray, AbstractArray = :AbstractArray) - return esc( - quote - Base.axes(a::$ConjArray) = - $TensorAlgebra.axes_conj(a) - Base.size(a::$ConjArray) = - $TensorAlgebra.size_conj(a) - Base.similar(a::$ConjArray, elt::Type) = - $TensorAlgebra.similar_conj(a, elt) - function Base.similar(a::$ConjArray, elt::Type, ax) - return $TensorAlgebra.similar_conj(a, elt, ax) - end - function Base.similar(a::$ConjArray, elt::Type, ax::Dims) - return $TensorAlgebra.similar_conj(a, elt, ax) - end - Base.@propagate_inbounds function Base.getindex(a::$ConjArray, I...) - return $TensorAlgebra.getindex_conj(a, I...) - end - function Base.copyto!(dest::$AbstractArray, src::$ConjArray) - return $TensorAlgebra.copyto!_conj(dest, src) - end - Base.show(io::IO, a::$ConjArray) = $TensorAlgebra.show_conj(io, a) - function Base.show(io::IO, mime::MIME"text/plain", a::$ConjArray) - return $TensorAlgebra.show_conj(io, mime, a) - end - end - ) -end - -macro conjarray_adjtrans(ConjArray, AbstractArray = :AbstractArray) - return esc( - quote - Base.adjoint(a::$ConjArray) = - $TensorAlgebra.adjoint_conj(a) - Base.transpose(a::$ConjArray) = - $TensorAlgebra.transpose_conj(a) - end - ) -end - -macro conjarray_broadcast(ConjArray, AbstractArray = :AbstractArray) - return esc( - quote - Base.Broadcast.materialize(a::$ConjArray) = $TensorAlgebra.materialize_conj(a) - function Base.Broadcast.BroadcastStyle(arrayt::Type{<:$ConjArray}) - return $TensorAlgebra.BroadcastStyle_conj(arrayt) - end - end - ) -end - -macro conjarray_stridedviews(ConjArray, AbstractArray = :AbstractArray) - return esc( - quote - $TensorAlgebra.SV.isstrided(a::$ConjArray) = - $TensorAlgebra.isstrided_conj(a) - function $TensorAlgebra.SV.StridedView(a::$ConjArray) - return $TensorAlgebra.StridedView_conj(a) - end - end - ) -end - -macro conjarray_terminterface(ConjArray, AbstractArray = :AbstractArray) - return esc( - quote - $TensorAlgebra.iscall(a::$ConjArray) = $TensorAlgebra.iscall_conj(a) - $TensorAlgebra.operation(a::$ConjArray) = $TensorAlgebra.operation_conj(a) - $TensorAlgebra.arguments(a::$ConjArray) = $TensorAlgebra.arguments_conj(a) - end - ) -end - -macro conjarray_functionimplementations(ConjArray, AbstractArray = :AbstractArray) - return esc( - quote - function $TensorAlgebra.FI.permuteddims(a::$ConjArray, perm) - return $TensorAlgebra.permuteddims_conj(a, perm) - end - end - ) -end - -macro conjarray(ConjArray, AbstractArray = :AbstractArray) - return esc( - quote - $TensorAlgebra.@conjarray_base $ConjArray $AbstractArray - $TensorAlgebra.@conjarray_adjtrans $ConjArray $AbstractArray - $TensorAlgebra.@conjarray_broadcast $ConjArray $AbstractArray - $TensorAlgebra.@conjarray_stridedviews $ConjArray $AbstractArray - $TensorAlgebra.@conjarray_terminterface $ConjArray $AbstractArray - $TensorAlgebra.@conjarray_functionimplementations $ConjArray $AbstractArray - end - ) -end - -# Generic constructors, accessors, and properties for AddArrays. -+ₗ(a::AbstractArray, b::AbstractArray) = AddArray(a, b) -addends(a::AbstractArray) = (a,) -addends_type(arrayt::Type{<:AbstractArray}) = Tuple{arrayt} -add_eltype(args::AbstractArray...) = Base.promote_op(+, eltype.(args)...) -function add_ndims(args::AbstractArray...) - return if allequal(ndims, args) - ndims(first(args)) - else - error("All addends must have the same number of dimensions.") - end -end - -# Base overloads for AddArrays. -add_axes(args::AbstractArray...) = BC.combine_axes(args...) -axes_add(a::AbstractArray) = add_axes(addends(a)...) -size_add(a::AbstractArray) = length.(axes_add(a)) -similar_add(a::AbstractArray) = similar(a, eltype(a)) -similar_add(a::AbstractArray, ax::Tuple) = similar(a, eltype(a), ax) -similar_add(a::AbstractArray, elt::Type) = similar(BC.Broadcasted(+, addends(a)), elt) -function similar_add(a::AbstractArray, elt::Type, ax) - return similar(BC.Broadcasted(+, addends(a)), elt, ax) -end -getindex_add(a::AbstractArray, I...) = sum(addend -> getindex(addend, I...), addends(a)) -copyto!_add(dest::AbstractArray, src::AbstractArray) = add!(dest, src, true, false) -show_add(io::IO, a::AbstractArray) = show_lazy(io, a) -show_add(io::IO, mime::MIME"text/plain", a::AbstractArray) = show_lazy(io, mime, a) - -# Base overloads of adjoint and transpose for AddArrays. -adjoint_add(a::AbstractArray) = +ₗ(adjoint.(addends(a))...) -transpose_add(a::AbstractArray) = +ₗ(transpose.(addends(a))...) - -# Base.Broadcast overloads for AddArrays. -materialize_add(a::AbstractArray) = copy(a) -function BroadcastStyle_add(arrayt::Type{<:AbstractArray}) - args_type = addends_type(arrayt) - style = Base.promote_op(BC.combine_styles, fieldtypes(args_type)...)() - return LazyArrayStyle(style) -end - -# TensorAlgebra overloads for AddArrays. -function add!_add(dest::AbstractArray, src::AbstractArray, α::Number, β::Number) - args = addends(src) - add!(dest, first(args), α, β) - for a in Base.tail(args) - add!(dest, a, α, true) - end - return dest -end - -# Lazy operations for AddArrays. -added_add(a::AbstractArray, b::AbstractArray) = AddArray((addends(a)..., addends(b)...)...) -mulled_add(α::Number, a::AbstractArray) = +ₗ((α .*ₗ addends(a))...) -## TODO: Define multiplication of added arrays by expanding all combinations, treating -## both inputs as AddArrays. -## mulled_add(a::AbstractArray, b::AbstractArray) = +ₗ((Ref(a) .*ₗ addends(b))...) -## mulled_add(a::AddArray, b::AbstractArray) = +ₗ((addends(a) .*ₗ Ref(b))...) -## mulled_add(a::AddArray, b::AddArray) = +ₗ((Ref(a) .*ₗ addends(b))...) -conjed_add(a::AbstractArray) = +ₗ(conjed.(addends(a))...) - -# TermInterface-like overloads for AddArrays. -iscall_add(::AbstractArray) = true -operation_add(::AbstractArray) = + -arguments_add(a::AbstractArray) = addends(a) - -# FunctionImplementations overloads for AddArrays. -function permuteddims_add(a::AbstractArray, perm) - return +ₗ(Base.Fix2(FI.permuteddims, perm).(addends(a))...) -end - -macro addarray_type(AddArray, AbstractArray = :AbstractArray) - return esc( - quote - struct $AddArray{T, N, Args <: Tuple{Vararg{AbstractArray{<:Any, N}}}} <: - $AbstractArray{T, N} - args::Args - function $AddArray(args::AbstractArray...) - T = $TensorAlgebra.add_eltype(args...) - N = $TensorAlgebra.add_ndims(args...) - return new{T, N, typeof(args)}(args) - end - end - $TensorAlgebra.addends(a::$AddArray) = a.args - function $TensorAlgebra.addends_type(arrayt::Type{<:$AddArray}) - return fieldtype(arrayt, :args) - end - end - ) -end - -macro addarray_base(AddArray, AbstractArray = :AbstractArray) - return esc( - quote - Base.axes(a::$AddArray) = $TensorAlgebra.axes_add(a) - Base.size(a::$AddArray) = $TensorAlgebra.size_add(a) - Base.similar(a::$AddArray) = $TensorAlgebra.similar_add(a) - Base.similar(a::$AddArray, ax::Tuple) = $TensorAlgebra.similar_add(a, ax) - Base.similar(a::$AddArray, elt::Type) = $TensorAlgebra.similar_add(a, elt) - function Base.similar( - a::$AddArray, elt::Type, - ax::Tuple{Union{Integer, Base.OneTo}, Vararg{Union{Integer, Base.OneTo}}} - ) - return $TensorAlgebra.similar_add(a, elt, ax) - end - function Base.similar(a::$AddArray, elt::Type, ax::Dims) - return $TensorAlgebra.similar_add(a, elt, ax) - end - function Base.similar(a::$AddArray, elt::Type, ax) - return $TensorAlgebra.similar_add(a, elt, ax) - end - Base.@propagate_inbounds function Base.getindex(a::$AddArray, I...) - return $TensorAlgebra.getindex_add(a, I...) - end - function Base.copyto!(dest::$AbstractArray, src::$AddArray) - return $TensorAlgebra.copyto!_add(dest, src) - end - Base.show(io::IO, a::$AddArray) = - $TensorAlgebra.show_add(io, a) - function Base.show(io::IO, mime::MIME"text/plain", a::$AddArray) - return $TensorAlgebra.show_add(io, mime, a) - end - end - ) -end - -macro addarray_adjtrans(AddArray, AbstractArray = :AbstractArray) - return esc( - quote - Base.adjoint(a::$AddArray) = - $TensorAlgebra.adjoint_add(a) - Base.transpose(a::$AddArray) = - $TensorAlgebra.transpose_add(a) - end - ) -end - -macro addarray_broadcast(AddArray, AbstractArray = :AbstractArray) - return esc( - quote - Base.Broadcast.materialize(a::$AddArray) = $TensorAlgebra.materialize_add(a) - function Base.Broadcast.BroadcastStyle(arrayt::Type{<:$AddArray}) - return $TensorAlgebra.BroadcastStyle_add(arrayt) - end - end - ) -end - -macro addarray_lazy(AddArray, AbstractArray = :AbstractArray) - return esc( - quote - function $TensorAlgebra.:+ₗ(a::$AbstractArray, b::$AddArray) - return $TensorAlgebra.added_add(a, b) - end - function $TensorAlgebra.:+ₗ(a::$AddArray, b::$AbstractArray) - return $TensorAlgebra.added_add(a, b) - end - $TensorAlgebra.:+ₗ(a::$AddArray, b::$AddArray) = - $TensorAlgebra.added_add(a, b) - $TensorAlgebra.:*ₗ(α::Number, a::$AddArray) = - $TensorAlgebra.mulled_add(α, a) - function $TensorAlgebra.:*ₗ(a::$AbstractArray, b::$AddArray) - return $TensorAlgebra.mulled_add(a, b) - end - function $TensorAlgebra.:*ₗ(a::$AddArray, b::$AbstractArray) - return $TensorAlgebra.mulled_add(a, b) - end - function $TensorAlgebra.:*ₗ(a::$AddArray, b::$AddArray) - return $TensorAlgebra.mulled_add(a, b) - end - $TensorAlgebra.conjed(a::$AddArray) = - $TensorAlgebra.conjed_add(a) - end - ) -end - -macro addarray_tensoralgebra(AddArray, AbstractArray = :AbstractArray) - return esc( - quote - function $TensorAlgebra.add!( - dest::$AbstractArray, src::$AddArray, α::Number, β::Number - ) - return $TensorAlgebra.add!_add(dest, src, α, β) - end - end - ) -end - -macro addarray_terminterface(AddArray, AbstractArray = :AbstractArray) - return esc( - quote - $TensorAlgebra.iscall(a::$AddArray) = $TensorAlgebra.iscall_add(a) - $TensorAlgebra.operation(a::$AddArray) = $TensorAlgebra.operation_add(a) - $TensorAlgebra.arguments(a::$AddArray) = $TensorAlgebra.arguments_add(a) - end - ) -end - -macro addarray_functionimplementations(AddArray, AbstractArray = :AbstractArray) - return esc( - quote - function $TensorAlgebra.FI.permuteddims(a::$AddArray, perm) - return $TensorAlgebra.permuteddims_add(a, perm) - end - end - ) -end - -macro addarray(AddArray, AbstractArray = :AbstractArray) - return esc( - quote - $TensorAlgebra.@addarray_base $AddArray $AbstractArray - $TensorAlgebra.@addarray_adjtrans $AddArray $AbstractArray - $TensorAlgebra.@addarray_broadcast $AddArray $AbstractArray - $TensorAlgebra.@addarray_lazy $AddArray $AbstractArray - $TensorAlgebra.@addarray_tensoralgebra $AddArray $AbstractArray - $TensorAlgebra.@addarray_terminterface $AddArray $AbstractArray - $TensorAlgebra.@addarray_functionimplementations $AddArray $AbstractArray - end - ) -end - -# Generic constructors, accessors, and properties for MulArrays. -*ₗ(a::AbstractArray, b::AbstractArray) = MulArray(a, b) -factors(a::AbstractArray) = (a,) -factor_types(arrayt::Type{<:AbstractArray}) = Base.promote_op(factors, arrayt) -# Same as `LinearAlgebra.matprod`, but duplicated here since it is private. -matprod(x, y) = x * y + x * y -function mul_eltype(a::AbstractArray, b::AbstractArray) - return Base.promote_op(matprod, eltype(a), eltype(b)) -end -mul_ndims(a::AbstractArray, b::AbstractArray) = ndims(b) -mul_axes(a::AbstractArray, b::AbstractArray) = (axes(a, 1), axes(b, ndims(b))) - -# Base overloads for MulArrays. -eltype_mul(a::AbstractArray{T}) where {T} = T -axes_mul(a::AbstractArray) = mul_axes(factors(a)...) -size_mul(a::AbstractArray) = length.(axes_mul(a)) -similar_mul(a::AbstractArray) = similar(a, eltype(a)) -similar_mul(a::AbstractArray, ax::Tuple) = similar(a, eltype(a), ax) -similar_mul(a::AbstractArray, elt::Type) = similar(a, elt, axes(a)) -# TODO: Make use of both arguments to determine the output, maybe -# using `LinearAlgebra.matprod_dest(factors(a)..., elt)`? -similar_mul(a::AbstractArray, elt::Type, ax) = similar(last(factors(a)), elt, ax) -function mul_getindex(a1::AbstractMatrix, a2::AbstractMatrix, i::Int, j::Int) - return transpose(view(a1, i, :)) * view(a2, :, j) -end -function mul_getindex(a1::AbstractMatrix, a2::AbstractVector, i::Int) - return transpose(view(a1, i, :)) * a2 -end -function mul_getindex(a1::AbstractVector, a2::AbstractMatrix, j::Int) - return transpose(a1) * view(a2, :, j) -end -function getindex_mul(a::AbstractArray, i::Int) - I = Tuple(CartesianIndices(axes(a))[i]) - return getindex_mul(a, I...) -end -getindex_mul(a::AbstractArray, I::Vararg{Int}) = mul_getindex(factors(a)..., I...) -copyto!_mul(dest::AbstractArray, src::AbstractArray) = add!(dest, src, true, false) -show_mul(io::IO, a::AbstractArray) = show_lazy(io, a) -show_mul(io::IO, mime::MIME"text/plain", a::AbstractArray) = show_lazy(io, mime, a) - -# Base overloads of adjoint and transpose for MulArrays. -adjoint_mul(a::AbstractArray) = *ₗ(reverse(adjoint.(factors(a)))...) -transpose_mul(a::AbstractArray) = *ₗ(reverse(transpose.(factors(a)))...) - -# Base.Broadcast overloads for MulArrays. -materialize_mul(a::AbstractArray) = copy(a) -function BroadcastStyle_mul(arrayt::Type{<:AbstractArray}) - style = Base.promote_op(BC.combine_styles, factor_types(arrayt)...)() - return LazyArrayStyle(style) -end - -# TensorAlgebra overloads for MulArrays. -# We materialize the arguments here to avoid nested lazy evaluation. -# Rewrite rules should make it so that `MulArray` is a "leaf` node of the -# expression tree. -function add!_mul(dest::AbstractArray, src::AbstractArray, α::Number, β::Number) - return LA.mul!(dest, BC.materialize.(factors(src))..., α, β) -end - -# Lazy operations for MulArrays. -conjed_mul(a::AbstractArray) = *ₗ(conjed.(factors(a))...) -# Matmul isn't a broadcasting operation so we materialize (i.e. -# perform the matrix multiplication) when building a broadcast -# expression involving a `MulArray`. -# TODO: Use `Broadcast.broadcastable` interface for this? -to_broadcasted_mul(a::AbstractArray) = *(factors(a)...) - -# TermInterface-like overloads for MulArrays. -iscall_mul(::AbstractArray) = true -operation_mul(::AbstractArray) = * -arguments_mul(a::AbstractArray) = factors(a) - -macro mularray_type(MulArray, AbstractArray = :AbstractArray) - return esc( - quote - struct $MulArray{T, N, A <: AbstractArray, B <: AbstractArray} <: - $AbstractArray{T, N} - a::A - b::B - function $MulArray(a::AbstractArray, b::AbstractArray) - T = $TensorAlgebra.mul_eltype(a, b) - N = $TensorAlgebra.mul_ndims(a, b) - return new{T, N, typeof(a), typeof(b)}(a, b) - end - end - $TensorAlgebra.factors(a::$MulArray) = (a.a, a.b) - function $TensorAlgebra.factor_types(arrayt::Type{<:$MulArray}) - return (fieldtype(arrayt, :a), fieldtype(arrayt, :b)) - end - end - ) -end - -function copy_permuteddims(a::PermutedDimsArray{<:Any, 2, perm}) where {perm} - perm == (1, 2) && return copy(parent(a)) - return copy(transpose(parent(a))) -end - -macro mularray_base(MulArray, AbstractArray = :AbstractArray) - return esc( - quote - Base.eltype(a::$MulArray) = $TensorAlgebra.eltype_mul(a) - Base.axes(a::$MulArray) = $TensorAlgebra.axes_mul(a) - Base.size(a::$MulArray) = $TensorAlgebra.size_mul(a) - Base.similar(a::$MulArray) = $TensorAlgebra.similar_mul(a) - Base.similar(a::$MulArray, ax::Tuple) = $TensorAlgebra.similar_mul(a, ax) - Base.similar(a::$MulArray, elt::Type) = $TensorAlgebra.similar_mul(a, elt) - function Base.similar( - a::$MulArray, elt::Type, - ax::Tuple{Union{Integer, Base.OneTo}, Vararg{Union{Integer, Base.OneTo}}} - ) - return $TensorAlgebra.similar_mul(a, elt, ax) - end - function Base.similar(a::$MulArray, elt::Type, ax) - return $TensorAlgebra.similar_mul(a, elt, ax) - end - function Base.similar(a::$MulArray, elt::Type, ax::Dims) - return $TensorAlgebra.similar_mul(a, elt, ax) - end - Base.@propagate_inbounds function Base.getindex(a::$MulArray, I...) - return $TensorAlgebra.getindex_mul(a, I...) - end - function Base.copyto!(dest::$AbstractArray, src::$MulArray) - return $TensorAlgebra.copyto!_mul(dest, src) - end - Base.show(io::IO, a::$MulArray) = $TensorAlgebra.show_mul(io, a) - function Base.show(io::IO, mime::MIME"text/plain", a::$MulArray) - return $TensorAlgebra.show_mul(io, mime, a) - end - end - ) -end - -macro mularray_adjtrans(MulArray, AbstractArray = :AbstractArray) - return esc( - quote - Base.adjoint(a::$MulArray) = - $TensorAlgebra.adjoint_mul(a) - Base.transpose(a::$MulArray) = - $TensorAlgebra.transpose_mul(a) - end - ) -end - -macro mularray_broadcast(MulArray, AbstractArray = :AbstractArray) - return esc( - quote - Base.Broadcast.materialize(a::$MulArray) = $TensorAlgebra.materialize_mul(a) - function Base.Broadcast.BroadcastStyle(arrayt::Type{<:$MulArray}) - return $TensorAlgebra.BroadcastStyle_mul(arrayt) - end - end - ) -end - -macro mularray_lazy(MulArray, AbstractArray = :AbstractArray) - return esc( - quote - $TensorAlgebra.conjed(a::$MulArray) = $TensorAlgebra.conjed_mul(a) - function $TensorAlgebra.to_broadcasted(a::$MulArray) - return $TensorAlgebra.to_broadcasted_mul(a) - end - end - ) -end - -macro mularray_tensoralgebra(MulArray, AbstractArray = :AbstractArray) - return esc( - quote - function $TensorAlgebra.add!( - dest::$AbstractArray, src::$MulArray, α::Number, β::Number - ) - return $TensorAlgebra.add!_mul(dest, src, α, β) - end - end - ) -end - -macro mularray_terminterface(MulArray, AbstractArray = :AbstractArray) - return esc( - quote - $TensorAlgebra.iscall(a::$MulArray) = $TensorAlgebra.iscall_mul(a) - $TensorAlgebra.operation(a::$MulArray) = $TensorAlgebra.operation_mul(a) - $TensorAlgebra.arguments(a::$MulArray) = $TensorAlgebra.arguments_mul(a) - function Base.copy(a::PermutedDimsArray{<:Any, 2, <:Any, <:Any, $MulArray}) - return $TensorAlgebra.copy_permuteddims(a) - end - end - ) -end - -macro mularray(MulArray, AbstractArray = :AbstractArray) - return esc( - quote - $TensorAlgebra.@mularray_base $MulArray $AbstractArray - $TensorAlgebra.@mularray_adjtrans $MulArray $AbstractArray - $TensorAlgebra.@mularray_broadcast $MulArray $AbstractArray - $TensorAlgebra.@mularray_lazy $MulArray $AbstractArray - $TensorAlgebra.@mularray_tensoralgebra $MulArray $AbstractArray - $TensorAlgebra.@mularray_terminterface $MulArray $AbstractArray - end - ) -end - -# Define types. -@scaledarray_type ScaledArray -@scaledarray ScaledArray -@conjarray_type ConjArray -@conjarray ConjArray -@addarray_type AddArray -@addarray AddArray -@mularray_type MulArray -@mularray MulArray diff --git a/src/linearbroadcasted.jl b/src/linearbroadcasted.jl new file mode 100644 index 0000000..637afbf --- /dev/null +++ b/src/linearbroadcasted.jl @@ -0,0 +1,370 @@ +import Base.Broadcast as BC +import LinearAlgebra as LA + +# TermInterface-like interface. +iscall(x) = false +function operation end +function arguments end + +# ---------------------------------------------------------------------------- # +# LinearBroadcasted — lazy linear broadcast expressions (not <: AbstractArray) +# ---------------------------------------------------------------------------- # + +""" + LinearBroadcasted + +Abstract supertype for lazy linear broadcast expressions. Analogous to +`Base.Broadcast.Broadcasted` but restricted to linear operations. + +Materializes via the protocol: +copy(lb) = copyto!(similar(lb), lb) +copyto!(dest, lb) → add!(dest, lb, 1, 0) +""" +abstract type LinearBroadcasted end + +# Generic interface for LinearBroadcasted subtypes. +Base.axes(a::LinearBroadcasted, d::Int) = axes(a)[d] +Base.similar(a::LinearBroadcasted) = similar(a, eltype(a)) +Base.similar(a::LinearBroadcasted, elt::Type) = similar(a, elt, axes(a)) +function Base.show(io::IO, a::LinearBroadcasted) + print(io, operation(a), "(", join(arguments(a), ", "), ")") + return nothing +end +iscall(::LinearBroadcasted) = true + +# Convert LinearBroadcasted back to Broadcasted (inverse of tryflattenlinear). +# Uses BC.Broadcasted constructor directly (not BC.broadcasted) to avoid style-based +# dispatch that could re-enter LinearBroadcasted conversion. +function BC.Broadcasted(a::LinearBroadcasted) + args = map(arguments(a)) do arg + return arg isa LinearBroadcasted ? BC.Broadcasted(arg) : arg + end + return BC.Broadcasted(BC.combine_styles(args...), operation(a), args) +end + +function Base.similar(a::LinearBroadcasted, elt::Type, ax) + return similar(BC.Broadcasted(a), elt, ax) +end + +# --- ScaledBroadcasted -------------------------------------------------------- + +struct ScaledBroadcasted{C <: Number, A} <: LinearBroadcasted + coeff::C + parent::A +end + +unscaled(a::ScaledBroadcasted) = a.parent +coeff(a::ScaledBroadcasted) = a.coeff + +Base.axes(a::ScaledBroadcasted) = axes(unscaled(a)) +function Base.eltype(a::ScaledBroadcasted) + return Base.promote_op(*, typeof(coeff(a)), eltype(unscaled(a))) +end +Base.ndims(a::ScaledBroadcasted) = ndims(unscaled(a)) + +operation(::ScaledBroadcasted) = * +arguments(a::ScaledBroadcasted) = (coeff(a), unscaled(a)) + +# --- ConjBroadcasted ---------------------------------------------------------- + +struct ConjBroadcasted{A} <: LinearBroadcasted + parent::A +end + +unconj(a::ConjBroadcasted) = a.parent + +Base.axes(a::ConjBroadcasted) = axes(unconj(a)) +Base.eltype(a::ConjBroadcasted) = eltype(unconj(a)) +Base.ndims(a::ConjBroadcasted) = ndims(unconj(a)) + +operation(::ConjBroadcasted) = conj +arguments(a::ConjBroadcasted) = (unconj(a),) + +# --- AddBroadcasted ----------------------------------------------------------- + +struct AddBroadcasted{Args <: Tuple} <: LinearBroadcasted + args::Args + AddBroadcasted(args...) = new{typeof(args)}(args) +end + +addends(a::AddBroadcasted) = a.args + +Base.axes(a::AddBroadcasted) = BC.combine_axes(addends(a)...) +Base.eltype(a::AddBroadcasted) = Base.promote_op(+, eltype.(addends(a))...) +Base.ndims(a::AddBroadcasted) = ndims(first(addends(a))) + +operation(::AddBroadcasted) = + +arguments(a::AddBroadcasted) = addends(a) + +# ---------------------------------------------------------------------------- # +# Mul — lazy matrix multiplication (standalone, not LinearBroadcasted) +# ---------------------------------------------------------------------------- # + +# Same as `LinearAlgebra.matprod`, but duplicated here since it is private. +matprod(x, y) = x * y + x * y + +struct Mul{A, B} + a::A + b::B +end + +factors(a::Mul) = (a.a, a.b) + +Base.axes(a::Mul) = (axes(a.a, 1), axes(a.b, ndims(a.b))) +Base.axes(a::Mul, d::Int) = axes(a)[d] +Base.eltype(a::Mul) = Base.promote_op(matprod, eltype(a.a), eltype(a.b)) +Base.ndims(a::Mul) = ndims(a.b) +Base.size(a::Mul) = length.(axes(a)) + +Base.similar(a::Mul) = similar(a, eltype(a)) +Base.similar(a::Mul, elt::Type) = similar(a, elt, axes(a)) +function Base.similar(a::Mul, elt::Type, ax) + return similar(BC.materialize(last(factors(a))), elt, ax) +end + +function Base.show(io::IO, a::Mul) + f = factors(a) + print(io, "*(", f[1], ", ", f[2], ")") + return nothing +end + +iscall(::Mul) = true +operation(::Mul) = * +arguments(a::Mul) = factors(a) + +# ---------------------------------------------------------------------------- # +# Materialization protocol: copy, copyto!, add! +# ---------------------------------------------------------------------------- # + +function Base.copy(a::LinearBroadcasted) + return copyto!(similar(a), a) +end + +function Base.copy(a::Mul) + return copyto!(similar(a), a) +end + +# copyto! for LinearBroadcasted dispatches to add!. +function Base.copyto!(dest::AbstractArray, src::LinearBroadcasted) + return add!(dest, src, true, false) +end + +# copyto! for Mul dispatches to mul!. Materialize factors first since +# they may be LinearBroadcasted types. +function Base.copyto!(dest::AbstractArray, src::Mul) + return LA.mul!(dest, BC.materialize.(factors(src))...) +end + +# Op composition with simplification rules. +_compose_op(::typeof(identity), g) = g +_compose_op(f, ::typeof(identity)) = f +_compose_op(::typeof(identity), ::typeof(identity)) = identity +_compose_op(::typeof(conj), ::typeof(conj)) = identity +_compose_op(f, g) = f ∘ g + +# permutedimsopadd! for LinearBroadcasted subtypes. +function permutedimsopadd!( + dest::AbstractArray, op, src::ScaledBroadcasted, perm, α::Number, β::Number + ) + return permutedimsopadd!(dest, op, unscaled(src), perm, op(coeff(src)) * α, β) +end + +function permutedimsopadd!( + dest::AbstractArray, op, src::ConjBroadcasted, perm, α::Number, β::Number + ) + return permutedimsopadd!(dest, _compose_op(op, conj), unconj(src), perm, α, β) +end + +function permutedimsopadd!( + dest::AbstractArray, op, src::AddBroadcasted, perm, α::Number, β::Number + ) + args = addends(src) + permutedimsopadd!(dest, op, first(args), perm, α, β) + for a in Base.tail(args) + permutedimsopadd!(dest, op, a, perm, α, true) + end + return dest +end + +# TODO: Replace with contractopadd! once that interface exists, +# to avoid materializing the Mul intermediate. +function permutedimsopadd!( + dest::AbstractArray, op, src::Mul, perm, α::Number, β::Number + ) + return permutedimsopadd!(dest, op, copy(src), perm, α, β) +end + +# ---------------------------------------------------------------------------- # +# linearbroadcasted — construct LinearBroadcasted subtypes by dispatching on f +# ---------------------------------------------------------------------------- # + +""" + linearbroadcasted(f, args...) + +Construct a `LinearBroadcasted` subtype from function `f` and arguments. +Analogous to `Base.Broadcast.broadcasted(f, args...)`. + +# Examples + +```julia +linearbroadcasted(*, 2.0, a) # ScaledBroadcasted(2.0, a) +linearbroadcasted(conj, a) # ConjBroadcasted(a) +linearbroadcasted(+, a, b) # AddBroadcasted(a, b) +``` +""" +function linearbroadcasted end + +# Scaling: Number * AbstractArray +linearbroadcasted(::typeof(*), α::Number, a::AbstractArray) = ScaledBroadcasted(α, a) +linearbroadcasted(::typeof(*), a::AbstractArray, α::Number) = ScaledBroadcasted(α, a) +# Scaling of ScaledBroadcasted: absorb coefficient. +function linearbroadcasted(::typeof(*), α::Number, a::ScaledBroadcasted) + return ScaledBroadcasted(α * coeff(a), unscaled(a)) +end + +# Conjugation. +linearbroadcasted(::typeof(conj), a::AbstractArray) = ConjBroadcasted(a) +linearbroadcasted(::typeof(conj), a::AbstractArray{<:Real}) = a +linearbroadcasted(::typeof(conj), a::ConjBroadcasted) = unconj(a) +function linearbroadcasted(::typeof(conj), a::ScaledBroadcasted) + return ScaledBroadcasted(conj(coeff(a)), linearbroadcasted(conj, unscaled(a))) +end + +# Addition. +linearbroadcasted(::typeof(+), a, b) = AddBroadcasted(a, b) +function linearbroadcasted(f::typeof(+), a, b, c, xs...) + return Base.afoldl( + (x, y) -> linearbroadcasted(f, x, y), + linearbroadcasted(f, linearbroadcasted(f, a, b), c), + xs... + ) +end +# Flatten AddBroadcasted + anything. +linearbroadcasted(::typeof(+), a::AddBroadcasted, b) = AddBroadcasted(addends(a)..., b) +linearbroadcasted(::typeof(+), a, b::AddBroadcasted) = AddBroadcasted(a, addends(b)...) +function linearbroadcasted(::typeof(+), a::AddBroadcasted, b::AddBroadcasted) + return AddBroadcasted(addends(a)..., addends(b)...) +end +linearbroadcasted(::typeof(+), a) = a + +# Subtraction. +linearbroadcasted(::typeof(-), a, b) = linearbroadcasted(+, a, linearbroadcasted(*, -1, b)) +linearbroadcasted(::typeof(-), a) = linearbroadcasted(*, -1, a) + +# Division / left-division by scalars. +linearbroadcasted(::typeof(/), a, b::Number) = linearbroadcasted(*, inv(b), a) +linearbroadcasted(::typeof(\), a::Number, b) = linearbroadcasted(*, inv(a), b) + +# Identity. +linearbroadcasted(::typeof(identity), a) = a + +# Fix1/Fix2 wrappers for scalar multiplication/division. +linearbroadcasted(f::Base.Fix1{typeof(*)}, a) = linearbroadcasted(*, f.x, a) +linearbroadcasted(f::Base.Fix2{typeof(*)}, a) = linearbroadcasted(*, a, f.x) +linearbroadcasted(f::Base.Fix2{typeof(/)}, a) = linearbroadcasted(/, a, f.x) + +# Scaling of AddBroadcasted distributes. +function linearbroadcasted(::typeof(*), α::Number, a::AddBroadcasted) + return linearbroadcasted(+, map(x -> linearbroadcasted(*, α, x), addends(a))...) +end + +# Conjugation of AddBroadcasted distributes. +function linearbroadcasted(::typeof(conj), a::AddBroadcasted) + return linearbroadcasted(+, map(x -> linearbroadcasted(conj, x), addends(a))...) +end + +# Conjugation of Mul distributes. +function linearbroadcasted(::typeof(conj), a::Mul) + f = factors(a) + return Mul(linearbroadcasted(conj, f[1]), linearbroadcasted(conj, f[2])) +end + +# Scaling of Mul: wrap in ScaledBroadcasted. +linearbroadcasted(::typeof(*), α::Number, a::Mul) = ScaledBroadcasted(α, a) + +# Number * Number passthrough (for broadcast lowering). +linearbroadcasted(::typeof(*), a::Number, b::Number) = a * b + +# ---------------------------------------------------------------------------- # +# Broadcast integration — instantiation-time conversion +# ---------------------------------------------------------------------------- # + +""" + islinearbroadcast(f, args...) -> Bool + +Per-node trait: can `(f, args...)` be expressed as a `LinearBroadcasted`? +Extensible by downstream packages for additional linear operations. +""" +islinearbroadcast(f, args...) = false +islinearbroadcast(::typeof(identity), ::Base.AbstractArrayOrBroadcasted) = true +islinearbroadcast(::typeof(+), ::Base.AbstractArrayOrBroadcasted...) = true +islinearbroadcast(::typeof(-), ::Base.AbstractArrayOrBroadcasted) = true +function islinearbroadcast( + ::typeof(-), ::Base.AbstractArrayOrBroadcasted, ::Base.AbstractArrayOrBroadcasted + ) + return true +end +islinearbroadcast(::typeof(*), ::Number, ::Base.AbstractArrayOrBroadcasted) = true +islinearbroadcast(::typeof(\), ::Number, ::Base.AbstractArrayOrBroadcasted) = true +islinearbroadcast(::typeof(*), ::Base.AbstractArrayOrBroadcasted, ::Number) = true +islinearbroadcast(::typeof(/), ::Base.AbstractArrayOrBroadcasted, ::Number) = true +function islinearbroadcast( + ::typeof(*), ::Base.AbstractArrayOrBroadcasted, ::Base.AbstractArrayOrBroadcasted + ) + return false +end +islinearbroadcast(::typeof(*), ::Number, ::Number) = true +islinearbroadcast(::typeof(conj), ::Base.AbstractArrayOrBroadcasted) = true +function islinearbroadcast( + ::Base.Fix1{typeof(*), <:Number}, ::Base.AbstractArrayOrBroadcasted + ) + return true +end +function islinearbroadcast( + ::Base.Fix2{typeof(*), <:Number}, ::Base.AbstractArrayOrBroadcasted + ) + return true +end +function islinearbroadcast( + ::Base.Fix2{typeof(/), <:Number}, ::Base.AbstractArrayOrBroadcasted + ) + return true +end + +""" + tryflattenlinear(bc::Broadcasted) -> LinearBroadcasted or nothing + +Recursively convert a `Broadcasted` tree to a `LinearBroadcasted` tree. +Returns `nothing` if any node is not linear (as determined by `islinearbroadcast`). + +Analogous to `Broadcast.flatten` for `Broadcasted` trees, but converts to +`LinearBroadcasted` subtypes via `linearbroadcasted`. + +Downstream styles call this from `Base.copy(::Broadcasted{MyStyle})` to +opt into linear broadcasting at materialization time. +""" +tryflattenlinear(x) = x +function tryflattenlinear(bc::BC.Broadcasted) + islinearbroadcast(bc.f, bc.args...) || return nothing + args = map(tryflattenlinear, bc.args) + any(isnothing, args) && return nothing + return linearbroadcasted(bc.f, args...) +end + +# BroadcastStyle for LinearBroadcasted subtypes — delegate to the wrapped array type. +function BC.BroadcastStyle(::Type{<:ScaledBroadcasted{<:Any, A}}) where {A} + return BC.BroadcastStyle(A) +end +function BC.BroadcastStyle(::Type{<:ConjBroadcasted{A}}) where {A} + return BC.BroadcastStyle(A) +end +function BC.BroadcastStyle(::Type{<:AddBroadcasted{Args}}) where {Args} + return Base.promote_op(BC.combine_styles, fieldtypes(Args)...)() +end +function BC.BroadcastStyle(::Type{<:Mul{A, B}}) where {A, B} + return BC.BroadcastStyle(BC.BroadcastStyle(A), BC.BroadcastStyle(B)) +end + +# Broadcast.materialize for LinearBroadcasted and Mul. +BC.materialize(a::LinearBroadcasted) = copy(a) +BC.materialize(a::Mul) = copy(a) diff --git a/src/permutedimsadd.jl b/src/permutedimsadd.jl index e96bf32..dac9490 100644 --- a/src/permutedimsadd.jl +++ b/src/permutedimsadd.jl @@ -9,40 +9,59 @@ function maybestrided(as::AbstractArray...) return all(a -> SV.isstrided(a) && iscpu(a), as) ? SV.StridedView.(as) : as end -""" - add!(dest, src) +# ---------------------------------------------------------------------------- # +# permutedimsopadd! — the single materialization primitive +# ---------------------------------------------------------------------------- # -Equivalent to `dest .+= src`, but maybe with a more optimized/specialized implementation. -Generally calls `add!(dest, src, true, true)`. """ -add!(dest::AbstractArray, src::AbstractArray) = add!(dest, src, true, true) + permutedimsopadd!(dest, op, src, perm, α, β) -""" - add!(dest, src, α, β) +`dest = β * dest + α * permutedims(op.(src), perm)`. + +This is the single materialization primitive for `LinearBroadcasted` types. +Downstream array types should implement this function. The `op` is an element-wise +linear map (e.g., `identity`, `conj`, `adjoint`, `transpose`, `Float32`). -Equivalent to `dest .= β .* dest .+ α .* src`, but maybe with a more optimized/specialized -implementation. +The default implementation applies `op` element-wise, permutes, then accumulates +via broadcasting with Strided.jl optimization when possible. """ -function add!(dest::AbstractArray, src::AbstractArray, α::Number, β::Number) - add!_broadcast(maybestrided(dest, src)..., α, β) - return dest -end +function permutedimsopadd!( + dest::AbstractArray, op, src::AbstractArray, perm, α::Number, β::Number + ) + # TODO: Remove this 0-dimensional special case once GradedArray is its own type + # (not an alias for BlockSparseArray), so the GradedArray permutedimsopadd! overload + # catches the 0-dimensional contraction result. + if iszero(ndims(dest)) + dest[] = β * dest[] + α * op(src[]) + return dest + end -# Broadcasting implementation of add!. -function add!_broadcast(dest::AbstractArray, src::AbstractArray, α::Number, β::Number) # This works around a bug in Strided.jl v2.3.4 and below when broadcasting # empty StridedViews: https://github.com/QuantumKitHub/Strided.jl/pull/50 # TODO: Delete this and bump the version of Strided.jl once that is fixed. isempty(dest) && return dest - if iszero(β) - dest .= α .* src + dest′, src′ = maybestrided(dest, permuteddims(src, perm)) + if op === identity + if iszero(β) + dest′ .= α .* src′ + else + dest′ .= β .* dest′ .+ α .* src′ + end else - dest .= β .* dest .+ α .* src + if iszero(β) + dest′ .= α .* op.(src′) + else + dest′ .= β .* dest′ .+ α .* op.(src′) + end end return dest end +# ---------------------------------------------------------------------------- # +# Convenience functions that lower to permutedimsopadd! +# ---------------------------------------------------------------------------- # + """ permutedimsadd!(dest, src, perm, α, β) @@ -51,5 +70,21 @@ end function permutedimsadd!( dest::AbstractArray, src::AbstractArray, perm, α::Number, β::Number ) - return add!(dest, permuteddims(src, perm), α, β) + return permutedimsopadd!(dest, identity, src, perm, α, β) end + +""" + add!(dest, src, α, β) + +`dest = β * dest + α * src`. +""" +function add!(dest::AbstractArray, src, α::Number, β::Number) + return permutedimsopadd!(dest, identity, src, ntuple(identity, ndims(src)), α, β) +end + +""" + add!(dest, src) + +`dest .+= src`. +""" +add!(dest::AbstractArray, src) = add!(dest, src, true, true) diff --git a/test/Project.toml b/test/Project.toml index 5a4045d..5c0c342 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -36,7 +36,7 @@ Random = "1.10" SafeTestsets = "0.1" StableRNGs = "1.0.2" Suppressor = "0.2" -TensorAlgebra = "0.7" +TensorAlgebra = "0.8" TensorOperations = "5.1.4" Test = "1.10" TestExtras = "0.3.1" diff --git a/test/test_lazy.jl b/test/test_lazy.jl deleted file mode 100644 index 4d3265f..0000000 --- a/test/test_lazy.jl +++ /dev/null @@ -1,123 +0,0 @@ -import FunctionImplementations as FI -using Base.Broadcast: Broadcast as BC -using TensorAlgebra: TensorAlgebra as TA, *ₗ, +ₗ, /ₗ, conjed -using Test: @test, @test_broken, @test_throws, @testset - -@testset "lazy arrays" begin - @testset "lazy array operations" begin - a = randn(ComplexF64, 3, 3) - b = randn(ComplexF64, 3, 3) - c = randn(ComplexF64, 3, 3) - - x = 2 *ₗ a - @test x ≡ TA.ScaledArray(2, a) - @test copy(x) ≈ 2a - - x = conjed(a) - @test x ≡ TA.ConjArray(a) - @test copy(x) ≈ conj(a) - @test conj(x) ≈ a - - x = a +ₗ b - @test x ≡ TA.AddArray(a, b) - @test copy(x) ≈ a + b - - x = a *ₗ b - @test x ≡ TA.MulArray(a, b) - @test copy(x) ≈ a * b - - x = a *ₗ b +ₗ c - @test x ≡ TA.AddArray(TA.MulArray(a, b), c) - @test copy(x) ≈ a *ₗ b .+ c ≈ a * b + c - - x = 2 *ₗ a *ₗ b +ₗ 3 *ₗ c - @test x ≡ TA.AddArray(TA.ScaledArray(2, TA.MulArray(a, b)), TA.ScaledArray(3, c)) - @test copy(x) ≈ 2 .* a *ₗ b .+ 3 .* c ≈ 2 * a * b + 3 * c - end - @testset "adjoint" begin - a = randn(ComplexF64, 2, 2) - b = randn(ComplexF64, 2, 2) - - x = (2 *ₗ a)' - @test x ≡ 2 *ₗ a' - @test copy(x) ≈ 2a' - - x = conjed(a)' - @test x ≡ transpose(a) - @test copy(x) ≈ permutedims(a) - - x = (a +ₗ b)' - @test x ≡ a' +ₗ b' - @test copy(x) ≈ a' + b' - - x = (a *ₗ b)' - @test x ≡ b' *ₗ a' - @test copy(x) ≈ b' * a' - end - @testset "transpose" begin - a = randn(ComplexF64, 2, 2) - b = randn(ComplexF64, 2, 2) - - x = transpose(2 *ₗ a) - @test x ≡ 2 *ₗ transpose(a) - @test copy(x) ≈ 2transpose(a) - - x = transpose(conjed(a)) - @test x ≡ adjoint(a) - @test copy(x) ≈ permutedims(conj(a)) - - x = transpose(a +ₗ b) - @test x ≡ transpose(a) +ₗ transpose(b) - @test copy(x) ≈ transpose(a) + transpose(b) - - x = transpose(a *ₗ b) - @test x ≡ transpose(b) *ₗ transpose(a) - @test copy(x) ≈ transpose(b) * transpose(a) - end - @testset "permuteddims" begin - a = randn(ComplexF64, 2, 2) - b = randn(ComplexF64, 2, 2) - perm = (2, 1) - - x = FI.permuteddims(2 *ₗ a, perm) - @test x ≡ 2 *ₗ FI.permuteddims(a, perm) - @test copy(x) ≈ 2permutedims(a, perm) - - x = FI.permuteddims(conjed(a), perm) - @test x ≡ conjed(FI.permuteddims(a, perm)) - @test copy(x) ≈ conj(permutedims(a, perm)) - - x = FI.permuteddims(a +ₗ b, perm) - @test x ≡ FI.permuteddims(a, perm) +ₗ FI.permuteddims(b, perm) - @test copy(x) ≈ permutedims(a, perm) + permutedims(b, perm) - - x = FI.permuteddims(a *ₗ b, perm) - @test x ≡ PermutedDimsArray(a *ₗ b, perm) - @test copy(x) ≈ permutedims(a * b, perm) - end - @testset "linear broadcast lowering" begin - a = randn(ComplexF64, 2, 2) - style = BC.DefaultArrayStyle{2}() - - @test TA.broadcasted_linear(identity, a) ≡ a - @test TA.broadcasted_linear(Base.Fix1(*, 2), a) ≡ 2 *ₗ a - @test TA.broadcasted_linear(Base.Fix2(*, 2), a) ≡ a *ₗ 2 - @test TA.broadcasted_linear(Base.Fix2(/, 2), a) ≡ a /ₗ 2 - @test TA.broadcasted_linear(style, identity, a) ≡ a - @test TA.broadcasted_linear(style, Base.Fix1(*, 2), a) ≡ 2 *ₗ a - @test TA.broadcasted_linear(style, Base.Fix2(*, 2), a) ≡ a *ₗ 2 - @test TA.broadcasted_linear(style, Base.Fix2(/, 2), a) ≡ a /ₗ 2 - @test TA.broadcasted_linear(style, conj, a) ≡ conjed(a) - @test_throws ArgumentError TA.broadcasted_linear(style, exp, a) - end - @testset "scalar getindex" begin - a = randn(ComplexF64, 2, 2) - b = randn(ComplexF64, 2, 2) - - @test (2 *ₗ a)[1, 2] == 2 * a[1, 2] - @test conjed(a)[2, 1] == conj(a[2, 1]) - @test (a +ₗ b)[2, 2] == a[2, 2] + b[2, 2] - @test (a *ₗ b)[1, 2] ≈ (a * b)[1, 2] - @test (a *ₗ b)[3] ≈ (a * b)[3] - end -end diff --git a/test/test_linearbroadcasted.jl b/test/test_linearbroadcasted.jl new file mode 100644 index 0000000..e2afc16 --- /dev/null +++ b/test/test_linearbroadcasted.jl @@ -0,0 +1,183 @@ +using Base.Broadcast: Broadcast as BC +using TensorAlgebra: TensorAlgebra as TA, linearbroadcasted +using Test: @test, @test_throws, @testset + +@testset "LinearBroadcasted and Mul" begin + @testset "construction and materialization" begin + a = randn(ComplexF64, 3, 3) + b = randn(ComplexF64, 3, 3) + c = randn(ComplexF64, 3, 3) + + x = linearbroadcasted(*, 2, a) + @test x ≡ TA.ScaledBroadcasted(2, a) + @test copy(x) ≈ 2a + + x = linearbroadcasted(conj, a) + @test x ≡ TA.ConjBroadcasted(a) + @test copy(x) ≈ conj(a) + + x = linearbroadcasted(+, a, b) + @test x ≡ TA.AddBroadcasted(a, b) + @test copy(x) ≈ a + b + + x = TA.Mul(a, b) + @test copy(x) ≈ a * b + + x = linearbroadcasted(+, TA.Mul(a, b), c) + @test x ≡ TA.AddBroadcasted(TA.Mul(a, b), c) + @test copy(x) ≈ a * b + c + + x = linearbroadcasted( + +, + linearbroadcasted(*, 2, TA.Mul(a, b)), + linearbroadcasted(*, 3, c) + ) + @test x ≡ TA.AddBroadcasted( + TA.ScaledBroadcasted(2, TA.Mul(a, b)), TA.ScaledBroadcasted(3, c) + ) + @test copy(x) ≈ 2 * a * b + 3 * c + end + @testset "tryflattenlinear" begin + a = randn(ComplexF64, 2, 2) + b = randn(ComplexF64, 2, 2) + + # Linear expressions convert successfully + @test TA.tryflattenlinear(BC.broadcasted(*, 2, a)) ≡ linearbroadcasted(*, 2, a) + @test TA.tryflattenlinear(BC.broadcasted(conj, a)) ≡ linearbroadcasted(conj, a) + @test TA.tryflattenlinear(BC.broadcasted(+, a, b)) ≡ linearbroadcasted(+, a, b) + @test TA.tryflattenlinear(BC.broadcasted(identity, a)) ≡ a + + # Nested linear expression + bc = BC.broadcasted(+, BC.broadcasted(*, 2, a), BC.broadcasted(*, 3, b)) + @test copy(TA.tryflattenlinear(bc)) ≈ 2a + 3b + + # Nonlinear expression returns nothing + @test TA.tryflattenlinear(BC.broadcasted(exp, a)) === nothing + @test TA.tryflattenlinear(BC.broadcasted(+, a, BC.broadcasted(exp, b))) === nothing + end + @testset "linearbroadcasted algebra" begin + a = randn(ComplexF64, 3, 3) + + # Scaling absorbs coefficients + @test linearbroadcasted(*, 3, linearbroadcasted(*, 2, a)) ≡ + TA.ScaledBroadcasted(6, a) + + # Conjugation of scaled + x = linearbroadcasted(conj, linearbroadcasted(*, 2im, a)) + @test x ≡ TA.ScaledBroadcasted(-2im, TA.ConjBroadcasted(a)) + + # Double conjugation cancels + @test linearbroadcasted(conj, linearbroadcasted(conj, a)) ≡ a + + # Subtraction + b = randn(ComplexF64, 3, 3) + x = linearbroadcasted(-, a, b) + @test copy(x) ≈ a - b + + # Unary minus + x = linearbroadcasted(-, a) + @test copy(x) ≈ -a + + # Division + x = linearbroadcasted(/, a, 2) + @test copy(x) ≈ a / 2 + + # Left division + x = linearbroadcasted(\, 2, a) + @test copy(x) ≈ a / 2 + + # Scaling distributes over AddBroadcasted + ab = linearbroadcasted(+, a, b) + x = linearbroadcasted(*, 3, ab) + @test copy(x) ≈ 3a + 3b + + # Conjugation distributes over AddBroadcasted + x = linearbroadcasted(conj, ab) + @test copy(x) ≈ conj(a) + conj(b) + + # Conjugation distributes over Mul + m = TA.Mul(a, b) + x = linearbroadcasted(conj, m) + @test copy(x) ≈ conj(a) * conj(b) + end + @testset "AddBroadcasted flattening" begin + a = randn(ComplexF64, 2, 2) + b = randn(ComplexF64, 2, 2) + c = randn(ComplexF64, 2, 2) + + # AddBroadcasted + array flattens + ab = linearbroadcasted(+, a, b) + x = linearbroadcasted(+, ab, c) + @test TA.addends(x) === (a, b, c) + + # array + AddBroadcasted flattens + x = linearbroadcasted(+, c, ab) + @test TA.addends(x) === (c, a, b) + + # AddBroadcasted + AddBroadcasted flattens + cd = linearbroadcasted(+, c, a) + x = linearbroadcasted(+, ab, cd) + @test TA.addends(x) === (a, b, c, a) + end + @testset "similar(::AddBroadcasted) with LinearBroadcasted addends" begin + a = randn(ComplexF64, 3, 4) + b = randn(ComplexF64, 3, 4) + + # Addends are ScaledBroadcasted, not AbstractArray + lb = linearbroadcasted(+, linearbroadcasted(*, 2, a), linearbroadcasted(*, 3, b)) + s = similar(lb) + @test size(s) == (3, 4) + @test eltype(s) === ComplexF64 + end + @testset "_compose_op" begin + @test TA._compose_op(identity, identity) === identity + @test TA._compose_op(identity, conj) === conj + @test TA._compose_op(conj, identity) === conj + @test TA._compose_op(conj, conj) === identity + f = TA._compose_op(sqrt, conj) + @test f isa ComposedFunction + end + @testset "Broadcasted(::LinearBroadcasted) round-trip" begin + a = randn(ComplexF64, 3, 3) + b = randn(ComplexF64, 3, 3) + + lb = linearbroadcasted(+, linearbroadcasted(*, 2, a), linearbroadcasted(conj, b)) + bc = BC.Broadcasted(lb) + @test bc isa BC.Broadcasted + @test copy(bc) ≈ 2a + conj(b) + end + @testset "add! and copyto! with LinearBroadcasted" begin + a = randn(ComplexF64, 3, 3) + b = randn(ComplexF64, 3, 3) + + # add! with ScaledBroadcasted + dest = zeros(ComplexF64, 3, 3) + TA.add!(dest, linearbroadcasted(*, 2, a), true, false) + @test dest ≈ 2a + + # add! with AddBroadcasted + dest = zeros(ComplexF64, 3, 3) + TA.add!(dest, linearbroadcasted(+, a, b), true, false) + @test dest ≈ a + b + + # add! with ConjBroadcasted + dest = zeros(ComplexF64, 3, 3) + TA.add!(dest, linearbroadcasted(conj, a), true, false) + @test dest ≈ conj(a) + + # add! with β accumulation + dest = ones(ComplexF64, 3, 3) + TA.add!(dest, linearbroadcasted(*, 2, a), 3, 1) + @test dest ≈ ones(ComplexF64, 3, 3) + 6a + end + @testset "0-dimensional permutedimsopadd!" begin + a = fill(3.0 + 2.0im) + dest = fill(1.0 + 0.0im) + TA.permutedimsopadd!(dest, identity, a, (), 2, 3) + @test dest[] ≈ 3 * 1 + 2 * a[] + + dest = fill(0.0 + 0.0im) + TA.permutedimsopadd!(dest, conj, a, (), 1, 0) + @test dest[] ≈ conj(a[]) + end +end diff --git a/test/test_permutedimsadd.jl b/test/test_permutedimsadd.jl index 01b6816..48491e0 100644 --- a/test/test_permutedimsadd.jl +++ b/test/test_permutedimsadd.jl @@ -1,6 +1,6 @@ using Adapt: adapt using JLArrays: JLArray -using TensorAlgebra: add!, permutedimsadd! +using TensorAlgebra: add!, permutedimsadd!, permutedimsopadd! using Test: @test, @testset @testset "[permutedims]add!" begin @@ -47,4 +47,25 @@ using Test: @test, @testset @test b′ ≈ β * b + α * permutedims(a, perm) end end + @testset "permutedimsopadd! (arraytype=$arrayt)" for arrayt in (Array,) + dev = adapt(arrayt) + a = dev(randn(ComplexF64, 2, 2, 2)) + perm = (3, 1, 2) + α = 2 + for β in (0, 3) + b = dev(randn(ComplexF64, 2, 2, 2)) + b′ = copy(b) + permutedimsopadd!(b′, conj, a, perm, α, β) + @test b′ ≈ β * b + α * permutedims(conj(a), perm) + end + # identity op should match permutedimsadd! + for β in (0, 3) + b = dev(randn(ComplexF64, 2, 2, 2)) + b′ = copy(b) + b″ = copy(b) + permutedimsopadd!(b′, identity, a, perm, α, β) + permutedimsadd!(b″, a, perm, α, β) + @test b′ ≈ b″ + end + end end