From 48dcedae6230f5ac0544266a42c19f85edd2124f Mon Sep 17 00:00:00 2001 From: Michael Goerz Date: Tue, 14 Apr 2026 09:00:30 -0400 Subject: [PATCH 1/5] Add support for gradients w.r.t. trajectories This allows to call `Zygote.gradient` for a function where a `Trajectory` is the argument. See also https://discourse.julialang.org/t/136704 While this is probably not something that people would do _directly_, the lack of a custom `rrule` was causing issues with Zygote constructing the derivative of state-dependent running costs where information relevant to the running cost was stored in a custom property of the relevant `Trajectory`. --- Project.toml | 2 + ext/QuantumControlChainRulesCoreExt.jl | 32 ++++++++++++ src/trajectories.jl | 5 ++ test/runtests.jl | 1 + test/test_traj_zygote.jl | 67 ++++++++++++++++++++++++++ 5 files changed, 107 insertions(+) create mode 100644 ext/QuantumControlChainRulesCoreExt.jl create mode 100644 test/test_traj_zygote.jl diff --git a/Project.toml b/Project.toml index 6126713f..3e6245fc 100644 --- a/Project.toml +++ b/Project.toml @@ -15,10 +15,12 @@ QuantumPropagators = "7bf12567-5742-4b91-a078-644e72a65fc1" UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" [weakdeps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] +QuantumControlChainRulesCoreExt = "ChainRulesCore" QuantumControlFiniteDifferencesExt = "FiniteDifferences" QuantumControlZygoteExt = "Zygote" diff --git a/ext/QuantumControlChainRulesCoreExt.jl b/ext/QuantumControlChainRulesCoreExt.jl new file mode 100644 index 00000000..56798932 --- /dev/null +++ b/ext/QuantumControlChainRulesCoreExt.jl @@ -0,0 +1,32 @@ +module QuantumControlChainRulesCoreExt + +using ChainRulesCore: ChainRulesCore, NoTangent +using QuantumControl: Trajectory + + +# Allow to differentiate w.r.t. to a trajectory. See `test_traj_zygote.jl` for +# an example. Evaluating a gradient with Zygote returns a NamedTuple with +# the fields of the trajectory. Unfortunately, Zygote gets confused about the +# custom `getproperty` method that is defined for a Trajectory, and we need a +# special method to differential through `getproperty` +function ChainRulesCore.rrule(::typeof(getproperty), traj::Trajectory, name::Symbol) + val = getproperty(traj, name) + if name in (:initial_state, :generator, :target_state, :weight) + function field_pullback(Δ) + dt = ChainRulesCore.Tangent{typeof(traj)}(; (name => Δ,)...) + return NoTangent(), dt, NoTangent() + end + return val, field_pullback + else + # kwargs-stored property: route gradient back into the kwargs Dict + function kwargs_pullback(Δ) + dkwargs = Dict{Symbol,Any}(name => Δ) + dt = ChainRulesCore.Tangent{typeof(traj)}(; kwargs = dkwargs) + return NoTangent(), dt, NoTangent() + end + return val, kwargs_pullback + end +end + + +end diff --git a/src/trajectories.jl b/src/trajectories.jl index 58d83935..50001aa4 100644 --- a/src/trajectories.jl +++ b/src/trajectories.jl @@ -167,6 +167,11 @@ function Base.setproperty!(traj::Trajectory, name::Symbol, value) end +# Transparently access properties stored in the `kwargs` field. +# Note: This also requires a custom ChainRulesCore.rrule for certain operations +# in Zygote (or other AD frameworks using ChainRules). This is implemented in +# the QuantumControlChainRulesCoreExt extension module. +# See `test_traj_zygote.jl` for an example function Base.getproperty(traj::Trajectory, name::Symbol) if name in (:initial_state, :generator, :target_state, :weight) return getfield(traj, name) diff --git a/test/runtests.jl b/test/runtests.jl index ec404f20..dd945115 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -59,6 +59,7 @@ end println("* Trajectories (test_trajectories.jl):") @time @safetestset "Trajectories" begin include("test_trajectories.jl") + include("test_traj_zygote.jl") end println("* Adjoint Trajectories (test_adjoint_trajectory.jl):") diff --git a/test/test_traj_zygote.jl b/test/test_traj_zygote.jl new file mode 100644 index 00000000..55854ab2 --- /dev/null +++ b/test/test_traj_zygote.jl @@ -0,0 +1,67 @@ +using Test +using StableRNGs +using IOCapture +using QuantumControl: Trajectory +using LinearAlgebra: dot, norm +using Random: rand +using Zygote + + +function J_T(Ψ; Ψtgt, N) + return 1 - (abs2(dot(Ψ, Ψtgt)) / N) +end + + + +@testset "Gradient w.r.t. trajectory.initial_state" begin + + function f(traj; Ψtgt, N) + return J_T(traj.initial_state; Ψtgt, N) + end + + rng = StableRNG(3143162815) + N = 4 + H = nothing + Ψ = rand(rng, ComplexF64, N) + Ψ ./ norm(Ψ) + Ψtgt = zeros(ComplexF64, N) + Ψtgt[1] = 1.0 + traj = Trajectory(Ψ, H) + @test f(traj; Ψtgt, N) > 0.0 + grad = Zygote.gradient(traj -> f(traj; Ψtgt, N), traj)[1] + @test grad isa NamedTuple + @test grad.initial_state isa Vector + +end + + +@testset "Gradient w.r.t. trajectory.x" begin + + function f(traj; Ψtgt, N) + return J_T(traj.x; Ψtgt, N) + end + + rng = StableRNG(3143162816) + N = 4 + H = nothing + Ψ = rand(rng, ComplexF64, N) + Ψ ./ norm(Ψ) + Ψtgt = zeros(ComplexF64, N) + Ψtgt[1] = 1.0 + x = Ψ + traj = Trajectory(Ψ, H; x) + @test f(traj; Ψtgt, N) > 0.0 + captured = IOCapture.capture(rethrow = Union{}) do + # Without the custom `rrule` in `QuantumControlchainRulesCoreExt`, this + # test would show a potentially very confusing error, and throw an + # `UndefRefError`. See also: https://discourse.julialang.org/t/136704/ + Zygote.gradient(traj -> f(traj; Ψtgt, N), traj)[1] + end + grad = captured.value + @test grad isa NamedTuple + if grad isa NamedTuple + @test grad.initial_state isa Nothing + @test grad.kwargs[:x] isa Vector + end + +end From c611f0364c67a4c3dde2085ae5597e3036037f09 Mon Sep 17 00:00:00 2001 From: Michael Goerz Date: Tue, 14 Apr 2026 10:12:40 -0400 Subject: [PATCH 2/5] Fix incompatibility with Zygote 0.7 --- ext/QuantumControlChainRulesCoreExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/QuantumControlChainRulesCoreExt.jl b/ext/QuantumControlChainRulesCoreExt.jl index 56798932..091fe880 100644 --- a/ext/QuantumControlChainRulesCoreExt.jl +++ b/ext/QuantumControlChainRulesCoreExt.jl @@ -20,7 +20,7 @@ function ChainRulesCore.rrule(::typeof(getproperty), traj::Trajectory, name::Sym else # kwargs-stored property: route gradient back into the kwargs Dict function kwargs_pullback(Δ) - dkwargs = Dict{Symbol,Any}(name => Δ) + dkwargs = Dict{Symbol,Any}(name => ChainRulesCore.unthunk(Δ)) dt = ChainRulesCore.Tangent{typeof(traj)}(; kwargs = dkwargs) return NoTangent(), dt, NoTangent() end From 0f889af5dc99c6a1edde7d6ebf4ea5f25611acba Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 14 Apr 2026 16:35:00 +0000 Subject: [PATCH 3/5] Fix normalization, comment typo, and add numerical gradient checks in Zygote tests Agent-Logs-Url: https://github.com/JuliaQuantumControl/QuantumControl.jl/sessions/8d24cbf1-8f4b-49f7-9008-ac22384e7772 Co-authored-by: goerz <112306+goerz@users.noreply.github.com> --- test/test_traj_zygote.jl | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/test/test_traj_zygote.jl b/test/test_traj_zygote.jl index 55854ab2..72c593e9 100644 --- a/test/test_traj_zygote.jl +++ b/test/test_traj_zygote.jl @@ -23,7 +23,7 @@ end N = 4 H = nothing Ψ = rand(rng, ComplexF64, N) - Ψ ./ norm(Ψ) + Ψ ./= norm(Ψ) Ψtgt = zeros(ComplexF64, N) Ψtgt[1] = 1.0 traj = Trajectory(Ψ, H) @@ -31,6 +31,8 @@ end grad = Zygote.gradient(traj -> f(traj; Ψtgt, N), traj)[1] @test grad isa NamedTuple @test grad.initial_state isa Vector + expected_grad = -Ψtgt .* conj(dot(Ψ, Ψtgt)) / N + @test grad.initial_state ≈ expected_grad end @@ -45,14 +47,14 @@ end N = 4 H = nothing Ψ = rand(rng, ComplexF64, N) - Ψ ./ norm(Ψ) + Ψ ./= norm(Ψ) Ψtgt = zeros(ComplexF64, N) Ψtgt[1] = 1.0 x = Ψ traj = Trajectory(Ψ, H; x) @test f(traj; Ψtgt, N) > 0.0 captured = IOCapture.capture(rethrow = Union{}) do - # Without the custom `rrule` in `QuantumControlchainRulesCoreExt`, this + # Without the custom `rrule` in `QuantumControlChainRulesCoreExt`, this # test would show a potentially very confusing error, and throw an # `UndefRefError`. See also: https://discourse.julialang.org/t/136704/ Zygote.gradient(traj -> f(traj; Ψtgt, N), traj)[1] @@ -62,6 +64,8 @@ end if grad isa NamedTuple @test grad.initial_state isa Nothing @test grad.kwargs[:x] isa Vector + expected_grad = -Ψtgt .* conj(dot(Ψ, Ψtgt)) / N + @test grad.kwargs[:x] ≈ expected_grad end end From 62cd8e49866fb12d7d40afecc9fbdddc10e416ad Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 14 Apr 2026 16:40:32 +0000 Subject: [PATCH 4/5] =?UTF-8?q?Replace=20=E2=89=88=20with=20norm-based=20c?= =?UTF-8?q?omparison=20in=20Zygote=20gradient=20tests?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Agent-Logs-Url: https://github.com/JuliaQuantumControl/QuantumControl.jl/sessions/db517dc6-313e-44f8-8f48-90f2a7f85542 Co-authored-by: goerz <112306+goerz@users.noreply.github.com> --- test/test_traj_zygote.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_traj_zygote.jl b/test/test_traj_zygote.jl index 72c593e9..41eac12e 100644 --- a/test/test_traj_zygote.jl +++ b/test/test_traj_zygote.jl @@ -32,7 +32,7 @@ end @test grad isa NamedTuple @test grad.initial_state isa Vector expected_grad = -Ψtgt .* conj(dot(Ψ, Ψtgt)) / N - @test grad.initial_state ≈ expected_grad + @test norm(grad.initial_state - expected_grad) < 1e-14 end @@ -65,7 +65,7 @@ end @test grad.initial_state isa Nothing @test grad.kwargs[:x] isa Vector expected_grad = -Ψtgt .* conj(dot(Ψ, Ψtgt)) / N - @test grad.kwargs[:x] ≈ expected_grad + @test norm(grad.kwargs[:x] - expected_grad) < 1e-14 end end From c8cf47a50dde16f5d4deafba9d2798ac52589440 Mon Sep 17 00:00:00 2001 From: Michael Goerz Date: Tue, 14 Apr 2026 13:15:11 -0400 Subject: [PATCH 5/5] Fix analytical derivative in test --- test/test_traj_zygote.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_traj_zygote.jl b/test/test_traj_zygote.jl index 41eac12e..72dbe784 100644 --- a/test/test_traj_zygote.jl +++ b/test/test_traj_zygote.jl @@ -31,7 +31,7 @@ end grad = Zygote.gradient(traj -> f(traj; Ψtgt, N), traj)[1] @test grad isa NamedTuple @test grad.initial_state isa Vector - expected_grad = -Ψtgt .* conj(dot(Ψ, Ψtgt)) / N + expected_grad = -2 .* Ψtgt .* conj(dot(Ψ, Ψtgt)) / N @test norm(grad.initial_state - expected_grad) < 1e-14 end @@ -64,7 +64,7 @@ end if grad isa NamedTuple @test grad.initial_state isa Nothing @test grad.kwargs[:x] isa Vector - expected_grad = -Ψtgt .* conj(dot(Ψ, Ψtgt)) / N + expected_grad = -2 .* Ψtgt .* conj(dot(Ψ, Ψtgt)) / N @test norm(grad.kwargs[:x] - expected_grad) < 1e-14 end