| (kw::KwFunc)(args...) = kw.kwf(args...) | ||
|
|
||
| function ChainRulesCore.rrule(::typeof(Core.kwcall), kwargs, f, args...) | ||
| rrule(KwFunc(f), kwargs, f, args...) |
Why isn't this
| rrule(KwFunc(f), kwargs, f, args...) | |
| rrule(f, args...; kwargs...) |
Is that the same, or is it different?
Because (I think) we want to hit this rrule for KwFunc:
https://github.com/JuliaDiff/Diffractor.jl/pull/270/files#diff-5bd76352c0319c1d2659edc9825a90b92ac204aa1c17b9b95a3880e0fd98855eR255
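For context on the question above, here is a minimal sketch of how a keyword call relates to `Core.kwcall`. It assumes Julia 1.9 or newer, and the function `g` is a hypothetical example, not something from the PR. For a plain function the two spellings give the same result, so the difference under discussion is presumably only about which `rrule` method dispatch selects.

```julia
# Assumes Julia >= 1.9, where keyword calls lower to Core.kwcall.
# `g` is a hypothetical function used only for illustration.
g(x; a=1) = a * x

# Surface syntax: a keyword call.
direct = g(3; a=2)

# Roughly what the call lowers to: the keyword arguments are packed
# into a NamedTuple and passed first, followed by the function and
# its positional arguments.
lowered = Core.kwcall((; a=2), g, 3)

println(direct == lowered)
```

Because the lowered form has `Core.kwcall` as the callee, a rule that dispatches on `::typeof(Core.kwcall)` sees `kwargs` and `f` as ordinary positional arguments.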
I am not exactly sure why the KwFunc struct is needed, though. It seems like this could be done via rrule(::typeof(Core.kwcall), kwargs, f, args...) directly?
Removing the KwFunc and dispatching on kwcall directly seems to work, but I was afraid to remove something that I don't fully understand:
function ChainRulesCore.rrule(::typeof(Core.kwcall), kwargs, f, args...)
    r = Core.kwfunc(rrule)(kwargs, rrule, f, args...)
    if r === nothing
        return nothing
    end
    x, back = r
    x, Δ->begin
        (NoTangent(), NoTangent(), back(Δ)...)
    end
end
Oh, this might be there to avoid ADing through so much of the kwarg machinery in the nested AD case.
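The pullback-padding pattern in the snippet quoted above can be tried in isolation. The following is a sketch under stated assumptions, not Diffractor's implementation: `scale` and `kw_rrule` are hypothetical names, and the helper is an ordinary function rather than an `rrule` method on `Core.kwcall`, so it does not add methods to functions it does not own. It only uses `rrule` and `NoTangent` from ChainRulesCore.

```julia
using ChainRulesCore

# Hypothetical function with a keyword argument, plus a hand-written rule.
scale(x; a=2.0) = a * x

function ChainRulesCore.rrule(::typeof(scale), x; a=2.0)
    # Tangents are returned for (scale, x); the keyword `a` is not differentiated.
    scale_pullback(Δ) = (NoTangent(), a * Δ)
    return scale(x; a=a), scale_pullback
end

# Hypothetical helper mirroring the argument order of a kwcall:
# (kwargs, f, args...). It forwards to the keyword-aware rrule and pads
# the pullback with two extra NoTangent() entries, one for the kwcall
# callee itself and one for the kwargs container.
function kw_rrule(kwargs, f, args...)
    r = rrule(f, args...; kwargs...)
    r === nothing && return nothing
    y, back = r
    return y, Δ -> (NoTangent(), NoTangent(), back(Δ)...)
end

y, back = kw_rrule((; a=3.0), scale, 2.0)
tangents = back(1.0)
```

The pullback result has one entry per argument of the lowered call: the callee, `kwargs`, `f`, and then each positional argument.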
|
Same as in #266: I think that, in order for this not to make inference worse, the method should be split into kw and non-kw versions.
| function (::∂⃖{N})(::typeof(Core.kwcall), kwargs, f::T, args...) where {T, N} | ||
| if N == 1 | ||
| # Base case (inlined to avoid ambiguities with manually specified | ||
| # higher order rules) | ||
| z = rrule(DiffractorRuleConfig(), KwFunc(f), kwargs, f, args...) |
It's basically a copy of the non-kw version of the function, but, if I understand correctly, that is what we want in order to avoid ADing through the kw machinery when there are no kws?
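The kw / non-kw split discussed here can be illustrated with plain dispatch. This is only a sketch: `Tracer` is a hypothetical callable standing in for `∂⃖{N}`, `g` is a made-up function, and Julia 1.9 or newer is assumed. It shows that a method specialized on `::typeof(Core.kwcall)` takes keyword calls, while calls without keywords keep using the generic method and never touch the keyword path.

```julia
# Hypothetical stand-in for a callable such as ∂⃖{N}; assumes Julia >= 1.9.
struct Tracer end

# Non-kw version: handles ordinary calls.
(::Tracer)(f, args...) = (:plain, f(args...))

# Kw version: selected only when the callee is Core.kwcall, so calls
# without keywords never go through this method.
(::Tracer)(::typeof(Core.kwcall), kwargs, f, args...) =
    (:kw, f(args...; kwargs...))

g(x; a=1) = a * x

plain_result = Tracer()(g, 3)
kw_result = Tracer()(Core.kwcall, (; a=2), g, 3)
```

The kw method's signature is strictly more specific than the generic one, so the two coexist without a dispatch ambiguity.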
This seems to fix non-differentiable keyword arguments by constructing the KwFunc defined in Diffractor.