Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions GraphNeuralNetworks/test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ GNNlib = "a6a84749-d869-43f8-aacc-be26a1996e48"
GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Expand Down
45 changes: 25 additions & 20 deletions GraphNeuralNetworks/test/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,20 @@ end
l = GCNConv(D_IN => D_OUT)
for g in TEST_GRAPHS
@test size(l(g, g.x)) == (D_OUT, g.num_nodes)
test_gradients(l, g, g.x, rtol = RTOL_HIGH)
test_gradients(l, g, g.x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
end

l = GCNConv(D_IN => D_OUT, tanh, bias = false)
for g in TEST_GRAPHS
@test size(l(g, g.x)) == (D_OUT, g.num_nodes)
test_gradients(l, g, g.x, rtol = RTOL_HIGH)
test_gradients(l, g, g.x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
end

l = GCNConv(D_IN => D_OUT, add_self_loops = false)
for g in TEST_GRAPHS
has_isolated_nodes(g) && continue
@test size(l(g, g.x)) == (D_OUT, g.num_nodes)
test_gradients(l, g, g.x, rtol = RTOL_HIGH)
test_gradients(l, g, g.x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
end
end

Expand All @@ -49,7 +49,7 @@ end
l = GCNConv(1 => 1, add_self_loops = false, use_edge_weight = true)
@test gradient(w -> sum(l(g, x, w)), w)[1] isa AbstractVector{Float32} # redundant test but more explicit
@test size(l(g, x, w)) == (1, g.num_nodes)
test_gradients(l, g, g.x, rtol = RTOL_HIGH)
test_gradients(l, g, g.x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
end

@testset "conv_weight" begin
Expand Down Expand Up @@ -86,6 +86,7 @@ end
for g in TEST_GRAPHS
g = add_self_loops(g)
@test size(l(g, g.x)) == (D_OUT, g.num_nodes)
# Note: test_mooncake not enabled for ChebConv (Mooncake backward pass error)
test_gradients(l, g, g.x, rtol = RTOL_LOW)
end

Expand Down Expand Up @@ -124,13 +125,13 @@ end
l = GraphConv(D_IN => D_OUT)
for g in TEST_GRAPHS
@test size(l(g, g.x)) == (D_OUT, g.num_nodes)
test_gradients(l, g, g.x, rtol = RTOL_HIGH)
test_gradients(l, g, g.x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
end

l = GraphConv(D_IN => D_OUT, tanh, bias = false, aggr = mean)
for g in TEST_GRAPHS
@test size(l(g, g.x)) == (D_OUT, g.num_nodes)
test_gradients(l, g, g.x, rtol = RTOL_HIGH)
test_gradients(l, g, g.x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
end

@testset "bias=false" begin
Expand All @@ -157,7 +158,7 @@ end
l = GATConv(D_IN => D_OUT; heads, concat, dropout=0)
for g in TEST_GRAPHS
@test size(l(g, g.x)) == (concat ? heads * D_OUT : D_OUT, g.num_nodes)
test_gradients(l, g, g.x, rtol = RTOL_LOW)
test_gradients(l, g, g.x, rtol = RTOL_LOW, test_mooncake = TEST_MOONCAKE)
end
end

Expand All @@ -166,7 +167,7 @@ end
l = GATConv((D_IN, ein) => D_OUT, add_self_loops = false, dropout=0)
g = GNNGraph(TEST_GRAPHS[1], edata = rand(Float32, ein, TEST_GRAPHS[1].num_edges))
@test size(l(g, g.x, g.e)) == (D_OUT, g.num_nodes)
test_gradients(l, g, g.x, g.e, rtol = RTOL_LOW)
test_gradients(l, g, g.x, g.e, rtol = RTOL_LOW, test_mooncake = TEST_MOONCAKE)
end

@testset "num params" begin
Expand Down Expand Up @@ -197,6 +198,7 @@ end
l = GATv2Conv(D_IN => D_OUT, tanh; heads, concat, dropout=0)
for g in TEST_GRAPHS
@test size(l(g, g.x)) == (concat ? heads * D_OUT : D_OUT, g.num_nodes)
# Mooncake backward pass error for this layer on CI
test_gradients(l, g, g.x, rtol = RTOL_LOW, atol=ATOL_LOW)
end
end
Expand All @@ -206,6 +208,7 @@ end
l = GATv2Conv((D_IN, ein) => D_OUT, add_self_loops = false, dropout=0)
g = GNNGraph(TEST_GRAPHS[1], edata = rand(Float32, ein, TEST_GRAPHS[1].num_edges))
@test size(l(g, g.x, g.e)) == (D_OUT, g.num_nodes)
# Mooncake backward pass error for this layer on CI
test_gradients(l, g, g.x, g.e, rtol = RTOL_LOW, atol=ATOL_LOW)
end

Expand Down Expand Up @@ -239,7 +242,7 @@ end

for g in TEST_GRAPHS
@test size(l(g, g.x)) == (D_OUT, g.num_nodes)
test_gradients(l, g, g.x, rtol = RTOL_HIGH)
test_gradients(l, g, g.x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
end
end

Expand All @@ -260,7 +263,7 @@ end
l = EdgeConv(Dense(2 * D_IN, D_OUT), aggr = +)
for g in TEST_GRAPHS
@test size(l(g, g.x)) == (D_OUT, g.num_nodes)
test_gradients(l, g, g.x, rtol = RTOL_HIGH)
test_gradients(l, g, g.x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
end
end

Expand All @@ -281,7 +284,7 @@ end
l = GINConv(nn, 0.01, aggr = mean)
for g in TEST_GRAPHS
@test size(l(g, g.x)) == (D_OUT, g.num_nodes)
test_gradients(l, g, g.x, rtol = RTOL_HIGH)
test_gradients(l, g, g.x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
end

@test !in(:eps, Flux.trainable(l))
Expand All @@ -307,7 +310,7 @@ end
for g in TEST_GRAPHS
g = GNNGraph(g, edata = rand(Float32, edim, g.num_edges))
@test size(l(g, g.x, g.e)) == (D_OUT, g.num_nodes)
test_gradients(l, g, g.x, g.e, rtol = RTOL_HIGH)
test_gradients(l, g, g.x, g.e, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
end
end

Expand All @@ -332,7 +335,7 @@ end
l = SAGEConv(D_IN => D_OUT, tanh, bias = false, aggr = +)
for g in TEST_GRAPHS
@test size(l(g, g.x)) == (D_OUT, g.num_nodes)
test_gradients(l, g, g.x, rtol = RTOL_HIGH)
test_gradients(l, g, g.x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
end
end

Expand All @@ -351,7 +354,7 @@ end
l = ResGatedGraphConv(D_IN => D_OUT, tanh, bias = true)
for g in TEST_GRAPHS
@test size(l(g, g.x)) == (D_OUT, g.num_nodes)
test_gradients(l, g, g.x, rtol = RTOL_HIGH)
test_gradients(l, g, g.x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
end
end

Expand Down Expand Up @@ -411,7 +414,7 @@ end
Flux.trainable(l) == (; β = [1f0])
for g in TEST_GRAPHS
@test size(l(g, g.x)) == (D_IN, g.num_nodes)
test_gradients(l, g, g.x, rtol = RTOL_HIGH)
test_gradients(l, g, g.x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
end
end

Expand All @@ -437,7 +440,7 @@ end
y = l(g, x, e)
return mean(y[1]) + sum(y[2])
end
test_gradients(l, g, g.x, g.e, rtol = RTOL_LOW; loss)
test_gradients(l, g, g.x, g.e, rtol = RTOL_LOW; loss, test_mooncake = TEST_MOONCAKE)
end
end

Expand Down Expand Up @@ -491,13 +494,13 @@ end
l = SGConv(D_IN => D_OUT, k, add_self_loops = true)
for g in TEST_GRAPHS
@test size(l(g, g.x)) == (D_OUT, g.num_nodes)
test_gradients(l, g, g.x, rtol = RTOL_HIGH)
test_gradients(l, g, g.x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
end

l = SGConv(D_IN => D_OUT, k, add_self_loops = true)
for g in TEST_GRAPHS
@test size(l(g, g.x)) == (D_OUT, g.num_nodes)
test_gradients(l, g, g.x, rtol = RTOL_HIGH)
test_gradients(l, g, g.x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
end
end
end
Expand All @@ -520,13 +523,13 @@ end
l = TAGConv(D_IN => D_OUT, k, add_self_loops = true)
for g in TEST_GRAPHS
@test size(l(g, g.x)) == (D_OUT, g.num_nodes)
test_gradients(l, g, g.x, rtol = RTOL_HIGH)
test_gradients(l, g, g.x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
end

l = TAGConv(D_IN => D_OUT, k, add_self_loops = true)
for g in TEST_GRAPHS
@test size(l(g, g.x)) == (D_OUT, g.num_nodes)
test_gradients(l, g, g.x, rtol = RTOL_HIGH)
test_gradients(l, g, g.x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
end
end
end
Expand Down Expand Up @@ -565,6 +568,7 @@ end
ein = 2
heads = 3
# used like in Kool et al., 2019
# Mooncake backward pass error for this layer on CI
l = TransformerConv(D_IN * heads => D_IN; heads, add_self_loops = true,
root_weight = false, ff_channels = 10, skip_connection = true,
batch_norm = false)
Expand Down Expand Up @@ -616,6 +620,7 @@ end
l = DConv(D_IN => D_OUT, k)
for g in TEST_GRAPHS
@test size(l(g, g.x)) == (D_OUT, g.num_nodes)
# Note: test_mooncake not enabled for DConv (Mooncake backward pass error)
test_gradients(l, g, g.x, rtol = RTOL_HIGH)
end
end
Expand Down
2 changes: 2 additions & 0 deletions GraphNeuralNetworks/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ using TestItemRunner
## for how to run the tests within VS Code.
## See test_module.jl for the test infrastructure.

const TEST_MOONCAKE = VERSION >= v"1.12"

## Uncomment below to change the default test settings
# ENV["GNN_TEST_CPU"] = "false"
# ENV["GNN_TEST_CUDA"] = "true"
Expand Down
36 changes: 32 additions & 4 deletions GraphNeuralNetworks/test/test_module.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ using ChainRulesTestUtils, FiniteDifferences
using Zygote: Zygote
using SparseArrays

# Mooncake.jl requires Julia >= 1.12
const TEST_MOONCAKE = VERSION >= v"1.12"
if TEST_MOONCAKE
import Mooncake
end

# from Base
export mean, randn, SparseArrays, AbstractSparseMatrix
Expand All @@ -45,7 +50,7 @@ export random_regular_graph, erdos_renyi
# from this module
export D_IN, D_OUT, GRAPH_TYPES, TEST_GRAPHS,
test_gradients, finitediff_withgradient,
check_equal_leaves, gpu_backend
check_equal_leaves, gpu_backend, TEST_MOONCAKE


const D_IN = 3
Expand Down Expand Up @@ -82,12 +87,13 @@ function test_gradients(
test_grad_f = true,
test_grad_x = true,
compare_finite_diff = true,
test_mooncake = false,
loss = (f, g, xs...) -> mean(f(g, xs...)),
)

if !test_gpu && !compare_finite_diff
error("You should either compare finite diff vs CPU AD \
or CPU AD vs GPU AD.")
if !test_gpu && !compare_finite_diff && !test_mooncake
error("You should either compare finite diff vs CPU AD, \
CPU AD vs GPU AD, or test Mooncake AD.")
end

## Let's make sure first that the forward pass works.
Expand Down Expand Up @@ -116,6 +122,17 @@ function test_gradients(
check_equal_leaves(g, g_fd; rtol, atol)
end

if test_mooncake
# Mooncake gradient with respect to input, compared against Zygote.
loss_mc_x = (xs...) -> loss(f, graph, xs...)
_cache_x = Base.invokelatest(Mooncake.prepare_gradient_cache, loss_mc_x, xs...)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why this invokelatest?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is related to the world age, the import Mooncake and TestItemRunner's eval have different world ages and throws error. It is prevented by invokelatest.

y_mc, g_mc = Base.invokelatest(Mooncake.value_and_gradient!!, _cache_x, loss_mc_x, xs...)
@assert isapprox(y, y_mc; rtol, atol)
for i in eachindex(xs)
@assert isapprox(g[i], g_mc[i+1]; rtol, atol)
end
end

if test_gpu
# Zygote gradient with respect to input on GPU.
y_gpu, g_gpu = Zygote.withgradient((xs...) -> loss(f_gpu, graph_gpu, xs...), xs_gpu...)
Expand All @@ -139,6 +156,17 @@ function test_gradients(
check_equal_leaves(g, g_fd; rtol, atol)
end

if test_mooncake
# Mooncake gradient with respect to f, compared against Zygote.
ps_mc, re_mc = Flux.destructure(f)
loss_mc_f = ps -> loss(re_mc(ps), graph, xs...)
_cache_f = Base.invokelatest(Mooncake.prepare_gradient_cache, loss_mc_f, ps_mc)
y_mc, g_mc = Base.invokelatest(Mooncake.value_and_gradient!!, _cache_f, loss_mc_f, ps_mc)
@assert isapprox(y, y_mc; rtol, atol)
g_mc_f = (re_mc(g_mc[2]),)
check_equal_leaves(g, g_mc_f; rtol, atol)
end

if test_gpu
# Zygote gradient with respect to f on GPU.
y_gpu, g_gpu = Zygote.withgradient(f -> loss(f,graph_gpu, xs_gpu...), f_gpu)
Expand Down
Loading