Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 11 additions & 12 deletions src/aggregation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ function as(::Type{<:SArray{S}}, inner_transformation::AbstractTransform) where
@argcheck all(x -> x ≥ 1, dim)
StaticArrayTransformation{prod(dim),S,typeof(inner_transformation)}(inner_transformation)
end
# Repeated with more specific typing to eliminate method ambiguity with
# Repeated with more specific typing to eliminate method ambiguity with
# the ScalarWrapperTransform method for `as`
function as(::Type{<:SArray{S}}, inner_transformation::ScalarTransform = Identity()) where S
dim = fieldtypes(S)
Expand All @@ -206,17 +206,16 @@ function transform_with(flag::LogJacFlag, transformation::StaticArrayTransformat
# first element.
y1, ℓ1, index1 = transform_with(flag, inner_transformation, x, index)
D == 1 && return SArray{S}(y1), ℓ1, index1
L = typeof(ℓ1)
let ℓ::L = ℓ1, index::Int = index1
function _f(_)
y, ℓΔ, index′ = transform_with(flag, inner_transformation, x, index)
index = index′
ℓ = ℓ + ℓΔ
y
end
yrest = SVector{D-1}(_f(i) for i in 2:D)
SArray{S}(pushfirst(yrest, y1)), ℓ, index
ℓ = Ref(ℓ1)
index = Ref(index1)
function _f(_)
y, ℓΔ, index′ = transform_with(flag, inner_transformation, x, index[])
index[] = index′
ℓ[] += ℓΔ
y
end
yrest = SVector{D-1}(_f(i) for i in 2:D)
SArray{S}(pushfirst(yrest, y1)), ℓ[], index[]
end

function inverse_eltype(transformation::Union{ArrayTransformation,StaticArrayTransformation},
Expand Down Expand Up @@ -520,7 +519,7 @@ dimension(t::TypeWrapperTransform) = dimension(t.inner_transformation)
function _summary_rows(transformation::TypeWrapperTransform{T, S}, mime) where {T, S<:TransformTuple}
(; inner_transformation) = transformation
innerinner = _inner(inner_transformation)
name = string("$T wrapper on ", nameof(typeof(innerinner)), " of transformations")
name = string("$T wrapper on ", nameof(typeof(innerinner)), " of transformations")
_tuple_summary_rows(name, inner_transformation, mime)
end

Expand Down
42 changes: 25 additions & 17 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,17 @@ const ALLOCS = get(ENV, "BUILD_IS_PRODUCTION_BUILD", "false") == "true"

@info "test environment" CIENV ALLOCS

####
#### static analysis with JET
####

import JET
JET.test_package(TransformVariables; target_modules = (TransformVariables,))

####
#### setup
####

include("utilities.jl")

Random.seed!(1)
Expand Down Expand Up @@ -544,7 +555,7 @@ end
end

@testset "transform to custom type" begin

struct CustomType{A, B}
a::A
b::B
Expand Down Expand Up @@ -573,22 +584,22 @@ end
@test_throws ArgumentError inverse(t2, MyType(3.0))

# Named tuple with different ordering
t1 = as((b = asℝ, a = asℝ))
t1 = as((b = asℝ, a = asℝ))
t2 = as(KwCustomType, t1)
y = @inferred transform(t2, [1.0, 2.0])
@test y == KwCustomType(a = 2.0, b = 1.0)
@test y == KwCustomType(a = 2.0, b = 1.0)
test_transformation(t2, y -> y isa KwCustomType; N=1, jac=false)

# Named tuple with wrong number or names of fields
t1 = as((;b = asℝ))
t1 = as((;b = asℝ))
t2 = as(KwCustomType, t1)
@test_throws UndefKeywordError transform(t2, [1.0])
@test inverse(t2, KwCustomType(1.0, 3.0)) == [3.0]
t1 = as((a = asℝ, c = asℝ))
t1 = as((a = asℝ, c = asℝ))
t2 = as(KwCustomType, t1)
@test_throws UndefKeywordError transform(t2, [1.0, 3.0])
@test_throws ArgumentError inverse(t2, KwCustomType(1.0, 3.0))
t1 = as((b = asℝ, a = asℝ, c = asℝ))
t1 = as((b = asℝ, a = asℝ, c = asℝ))
t2 = as(KwCustomType, t1)
@test_throws MethodError transform(t2, [1.0, 2.0, 3.0])
@test_throws ArgumentError inverse(t2, KwCustomType(1.0, 3.0))
Expand All @@ -607,7 +618,7 @@ end
t = as(MaskedType, as((;a=asℝ)))
@test_throws MethodError transform(t, [1.0])

# When constructor accepts less args than struct has fields,
# When constructor accepts less args than struct has fields,
# inverse errors
t = as(MaskedType, (asℝ,))
x = [1.0]
Expand Down Expand Up @@ -1009,6 +1020,11 @@ end
@testset "static arrays inference" begin
@test @inferred transform_with(NOLOGJAC, as(SVector{3, Float64}), zeros(3), 1) == (SVector(0.0, 0.0, 0.0), NOLOGJAC, 4)
@test @inferred transform_with(NOLOGJAC, as(SVector{1, Float64}), zeros(1), 1) == (SVector(0.0), NOLOGJAC, 2)

@testset "type stability of captured variable" begin
t = as(SVector{3}, asℝ₊)
@test isempty(JET.get_reports(JET.@report_opt transform(t, ones(3))))
end
end

@testset "view transformations" begin
Expand Down Expand Up @@ -1155,7 +1171,7 @@ end
@test outr[1] ≈ out[1]
@test outr[2] ≈ out[2]
end
end
end
@testset "Test inner transformations" begin
for T in (asℝ, asℝ₊, asℝ₋, as𝕀)
tr = as(Array, as(Array, T, 3), 3)
Expand Down Expand Up @@ -1229,16 +1245,8 @@ end
t = as(Vector, as((a = asℝ,)), 4)
a = randn(dimension(t))
ar = Reactant.to_rarray(a)

@test_throws ArgumentError @jit(transform_and_logjac(t, ar))
end
end
end


####
#### static analysis with JET
####

import JET
JET.test_package(TransformVariables; target_modules = (TransformVariables,))
Loading