@Technici4n I'd love to have some feedback on this, and the related chalk-lab/Mooncake.jl#548 it enables!
Codecov Report ❌
@@ Coverage Diff @@
## main #975 +/- ##
==========================================
- Coverage 98.20% 97.27% -0.93%
==========================================
Files 135 131 -4
Lines 8000 8008 +8
==========================================
- Hits 7856 7790 -66
- Misses 144 218 +74
I'm not very familiar with contexts, but the basic idea sounds reasonable. Why is the NaN poisoning only used now, and why wasn't it used previously? Does this PR introduce a new kind of context?
The idea behind contexts is summed up concisely in https://juliadiff.org/DifferentiationInterface.jl/DifferentiationInterface/stable/explanation/arguments/.
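For readers who don't want to click through, here is a minimal sketch of what a context looks like from the user side. It assumes DifferentiationInterface and ForwardDiff are installed; the function `f` and the values are made up for illustration.

```julia
using DifferentiationInterface
import ForwardDiff

# f takes the active argument x plus one extra argument c.
# Only x is differentiated; c is a context.
f(x, c) = c * sum(abs2, x)

backend = AutoForwardDiff()
x = [1.0, 2.0]

# Constant(3.0) wraps c as a context: it is passed through to f,
# but DI does not return a derivative with respect to it.
g = gradient(f, backend, x, Constant(3.0))
```

The result only contains the derivative with respect to `x`, which is why every `DifferentiateWith` integration needs some way to say "unknown" for the context part.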
Pull request overview
This PR extends DifferentiateWith to support additional “context” arguments by introducing a context_wrappers field, and updates backend integrations (ChainRulesCore/Zygote, Mooncake) plus internal utilities/tests to preserve array formats when converting between array-of-tuples and tuple-of-arrays.
Changes:
- Extend `DifferentiateWith` to accept `context_wrappers` (default `()`), update `show`, and adjust ChainRulesCore/Mooncake/ForwardDiff integrations accordingly.
- Add/adjust internal array conversion utility `arroftup_to_tupofarr(tx, x)` to preserve the primal’s array type (incl. a GPUArraysCore-specific method).
- Expand test coverage for “wrong-mode” array format preservation and for `DifferentiateWith` scenarios (normal/Constant/Cache).
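To make the array-of-tuples to tuple-of-arrays conversion concrete, here is a rough sketch of the idea of allocating the outputs with `similar(x)` so that they follow the primal's array type. The name `tuple_of_arrays_sketch` is hypothetical and this is not DI's internal implementation; it also assumes every tuple entry has the same element type.

```julia
# Hypothetical stand-in, NOT the actual `arroftup_to_tupofarr` from DI.
# tx: an array whose elements are tuples of B values (one per tangent).
# x:  the primal array, used only to pick the output container type.
function tuple_of_arrays_sketch(
    tx::AbstractArray{<:NTuple{B,T}}, x::AbstractArray
) where {B,T}
    return ntuple(Val(B)) do b
        # similar(x, T, dims) allocates an array of the same family as x
        out = similar(x, T, size(tx))
        # fill it with the b-th component of every tuple
        map!(t -> t[b], out, tx)
        return out
    end
end

x = [1.0, 2.0]
tx = [(1.0, 10.0), (2.0, 20.0)]
result = tuple_of_arrays_sketch(tx, x)
```

With a plain `Vector` primal this returns a tuple of two `Vector{Float64}`; the point of going through `similar` is that other array types get a chance to return their own container instead.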
Reviewed changes
Copilot reviewed 17 out of 17 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| DifferentiationInterface/src/misc/differentiate_with.jl | Adds context_wrappers to DifferentiateWith, default constructor, updated call/show behavior. |
| DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl | Updates rrule to support contexts and return unknown tangents for context derivatives. |
| DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl | Adds Mooncake primitive handling for DifferentiateWith{C} and poisons context derivatives in reverse mode. |
| DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl | Adds nanify* helpers intended to poison Mooncake fdata/rdata for contexts. |
| DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/differentiate_with.jl | Restricts ForwardDiff override to DifferentiateWith{0} (no contexts). |
| DifferentiationInterface/src/utils/linalg.jl | Refactors arroftup_to_tupofarr to preserve primal array type via similar(x). |
| DifferentiationInterface/ext/DifferentiationInterfaceGPUArraysCoreExt/DifferentiationInterfaceGPUArraysCoreExt.jl | Adds GPU specialization for arroftup_to_tupofarr using Adapt.adapt. |
| DifferentiationInterface/src/first_order/pushforward.jl | Updates calls to arroftup_to_tupofarr to pass the primal output y. |
| DifferentiationInterface/src/first_order/pullback.jl | Updates calls to arroftup_to_tupofarr to pass the primal input x. |
| DifferentiationInterface/src/utils/context.jl | Adds DI.call helper used to apply wrappers cleanly in mappings. |
| DifferentiationInterface/test/Core/Internals/linalg.jl | Adds tests ensuring arroftup_to_tupofarr preserves array/container types. |
| DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl | Adds regression tests for array format preservation. |
| DifferentiationInterface/test/Core/Internals/display.jl | Updates expected DifferentiateWith string representation. |
| DifferentiationInterface/test/Back/DifferentiateWith/test.jl | Expands DI tests to normal/Constant/Cache scenario variants and updates Mooncake error expectations. |
| DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl | Imports @not_implemented for unknown context tangents. |
| DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl | Adds Mooncake imports needed for new poisoning utilities. |
| DifferentiationInterface/Project.toml | Adds Adapt as a weakdep + compat for GPU array conversions. |
Comments suppressed due to low confidence (1)
DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/differentiate_with.jl:15
`DifferentiateWith` supports context arguments, but the ForwardDiff override only handles `DifferentiateWith{0}`. For `DifferentiateWith{C}` with `C>0`, ForwardDiff will fall back to the generic call implementation and will propagate `Dual`s through context arguments if they contain them, which can silently produce incorrect derivatives (the PR description notes this problem). Consider adding a more specific method for `DifferentiateWith{C}` + `Dual` that throws an explicit error (or otherwise marks the context derivative as unknown) so ForwardDiff usage with contexts fails loudly instead of computing the wrong result.
function (dw::DI.DifferentiateWith{0})(x::Dual{T, V, N}) where {T, V, N}
    (; f, backend) = dw
    xval = myvalue(T, x)
    tx = mypartials(T, Val(N), x)
    y, ty = DI.value_and_pushforward(f, backend, xval, tx)
    return make_dual(T, y, ty)
end
function (dw::DI.DifferentiateWith{0})(x::AbstractArray{Dual{T, V, N}}) where {T, V, N}
    (; f, backend) = dw
    xval = myvalue(T, x)
    tx = mypartials(T, Val(N), x)
    y, ty = DI.value_and_pushforward(f, backend, xval, tx)
    return make_dual(T, y, ty)
end
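The "fail loudly" suggestion could look roughly like the following toy sketch. `ToyDual` and `ToyDifferentiateWith` are hypothetical stand-ins, not `ForwardDiff.Dual` or `DI.DifferentiateWith`, and the central finite difference is only a placeholder for whatever the inner backend would compute.

```julia
# Toy stand-ins for illustration only.
struct ToyDual
    value::Float64
    partial::Float64
end

struct ToyDifferentiateWith{C,F}
    f::F
    context_wrappers::NTuple{C,Any}
end

# No contexts: compute the derivative separately and repackage it as a dual.
# (Placeholder derivative: central finite difference.)
function (dw::ToyDifferentiateWith{0})(x::ToyDual)
    h = 1e-6
    dfdx = (dw.f(x.value + h) - dw.f(x.value - h)) / (2h)
    return ToyDual(dw.f(x.value), dfdx * x.partial)
end

# At least one context: refuse instead of silently propagating duals
# through the generic call path.
function (dw::ToyDifferentiateWith)(x::ToyDual, context, more_contexts...)
    throw(ArgumentError("dual-number input is not supported when contexts are present"))
end

dw0 = ToyDifferentiateWith(sin, ())
y = dw0(ToyDual(0.0, 1.0))

dwc = ToyDifferentiateWith((x, c) -> c * x, (identity,))
```

Calling `dwc(ToyDual(1.0, 1.0), 2.0)` then raises an `ArgumentError` rather than returning a plausible-looking but wrong derivative.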
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Partial fix for #806, #675.
Since DI doesn't compute derivatives with respect to contexts, in each implementation we need a way to mark the derivative as unknown.
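- `DifferentiateWith` with a `context_wrappers` field, denoting how additional arguments beyond `x` must be taken into account.
- ChainRulesCore: context derivatives are returned as a `ChainRulesCore.@not_implemented` tangent.
- Mooncake: poison the `FData` and `RData` of context arguments using `NaN`.
- ForwardDiff: contexts are not supported, because a `Constant` context containing `Dual`s would make the derivative wrong. We could allow `Cache`s though.
- `frule!!` doesn't yet exist for `DifferentiateWith`.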
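As a rough illustration of the `NaN` poisoning idea (not the actual `nanify*` helpers from the Mooncake extension, which operate on Mooncake's `FData`/`RData`), a non-mutating sketch on plain Julia values could look like this. The name `poison_sketch` is hypothetical.

```julia
# Hypothetical helper: replace every floating-point slot with NaN so that
# any downstream use of the "derivative" is visibly invalid.
poison_sketch(x::AbstractFloat) = oftype(x, NaN)
poison_sketch(x::AbstractArray{<:AbstractFloat}) = fill!(similar(x), NaN)
poison_sketch(x::Tuple) = map(poison_sketch, x)
poison_sketch(x) = x  # non-differentiable data is left untouched

scalar_tangent = poison_sketch(1.0f0)
array_tangent = poison_sketch([1.0, 2.0])
nested_tangent = poison_sketch((1.0, [2.0, 3.0], "label"))

# Anything computed from a poisoned tangent is NaN as well.
downstream = sum(array_tangent)
```

The design choice is that a `NaN` cannot be mistaken for a legitimate zero derivative: it propagates through arithmetic, so a caller who accidentally relies on the context derivative sees it immediately.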