diff --git a/Project.toml b/Project.toml index 6126713f..3e6245fc 100644 --- a/Project.toml +++ b/Project.toml @@ -15,10 +15,12 @@ QuantumPropagators = "7bf12567-5742-4b91-a078-644e72a65fc1" UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" [weakdeps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] +QuantumControlChainRulesCoreExt = "ChainRulesCore" QuantumControlFiniteDifferencesExt = "FiniteDifferences" QuantumControlZygoteExt = "Zygote" diff --git a/ext/QuantumControlChainRulesCoreExt.jl b/ext/QuantumControlChainRulesCoreExt.jl new file mode 100644 index 00000000..091fe880 --- /dev/null +++ b/ext/QuantumControlChainRulesCoreExt.jl @@ -0,0 +1,32 @@ +module QuantumControlChainRulesCoreExt + +using ChainRulesCore: ChainRulesCore, NoTangent +using QuantumControl: Trajectory + + +# Allow to differentiate w.r.t. to a trajectory. See `test_traj_zygote.jl` for +# an example. Evaluating a gradient with Zygote returns a NamedTuple with +# the fields of the trajectory. Unfortunately, Zygote gets confused about the +# custom `getproperty` method that is defined for a Trajectory, and we need a +# special method to differential through `getproperty` +function ChainRulesCore.rrule(::typeof(getproperty), traj::Trajectory, name::Symbol) + val = getproperty(traj, name) + if name in (:initial_state, :generator, :target_state, :weight) + function field_pullback(Δ) + dt = ChainRulesCore.Tangent{typeof(traj)}(; (name => Δ,)...) + return NoTangent(), dt, NoTangent() + end + return val, field_pullback + else + # kwargs-stored property: route gradient back into the kwargs Dict + function kwargs_pullback(Δ) + dkwargs = Dict{Symbol,Any}(name => ChainRulesCore.unthunk(Δ)) + dt = ChainRulesCore.Tangent{typeof(traj)}(; kwargs = dkwargs) + return NoTangent(), dt, NoTangent() + end + return val, kwargs_pullback + end +end + + +end diff --git a/src/trajectories.jl b/src/trajectories.jl index 58d83935..50001aa4 100644 --- a/src/trajectories.jl +++ b/src/trajectories.jl @@ -167,6 +167,11 @@ function Base.setproperty!(traj::Trajectory, name::Symbol, value) end +# Transparently access properties stored in the `kwargs` field. +# Note: This also requires a custom ChainRulesCore.rrule for certain operations +# in Zygote (or other AD frameworks using ChainRules). This is implemented in +# the QuantumControlChainRulesCoreExt extension module. +# See `test_traj_zygote.jl` for an example function Base.getproperty(traj::Trajectory, name::Symbol) if name in (:initial_state, :generator, :target_state, :weight) return getfield(traj, name) diff --git a/test/runtests.jl b/test/runtests.jl index ec404f20..dd945115 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -59,6 +59,7 @@ end println("* Trajectories (test_trajectories.jl):") @time @safetestset "Trajectories" begin include("test_trajectories.jl") + include("test_traj_zygote.jl") end println("* Adjoint Trajectories (test_adjoint_trajectory.jl):") diff --git a/test/test_traj_zygote.jl b/test/test_traj_zygote.jl new file mode 100644 index 00000000..72dbe784 --- /dev/null +++ b/test/test_traj_zygote.jl @@ -0,0 +1,71 @@ +using Test +using StableRNGs +using IOCapture +using QuantumControl: Trajectory +using LinearAlgebra: dot, norm +using Random: rand +using Zygote + + +function J_T(Ψ; Ψtgt, N) + return 1 - (abs2(dot(Ψ, Ψtgt)) / N) +end + + + +@testset "Gradient w.r.t. trajectory.initial_state" begin + + function f(traj; Ψtgt, N) + return J_T(traj.initial_state; Ψtgt, N) + end + + rng = StableRNG(3143162815) + N = 4 + H = nothing + Ψ = rand(rng, ComplexF64, N) + Ψ ./= norm(Ψ) + Ψtgt = zeros(ComplexF64, N) + Ψtgt[1] = 1.0 + traj = Trajectory(Ψ, H) + @test f(traj; Ψtgt, N) > 0.0 + grad = Zygote.gradient(traj -> f(traj; Ψtgt, N), traj)[1] + @test grad isa NamedTuple + @test grad.initial_state isa Vector + expected_grad = -2 .* Ψtgt .* conj(dot(Ψ, Ψtgt)) / N + @test norm(grad.initial_state - expected_grad) < 1e-14 + +end + + +@testset "Gradient w.r.t. trajectory.x" begin + + function f(traj; Ψtgt, N) + return J_T(traj.x; Ψtgt, N) + end + + rng = StableRNG(3143162816) + N = 4 + H = nothing + Ψ = rand(rng, ComplexF64, N) + Ψ ./= norm(Ψ) + Ψtgt = zeros(ComplexF64, N) + Ψtgt[1] = 1.0 + x = Ψ + traj = Trajectory(Ψ, H; x) + @test f(traj; Ψtgt, N) > 0.0 + captured = IOCapture.capture(rethrow = Union{}) do + # Without the custom `rrule` in `QuantumControlChainRulesCoreExt`, this + # test would show a potentially very confusing error, and throw an + # `UndefRefError`. See also: https://discourse.julialang.org/t/136704/ + Zygote.gradient(traj -> f(traj; Ψtgt, N), traj)[1] + end + grad = captured.value + @test grad isa NamedTuple + if grad isa NamedTuple + @test grad.initial_state isa Nothing + @test grad.kwargs[:x] isa Vector + expected_grad = -2 .* Ψtgt .* conj(dot(Ψ, Ψtgt)) / N + @test norm(grad.kwargs[:x] - expected_grad) < 1e-14 + end + +end