Add support for gradients w.r.t. trajectories#101
Conversation
This allows to call `Zygote.gradient` for a function where a `Trajectory` is the argument. See also https://discourse.julialang.org/t/136704 While this is probably not something that people would do _directly_, the lack of a custom `rrule` was causing issues with Zygote constructing the derivative of state-dependent running costs where information relevant to the running cost was stored in a custom property of the relevant `Trajectory`.
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## master #101 +/- ##
========================================
+ Coverage 88.2% 88.3% +0.2%
========================================
Files 19 20 +1
Lines 1035 1046 +11
========================================
+ Hits 912 923 +11
Misses 123 123 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
There was a problem hiding this comment.
Pull request overview
Adds ChainRules support so Zygote can take gradients w.r.t. Trajectory inputs when Trajectory properties are accessed via the custom getproperty (including properties stored in kwargs).
Changes:
- Add a
ChainRulesCore.rruleforgetproperty(::Trajectory, ::Symbol)via a new package extension. - Add regression tests exercising Zygote gradients through
traj.initial_stateand akwargs-backed property (traj.x). - Wire the new tests into the trajectories testset and document the AD implication in
getproperty.
Reviewed changes
Copilot reviewed 4 out of 5 changed files in this pull request and generated 4 comments.
Show a summary per file
| File | Description |
|---|---|
ext/QuantumControlChainRulesCoreExt.jl |
Introduces a ChainRulesCore extension providing an rrule for getproperty on Trajectory. |
src/trajectories.jl |
Documents that the custom getproperty behavior relies on a custom rrule for AD frameworks like Zygote. |
test/test_traj_zygote.jl |
Adds tests to ensure Zygote gradients work for field and kwargs-backed properties on Trajectory. |
test/runtests.jl |
Includes the new trajectory/Zygote test file in the trajectories testset. |
Project.toml |
Registers ChainRulesCore as a weak dep and adds the new extension mapping. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
|
@copilot apply changes based on the comments in this thread |
… Zygote tests Agent-Logs-Url: https://github.com/JuliaQuantumControl/QuantumControl.jl/sessions/8d24cbf1-8f4b-49f7-9008-ac22384e7772 Co-authored-by: goerz <112306+goerz@users.noreply.github.com>
|
Just as a heads up, I was blocked by some firewall rules while working on your feedback. Expand below for details. Warning Firewall rules blocked me from connecting to one or more addresses (expand for details)I tried to connect to the following addresses, but was blocked by firewall rules:
If you need me to access, download, or install something from one of these locations, you can either:
|
Agent-Logs-Url: https://github.com/JuliaQuantumControl/QuantumControl.jl/sessions/db517dc6-313e-44f8-8f48-90f2a7f85542 Co-authored-by: goerz <112306+goerz@users.noreply.github.com>
This allows to call
Zygote.gradientfor a function where aTrajectoryis the argument.See also https://discourse.julialang.org/t/136704
While this is probably not something that people would do directly, the lack of a custom
rrulewas causing issues with Zygote constructing the derivative of state-dependent running costs where information relevant to the running cost was stored in a custom property of the relevantTrajectory.This might also be considered as working around a possible bug in Zygote/Chainrules: FluxML/Zygote.jl#1610