
Test: Add Mooncake AD testing to conv layer test infrastructure #641

Merged
CarloLucibello merged 2 commits into JuliaGraphs:master from Parvm1102:mooncake-ad-testing
Mar 5, 2026

@Parvm1102
Contributor

This is related to issue #640.
I have implemented Mooncake AD testing for the conv layers (conv.jl).

  • Modified the test_gradients function in GraphNeuralNetworks/test/test_module.jl to accommodate Mooncake AD when the test_mooncake flag is true.
  • I found that DConv and ChebConv are not fully compatible with Mooncake.
  • EGNNConv was excluded because it is still marked TODO.
  • CGConv and GMMConv were excluded because of a segmentation fault (they were already commented out).

@Parvm1102
Contributor Author

@CarloLucibello Does this look okay? I can proceed to check other layers too.

if test_mooncake
# Mooncake gradient with respect to input, compared against Zygote.
loss_mc_x = (xs...) -> loss(f, graph, xs...)
_cache_x = Base.invokelatest(Mooncake.prepare_gradient_cache, loss_mc_x, xs...)
Member

Why is this invokelatest needed?

Contributor Author

It is related to world age: the import of Mooncake and TestItemRunner's eval run in different world ages, which throws an error. Calling through invokelatest prevents it.
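The world-age behavior mentioned here can be reproduced with Base alone. The following is a minimal sketch, not the code from this PR: `late_method` and `define_then_call` are hypothetical names, and the `@eval` stands in for code (such as a package import) that defines methods after the caller has already started running.

```julia
# Declare the generic function up front so the name exists before any call.
function late_method end

function define_then_call()
    # This method is added while define_then_call is already running, so it
    # lives in a newer world than the one this call is executing in.
    @eval late_method(x::Int) = x + 1

    # A direct call only sees methods from the caller's world and fails.
    direct = try
        late_method(1)
    catch err
        err
    end

    # invokelatest dispatches in the newest world and finds the new method.
    latest = Base.invokelatest(late_method, 1)
    return direct, latest
end

direct, latest = define_then_call()
```

In this sketch the direct call yields a MethodError while the invokelatest call returns the expected value, which mirrors why the Mooncake call in the test helper is wrapped.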

@CarloLucibello
Member

CarloLucibello commented Mar 4, 2026

Tests related to this PR are failing (GNN julia 1).

On Julia < 1.12, Mooncake testing should be skipped. You can define a global flag in test/runtests.jl:

const TEST_MOONCAKE = VERSION >= v"1.12" 
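One way such a flag could gate the Mooncake checks, sketched with Base only. `backends_to_test` and the backend symbols are hypothetical names for illustration, not part of the GraphNeuralNetworks.jl test suite.

```julia
# Flag suggested in the review: true only on Julia 1.12 or newer.
const TEST_MOONCAKE = VERSION >= v"1.12"

# Hypothetical helper: list the AD backends a gradient test should exercise.
function backends_to_test(; test_mooncake::Bool = TEST_MOONCAKE)
    backends = [:zygote]
    test_mooncake && push!(backends, :mooncake)
    return backends
end

# Callers can still force the Mooncake path off for a single test.
zygote_only = backends_to_test(test_mooncake = false)
```

Note that prerelease builds such as 1.12.0-rc1 compare lower than v"1.12", so the flag is false on them; comparing against v"1.12-" would include them.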

@CarloLucibello
Member

Also, instead of calling Mooncake directly, we can call Flux.gradient(AutoMooncake(), ...) or Flux.withgradient(AutoMooncake(), ...).

See the testing function https://github.com/FluxML/Flux.jl/blob/master/test/test_utils.jl#L82,
which we should try to replicate here (leaving out reactant for the time being).

@Parvm1102
Contributor Author

Parvm1102 commented Mar 4, 2026

Also, instead of calling Mooncake directly, we can call Flux.gradient(AutoMooncake(), ...) or Flux.withgradient(AutoMooncake(), ...).

See the testing function https://github.com/FluxML/Flux.jl/blob/master/test/test_utils.jl#L82, which we should try to replicate here (leaving out reactant for the time being).

@CarloLucibello, thanks for the detailed review! For Flux.withgradient(AutoMooncake(), ...), the FluxMooncakeExt hardcodes Mooncake.Config(friendly_tangents=true) and ignores the config field passed via AutoMooncake. (https://github.com/FluxML/Flux.jl/blob/ce4b8a081aec37d5f8144044a5a7940f47210551/ext/FluxMooncakeExt.jl#L11)

With friendly_tangents=true, Mooncake tries to deep-copy the closure's captured variables (including GNNGraph), which fails because DataStore contains Dict{Symbol,Any}. I'm calling Mooncake's API directly with the default config as a workaround.

I will include the version check for Julia 1.12.

@CarloLucibello
Member

With friendly_tangents=true, Mooncake tries to deep-copy the closure's captured variables (including GNNGraph), which fails because DataStore contains Dict{Symbol,Any}. I'm calling Mooncake's API directly with the default config as a workaround.

If you could provide an example, this should be reported to the Mooncake repo. It is OK to keep the workaround for the time being.

@CarloLucibello
Member

There are a couple of layers that fail tests. We should:

  • skip them in the Mooncake tests so that the tests pass and we can merge this PR
  • create a list of the broken layers to keep track of them (e.g. in "check compatibility with Mooncake" #640)
  • create MWEs and open issues in Mooncake.jl

Signed-off-by: Parvm1102 <parvmittal31757@gmail.com>
Parvm1102 force-pushed the mooncake-ad-testing branch from 7aeae52 to 84d49d1 on March 4, 2026 20:38
@Parvm1102
Contributor Author

@CarloLucibello I have updated my PR and the checks are passing now. I had to disable Mooncake testing for four layers:

  • ChebConv
  • DConv
  • GATv2Conv
  • TransformerConv
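A skip list like this could be kept in one place so the excluded layers stay visible. The following is a minimal sketch under assumptions: `MOONCAKE_BROKEN_LAYERS` and `should_test_mooncake` are hypothetical names, not identifiers from the PR.

```julia
# Layers reported above as failing under Mooncake (hypothetical registry).
const MOONCAKE_BROKEN_LAYERS = Set([:ChebConv, :DConv, :GATv2Conv, :TransformerConv])

# Run the Mooncake check only when it is enabled globally and the layer
# is not on the known-broken list.
should_test_mooncake(layer::Symbol; enabled::Bool = true) =
    enabled && !(layer in MOONCAKE_BROKEN_LAYERS)

should_test_mooncake(:DConv)  # false, DConv is on the skip list
```

Removing a layer from the set re-enables its Mooncake test once the upstream issue is fixed.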

I have also made an MWE:

using Mooncake
# _copy_output fails on any Dict
Mooncake._copy_output(Dict(:x => 1))
# TypeError: in new, expected DataType, got Type{Symbol}

# This breaks prepare_gradient_cache when friendly_tangents=true
struct DataStore
    _n::Int
    _data::Dict{Symbol, Any}
end

f(x, ds) = sum(x) * Float32(ds._n)
x = randn(Float32, 3, 3)
ds = DataStore(3, Dict{Symbol,Any}(:x => x))

# FAILS: friendly_tangents=true triggers _copy_output on all args
Mooncake.prepare_gradient_cache(f, x, ds; config=Mooncake.Config(friendly_tangents=true))

# WORKS: default (friendly_tangents=false) is fine
cache = Mooncake.prepare_gradient_cache(f, x, ds)
Mooncake.value_and_gradient!!(cache, f, x, ds)  # works

@CarloLucibello
Member

I have also made an MWE:

Nice! Could you open an issue in Mooncake.jl with it?

CarloLucibello merged commit c4c69f4 into JuliaGraphs:master on Mar 5, 2026
8 of 10 checks passed
@Parvm1102
Contributor Author

Nice! Could you open an issue in Mooncake.jl with it?

I will open it. Should I also check Mooncake compatibility for layers other than the conv layers?

@CarloLucibello
Member

yes!
