From 5857fc125273bf33fb0c5e368eb5c60d7112b1c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20K=2E=20Papp?= Date: Mon, 6 Apr 2026 13:36:11 +0200 Subject: [PATCH] fix captured variable in static array aggregation add tests Incidental: move JET tests up front, whitespace changes --- src/aggregation.jl | 23 +++++++++++------------ test/runtests.jl | 42 +++++++++++++++++++++++++----------------- 2 files changed, 36 insertions(+), 29 deletions(-) diff --git a/src/aggregation.jl b/src/aggregation.jl index e95c060..f124c29 100644 --- a/src/aggregation.jl +++ b/src/aggregation.jl @@ -185,7 +185,7 @@ function as(::Type{<:SArray{S}}, inner_transformation::AbstractTransform) where @argcheck all(x -> x ≥ 1, dim) StaticArrayTransformation{prod(dim),S,typeof(inner_transformation)}(inner_transformation) end -# Repeated with more specific typing to eliminate method ambiguity with +# Repeated with more specific typing to eliminate method ambiguity with # the ScalarWrapperTransform method for `as` function as(::Type{<:SArray{S}}, inner_transformation::ScalarTransform = Identity()) where S dim = fieldtypes(S) @@ -206,17 +206,16 @@ function transform_with(flag::LogJacFlag, transformation::StaticArrayTransformat # first element. y1, ℓ1, index1 = transform_with(flag, inner_transformation, x, index) D == 1 && return SArray{S}(y1), ℓ1, index1 - L = typeof(ℓ1) - let ℓ::L = ℓ1, index::Int = index1 - function _f(_) - y, ℓΔ, index′ = transform_with(flag, inner_transformation, x, index) - index = index′ - ℓ = ℓ + ℓΔ - y - end - yrest = SVector{D-1}(_f(i) for i in 2:D) - SArray{S}(pushfirst(yrest, y1)), ℓ, index + ℓ = Ref(ℓ1) + index = Ref(index1) + function _f(_) + y, ℓΔ, index′ = transform_with(flag, inner_transformation, x, index[]) + index[] = index′ + ℓ[] += ℓΔ + y end + yrest = SVector{D-1}(_f(i) for i in 2:D) + SArray{S}(pushfirst(yrest, y1)), ℓ[], index[] end function inverse_eltype(transformation::Union{ArrayTransformation,StaticArrayTransformation}, @@ -520,7 +519,7 @@ dimension(t::TypeWrapperTransform) = dimension(t.inner_transformation) function _summary_rows(transformation::TypeWrapperTransform{T, S}, mime) where {T, S<:TransformTuple} (; inner_transformation) = transformation innerinner = _inner(inner_transformation) - name = string("$T wrapper on ", nameof(typeof(innerinner)), " of transformations") + name = string("$T wrapper on ", nameof(typeof(innerinner)), " of transformations") _tuple_summary_rows(name, inner_transformation, mime) end diff --git a/test/runtests.jl b/test/runtests.jl index 5f7d537..79c3ee7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -20,6 +20,17 @@ const ALLOCS = get(ENV, "BUILD_IS_PRODUCTION_BUILD", "false") == "true" @info "test environment" CIENV ALLOCS +#### +#### static analysis with JET +#### + +import JET +JET.test_package(TransformVariables; target_modules = (TransformVariables,)) + +#### +#### setup +#### + include("utilities.jl") Random.seed!(1) @@ -544,7 +555,7 @@ end end @testset "transform to custom type" begin - + struct CustomType{A, B} a::A b::B @@ -573,22 +584,22 @@ end @test_throws ArgumentError inverse(t2, MyType(3.0)) # Named tuple with different ordering - t1 = as((b = asℝ, a = asℝ)) + t1 = as((b = asℝ, a = asℝ)) t2 = as(KwCustomType, t1) y = @inferred transform(t2, [1.0, 2.0]) - @test y == KwCustomType(a = 2.0, b = 1.0) + @test y == KwCustomType(a = 2.0, b = 1.0) test_transformation(t2, y -> y isa KwCustomType; N=1, jac=false) # Named tuple with wrong number or names of fields - t1 = as((;b = asℝ)) + t1 = as((;b = asℝ)) t2 = as(KwCustomType, t1) @test_throws UndefKeywordError transform(t2, [1.0]) @test inverse(t2, KwCustomType(1.0, 3.0)) == [3.0] - t1 = as((a = asℝ, c = asℝ)) + t1 = as((a = asℝ, c = asℝ)) t2 = as(KwCustomType, t1) @test_throws UndefKeywordError transform(t2, [1.0, 3.0]) @test_throws ArgumentError inverse(t2, KwCustomType(1.0, 3.0)) - t1 = as((b = asℝ, a = asℝ, c = asℝ)) + t1 = as((b = asℝ, a = asℝ, c = asℝ)) t2 = as(KwCustomType, t1) @test_throws MethodError transform(t2, [1.0, 2.0, 3.0]) @test_throws ArgumentError inverse(t2, KwCustomType(1.0, 3.0)) @@ -607,7 +618,7 @@ end t = as(MaskedType, as((;a=asℝ))) @test_throws MethodError transform(t, [1.0]) - # When constructor accepts less args than struct has fields, + # When constructor accepts less args than struct has fields, # inverse errors t = as(MaskedType, (asℝ,)) x = [1.0] @@ -1009,6 +1020,11 @@ end @testset "static arrays inference" begin @test @inferred transform_with(NOLOGJAC, as(SVector{3, Float64}), zeros(3), 1) == (SVector(0.0, 0.0, 0.0), NOLOGJAC, 4) @test @inferred transform_with(NOLOGJAC, as(SVector{1, Float64}), zeros(1), 1) == (SVector(0.0), NOLOGJAC, 2) + + @testset "type stability of captured variable" begin + t = as(SVector{3}, asℝ₊) + @test isempty(JET.get_reports(JET.@report_opt transform(t, ones(3)))) + end end @testset "view transformations" begin @@ -1155,7 +1171,7 @@ end @test outr[1] ≈ out[1] @test outr[2] ≈ out[2] end - end + end @testset "Test inner transformations" begin for T in (asℝ, asℝ₊, asℝ₋, as𝕀) tr = as(Array, as(Array, T, 3), 3) @@ -1229,16 +1245,8 @@ end t = as(Vector, as((a = asℝ,)), 4) a = randn(dimension(t)) ar = Reactant.to_rarray(a) - + @test_throws ArgumentError @jit(transform_and_logjac(t, ar)) end end end - - -#### -#### static analysis with JET -#### - -import JET -JET.test_package(TransformVariables; target_modules = (TransformVariables,))