diff --git a/Project.toml b/Project.toml index 7d89e15..ee2ec89 100644 --- a/Project.toml +++ b/Project.toml @@ -11,6 +11,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [compat] ArgCheck = "1, 2" DocStringExtensions = "0.8, 0.9" +Random = "1.10" julia = "1.10" [extras] diff --git a/src/LogDensityProblems.jl b/src/LogDensityProblems.jl index 650fb44..cddca0e 100644 --- a/src/LogDensityProblems.jl +++ b/src/LogDensityProblems.jl @@ -17,6 +17,7 @@ using DocStringExtensions: SIGNATURES, TYPEDEF using Random: AbstractRNG, default_rng # https://github.com/JuliaLang/julia/pull/50105 +# NOTE remove this once we require Julia v1.11 and use public @static if VERSION >= v"1.11.0-DEV.469" eval( Expr( @@ -27,6 +28,8 @@ using Random: AbstractRNG, default_rng :logdensity, :logdensity_and_gradient, :logdensity_gradient_and_hessian, + :stresstest, + :converting_logdensity, ) ) end diff --git a/src/utilities.jl b/src/utilities.jl index 7e7d514..2b189e8 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -59,3 +59,81 @@ function stresstest(f, ℓ; N = 1000, rng::AbstractRNG = default_rng(), scale = end failures end + +#### +#### converting input and output +#### + +""" +Type implementing [`converting_logdensity`](@ref). + +Not part of the API *per se*, use the eponymous function to construct. +""" +struct ConvertingLogDensity{I, # input type + L, # logdensity output type + G, # gradient output type + H, # Hessian output type + P} + parent::P +end + +""" +$(SIGNATURES) + +Return an object implementing the same [`capabilities`](@ref) as the first argument, +converting inputs and outputs as specified. + +All conversions are implemented via `convert(T, …)::T`. Typical use cases include fixing +type instability introduced by automatic differentiation. + +The original logdensity can be retrieved with `parent`. + +# Keyword arguments (with defaults) + +- `input = Any`: convert *inputs* to the given type before evaluating log densities, gradients, …. 
+ +- `logdensity = Any`: convert the *log density* to that type. + +- `gradient = Any`: convert the *gradient* to that type. + +- `hessian = Any`: convert the *Hessian* to that type. + +# Note + +The types are not checked for validity. + +# Examples + +```julia +converting_logdensity(ℓ; input = Vector{Float32}, logdensity = Float64, + gradient = Vector{Float64}) +``` +will convert to a vector of `Float32`s, then enforce Float64 elements for the logdensity +and its gradient, leaving the Hessian alone. +""" +function converting_logdensity(ℓ::P; input::Type = Any, logdensity::Type = Any, gradient::Type = Any, + hessian::Type = Any) where P + @argcheck(capabilities(ℓ) ≥ LogDensityOrder(0), + "Input does not implement the log density interface.") + ConvertingLogDensity{input,logdensity,gradient,hessian,P}(ℓ) +end + +Base.parent(ℓ::ConvertingLogDensity) = ℓ.parent + +capabilities(ℓ::ConvertingLogDensity) = capabilities(ℓ.parent) + +dimension(ℓ::ConvertingLogDensity) = dimension(ℓ.parent) + +function logdensity(ℓ::ConvertingLogDensity{I,L}, x) where {I,L} + convert(L, logdensity(ℓ.parent, convert(I, x)))::L +end + +function logdensity_and_gradient(ℓ::ConvertingLogDensity{I,L,G}, x) where {I,L,G} + l, g = logdensity_and_gradient(ℓ.parent, convert(I, x)) + convert(L, l)::L, convert(G, g)::G +end + +function logdensity_gradient_and_hessian(ℓ::ConvertingLogDensity{I,L,G,H}, x) where {I,L,G,H} + l, g, h = logdensity_gradient_and_hessian(ℓ.parent, convert(I, x)) + convert(L, l)::L, convert(G, g)::G, convert(H, h)::H +end diff --git a/test/runtests.jl b/test/runtests.jl index bd04938..9912673 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,6 @@ using LogDensityProblems, Test, Random -import LogDensityProblems: capabilities, dimension, logdensity +import LogDensityProblems: capabilities, dimension, logdensity, logdensity_and_gradient, + logdensity_gradient_and_hessian using LogDensityProblems: logdensity_and_gradient, LogDensityOrder #### @@ -96,6 +97,7 @@ end 
#### @testset "public API" begin + # NOTE remove this once we require Julia v1.11 and use public if isdefined(Base, :ispublic) @test Base.ispublic(LogDensityProblems, :capabilities) @test Base.ispublic(LogDensityProblems, :LogDensityOrder) @@ -103,5 +105,42 @@ end @test Base.ispublic(LogDensityProblems, :logdensity) @test Base.ispublic(LogDensityProblems, :logdensity_and_gradient) @test Base.ispublic(LogDensityProblems, :logdensity_gradient_and_hessian) + @test Base.ispublic(LogDensityProblems, :stresstest) + @test Base.ispublic(LogDensityProblems, :converting_logdensity) end end + +#### +#### converting logdensity +#### + +struct BadLogDensity end +dimension(::BadLogDensity) = 1 +capabilities(::BadLogDensity) = LogDensityOrder(2) +_bad_x(x) = (_x = only(x); _x > 0 ? Float64(_x) : _x) # introduce type instability +function logdensity(::BadLogDensity, x::Vector{Float32}) # deliberate restriction + -_bad_x(x)^2 / 2 +end +function logdensity_and_gradient(::BadLogDensity, x::Vector{Float32}) + _x = _bad_x(x) + -_x^2 / 2, [-_x] +end +function logdensity_gradient_and_hessian(::BadLogDensity, x::Vector{Float32}) + _x = _bad_x(x) + -_x^2 / 2, [-_x], [-one(_x)] +end + +@testset "converting logdensity" begin + bad = BadLogDensity() + ℓ = LogDensityProblems.converting_logdensity(bad; + input = Vector{Float32}, + logdensity = Float64, + gradient = Vector{Float64}) + @test dimension(ℓ) == dimension(bad) + @test capabilities(ℓ) == capabilities(bad) + x = [0.9] # no such method for the parent + xF32 = Float32.(x) + @test @inferred(logdensity(ℓ, x)) == logdensity(bad, xF32) + @test @inferred(logdensity_and_gradient(ℓ, x)) == logdensity_and_gradient(bad, xF32) + @test eltype(logdensity_gradient_and_hessian(ℓ, .-x)[3]) ≡ Float32 # we do not touch this +end