From ccf89a0e9c90a34dde8f66d5bede8dcd656ccf20 Mon Sep 17 00:00:00 2001 From: Romain Poncet Date: Tue, 1 Aug 2023 08:49:01 +0100 Subject: [PATCH 1/8] wip --- src/ReactiveGraphs.jl | 11 ++- src/compilation.jl | 111 +++++++++++++++++++----------- src/foldl.jl | 2 +- src/graph.jl | 154 ++++++++++++++++-------------------------- src/inlinedmap.jl | 5 +- src/input.jl | 4 +- src/map.jl | 6 +- src/updated.jl | 3 +- todo.txt | 5 ++ 9 files changed, 152 insertions(+), 149 deletions(-) create mode 100644 todo.txt diff --git a/src/ReactiveGraphs.jl b/src/ReactiveGraphs.jl index 3f4e0cf..34c4159 100644 --- a/src/ReactiveGraphs.jl +++ b/src/ReactiveGraphs.jl @@ -34,21 +34,20 @@ end TypeOrValue{X} = Union{X,Type{<:X}} -struct TypeSymbol{x} - TypeSymbol(x::Symbol) = new{x}() -end +# struct TypeSymbol{x} +# TypeSymbol(x::Symbol) = new{x}() +# end -getsymbol(::TypeOrValue{TypeSymbol{x}}) where {x} = x +# getsymbol(::TypeOrValue{TypeSymbol{x}}) where {x} = x -include("graph.jl") include("operations.jl") +include("graph.jl") include("trackers.jl") include("compilation.jl") genname(::Nothing) = gensym() genname(s::Symbol) = gensym(s) genname(s::AbstractString) = gensym(string(s)) -getoperationtype(node::Node) = getnode(node) |> getelement |> eltype include("input.jl") include("map.jl") diff --git a/src/compilation.jl b/src/compilation.jl index c582c3d..b36d6df 100644 --- a/src/compilation.jl +++ b/src/compilation.jl @@ -1,13 +1,44 @@ -# todo: ensure that the node exists, and is an input -struct Source{inputname,T,LN<:ListNode} - list::LN - function Source(listnode::LN, inputname::Symbol) where {LN<:ListNode} - node = getnode(listnode, TypeSymbol(inputname)) - T = eltype(node) - new{inputname,T,LN}(listnode) +struct CompiledNode{name,parentnames,Op<:Operation} + operation::Op + function CompiledNode(name::Symbol, parentnames::NTuple{N,Symbol}, operation::Operation) where {N} + new{name, parentnames, typeof(operation)}(operation) end end +getname(::TypeOrValue{CompiledNode{name}}) where {name} = name +getparentnames(::TypeOrValue{CompiledNode{name,parentnames}}) where {name,parentnames} = + parentnames +getoperationtype(::TypeOrValue{CompiledNode{name,parentnames,Op}}) where {name,parentnames,Op} = Op +getoperation(n::CompiledNode) = n.operation + +# function CompiledNode(node::Node) +# CompiledNode(node.name, Tuple(node.parentnames), node.operation) +# end + +struct CompiledGraph{N,T<:NTuple{N,CompiledNode},Tr<:AbstractGraphTracker} + nodes::T + tracker::Tr +end + +# CompiledGraph(node::Node) = CompiledGraph(node.ref[]) +# function CompiledGraph(nodes::Vector{Node}) +# Tuple(CompiledNode(n) for n in nodes) |> CompiledGraph +# end + +nodetypes(::TypeOrValue{CompiledGraph{N,T}}) where {N,T} = T.parameters + +@generated function Base.getindex(g::CompiledGraph, s::Symbol) + for (i, nodetype) in nodetypes(g) |> enumerate + getname(nodetype) == s && return :(g.nodes[$i]) + end + throw(ErrorException("symbol $s not found in graph")) +end + +struct Source{inputname,T} end + +Base.eltype(::TypeOrValue{Source{inputname,T}}) where {inputname,T} = T +getinputname(::TypeOrValue{Source{inputname,T}}) where {inputname,T} = inputname + """ Source(::Node) @@ -43,55 +74,57 @@ julia> i1 = input(Int) 37 ``` """ -Source(node::Node) = Source(getgraph(node)[], getname(node)) - -function Base.show(io::IO, s::Source{inputname,T}) where {inputname,T} - print(io, "Source($inputname, $T)") +function Source(node::Node) + @assert node.operation isa Input + new{node.name, eltype(node.operation)}() end -getlisttype(::TypeOrValue{Source{inputname,T,LN}}) where {inputname,T,LN} = LN -getinputname(::TypeOrValue{Source{inputname,T,LN}}) where {inputname,T,LN} = inputname +function compile(inputs::Node...; tracker::AbstractGraphTracker=NullGraphTracker()) + @assert !isempty(inputs) + sources = Source.(inputs) -@inline Base.push!(src::Source, x) = push!(src => x) -@inline Base.push!(monitor::AbstractGraphTracker, src::Source, x) = push!(monitor, src => x) -@inline Base.push!(p::Pair{<:Source,<:Any}...) = push!(NullGraphTracker(), p...) -@generated function Base.push!(monitor::AbstractGraphTracker, p::Pair{<:Source,<:Any}...) - src_types = p .|> fieldtypes .|> first - inputnames = getinputname.(src_types) - listtypes = getlisttype.(src_types) + @assert allequal(map(x->x.ref, inputs)) + nodes = inputs[1].ref[] + compilednodes = Tuple(CompiledNode(node.name, Tuple(node.parentnames), node.operation) for node in nodes) + compiledgraph = CompiledGraph(compilednodes, tracker) - if !allequal(listtypes) - throw(ErrorException("nodes do not belong to the same graph")) - end - LN = first(listtypes) + compiledgraph, sources... +end + +@inline Base.push!(graph::CompiledGraph, src::Source, x) = push!(graph, src => x) +@generated function Base.push!(graph::CompiledGraph, p::Pair{<:Source,<:Any}...) + sources = p .|> fieldtypes .|> first + generate(graph, sources...) +end +function generate(graph::Type{<:CompiledGraph}, sources::Type{<:Source}...) + inputnames = getinputname.(sources) expr = quote - on_update_start!(monitor, $inputnames) - list = p[1][1].list + on_update_start!(graph.monitor, $inputnames) + end + for compilednodetype in nodetypes(graph) + generate!(expr, compilednodetype, inputnames...) end - generate!(expr, LN, inputnames...) - push!(expr.args, :(on_update_stop!(monitor))) + push!(expr.args, :(on_update_stop!(graph.monitor))) push!(expr.args, nothing) expr end -generate!(::Expr, ::Type{Root}, ::Symbol...) = nothing function generate!( expr::Expr, - ::Type{ListNode{name,parentnames,X,Next}}, + ::Type{CompiledNode{name,parentnames,Op}}, inputnames::Symbol..., -) where {name,parentnames,X,Next} - generate!(expr, Next, inputnames...) - e = generate(inputnames, name, parentnames, X) +) where {name,parentnames,Op} + e = generate(inputnames, name, parentnames, Op) append!(expr.args, e.args) - push!(expr.args, :(on_update_node!(monitor, $(Meta.quot(name))))) + push!(expr.args, :(on_update_node!(graph.monitor, $(Meta.quot(name))))) expr end -# getvalue(list::ListNode, name::Symbol) = getnode(list, name) |> getvalue -getvalue(list::ListNode, v::TypeSymbol) = getnode(list, v) |> getvalue -getvalue(node::ListNode) = getvalue(node, getelement(node)) - -function debugsource(src::Source{inputnames,LN}) where {inputnames,LN} +function debugsource(graph::CompiledGraph, src::Source...) + inputnames = getinputname.(src_types) generate!(Expr(:block), LN, inputnames...) end + +getvalue(graph::CompiledGraph, name::Symbol) = getvalue(graph, name, graph[name].operation) +getvalue(graph::CompiledGraph, name::Symbol, op::Operation) = getvalue(op) diff --git a/src/foldl.jl b/src/foldl.jl index 6e0e2ae..4991300 100644 --- a/src/foldl.jl +++ b/src/foldl.jl @@ -46,7 +46,7 @@ function generate(::Any, name::Symbol, parentnames::NTuple{<:Any,Symbol}, ::Type nodename_s = Symbol(:node, name) quote $updated_s = if $condition_initialized & $condition_updated - $nodename_s = getnode(list, $(TypeSymbol(name))) + $nodename_s = getelement(list, $(TypeSymbol(name))) $(Expr(:call, :update!, nodename_s, args...)) true else diff --git a/src/graph.jl b/src/graph.jl index 3cba936..b880e33 100644 --- a/src/graph.jl +++ b/src/graph.jl @@ -1,107 +1,73 @@ -struct Root end - -struct ListNode{name,parentnames,X,Next} - x::X - next::Next - function ListNode( - name::Symbol, - parentnames::NTuple{<:Any,Symbol}, - x, - next::Union{Root,ListNode}, - ) - new{name,parentnames,typeof(x),typeof(next)}(x, next) - end -end - -getname(::TypeOrValue{ListNode{name}}) where {name} = name -getparentnames(::TypeOrValue{ListNode{name,parentnames}}) where {name,parentnames} = - parentnames -getelementtype(::TypeOrValue{ListNode{name,parentnames,X}}) where {name,parentnames,X} = X -getelement(n::ListNode) = n.x -getnext(n::ListNode) = n.next -Base.eltype(x::TypeOrValue{<:ListNode}) = eltype(getelementtype(x)) - -mutable struct Graph - last::Ref{Union{Root,ListNode}} -end - -Graph() = Graph(Ref{Union{Root,ListNode}}(Root())) - -Base.:(==)(g1::Graph, g2::Graph) = g1.last == g2.last - -Base.getindex(g::Graph) = g.last[] -function Base.setindex!(g::Graph, x::Union{Root,ListNode}) - g.last[] = x - g -end - -Base.map(f::Function, g::Graph) = map(f, g.last[]) -Base.map(::Function, ::Root) = nothing -function Base.map(f::Function, n::ListNode) - map(f, n.next) - f(getname(n), getparentnames(n), n.x) -end - -function Base.push!(graph::Graph, name, parentnames, x) - graph[] = ListNode(name, parentnames, x, graph[]) - graph -end - -Base.merge!(graph::Graph) = graph -function Base.merge!(graph1::Graph, graph2::Graph, graphs...) - if graph1 != graph2 - map(graph2) do name, parentnames, x - push!(graph1, name, parentnames, x) - end - graph2.last = graph1.last - end - merge!(graph1, graphs...) +mutable struct Node + ref::Base.RefValue{Vector{Node}} + @tryconst name::Symbol + @tryconst parentnames::Vector{Symbol} + @tryconst operation::Operation end -getnode(graph::Graph, s::TypeSymbol) = getnode(graph[], s) -getnode(::Root, s::TypeSymbol) = - throw(ErrorException("symbol $(getsymbol(s)) not found in graph")) -function getnode(x::ListNode, v::TypeSymbol) - if getname(x) == getsymbol(v) - x - else - getnode(x.next, v) - end +# todo: make sure name is unique +function Node(name::Symbol, op::Operation, parents::Node...) + ref = mergegraphs!(parents...) + parentnames = [a.name for a in parents] + node = Node(ref, name, parentnames, op) + push!(ref[].nodes, node) + node end -""" - Node{name} - -Objects of type `Node` correspond to the nodes of the computational graph. -Each node is identified by a uniquely generated name `name`. -""" -struct Node{name} - graph::Graph - Node(name::Symbol, graph::Graph) = new{name}(graph) -end +Base.eltype(node::Node) = eltype(node.operation) function Base.show(io::IO, node::Node) - name = getname(node) + name = node.name type = eltype(node) print(io, "Node($name,$type)") end -Base.eltype(node::Node) = node |> getnode |> eltype -getname(::TypeOrValue{Node{name}}) where {name} = name -getgraph(node::Node) = node.graph -getnode(node::Node) = getnode(getgraph(node), TypeSymbol(getname(node))) - -function Node(name::Symbol, x::X, parents::Node...) where {X} - graph = if isempty(parents) - Graph() - else - merge!((n.graph for n in parents)...) # return Root() if empty +mergegraphs!() = Graph(Node[], nothing) +mergegraphs!(node::Node) = node.ref +function mergegraphs!(node1::Node, node2::Node, nodes::Node...) + if node1.ref != node2.ref + append!(node1.ref[], node2.ref[]) + node2.ref = node1.ref end - parentnames = Tuple(getname(a) for a in parents) - push!(graph, name, parentnames, x) - Node(name, graph) + mergegraphs!(node1, nodes...) end -update!(node, args...) = update!(getelement(node), args...) -isinitialized(node, args...) = isinitialized(getelement(node), args...) -(node::Node)(args...) = getelement(node)(args...) +# """ +# Node{name} + +# Objects of type `Node` correspond to the nodes of the computational graph. +# Each node is identified by a uniquely generated name `name`. +# """ +# struct Node{name} +# graph::GraphRef +# Node(name::Symbol, graph::GraphRef) = new{name}(graph) +# end + +# function Base.show(io::IO, node::Node) +# name = getname(node) +# type = eltype(node) +# print(io, "Node($name,$type)") +# end + +# Base.eltype(node::Node) = node |> getedge |> eltype +# getname(::TypeOrValue{Node{name}}) where {name} = name +# getgraph(node::Node) = node.graph +# getedge(node::Node) = getedge(getgraph(node), TypeSymbol(getname(node))) + +# function Node(name::Symbol, op::Op, parents::Node...) where {Op <: Operation} +# graph = if isempty(parents) +# GraphRef() +# else +# merge!((n.graph for n in parents)...) # return Root() if empty +# end +# parentnames = Tuple(getname(a) for a in parents) +# push!(graph, name, parentnames, op) +# Node(name, graph) +# end + +# update!(node, args...) = update!(getedge(node), args...) +# # isinitialized(node, args...) = isinitialized(getedge(node), args...) +# (node::Node)(args...) = getedge(node)(args...) + +# @inline getvalue(::Graph, element::Operation) = getvalue(element) +# getoperationtype(node::Node) = getedge(node) |> getoperationtype diff --git a/src/inlinedmap.jl b/src/inlinedmap.jl index c35fb5b..8ba145c 100644 --- a/src/inlinedmap.jl +++ b/src/inlinedmap.jl @@ -44,9 +44,10 @@ function generate( end end -@generated function getvalue(node::ListNode, imap::InlinedMap) +@generated function getvalue(graph::CompiledGraph, name::Symbol, imap::InlinedMap) + node = graph[name] names = getparentnames(node) quote - Base.@ncall $(length(names)) imap.f i -> (getvalue(node, TypeSymbol($names[i]))) + Base.@ncall $(length(names)) imap.f i -> (getvalue(graph, $(Meta.quot(names[i])))) end end diff --git a/src/input.jl b/src/input.jl index a82f246..4959c25 100644 --- a/src/input.jl +++ b/src/input.jl @@ -66,7 +66,7 @@ function input(x::T; name::Union{Nothing,Symbol} = nothing) where {T} Node(uniquename, op) end -getvalue(::ListNode, element::Input) = getvalue(element) +getvalue(::Graph, element::Input) = getvalue(element) function generate( inputnames::NTuple{<:Any,Symbol}, @@ -82,7 +82,7 @@ function generate( expr = Expr(:quote) push!(expr.args, :($updated_s = $updated)) - push!(expr.args, :($nodename_s = getnode(list, $(TypeSymbol(name))))) + push!(expr.args, :($nodename_s = getedge(list, $(TypeSymbol(name))))) if updated push!( expr.args, diff --git a/src/map.jl b/src/map.jl index 91a93f7..7d39ac9 100644 --- a/src/map.jl +++ b/src/map.jl @@ -22,7 +22,7 @@ generated symbol that identifies the node. """ function Base.map(f::Function, arg::Node, args::Node...; name = nothing) uniquename = genname(name) - argtypes = getoperationtype.((arg, args...)) + argtypes = eltype.((arg, args...)) T = Base._return_type(f, Tuple{argtypes...}) op = Map{T}(f) Node(uniquename, op, arg, args...) @@ -32,7 +32,7 @@ end function generate(::Any, name::Symbol, parentnames::NTuple{<:Any,Symbol}, ::Type{<:Map}) updated_s = Symbol(:updated, name) initialized_s = Symbol(:initialized, name) - args = (:(getvalue(list, $(TypeSymbol(n)))) for n in parentnames) + args = (:(getvalue(graph, $n)) for n in parentnames) condition_updated = Expr(:call, :|, (Symbol(:updated, n) for n in parentnames)...) condition_initialized = Expr(:call, :&, (Symbol(:initialized, n) for n in parentnames)...) @@ -40,7 +40,7 @@ function generate(::Any, name::Symbol, parentnames::NTuple{<:Any,Symbol}, ::Type quote $initialized_s = $condition_initialized $updated_s = if $condition_initialized & $condition_updated - $nodename_s = getnode(list, $(TypeSymbol(name))) + $nodename_s = graph[$name] $(Expr(:call, :update!, nodename_s, args...)) true else diff --git a/src/updated.jl b/src/updated.jl index a4b32f8..9dbbf06 100644 --- a/src/updated.jl +++ b/src/updated.jl @@ -32,10 +32,9 @@ function generate(::Any, name::Symbol, parentnames::NTuple{<:Any,Symbol}, ::Type quote $initialized_s = $condition_initialized $updated_s = $condition_initialized & $condition_updated - $nodename_s = getnode(list, $(TypeSymbol(name))) + $nodename_s = getedge(list, $(TypeSymbol(name))) $(Expr(:call, :update!, nodename_s, updated_s)) end end @inline getvalue(x::Updated) = x.updated -@inline getvalue(::ListNode, element::Updated) = getvalue(element) diff --git a/todo.txt b/todo.txt new file mode 100644 index 0000000..c4a3a88 --- /dev/null +++ b/todo.txt @@ -0,0 +1,5 @@ +# refacto of the logic +- how to name the edges? can we use integers instead of Symbols, and make the name optional? +getelementtype -> getoperationtype +getelement->getoperation +getnode => getedge From 8577c842ddce10942723ad818af968e52801ce93 Mon Sep 17 00:00:00 2001 From: Romain Poncet Date: Wed, 2 Aug 2023 10:16:02 +0100 Subject: [PATCH 2/8] passing test --- benchmark/Project.toml | 1 - benchmark/benchmark.jl | 12 +- src/ReactiveGraphs.jl | 9 +- src/compilation.jl | 33 +++--- src/constant.jl | 2 - src/filter.jl | 10 +- src/foldl.jl | 5 +- src/graph.jl | 25 +++-- src/inlinedmap.jl | 11 +- src/input.jl | 4 +- src/lag.jl | 2 +- src/map.jl | 5 +- src/selecter.jl | 10 +- src/updated.jl | 2 +- test/runtests.jl | 243 +++++++++++++++++++---------------------- 15 files changed, 181 insertions(+), 193 deletions(-) diff --git a/benchmark/Project.toml b/benchmark/Project.toml index da523e2..32e44d8 100644 --- a/benchmark/Project.toml +++ b/benchmark/Project.toml @@ -1,4 +1,3 @@ [deps] BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" ReactiveGraphs = "e49d8811-4385-4fad-be34-8522254608f3" -Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" diff --git a/benchmark/benchmark.jl b/benchmark/benchmark.jl index 82a5426..2b8831d 100644 --- a/benchmark/benchmark.jl +++ b/benchmark/benchmark.jl @@ -1,4 +1,3 @@ -using Revise using ReactiveGraphs using BenchmarkTools @@ -15,8 +14,9 @@ n5 = lag(1, n4) s1 = Source(i1) s2 = Source(i2) s3 = Source(i3) -s1[] = 1 -s2[] = true -s3[] = true -v = 1 -@benchmark setindex!($s1, $v) +g, s1, s2, s3 = compile(i1, i2, i3) +push!(g, s1, 1) +push!(g, s2, true) +push!(g, s3, true) + +@benchmark push!($g, $s1, 1) diff --git a/src/ReactiveGraphs.jl b/src/ReactiveGraphs.jl index 34c4159..518ed7a 100644 --- a/src/ReactiveGraphs.jl +++ b/src/ReactiveGraphs.jl @@ -11,6 +11,7 @@ export updated export PerformanceGraphTracker export gettrackingnodes, gettrackingtriggers +export compile macro tryinline(e) @static if VERSION >= v"1.8" @@ -34,11 +35,11 @@ end TypeOrValue{X} = Union{X,Type{<:X}} -# struct TypeSymbol{x} -# TypeSymbol(x::Symbol) = new{x}() -# end +struct TypeSymbol{x} + TypeSymbol(x::Symbol) = new{x}() +end -# getsymbol(::TypeOrValue{TypeSymbol{x}}) where {x} = x +getsymbol(::TypeOrValue{TypeSymbol{x}}) where {x} = x include("operations.jl") include("graph.jl") diff --git a/src/compilation.jl b/src/compilation.jl index b36d6df..3de16c9 100644 --- a/src/compilation.jl +++ b/src/compilation.jl @@ -27,11 +27,11 @@ end nodetypes(::TypeOrValue{CompiledGraph{N,T}}) where {N,T} = T.parameters -@generated function Base.getindex(g::CompiledGraph, s::Symbol) - for (i, nodetype) in nodetypes(g) |> enumerate - getname(nodetype) == s && return :(g.nodes[$i]) +@generated function Base.getindex(g::CompiledGraph{N}, ::TypeSymbol{name}) where {N, name} + for (i, node) in nodetypes(g) |> enumerate + getname(node) == name && return :(g.nodes[$i]) end - throw(ErrorException("symbol $s not found in graph")) + throw(ErrorException("symbol $name not found in graph $g")) end struct Source{inputname,T} end @@ -76,15 +76,15 @@ julia> i1 = input(Int) """ function Source(node::Node) @assert node.operation isa Input - new{node.name, eltype(node.operation)}() + Source{node.name, eltype(node.operation)}() end function compile(inputs::Node...; tracker::AbstractGraphTracker=NullGraphTracker()) @assert !isempty(inputs) sources = Source.(inputs) - @assert allequal(map(x->x.ref, inputs)) - nodes = inputs[1].ref[] + @assert allequal(map(x->x.graph.ref, inputs)) + nodes = inputs[1].graph.ref[] compilednodes = Tuple(CompiledNode(node.name, Tuple(node.parentnames), node.operation) for node in nodes) compiledgraph = CompiledGraph(compilednodes, tracker) @@ -100,12 +100,12 @@ end function generate(graph::Type{<:CompiledGraph}, sources::Type{<:Source}...) inputnames = getinputname.(sources) expr = quote - on_update_start!(graph.monitor, $inputnames) + on_update_start!(graph.tracker, $(inputnames)) end for compilednodetype in nodetypes(graph) generate!(expr, compilednodetype, inputnames...) end - push!(expr.args, :(on_update_stop!(graph.monitor))) + push!(expr.args, :(on_update_stop!(graph.tracker))) push!(expr.args, nothing) expr end @@ -117,14 +117,17 @@ function generate!( ) where {name,parentnames,Op} e = generate(inputnames, name, parentnames, Op) append!(expr.args, e.args) - push!(expr.args, :(on_update_node!(graph.monitor, $(Meta.quot(name))))) + push!(expr.args, :(on_update_node!(graph.tracker, $(Meta.quot(name))))) expr end -function debugsource(graph::CompiledGraph, src::Source...) - inputnames = getinputname.(src_types) - generate!(Expr(:block), LN, inputnames...) +function debugsource(graph::CompiledGraph, sources::Source...) + generate(graph, sources...) end -getvalue(graph::CompiledGraph, name::Symbol) = getvalue(graph, name, graph[name].operation) -getvalue(graph::CompiledGraph, name::Symbol, op::Operation) = getvalue(op) +getoperation(graph::CompiledGraph, name::TypeSymbol) = graph[name].operation +function getvalue(graph::CompiledGraph, name::TypeSymbol) + node = graph[name] + getvalue(graph, node, getoperation(node)) +end +getvalue(::CompiledGraph, ::CompiledNode, op::Operation) = getvalue(op) # specific implementations should dispatch on this method diff --git a/src/constant.jl b/src/constant.jl index 044331f..be18f49 100644 --- a/src/constant.jl +++ b/src/constant.jl @@ -30,8 +30,6 @@ function constant(x; name = nothing) Node(uniquename, op) end -getvalue(::ListNode, element::AbstractConstant) = getvalue(element) - function generate(::Any, name::Symbol, parentnames::Tuple{}, ::Type{<:AbstractConstant}) updated_s = Symbol(:updated, name) initialized_s = Symbol(:initialized, name) diff --git a/src/filter.jl b/src/filter.jl index e69f56a..a817018 100644 --- a/src/filter.jl +++ b/src/filter.jl @@ -15,7 +15,7 @@ function Base.filter(x::Node, condition::Node; name = nothing) throw(ErrorException("eltype(condition) is $(eltype(condition)), expected Bool")) end uniquename = genname(name) - op = Filter{getoperationtype(x)}() + op = Filter{eltype(x)}() Node(uniquename, op, x, condition) end @@ -37,15 +37,15 @@ function Base.filter(f::Function, x::Node; name = nothing) filter(x, condition; name) end -function getvalue(node::ListNode, ::Filter) - node_name, _ = getparentnames(node) - getvalue(node, TypeSymbol(node_name)) # todo: avoid starting from the leaf +function getvalue(graph::CompiledGraph, node::CompiledNode, ::Filter) + parent_name = getparentnames(node) |> first |> TypeSymbol + getvalue(graph, parent_name) end function generate(::Any, name::Symbol, parentnames::NTuple{<:Any,Symbol}, ::Type{<:Filter}) updated_s = Symbol(:updated, name) initialized_s = Symbol(:initialized, name) - args = [:(getvalue(list, $(TypeSymbol(n)))) for n in parentnames] + args = [:(getvalue(graph, $(TypeSymbol(n)))) for n in parentnames] condition_updated = Expr(:call, :|, (Symbol(:updated, n) for n in parentnames)...) condition_initialized = Expr(:call, :&, (Symbol(:initialized, n) for n in parentnames)...) diff --git a/src/foldl.jl b/src/foldl.jl index 4991300..63850ac 100644 --- a/src/foldl.jl +++ b/src/foldl.jl @@ -3,7 +3,6 @@ mutable struct Foldl{TState,F} <: Operation{TState} state::TState end @inline getvalue(x::Foldl) = x.state -@inline getvalue(::ListNode, element::Foldl) = getvalue(element) @inline function update!(m::Foldl, args...) @tryinline state = m.f(m.state, args...) @@ -39,14 +38,14 @@ end function generate(::Any, name::Symbol, parentnames::NTuple{<:Any,Symbol}, ::Type{<:Foldl}) updated_s = Symbol(:updated, name) initialized_s = Symbol(:initialized, name) - args = (:(getvalue(list, $(TypeSymbol(n)))) for n in parentnames) + args = (:(getvalue(graph, $(TypeSymbol(n)))) for n in parentnames) condition_updated = Expr(:call, :|, (Symbol(:updated, n) for n in parentnames)...) condition_initialized = Expr(:call, :&, (Symbol(:initialized, n) for n in parentnames)...) nodename_s = Symbol(:node, name) quote $updated_s = if $condition_initialized & $condition_updated - $nodename_s = getelement(list, $(TypeSymbol(name))) + $nodename_s = getoperation(graph, $(TypeSymbol(name))) $(Expr(:call, :update!, nodename_s, args...)) true else diff --git a/src/graph.jl b/src/graph.jl index b880e33..5b5888d 100644 --- a/src/graph.jl +++ b/src/graph.jl @@ -1,5 +1,12 @@ +mutable struct Graph{N} + ref::Base.RefValue{Vector{N}} +end + +Graph{N}() where {N} = Graph(Ref(N[])) +Base.push!(g::Graph{N}, n::N) where {N} = push!(g.ref[], n) + mutable struct Node - ref::Base.RefValue{Vector{Node}} + graph::Graph{Node} @tryconst name::Symbol @tryconst parentnames::Vector{Symbol} @tryconst operation::Operation @@ -7,10 +14,10 @@ end # todo: make sure name is unique function Node(name::Symbol, op::Operation, parents::Node...) - ref = mergegraphs!(parents...) + graph = mergegraphs!(parents...) parentnames = [a.name for a in parents] - node = Node(ref, name, parentnames, op) - push!(ref[].nodes, node) + node = Node(graph, name, parentnames, op) + push!(graph, node) node end @@ -22,12 +29,12 @@ function Base.show(io::IO, node::Node) print(io, "Node($name,$type)") end -mergegraphs!() = Graph(Node[], nothing) -mergegraphs!(node::Node) = node.ref +mergegraphs!() = Graph{Node}() +mergegraphs!(node::Node) = node.graph function mergegraphs!(node1::Node, node2::Node, nodes::Node...) - if node1.ref != node2.ref - append!(node1.ref[], node2.ref[]) - node2.ref = node1.ref + if node1.graph.ref != node2.graph.ref + append!(node1.graph.ref[], node2.graph.ref[]) + node2.graph.ref = node1.graph.ref end mergegraphs!(node1, nodes...) end diff --git a/src/inlinedmap.jl b/src/inlinedmap.jl index 8ba145c..52ad78b 100644 --- a/src/inlinedmap.jl +++ b/src/inlinedmap.jl @@ -21,7 +21,7 @@ generated symbol that identifies the node. """ function inlinedmap(f, arg::Node, args::Node...; name = nothing) uniquename = genname(name) - argtypes = getoperationtype.((arg, args...)) + argtypes = eltype.((arg, args...)) T = Base._return_type(f, Tuple{argtypes...}) op = InlinedMap(T, f) Node(uniquename, op, arg, args...) @@ -44,10 +44,7 @@ function generate( end end -@generated function getvalue(graph::CompiledGraph, name::Symbol, imap::InlinedMap) - node = graph[name] - names = getparentnames(node) - quote - Base.@ncall $(length(names)) imap.f i -> (getvalue(graph, $(Meta.quot(names[i])))) - end +function getvalue(graph::CompiledGraph, node::CompiledNode, imap::InlinedMap) + parentnames = getparentnames(node) .|> TypeSymbol + imap.f((getvalue(graph, n) for n in parentnames)...) end diff --git a/src/input.jl b/src/input.jl index 4959c25..03bf49c 100644 --- a/src/input.jl +++ b/src/input.jl @@ -66,8 +66,6 @@ function input(x::T; name::Union{Nothing,Symbol} = nothing) where {T} Node(uniquename, op) end -getvalue(::Graph, element::Input) = getvalue(element) - function generate( inputnames::NTuple{<:Any,Symbol}, name::Symbol, @@ -82,7 +80,7 @@ function generate( expr = Expr(:quote) push!(expr.args, :($updated_s = $updated)) - push!(expr.args, :($nodename_s = getedge(list, $(TypeSymbol(name))))) + push!(expr.args, :($nodename_s = getoperation(graph, $(TypeSymbol(name))))) if updated push!( expr.args, diff --git a/src/lag.jl b/src/lag.jl index 61fbf9a..d87a668 100644 --- a/src/lag.jl +++ b/src/lag.jl @@ -44,7 +44,7 @@ julia> i = input(Int) ``` """ function lag(n::Integer, node::Node; name = nothing) - T = getoperationtype(node) + T = eltype(node) lagnode = foldl(Lag(T, n), node; name) do state, x push!(state, x) state diff --git a/src/map.jl b/src/map.jl index 7d39ac9..ea92686 100644 --- a/src/map.jl +++ b/src/map.jl @@ -5,7 +5,6 @@ mutable struct Map{T,F} <: Operation{T} end @inline getvalue(x::Map) = x.x -@inline getvalue(::ListNode, element::Map) = getvalue(element) @inline function update!(m::Map, args...) @tryinline m.x = m.f(args...) @@ -32,7 +31,7 @@ end function generate(::Any, name::Symbol, parentnames::NTuple{<:Any,Symbol}, ::Type{<:Map}) updated_s = Symbol(:updated, name) initialized_s = Symbol(:initialized, name) - args = (:(getvalue(graph, $n)) for n in parentnames) + args = (:(getvalue(graph, $(TypeSymbol(n)))) for n in parentnames) condition_updated = Expr(:call, :|, (Symbol(:updated, n) for n in parentnames)...) condition_initialized = Expr(:call, :&, (Symbol(:initialized, n) for n in parentnames)...) @@ -40,7 +39,7 @@ function generate(::Any, name::Symbol, parentnames::NTuple{<:Any,Symbol}, ::Type quote $initialized_s = $condition_initialized $updated_s = if $condition_initialized & $condition_updated - $nodename_s = graph[$name] + $nodename_s = getoperation(graph, $(TypeSymbol(name))) $(Expr(:call, :update!, nodename_s, args...)) true else diff --git a/src/selecter.jl b/src/selecter.jl index 0379dc5..9c5072b 100644 --- a/src/selecter.jl +++ b/src/selecter.jl @@ -15,7 +15,7 @@ function select(x::Node, condition::Node; name::Union{Nothing,Symbol} = nothing) throw(ErrorException("eltype(condition) is $(eltype(condition)), expected Bool")) end uniquename = genname(name) - op = Selecter{getoperationtype(x)}() + op = Selecter{eltype(x)}() Node(uniquename, op, x, condition) end @@ -37,9 +37,9 @@ function select(f::Function, x::Node; name::Union{Nothing,Symbol} = nothing) select(x, condition; name) end -function getvalue(node::ListNode, ::Selecter) - node_name, _ = getparentnames(node) - getvalue(node, TypeSymbol(node_name)) # todo: avoid starting from the leaf +function getvalue(graph::CompiledGraph, node::CompiledNode, ::Selecter) + parent_name = getparentnames(node) |> first |> TypeSymbol + getvalue(graph, parent_name) end function generate( @@ -50,7 +50,7 @@ function generate( ) updated_s = Symbol(:updated, name) initialized_s = Symbol(:initialized, name) - args = [:(getvalue(list, $(TypeSymbol(n)))) for n in parentnames] + args = [:(getvalue(graph, $(TypeSymbol(n)))) for n in parentnames] condition_updated = Expr(:call, :|, (Symbol(:updated, n) for n in parentnames)...) condition_initialized = Expr(:call, :&, (Symbol(:initialized, n) for n in parentnames)...) diff --git a/src/updated.jl b/src/updated.jl index 9dbbf06..71f2c53 100644 --- a/src/updated.jl +++ b/src/updated.jl @@ -32,7 +32,7 @@ function generate(::Any, name::Symbol, parentnames::NTuple{<:Any,Symbol}, ::Type quote $initialized_s = $condition_initialized $updated_s = $condition_initialized & $condition_updated - $nodename_s = getedge(list, $(TypeSymbol(name))) + $nodename_s = getoperation(graph, $(TypeSymbol(name))) $(Expr(:call, :update!, nodename_s, updated_s)) end end diff --git a/test/runtests.jl b/test/runtests.jl index 1795c84..bcf1616 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,7 +2,7 @@ using BenchmarkTools using ReactiveGraphs using Test -import ReactiveGraphs: Graph, getoperationtype, Node +import ReactiveGraphs: getoperationtype, Node macro testnoalloc(expr, kws...) esc(quote @@ -11,45 +11,43 @@ macro testnoalloc(expr, kws...) end) end -function collectgraph(g::Graph) - c = [] - map((_, _, x) -> push!(c, x), g) - c -end +# function collectgraph(g::Graph) +# c = [] +# map((_, _, x) -> push!(c, x), g) +# c +# end sink(arg::Node) = sink(identity, arg) function sink(f::Function, args::Node...) - T = Base._return_type(f, Tuple{(getoperationtype(a) for a in args)...}) + m = inlinedmap(f, args...) + T = eltype(m) v = Vector{T}() - map(args...) do x::Vararg - push!(v, f(x...)) - nothing - end + foldl(push!, v, m) v end -@testset "graph" begin - g1 = Graph() - @test g1 == merge!(g1) +# @testset "graph" begin +# g1 = mergegraphs!() +# @test g1 == merge!(g1) - push!(g1, :a, (), 1) - push!(g1, :b, (), 2) - @test collectgraph(g1) == [1, 2] +# push!(g1, :a, (), 1) +# push!(g1, :b, (), 2) +# @test collectgraph(g1) == [1, 2] - g2 = Graph() - push!(g2, :c, (), 3) - merge!(g1, g2) - @test g1 == g2 +# g2 = Graph() +# push!(g2, :c, (), 3) +# merge!(g1, g2) +# @test g1 == g2 - @test collectgraph(g1) == [1, 2, 3] -end +# @test collectgraph(g1) == [1, 2, 3] +# end @testset "input" begin @testset "1 input" begin n1 = input(Int) c = sink(x -> 2x, n1) - s = Source(n1) - push!(s, 2) + g, s = compile(n1) + push!(g, s => 2) @test c == [4] end @@ -57,18 +55,17 @@ end n1 = input(Int) n2 = input(Int) c = sink(+, n1, n2) - s1 = Source(n1) - s2 = Source(n2) - push!(s1 => 1, s2 => 2) - push!(s1 => 3, s2 => 4) + g, s1, s2 = compile(n1, n2) + push!(g, s1 => 1, s2 => 2) + push!(g, s1 => 3, s2 => 4) @test c == [3, 7] end @testset "1 mutable input" begin n1 = input(Ref(0)) c = sink(x -> 2x[], n1) - s = Source(n1) - push!(s, x -> x[] = 2) + g, s = compile(n1) + push!(g, s, x -> x[] = 2) @test c == [4] end @@ -76,18 +73,15 @@ end n1 = input(Ref(0)) n2 = input(Int) c = sink((x, y) -> 2x[] + y, n1, n2) - s1 = Source(n1) - s2 = Source(n2) - push!(s1 => x -> x[] = 2, s2 => 3) + g, s1, s2 = compile(n1, n2) + push!(g, s1 => x -> x[] = 2, s2 => 3) @test c == [7] end @testset "disjoint graphs" begin n1 = input(Int) n2 = input(Int) - s1 = Source(n1) - s2 = Source(n2) - @test_throws ErrorException push!(s1 => 1, s2 => 2) + @test_throws AssertionError compile(n1, n2) end end @@ -96,11 +90,10 @@ end n1 = input(Int) n2 = input(Int) c = sink(+, n1, n2) - s1 = Source(n1) - s2 = Source(n2) - push!(s1, 1) - push!(s2, 2) - push!(s1, 3) + g, s1, s2 = compile(n1, n2) + push!(g, s1, 1) + push!(g, s2, 2) + push!(g, s1, 3) @test c == [3, 5] end @@ -112,11 +105,10 @@ end @test eltype(n2) == Int @test eltype(m) == Int c = sink(m) - s1 = Source(n1) - s2 = Source(n2) - push!(s1, 1) - push!(s2, 2) - push!(s1, 3) + g, s1, s2 = compile(n1, n2) + push!(g, s1, 1) + push!(g, s2, 2) + push!(g, s1, 3) @test c == [3, 5] end @@ -125,11 +117,10 @@ end n2 = input(Int) c1 = sink(+, n1, n2) c2 = sink((x, y) -> -(x + y), n1, n2) - s1 = Source(n1) - s2 = Source(n2) - push!(s1, 1) - push!(s2, 2) - push!(s1, 3) + g, s1, s2 = compile(n1, n2) + push!(g, s1, 1) + push!(g, s2, 2) + push!(g, s1, 3) @test c1 == [3, 5] @test c2 == [-3, -5] end @@ -140,9 +131,9 @@ end n1 = input(Int) n2 = foldl(+, 1, n1) c = sink(n2) - s = Source(n1) - push!(s, 2) - push!(s, 3) + g, s = compile(n1) + push!(g, s, 2) + push!(g, s, 3) @test c == [3, 6] end @@ -154,9 +145,9 @@ end end n3 = map(x -> x[], n2) c = sink(n3) - s = Source(n1) - push!(s, 2) - push!(s, 3) + g, s = compile(n1) + push!(g, s, 2) + push!(g, s, 3) @test c == [3, 6] end end @@ -166,10 +157,9 @@ end n2 = input(Int) n3 = inlinedmap(+, n1, n2) c = sink(n3) - s1 = Source(n1) - s2 = Source(n2) - push!(s1, 1) - push!(s2, 2) + g, s1, s2 = compile(n1, n2) + push!(g, s1, 1) + push!(g, s2, 2) @test c == [3] end @@ -179,15 +169,14 @@ end n2 = input(Bool) n3 = filter(n1, n2) c = sink(n3) - s1 = Source(n1) - s2 = Source(n2) - push!(s1, 2) - push!(s2, false) - push!(s1, 3) - push!(s2, true) - push!(s1, 4) - push!(s2, false) - push!(s1, 5) + g, s1, s2 = compile(n1, n2) + push!(g, s1, 2) + push!(g, s2, false) + push!(g, s1, 3) + push!(g, s2, true) + push!(g, s1, 4) + push!(g, s2, false) + push!(g, s1, 5) @test c == [3, 4] end @@ -195,9 +184,9 @@ end n1 = input(Int) n2 = filter(iseven, n1) c = sink(n2) - s1 = Source(n1) + g, s1 = compile(n1) for i = 1:4 - push!(s1, i) + push!(g, s1, i) end @test c == [2, 4] end @@ -218,18 +207,16 @@ end n4 = input(Int) n5 = map(+, n3, n4) c = sink(n5) - s1 = Source(n1) - s2 = Source(n2) - s4 = Source(n4) - push!(s1, 1) - push!(s2, false) - push!(s4, 2) - push!(s2, true) - push!(s1, 3) - push!(s4, 4) - push!(s2, false) - push!(s1, 5) - push!(s4, 6) + g, s1, s2, s4 = compile(n1, n2, n4) + push!(g, s1, 1) + push!(g, s2, false) + push!(g, s4, 2) + push!(g, s2, true) + push!(g, s1, 3) + push!(g, s4, 4) + push!(g, s2, false) + push!(g, s1, 5) + push!(g, s4, 6) @test c == [3, 5, 7] end @@ -237,9 +224,9 @@ end n1 = input(Int) n2 = select(iseven, n1) c = sink(n2) - s1 = Source(n1) + g, s1 = compile(n1) for i = 1:7 - push!(s1, i) + push!(g, s1, i) end @test c == [2, 4, 6] end @@ -256,16 +243,16 @@ end n1 = input(Int) n2 = constant(1) c = sink(+, n1, n2) - s = Source(n1) - push!(s, 2) + g, s = compile(n1) + push!(g, s, 2) @test c == [3] n1 = input(Bool) n2 = constant(true) c = sink(&, n1, n2) - s = Source(n1) - push!(s, true) - push!(s, false) + g, s = compile(n1) + push!(g, s, true) + push!(g, s, false) @test c == [true, false] end @@ -274,13 +261,12 @@ end n2 = input(Int) n3 = quiet(n2) c = sink(+, n1, n3) - s1 = Source(n1) - s2 = Source(n2) - push!(s1, 1) - push!(s2, 2) - push!(s1, 3) - push!(s2, 4) - push!(s1, 5) + g, s1, s2 = compile(n1, n2) + push!(g, s1, 1) + push!(g, s2, 2) + push!(g, s1, 3) + push!(g, s2, 4) + push!(g, s1, 5) @test c == [5, 9] end @@ -288,9 +274,9 @@ end n1 = input(Int) n2 = lag(2, n1) c = sink(n2) - s1 = Source(n1) + g, s1 = compile(n1) for i = 1:10 - push!(s1, i) + push!(g, s1, i) end @test c == 1:8 end @@ -302,9 +288,9 @@ end c1 = sink(identity, n3) c2 = sink((x, y) -> x, n3, n1) - s = Source(n1) + g, s = compile(n1) for i = 1:4 - push!(s, i) + push!(g, s, i) end @test c1 == [true, true] @test c2 == [false, true, false, true] @@ -324,30 +310,31 @@ end s1 = Source(i1) s2 = Source(i2) s3 = Source(i3) - push!(s1, 1) - push!(s2, true) - push!(s3, true) - @testnoalloc push!($s1, 1) + g, s1, s2, s3 = compile(i1, i2, i3) + push!(g, s1, 1) + push!(g, s2, true) + push!(g, s3, true) + @testnoalloc push!($g, $s1, 1) end -@testset "PerformanceTrackers" begin - i1 = input(Int; name = "input1") - i2 = input(Int; name = "input2") - n1 = map(x -> 2x, i1) - n2 = map(+, n1, i2) - s1 = Source(i1) - s2 = Source(i2) - push!(s1, 1) - push!(s2, 2) - tracker = PerformanceGraphTracker() - push!(tracker, s1, 2) - push!(tracker, s2, 2) - push!(tracker, s1 => 3, s2 => 3) - triggers = gettrackingtriggers(tracker) - nodes = gettrackingnodes(tracker) - @test length(triggers) == 4 - @test map(x -> x.id, triggers) == [1, 2, 3, 3] - @test length(nodes) == 3 * 4 - @test map(x -> x.id, nodes) == [i for i = 1:3 for _ = 1:4] - @test map(x -> x.bytes_allocated, nodes) == [0 for i = 1:3 for _ = 1:4] -end +# @testset "PerformanceTrackers" begin +# i1 = input(Int; name = "input1") +# i2 = input(Int; name = "input2") +# n1 = map(x -> 2x, i1) +# n2 = map(+, n1, i2) +# s1 = Source(i1) +# s2 = Source(i2) +# push!(s1, 1) +# push!(s2, 2) +# tracker = PerformanceGraphTracker() +# push!(tracker, s1, 2) +# push!(tracker, s2, 2) +# push!(tracker, s1 => 3, s2 => 3) +# triggers = gettrackingtriggers(tracker) +# nodes = gettrackingnodes(tracker) +# @test length(triggers) == 4 +# @test map(x -> x.id, triggers) == [1, 2, 3, 3] +# @test length(nodes) == 3 * 4 +# @test map(x -> x.id, nodes) == [i for i = 1:3 for _ = 1:4] +# @test map(x -> x.bytes_allocated, nodes) == [0 for i = 1:3 for _ = 1:4] +# end From 2919afec6e55af582b58addc2532a1ab4335e7e3 Mon Sep 17 00:00:00 2001 From: Romain Poncet Date: Wed, 2 Aug 2023 13:39:13 +0100 Subject: [PATCH 3/8] documentation --- README.md | 8 ++--- benchmark/benchmark.jl | 3 -- docs/src/index.md | 69 +++++++++++++++++++++--------------------- src/ReactiveGraphs.jl | 3 +- src/compilation.jl | 64 ++++++++++++++++++++++----------------- src/graph.jl | 40 ------------------------ test/runtests.jl | 4 +-- 7 files changed, 76 insertions(+), 115 deletions(-) diff --git a/README.md b/README.md index 17f3b9b..6da7e32 100644 --- a/README.md +++ b/README.md @@ -22,12 +22,10 @@ julia>i1 = input(Int) n4 = inlinedmap(+,n2,n3) n5 = lag(1, n4) - s1 = Source(i1) - s2 = Source(i2) - s3 = Source(i3) - push!(s1 => true, s2 => true, s3 => 1) + graph, s1, s2, s3 = compile(i1, i2, i3) + push!(graph, s1 => true, s2 => true, s3 => 1) v = 1 - @benchmark push!($s1, $v) + @benchmark push!($graph, $s1, $v) BenchmarkTools.Trial: 10000 samples with 1000 evaluations. Range (min … max): 8.708 ns … 27.625 ns ┊ GC (min … max): 0.00% … 0.00% Time (median): 8.792 ns ┊ GC (median): 0.00% diff --git a/benchmark/benchmark.jl b/benchmark/benchmark.jl index 2b8831d..f40f63c 100644 --- a/benchmark/benchmark.jl +++ b/benchmark/benchmark.jl @@ -11,9 +11,6 @@ n3 = foldl((state, x) -> state + x, 1, i1s) n4 = inlinedmap(+, n2, n3) n5 = lag(1, n4) -s1 = Source(i1) -s2 = Source(i2) -s3 = Source(i3) g, s1, s2, s3 = compile(i1, i2, i3) push!(g, s1, 1) push!(g, s2, true) diff --git a/docs/src/index.md b/docs/src/index.md index 3708d0b..eddff02 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -73,15 +73,16 @@ end To use this graph, we need to push values into the two inputs. We can either push a new value, or a function that mutates the current value. -To do so, we need to wrap our inputs into a `Source`, and push to the sources, -and not the inputs directly. + +To do so, the graph is compiled by calling `compile` on the inputs. +The methods returns the compiled graph, and some empty objects that identify +the inputs in the graph. ```jldoctest; output = false -s1 = Source(input_1) # Captures the current state of the graph. Nodes must not be added afterwards. -s2 = Source(input_2) # Captures the current state of the graph. Nodes must not be added afterwards. -push!(s1, 1) # prints "node 1: 2". The second node cannot be evaluated since the data is missing -push!(s2, x -> x[] = 2)# prints "node 2: 5" -push!(s1, 3) # prints "node 1: 6" "node 2: 11". +g, s1, s2 = compile(input_1, input_2) +push!(g, s1, 1) # prints "node 1: 2". The second node cannot be evaluated since the data is missing +push!(g, s2, x -> x[] = 2)# prints "node 2: 5" +push!(g, s1, 3) # prints "node 1: 6" "node 2: 11". # output @@ -97,14 +98,15 @@ DocTestSetup = quote end ``` -Introducing this wrapping may seem a bit cumbersome to the user. +Introducing this compilation step may seem a bit cumbersome to the user. But it is used to achieve high performance. -When creating a source, the current state of the graph is being captured in the -parameters of the `Source` type. -This allows to dispatch the subsequent calls to `push!` on performant generated methods. -This capture of the the graph implies that we first need to build the complete graph -before wrapping the inputs into sources. -Otherwise, the nodes added subsequently will be ignored. +To update the graph efficiently, we need to generate methods that are specific +to the topology of the graph. +When compiling the graph, we simply generate an object whose type encodes this topology. +Hence, the update methods can simply be implemented as generated functions, based on the +type of the graph. + +Note that, the nodes added subsequently to the compiled graph will be ignored. As explained, in the introduction, each time an input is updated, the data will flow down the graph, @@ -129,7 +131,8 @@ n2 = input(Ref(0)) map(n1, n2) do x, y println(x + y[]) end -push!(Source(n1) => 1, Source(n2) => (x->x[]=2)) # prints "3" +g, s1, s2 = compile(n1, n2) +push!(g, s1 => 1, s2 => (x->x[]=2)) # prints "3" # output @@ -157,9 +160,9 @@ input_1 = input(Float64) filtered = filter(x->!isnan(x), input_1) n = map(x->println("new update: $x"), filtered) -s1 = Source(input_1) # compiles the graph. Nodes must not be added afterwards. -push!(s1, 1.0) # prints "new update: 1.0" -push!(s1, NaN) # prints nothing +g, s1= compile(input_1) +push!(g, s1, 1.0) # prints "new update: 1.0" +push!(g, s1, NaN) # prints nothing # output @@ -188,12 +191,11 @@ selected = select(x->!isnan(x), input_1) map((x,y) -> println("filtered"), filtered, input_2) map((x,y) -> println("selected"), selected, input_2) -s1 = Source(input_1) # compiles the graph. Nodes must not be added afterwards. -s2 = Source(input_2) # compiles the graph. Nodes must not be added afterwards. -push!(s1, 1.0) # prints nothing -push!(s2, nothing) # prints "filtered" and then "selected" -push!(s1, NaN) # prints nothing -push!(s2, nothing) # prints "filtered" only +g, s1, s2 = compile(input_1, input_2) +push!(g, s1, 1.0) # prints nothing +push!(g, s2, nothing) # prints "filtered" and then "selected" +push!(g, s1, NaN) # prints nothing +push!(g, s2, nothing) # prints "filtered" only # output @@ -219,7 +221,8 @@ x1 = input(nothing) x2 = map(x->println("x2"), x1) x3 = map(x->println("x3"), x1) x4 = map((x,y)->println("x4"), x2, x3) -push!(Source(x1), nothing) # prints x2 x3 x4 +g, s1 = compile(x1) +push!(g, s1, nothing) # prints x2 x3 x4 # output x2 @@ -255,8 +258,8 @@ BenchmarkTools.Trial: 10000 samples with 195 evaluations. julia> x1 = input(Int) x2 = map(x->x+1, x1) - s = Source(x1) - @benchmark push!($x1, 2) + g, s = compile(x1) + @benchmark push!($g, $x1, 2) BenchmarkTools.Trial: 10000 samples with 1000 evaluations. Range (min … max): 1.500 ns … 10.625 ns ┊ GC (min … max): 0.00% … 0.00% Time (median): 1.542 ns ┊ GC (median): 0.00% @@ -285,14 +288,12 @@ julia>i1 = input(Int) n4 = inlinedmap(+,n2,n3) n5 = lag(1, n4) - s1 = Source(i1) - s2 = Source(i2) - s3 = Source(i3) - push!(s1, 1) - push!(s2, true) - push!(s3, true) + g, s1, s2, s3 = compile(i1, i2, i3) + push!(g, s1, 1) + push!(g, s2, true) + push!(g, s3, true) v = 1 - @benchmark push!($s1, $v) + @benchmark push!($g, $s1, $v) BenchmarkTools.Trial: 10000 samples with 1000 evaluations. Range (min … max): 8.708 ns … 27.625 ns ┊ GC (min … max): 0.00% … 0.00% Time (median): 8.792 ns ┊ GC (median): 0.00% diff --git a/src/ReactiveGraphs.jl b/src/ReactiveGraphs.jl index 518ed7a..bfdd682 100644 --- a/src/ReactiveGraphs.jl +++ b/src/ReactiveGraphs.jl @@ -1,7 +1,7 @@ module ReactiveGraphs +export compile export input -export Source export constant export inlinedmap export quiet @@ -11,7 +11,6 @@ export updated export PerformanceGraphTracker export gettrackingnodes, gettrackingtriggers -export compile macro tryinline(e) @static if VERSION >= v"1.8" diff --git a/src/compilation.jl b/src/compilation.jl index 3de16c9..fca3651 100644 --- a/src/compilation.jl +++ b/src/compilation.jl @@ -11,20 +11,11 @@ getparentnames(::TypeOrValue{CompiledNode{name,parentnames}}) where {name,parent getoperationtype(::TypeOrValue{CompiledNode{name,parentnames,Op}}) where {name,parentnames,Op} = Op getoperation(n::CompiledNode) = n.operation -# function CompiledNode(node::Node) -# CompiledNode(node.name, Tuple(node.parentnames), node.operation) -# end - struct CompiledGraph{N,T<:NTuple{N,CompiledNode},Tr<:AbstractGraphTracker} nodes::T tracker::Tr end -# CompiledGraph(node::Node) = CompiledGraph(node.ref[]) -# function CompiledGraph(nodes::Vector{Node}) -# Tuple(CompiledNode(n) for n in nodes) |> CompiledGraph -# end - nodetypes(::TypeOrValue{CompiledGraph{N,T}}) where {N,T} = T.parameters @generated function Base.getindex(g::CompiledGraph{N}, ::TypeSymbol{name}) where {N, name} @@ -34,51 +25,68 @@ nodetypes(::TypeOrValue{CompiledGraph{N,T}}) where {N,T} = T.parameters throw(ErrorException("symbol $name not found in graph $g")) end +""" + Source{inputname,T} + +An empty object that indexes an `Input` node within a `CompiledGraph`. +""" struct Source{inputname,T} end Base.eltype(::TypeOrValue{Source{inputname,T}}) where {inputname,T} = T getinputname(::TypeOrValue{Source{inputname,T}}) where {inputname,T} = inputname +function Source(node::Node) + @assert node.operation isa Input + Source{node.name, eltype(node.operation)}() +end + """ - Source(::Node) + compile(inputs::Node....; tracker::AbstractGraphTracker=NullGraphTracker) + +This library does not optimize for the construction of the graph, but once +built, updating it should be as fast as possible, and allocation free. + +In practice, the user first creates a computation graph iteratively, by creating nodes. +Then, in order to generate efficient methods to update such a graph, this library +requires the user to "compile" those nodes. +This simply amounts to build an object that whose type completely encodes the topology of the graph. +This way, it is possible to generate update methods based on the type of graph object, which +is quite idiomatic in Julia. -Transforms an input node into a `Source`, which is a type stable version of the former. -This type is used to update the roots of the graph with `Base.push!`. -The input objects are not used directly, for performance considerations. +This conversion from a collection of node, to a strongly typed graph is done by this function. +It needs to be called on all the inputs of the graph. +The methods returns: +- a compiled graph object +- a `Source` for each of the inputs ```jldoctest julia> i = input(String) map(println, i) - s = Source(i) - push!(s, "example") + g, s = compile(i) + push!(g, s, "example") example julia> i = input(Ref(0)) map(x->println(x[]), i) - s = Source(i) - push!(s, ref -> ref[] = 123) + g, s = compile(i) + push!(g, s, ref -> ref[] = 123) 123 ``` -Sources can also be used simultaneously, +It is also possible to update several inputs simultaneously ```jldoctest julia> i1 = input(Int) i2 = input(Int) m = map(+, i1, i2) map(print, m) - s1 = Source(i1) - s2 = Source(i2) - push!(s1 => 1, s2 => 2) - push!(s1 => 3, s2 => 4) + g, s1, s2 = compile(i1, i2) + push!(g, s1 => 1, s2 => 2) + push!(g, s1 => 3, s2 => 4) 37 ``` -""" -function Source(node::Node) - @assert node.operation isa Input - Source{node.name, eltype(node.operation)}() -end +""" function compile(inputs::Node...; tracker::AbstractGraphTracker=NullGraphTracker()) @assert !isempty(inputs) sources = Source.(inputs) @@ -130,4 +138,4 @@ function getvalue(graph::CompiledGraph, name::TypeSymbol) node = graph[name] getvalue(graph, node, getoperation(node)) end -getvalue(::CompiledGraph, ::CompiledNode, op::Operation) = getvalue(op) # specific implementations should dispatch on this method +getvalue(::CompiledGraph, ::CompiledNode, op::Operation) = getvalue(op) # default implementaion. Concrete operations should dispatch on this method diff --git a/src/graph.jl b/src/graph.jl index 5b5888d..204e9e8 100644 --- a/src/graph.jl +++ b/src/graph.jl @@ -38,43 +38,3 @@ function mergegraphs!(node1::Node, node2::Node, nodes::Node...) end mergegraphs!(node1, nodes...) end - -# """ -# Node{name} - -# Objects of type `Node` correspond to the nodes of the computational graph. -# Each node is identified by a uniquely generated name `name`. -# """ -# struct Node{name} -# graph::GraphRef -# Node(name::Symbol, graph::GraphRef) = new{name}(graph) -# end - -# function Base.show(io::IO, node::Node) -# name = getname(node) -# type = eltype(node) -# print(io, "Node($name,$type)") -# end - -# Base.eltype(node::Node) = node |> getedge |> eltype -# getname(::TypeOrValue{Node{name}}) where {name} = name -# getgraph(node::Node) = node.graph -# getedge(node::Node) = getedge(getgraph(node), TypeSymbol(getname(node))) - -# function Node(name::Symbol, op::Op, parents::Node...) where {Op <: Operation} -# graph = if isempty(parents) -# GraphRef() -# else -# merge!((n.graph for n in parents)...) # return Root() if empty -# end -# parentnames = Tuple(getname(a) for a in parents) -# push!(graph, name, parentnames, op) -# Node(name, graph) -# end - -# update!(node, args...) = update!(getedge(node), args...) -# # isinitialized(node, args...) = isinitialized(getedge(node), args...) -# (node::Node)(args...) = getedge(node)(args...) - -# @inline getvalue(::Graph, element::Operation) = getvalue(element) -# getoperationtype(node::Node) = getedge(node) |> getoperationtype diff --git a/test/runtests.jl b/test/runtests.jl index bcf1616..82c81ca 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -306,10 +306,8 @@ end n3 = foldl((state, x) -> state + x, 1, i1s) n4 = inlinedmap(+, n2, n3) n5 = lag(1, n4) + n6 = updated(n5) - s1 = Source(i1) - s2 = Source(i2) - s3 = Source(i3) g, s1, s2, s3 = compile(i1, i2, i3) push!(g, s1, 1) push!(g, s2, true) From c70df64c785748f59f95b6778080a5f5f36d4d0d Mon Sep 17 00:00:00 2001 From: Romain Poncet Date: Wed, 2 Aug 2023 13:50:43 +0100 Subject: [PATCH 4/8] cleanup getsymbol --- src/ReactiveGraphs.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/ReactiveGraphs.jl b/src/ReactiveGraphs.jl index bfdd682..e117ff4 100644 --- a/src/ReactiveGraphs.jl +++ b/src/ReactiveGraphs.jl @@ -38,8 +38,6 @@ struct TypeSymbol{x} TypeSymbol(x::Symbol) = new{x}() end -getsymbol(::TypeOrValue{TypeSymbol{x}}) where {x} = x - include("operations.jl") include("graph.jl") include("trackers.jl") From ca3d3bd316b830ccd70e046e2169192ce84176c2 Mon Sep 17 00:00:00 2001 From: Romain Poncet Date: Wed, 2 Aug 2023 17:00:58 +0100 Subject: [PATCH 5/8] trackers --- src/compilation.jl | 5 +++-- src/trackers.jl | 27 +++++++-------------------- test/runtests.jl | 40 +++++++++++++++++++--------------------- 3 files changed, 29 insertions(+), 43 deletions(-) diff --git a/src/compilation.jl b/src/compilation.jl index fca3651..eacc76f 100644 --- a/src/compilation.jl +++ b/src/compilation.jl @@ -17,6 +17,7 @@ struct CompiledGraph{N,T<:NTuple{N,CompiledNode},Tr<:AbstractGraphTracker} end nodetypes(::TypeOrValue{CompiledGraph{N,T}}) where {N,T} = T.parameters +gettrackingnodes(g::CompiledGraph) = gettrackingnodes(g.tracker) @generated function Base.getindex(g::CompiledGraph{N}, ::TypeSymbol{name}) where {N, name} for (i, node) in nodetypes(g) |> enumerate @@ -108,7 +109,7 @@ end function generate(graph::Type{<:CompiledGraph}, sources::Type{<:Source}...) inputnames = getinputname.(sources) expr = quote - on_update_start!(graph.tracker, $(inputnames)) + on_update_start!(graph.tracker) end for compilednodetype in nodetypes(graph) generate!(expr, compilednodetype, inputnames...) @@ -125,7 +126,7 @@ function generate!( ) where {name,parentnames,Op} e = generate(inputnames, name, parentnames, Op) append!(expr.args, e.args) - push!(expr.args, :(on_update_node!(graph.tracker, $(Meta.quot(name))))) + push!(expr.args, :(on_update_node!(graph.tracker, $(Meta.quot(name)), $(name in inputnames)))) expr end diff --git a/src/trackers.jl b/src/trackers.jl index 93f2516..dc1d39f 100644 --- a/src/trackers.jl +++ b/src/trackers.jl @@ -1,10 +1,9 @@ abstract type AbstractGraphTracker end -using Base: nothing_sentinel struct NullGraphTracker <: AbstractGraphTracker end -on_update_start!(::AbstractGraphTracker, inputnames) = nothing -on_update_node!(::AbstractGraphTracker, name) = nothing +on_update_start!(::AbstractGraphTracker) = nothing +on_update_node!(::AbstractGraphTracker, name, isinput) = nothing on_update_stop!(::AbstractGraphTracker) = nothing struct TrackingNode @@ -12,16 +11,11 @@ struct TrackingNode id::Int64 elapsed_time::Int64 bytes_allocated::Int -end - -struct TrackingTriggers - name::Symbol - id::Int64 + isinput::Bool end mutable struct PerformanceGraphTracker <: AbstractGraphTracker @tryconst trackingnodes::Vector{TrackingNode} - @tryconst trackingtriggers::Vector{TrackingTriggers} currentid::Int64 lasttime::UInt64 @tryconst total_bytes_allocated::Base.RefValue{Int64} @@ -29,14 +23,12 @@ end PerformanceGraphTracker() = PerformanceGraphTracker( TrackingNode[], - TrackingTriggers[], zero(Int64), zero(UInt64), Ref(zero(Int64)), ) gettrackingnodes(pm::PerformanceGraphTracker) = pm.trackingnodes -gettrackingtriggers(pm::PerformanceGraphTracker) = pm.trackingtriggers gettime(::PerformanceGraphTracker) = time_ns() getelapsedtime(pm::PerformanceGraphTracker) = reinterpret(Int64, gettime(pm) - pm.lasttime) @@ -49,23 +41,18 @@ function getallocatedbytes!(pm::PerformanceGraphTracker) pm.total_bytes_allocated[] - old_total_bytes_allocated end -function on_update_start!(pm::PerformanceGraphTracker, names) - id = pm.currentid += 1 - for name in names - rs = TrackingTriggers(name, id) - push!(pm.trackingtriggers, rs) - end - +function on_update_start!(pm::PerformanceGraphTracker) + pm.currentid += 1 setallocatedbytes!(pm) settime!(pm) nothing end -function on_update_node!(pm::PerformanceGraphTracker, name) +function on_update_node!(pm::PerformanceGraphTracker, name, isinput::Bool) elapsed_time = getelapsedtime(pm) bytes_allocated = getallocatedbytes!(pm) - rs = TrackingNode(name, pm.currentid, elapsed_time, bytes_allocated) + rs = TrackingNode(name, pm.currentid, elapsed_time, bytes_allocated, isinput) push!(pm.trackingnodes, rs) setallocatedbytes!(pm) diff --git a/test/runtests.jl b/test/runtests.jl index 82c81ca..95a75c2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -315,24 +315,22 @@ end @testnoalloc push!($g, $s1, 1) end -# @testset "PerformanceTrackers" begin -# i1 = input(Int; name = "input1") -# i2 = input(Int; name = "input2") -# n1 = map(x -> 2x, i1) -# n2 = map(+, n1, i2) -# s1 = Source(i1) -# s2 = Source(i2) -# push!(s1, 1) -# push!(s2, 2) -# tracker = PerformanceGraphTracker() -# push!(tracker, s1, 2) -# push!(tracker, s2, 2) -# push!(tracker, s1 => 3, s2 => 3) -# triggers = gettrackingtriggers(tracker) -# nodes = gettrackingnodes(tracker) -# @test length(triggers) == 4 -# @test map(x -> x.id, triggers) == [1, 2, 3, 3] -# @test length(nodes) == 3 * 4 -# @test map(x -> x.id, nodes) == [i for i = 1:3 for _ = 1:4] -# @test map(x -> x.bytes_allocated, nodes) == [0 for i = 1:3 for _ = 1:4] -# end +@testset "PerformanceTrackers" begin + i1 = input(Int; name = "input1") + i2 = input(Int; name = "input2") + n1 = map(x -> 2x, i1) + n2 = map(+, n1, i2) + + g, s1, s2 = compile(i1, i2) + push!(g, s1, 1) + push!(g, s2, 2) + g, s1, s2 = compile(i1, i2; tracker=PerformanceGraphTracker()) + push!(g, s1, 2) + push!(g, s2, 2) + push!(g, s1 => 3, s2 => 3) + nodes = gettrackingnodes(g) + @test length(nodes) == 3 * 4 + @test map(x -> x.id, nodes) == [i for i = 1:3 for _ = 1:4] + @test map(x -> x.bytes_allocated, nodes) == [0 for i = 1:3 for _ = 1:4] + @test map(x -> x.isinput, nodes) == [(i in x) for x in [(1,), (3,), (1,3)] for i = 1:4] +end From 72923ae9905f5d5c529878186ad79753fc1ce35f Mon Sep 17 00:00:00 2001 From: Romain Poncet Date: Thu, 3 Aug 2023 09:59:36 +0100 Subject: [PATCH 6/8] no export gettrackingnodes --- src/ReactiveGraphs.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ReactiveGraphs.jl b/src/ReactiveGraphs.jl index e117ff4..5a9973f 100644 --- a/src/ReactiveGraphs.jl +++ b/src/ReactiveGraphs.jl @@ -10,7 +10,7 @@ export lag export updated export PerformanceGraphTracker -export gettrackingnodes, gettrackingtriggers +export gettrackingnodes macro tryinline(e) @static if VERSION >= v"1.8" From 1ce7c9405f415b09646ec46169e818888df63ae4 Mon Sep 17 00:00:00 2001 From: Romain Poncet Date: Thu, 3 Aug 2023 10:05:03 +0100 Subject: [PATCH 7/8] format --- src/compilation.jl | 31 +++++++++++++++++++++---------- src/trackers.jl | 8 ++------ test/runtests.jl | 4 ++-- 3 files changed, 25 insertions(+), 18 deletions(-) diff --git a/src/compilation.jl b/src/compilation.jl index eacc76f..c165edf 100644 --- a/src/compilation.jl +++ b/src/compilation.jl @@ -1,14 +1,20 @@ struct CompiledNode{name,parentnames,Op<:Operation} operation::Op - function CompiledNode(name::Symbol, parentnames::NTuple{N,Symbol}, operation::Operation) where {N} - new{name, parentnames, typeof(operation)}(operation) + function CompiledNode( + name::Symbol, + parentnames::NTuple{N,Symbol}, + operation::Operation, + ) where {N} + new{name,parentnames,typeof(operation)}(operation) end end getname(::TypeOrValue{CompiledNode{name}}) where {name} = name getparentnames(::TypeOrValue{CompiledNode{name,parentnames}}) where {name,parentnames} = parentnames -getoperationtype(::TypeOrValue{CompiledNode{name,parentnames,Op}}) where {name,parentnames,Op} = Op +getoperationtype( + ::TypeOrValue{CompiledNode{name,parentnames,Op}}, +) where {name,parentnames,Op} = Op getoperation(n::CompiledNode) = n.operation struct CompiledGraph{N,T<:NTuple{N,CompiledNode},Tr<:AbstractGraphTracker} @@ -19,7 +25,7 @@ end nodetypes(::TypeOrValue{CompiledGraph{N,T}}) where {N,T} = T.parameters gettrackingnodes(g::CompiledGraph) = gettrackingnodes(g.tracker) -@generated function Base.getindex(g::CompiledGraph{N}, ::TypeSymbol{name}) where {N, name} +@generated function Base.getindex(g::CompiledGraph{N}, ::TypeSymbol{name}) where {N,name} for (i, node) in nodetypes(g) |> enumerate getname(node) == name && return :(g.nodes[$i]) end @@ -38,7 +44,7 @@ getinputname(::TypeOrValue{Source{inputname,T}}) where {inputname,T} = inputname function Source(node::Node) @assert node.operation isa Input - Source{node.name, eltype(node.operation)}() + Source{node.name,eltype(node.operation)}() end """ @@ -88,13 +94,15 @@ julia> i1 = input(Int) ``` """ -function compile(inputs::Node...; tracker::AbstractGraphTracker=NullGraphTracker()) - @assert !isempty(inputs) +function compile(inputs::Node...; tracker::AbstractGraphTracker = NullGraphTracker()) + @assert !isempty(inputs) sources = Source.(inputs) - @assert allequal(map(x->x.graph.ref, inputs)) + @assert allequal(map(x -> x.graph.ref, inputs)) nodes = inputs[1].graph.ref[] - compilednodes = Tuple(CompiledNode(node.name, Tuple(node.parentnames), node.operation) for node in nodes) + compilednodes = Tuple( + CompiledNode(node.name, Tuple(node.parentnames), node.operation) for node in nodes + ) compiledgraph = CompiledGraph(compilednodes, tracker) compiledgraph, sources... @@ -126,7 +134,10 @@ function generate!( ) where {name,parentnames,Op} e = generate(inputnames, name, parentnames, Op) append!(expr.args, e.args) - push!(expr.args, :(on_update_node!(graph.tracker, $(Meta.quot(name)), $(name in inputnames)))) + push!( + expr.args, + :(on_update_node!(graph.tracker, $(Meta.quot(name)), $(name in inputnames))), + ) expr end diff --git a/src/trackers.jl b/src/trackers.jl index dc1d39f..9034f33 100644 --- a/src/trackers.jl +++ b/src/trackers.jl @@ -21,12 +21,8 @@ mutable struct PerformanceGraphTracker <: AbstractGraphTracker @tryconst total_bytes_allocated::Base.RefValue{Int64} end -PerformanceGraphTracker() = PerformanceGraphTracker( - TrackingNode[], - zero(Int64), - zero(UInt64), - Ref(zero(Int64)), -) +PerformanceGraphTracker() = + PerformanceGraphTracker(TrackingNode[], zero(Int64), zero(UInt64), Ref(zero(Int64))) gettrackingnodes(pm::PerformanceGraphTracker) = pm.trackingnodes diff --git a/test/runtests.jl b/test/runtests.jl index 95a75c2..a82d168 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -324,7 +324,7 @@ end g, s1, s2 = compile(i1, i2) push!(g, s1, 1) push!(g, s2, 2) - g, s1, s2 = compile(i1, i2; tracker=PerformanceGraphTracker()) + g, s1, s2 = compile(i1, i2; tracker = PerformanceGraphTracker()) push!(g, s1, 2) push!(g, s2, 2) push!(g, s1 => 3, s2 => 3) @@ -332,5 +332,5 @@ end @test length(nodes) == 3 * 4 @test map(x -> x.id, nodes) == [i for i = 1:3 for _ = 1:4] @test map(x -> x.bytes_allocated, nodes) == [0 for i = 1:3 for _ = 1:4] - @test map(x -> x.isinput, nodes) == [(i in x) for x in [(1,), (3,), (1,3)] for i = 1:4] + @test map(x -> x.isinput, nodes) == [(i in x) for x in [(1,), (3,), (1, 3)] for i = 1:4] end From 6ff8519857c5eb4ce3c1e82c8f07f8fa8bd05b67 Mon Sep 17 00:00:00 2001 From: Romain Poncet Date: Mon, 7 Aug 2023 18:59:05 +0100 Subject: [PATCH 8/8] bug fix, adding several times the same nodes --- src/graph.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/graph.jl b/src/graph.jl index 204e9e8..5a039c2 100644 --- a/src/graph.jl +++ b/src/graph.jl @@ -34,6 +34,7 @@ mergegraphs!(node::Node) = node.graph function mergegraphs!(node1::Node, node2::Node, nodes::Node...) if node1.graph.ref != node2.graph.ref append!(node1.graph.ref[], node2.graph.ref[]) + empty!(node2.graph.ref[]) node2.graph.ref = node1.graph.ref end mergegraphs!(node1, nodes...)