diff --git a/README.md b/README.md index 17f3b9b..6da7e32 100644 --- a/README.md +++ b/README.md @@ -22,12 +22,10 @@ julia>i1 = input(Int) n4 = inlinedmap(+,n2,n3) n5 = lag(1, n4) - s1 = Source(i1) - s2 = Source(i2) - s3 = Source(i3) - push!(s1 => true, s2 => true, s3 => 1) + graph, s1, s2, s3 = compile(i1, i2, i3) + push!(graph, s1 => true, s2 => true, s3 => 1) v = 1 - @benchmark push!($s1, $v) + @benchmark push!($graph, $s1, $v) BenchmarkTools.Trial: 10000 samples with 1000 evaluations. Range (min … max): 8.708 ns … 27.625 ns ┊ GC (min … max): 0.00% … 0.00% Time (median): 8.792 ns ┊ GC (median): 0.00% diff --git a/benchmark/Project.toml b/benchmark/Project.toml index da523e2..32e44d8 100644 --- a/benchmark/Project.toml +++ b/benchmark/Project.toml @@ -1,4 +1,3 @@ [deps] BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" ReactiveGraphs = "e49d8811-4385-4fad-be34-8522254608f3" -Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" diff --git a/benchmark/benchmark.jl b/benchmark/benchmark.jl index 82a5426..f40f63c 100644 --- a/benchmark/benchmark.jl +++ b/benchmark/benchmark.jl @@ -1,4 +1,3 @@ -using Revise using ReactiveGraphs using BenchmarkTools @@ -12,11 +11,9 @@ n3 = foldl((state, x) -> state + x, 1, i1s) n4 = inlinedmap(+, n2, n3) n5 = lag(1, n4) -s1 = Source(i1) -s2 = Source(i2) -s3 = Source(i3) -s1[] = 1 -s2[] = true -s3[] = true -v = 1 -@benchmark setindex!($s1, $v) +g, s1, s2, s3 = compile(i1, i2, i3) +push!(g, s1, 1) +push!(g, s2, true) +push!(g, s3, true) + +@benchmark push!($g, $s1, 1) diff --git a/docs/src/index.md b/docs/src/index.md index 3708d0b..eddff02 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -73,15 +73,16 @@ end To use this graph, we need to push values into the two inputs. We can either push a new value, or a function that mutates the current value. -To do so, we need to wrap our inputs into a `Source`, and push to the sources, -and not the inputs directly. + +To do so, the graph is compiled by calling `compile` on the inputs. +The methods returns the compiled graph, and some empty objects that identify +the inputs in the graph. ```jldoctest; output = false -s1 = Source(input_1) # Captures the current state of the graph. Nodes must not be added afterwards. -s2 = Source(input_2) # Captures the current state of the graph. Nodes must not be added afterwards. -push!(s1, 1) # prints "node 1: 2". The second node cannot be evaluated since the data is missing -push!(s2, x -> x[] = 2)# prints "node 2: 5" -push!(s1, 3) # prints "node 1: 6" "node 2: 11". +g, s1, s2 = compile(input_1, input_2) +push!(g, s1, 1) # prints "node 1: 2". The second node cannot be evaluated since the data is missing +push!(g, s2, x -> x[] = 2)# prints "node 2: 5" +push!(g, s1, 3) # prints "node 1: 6" "node 2: 11". # output @@ -97,14 +98,15 @@ DocTestSetup = quote end ``` -Introducing this wrapping may seem a bit cumbersome to the user. +Introducing this compilation step may seem a bit cumbersome to the user. But it is used to achieve high performance. -When creating a source, the current state of the graph is being captured in the -parameters of the `Source` type. -This allows to dispatch the subsequent calls to `push!` on performant generated methods. -This capture of the the graph implies that we first need to build the complete graph -before wrapping the inputs into sources. -Otherwise, the nodes added subsequently will be ignored. +To update the graph efficiently, we need to generate methods that are specific +to the topology of the graph. +When compiling the graph, we simply generate an object whose type encodes this topology. +Hence, the update methods can simply be implemented as generated functions, based on the +type of the graph. + +Note that, the nodes added subsequently to the compiled graph will be ignored. As explained, in the introduction, each time an input is updated, the data will flow down the graph, @@ -129,7 +131,8 @@ n2 = input(Ref(0)) map(n1, n2) do x, y println(x + y[]) end -push!(Source(n1) => 1, Source(n2) => (x->x[]=2)) # prints "3" +g, s1, s2 = compile(n1, n2) +push!(g, s1 => 1, s2 => (x->x[]=2)) # prints "3" # output @@ -157,9 +160,9 @@ input_1 = input(Float64) filtered = filter(x->!isnan(x), input_1) n = map(x->println("new update: $x"), filtered) -s1 = Source(input_1) # compiles the graph. Nodes must not be added afterwards. -push!(s1, 1.0) # prints "new update: 1.0" -push!(s1, NaN) # prints nothing +g, s1= compile(input_1) +push!(g, s1, 1.0) # prints "new update: 1.0" +push!(g, s1, NaN) # prints nothing # output @@ -188,12 +191,11 @@ selected = select(x->!isnan(x), input_1) map((x,y) -> println("filtered"), filtered, input_2) map((x,y) -> println("selected"), selected, input_2) -s1 = Source(input_1) # compiles the graph. Nodes must not be added afterwards. -s2 = Source(input_2) # compiles the graph. Nodes must not be added afterwards. -push!(s1, 1.0) # prints nothing -push!(s2, nothing) # prints "filtered" and then "selected" -push!(s1, NaN) # prints nothing -push!(s2, nothing) # prints "filtered" only +g, s1, s2 = compile(input_1, input_2) +push!(g, s1, 1.0) # prints nothing +push!(g, s2, nothing) # prints "filtered" and then "selected" +push!(g, s1, NaN) # prints nothing +push!(g, s2, nothing) # prints "filtered" only # output @@ -219,7 +221,8 @@ x1 = input(nothing) x2 = map(x->println("x2"), x1) x3 = map(x->println("x3"), x1) x4 = map((x,y)->println("x4"), x2, x3) -push!(Source(x1), nothing) # prints x2 x3 x4 +g, s1 = compile(x1) +push!(g, s1, nothing) # prints x2 x3 x4 # output x2 @@ -255,8 +258,8 @@ BenchmarkTools.Trial: 10000 samples with 195 evaluations. julia> x1 = input(Int) x2 = map(x->x+1, x1) - s = Source(x1) - @benchmark push!($x1, 2) + g, s = compile(x1) + @benchmark push!($g, $x1, 2) BenchmarkTools.Trial: 10000 samples with 1000 evaluations. Range (min … max): 1.500 ns … 10.625 ns ┊ GC (min … max): 0.00% … 0.00% Time (median): 1.542 ns ┊ GC (median): 0.00% @@ -285,14 +288,12 @@ julia>i1 = input(Int) n4 = inlinedmap(+,n2,n3) n5 = lag(1, n4) - s1 = Source(i1) - s2 = Source(i2) - s3 = Source(i3) - push!(s1, 1) - push!(s2, true) - push!(s3, true) + g, s1, s2, s3 = compile(i1, i2, i3) + push!(g, s1, 1) + push!(g, s2, true) + push!(g, s3, true) v = 1 - @benchmark push!($s1, $v) + @benchmark push!($g, $s1, $v) BenchmarkTools.Trial: 10000 samples with 1000 evaluations. Range (min … max): 8.708 ns … 27.625 ns ┊ GC (min … max): 0.00% … 0.00% Time (median): 8.792 ns ┊ GC (median): 0.00% diff --git a/src/ReactiveGraphs.jl b/src/ReactiveGraphs.jl index 3f4e0cf..5a9973f 100644 --- a/src/ReactiveGraphs.jl +++ b/src/ReactiveGraphs.jl @@ -1,7 +1,7 @@ module ReactiveGraphs +export compile export input -export Source export constant export inlinedmap export quiet @@ -10,7 +10,7 @@ export lag export updated export PerformanceGraphTracker -export gettrackingnodes, gettrackingtriggers +export gettrackingnodes macro tryinline(e) @static if VERSION >= v"1.8" @@ -38,17 +38,14 @@ struct TypeSymbol{x} TypeSymbol(x::Symbol) = new{x}() end -getsymbol(::TypeOrValue{TypeSymbol{x}}) where {x} = x - -include("graph.jl") include("operations.jl") +include("graph.jl") include("trackers.jl") include("compilation.jl") genname(::Nothing) = gensym() genname(s::Symbol) = gensym(s) genname(s::AbstractString) = gensym(string(s)) -getoperationtype(node::Node) = getnode(node) |> getelement |> eltype include("input.jl") include("map.jl") diff --git a/src/compilation.jl b/src/compilation.jl index c582c3d..c165edf 100644 --- a/src/compilation.jl +++ b/src/compilation.jl @@ -1,97 +1,153 @@ -# todo: ensure that the node exists, and is an input -struct Source{inputname,T,LN<:ListNode} - list::LN - function Source(listnode::LN, inputname::Symbol) where {LN<:ListNode} - node = getnode(listnode, TypeSymbol(inputname)) - T = eltype(node) - new{inputname,T,LN}(listnode) +struct CompiledNode{name,parentnames,Op<:Operation} + operation::Op + function CompiledNode( + name::Symbol, + parentnames::NTuple{N,Symbol}, + operation::Operation, + ) where {N} + new{name,parentnames,typeof(operation)}(operation) end end +getname(::TypeOrValue{CompiledNode{name}}) where {name} = name +getparentnames(::TypeOrValue{CompiledNode{name,parentnames}}) where {name,parentnames} = + parentnames +getoperationtype( + ::TypeOrValue{CompiledNode{name,parentnames,Op}}, +) where {name,parentnames,Op} = Op +getoperation(n::CompiledNode) = n.operation + +struct CompiledGraph{N,T<:NTuple{N,CompiledNode},Tr<:AbstractGraphTracker} + nodes::T + tracker::Tr +end + +nodetypes(::TypeOrValue{CompiledGraph{N,T}}) where {N,T} = T.parameters +gettrackingnodes(g::CompiledGraph) = gettrackingnodes(g.tracker) + +@generated function Base.getindex(g::CompiledGraph{N}, ::TypeSymbol{name}) where {N,name} + for (i, node) in nodetypes(g) |> enumerate + getname(node) == name && return :(g.nodes[$i]) + end + throw(ErrorException("symbol $name not found in graph $g")) +end + +""" + Source{inputname,T} + +An empty object that indexes an `Input` node within a `CompiledGraph`. +""" +struct Source{inputname,T} end + +Base.eltype(::TypeOrValue{Source{inputname,T}}) where {inputname,T} = T +getinputname(::TypeOrValue{Source{inputname,T}}) where {inputname,T} = inputname + +function Source(node::Node) + @assert node.operation isa Input + Source{node.name,eltype(node.operation)}() +end + """ - Source(::Node) + compile(inputs::Node....; tracker::AbstractGraphTracker=NullGraphTracker) + +This library does not optimize for the construction of the graph, but once +built, updating it should be as fast as possible, and allocation free. -Transforms an input node into a `Source`, which is a type stable version of the former. -This type is used to update the roots of the graph with `Base.push!`. -The input objects are not used directly, for performance considerations. +In practice, the user first creates a computation graph iteratively, by creating nodes. +Then, in order to generate efficient methods to update such a graph, this library +requires the user to "compile" those nodes. +This simply amounts to build an object that whose type completely encodes the topology of the graph. +This way, it is possible to generate update methods based on the type of graph object, which +is quite idiomatic in Julia. + +This conversion from a collection of node, to a strongly typed graph is done by this function. +It needs to be called on all the inputs of the graph. +The methods returns: +- a compiled graph object +- a `Source` for each of the inputs ```jldoctest julia> i = input(String) map(println, i) - s = Source(i) - push!(s, "example") + g, s = compile(i) + push!(g, s, "example") example julia> i = input(Ref(0)) map(x->println(x[]), i) - s = Source(i) - push!(s, ref -> ref[] = 123) + g, s = compile(i) + push!(g, s, ref -> ref[] = 123) 123 ``` -Sources can also be used simultaneously, +It is also possible to update several inputs simultaneously ```jldoctest julia> i1 = input(Int) i2 = input(Int) m = map(+, i1, i2) map(print, m) - s1 = Source(i1) - s2 = Source(i2) - push!(s1 => 1, s2 => 2) - push!(s1 => 3, s2 => 4) + g, s1, s2 = compile(i1, i2) + push!(g, s1 => 1, s2 => 2) + push!(g, s1 => 3, s2 => 4) 37 ``` -""" -Source(node::Node) = Source(getgraph(node)[], getname(node)) -function Base.show(io::IO, s::Source{inputname,T}) where {inputname,T} - print(io, "Source($inputname, $T)") -end +""" +function compile(inputs::Node...; tracker::AbstractGraphTracker = NullGraphTracker()) + @assert !isempty(inputs) + sources = Source.(inputs) -getlisttype(::TypeOrValue{Source{inputname,T,LN}}) where {inputname,T,LN} = LN -getinputname(::TypeOrValue{Source{inputname,T,LN}}) where {inputname,T,LN} = inputname + @assert allequal(map(x -> x.graph.ref, inputs)) + nodes = inputs[1].graph.ref[] + compilednodes = Tuple( + CompiledNode(node.name, Tuple(node.parentnames), node.operation) for node in nodes + ) + compiledgraph = CompiledGraph(compilednodes, tracker) -@inline Base.push!(src::Source, x) = push!(src => x) -@inline Base.push!(monitor::AbstractGraphTracker, src::Source, x) = push!(monitor, src => x) -@inline Base.push!(p::Pair{<:Source,<:Any}...) = push!(NullGraphTracker(), p...) -@generated function Base.push!(monitor::AbstractGraphTracker, p::Pair{<:Source,<:Any}...) - src_types = p .|> fieldtypes .|> first - inputnames = getinputname.(src_types) - listtypes = getlisttype.(src_types) + compiledgraph, sources... +end - if !allequal(listtypes) - throw(ErrorException("nodes do not belong to the same graph")) - end - LN = first(listtypes) +@inline Base.push!(graph::CompiledGraph, src::Source, x) = push!(graph, src => x) +@generated function Base.push!(graph::CompiledGraph, p::Pair{<:Source,<:Any}...) + sources = p .|> fieldtypes .|> first + generate(graph, sources...) +end +function generate(graph::Type{<:CompiledGraph}, sources::Type{<:Source}...) + inputnames = getinputname.(sources) expr = quote - on_update_start!(monitor, $inputnames) - list = p[1][1].list + on_update_start!(graph.tracker) + end + for compilednodetype in nodetypes(graph) + generate!(expr, compilednodetype, inputnames...) end - generate!(expr, LN, inputnames...) - push!(expr.args, :(on_update_stop!(monitor))) + push!(expr.args, :(on_update_stop!(graph.tracker))) push!(expr.args, nothing) expr end -generate!(::Expr, ::Type{Root}, ::Symbol...) = nothing function generate!( expr::Expr, - ::Type{ListNode{name,parentnames,X,Next}}, + ::Type{CompiledNode{name,parentnames,Op}}, inputnames::Symbol..., -) where {name,parentnames,X,Next} - generate!(expr, Next, inputnames...) - e = generate(inputnames, name, parentnames, X) +) where {name,parentnames,Op} + e = generate(inputnames, name, parentnames, Op) append!(expr.args, e.args) - push!(expr.args, :(on_update_node!(monitor, $(Meta.quot(name))))) + push!( + expr.args, + :(on_update_node!(graph.tracker, $(Meta.quot(name)), $(name in inputnames))), + ) expr end -# getvalue(list::ListNode, name::Symbol) = getnode(list, name) |> getvalue -getvalue(list::ListNode, v::TypeSymbol) = getnode(list, v) |> getvalue -getvalue(node::ListNode) = getvalue(node, getelement(node)) +function debugsource(graph::CompiledGraph, sources::Source...) + generate(graph, sources...) +end -function debugsource(src::Source{inputnames,LN}) where {inputnames,LN} - generate!(Expr(:block), LN, inputnames...) +getoperation(graph::CompiledGraph, name::TypeSymbol) = graph[name].operation +function getvalue(graph::CompiledGraph, name::TypeSymbol) + node = graph[name] + getvalue(graph, node, getoperation(node)) end +getvalue(::CompiledGraph, ::CompiledNode, op::Operation) = getvalue(op) # default implementaion. Concrete operations should dispatch on this method diff --git a/src/constant.jl b/src/constant.jl index 044331f..be18f49 100644 --- a/src/constant.jl +++ b/src/constant.jl @@ -30,8 +30,6 @@ function constant(x; name = nothing) Node(uniquename, op) end -getvalue(::ListNode, element::AbstractConstant) = getvalue(element) - function generate(::Any, name::Symbol, parentnames::Tuple{}, ::Type{<:AbstractConstant}) updated_s = Symbol(:updated, name) initialized_s = Symbol(:initialized, name) diff --git a/src/filter.jl b/src/filter.jl index e69f56a..a817018 100644 --- a/src/filter.jl +++ b/src/filter.jl @@ -15,7 +15,7 @@ function Base.filter(x::Node, condition::Node; name = nothing) throw(ErrorException("eltype(condition) is $(eltype(condition)), expected Bool")) end uniquename = genname(name) - op = Filter{getoperationtype(x)}() + op = Filter{eltype(x)}() Node(uniquename, op, x, condition) end @@ -37,15 +37,15 @@ function Base.filter(f::Function, x::Node; name = nothing) filter(x, condition; name) end -function getvalue(node::ListNode, ::Filter) - node_name, _ = getparentnames(node) - getvalue(node, TypeSymbol(node_name)) # todo: avoid starting from the leaf +function getvalue(graph::CompiledGraph, node::CompiledNode, ::Filter) + parent_name = getparentnames(node) |> first |> TypeSymbol + getvalue(graph, parent_name) end function generate(::Any, name::Symbol, parentnames::NTuple{<:Any,Symbol}, ::Type{<:Filter}) updated_s = Symbol(:updated, name) initialized_s = Symbol(:initialized, name) - args = [:(getvalue(list, $(TypeSymbol(n)))) for n in parentnames] + args = [:(getvalue(graph, $(TypeSymbol(n)))) for n in parentnames] condition_updated = Expr(:call, :|, (Symbol(:updated, n) for n in parentnames)...) condition_initialized = Expr(:call, :&, (Symbol(:initialized, n) for n in parentnames)...) diff --git a/src/foldl.jl b/src/foldl.jl index 6e0e2ae..63850ac 100644 --- a/src/foldl.jl +++ b/src/foldl.jl @@ -3,7 +3,6 @@ mutable struct Foldl{TState,F} <: Operation{TState} state::TState end @inline getvalue(x::Foldl) = x.state -@inline getvalue(::ListNode, element::Foldl) = getvalue(element) @inline function update!(m::Foldl, args...) @tryinline state = m.f(m.state, args...) @@ -39,14 +38,14 @@ end function generate(::Any, name::Symbol, parentnames::NTuple{<:Any,Symbol}, ::Type{<:Foldl}) updated_s = Symbol(:updated, name) initialized_s = Symbol(:initialized, name) - args = (:(getvalue(list, $(TypeSymbol(n)))) for n in parentnames) + args = (:(getvalue(graph, $(TypeSymbol(n)))) for n in parentnames) condition_updated = Expr(:call, :|, (Symbol(:updated, n) for n in parentnames)...) condition_initialized = Expr(:call, :&, (Symbol(:initialized, n) for n in parentnames)...) nodename_s = Symbol(:node, name) quote $updated_s = if $condition_initialized & $condition_updated - $nodename_s = getnode(list, $(TypeSymbol(name))) + $nodename_s = getoperation(graph, $(TypeSymbol(name))) $(Expr(:call, :update!, nodename_s, args...)) true else diff --git a/src/graph.jl b/src/graph.jl index 3cba936..5a039c2 100644 --- a/src/graph.jl +++ b/src/graph.jl @@ -1,107 +1,41 @@ -struct Root end - -struct ListNode{name,parentnames,X,Next} - x::X - next::Next - function ListNode( - name::Symbol, - parentnames::NTuple{<:Any,Symbol}, - x, - next::Union{Root,ListNode}, - ) - new{name,parentnames,typeof(x),typeof(next)}(x, next) - end -end - -getname(::TypeOrValue{ListNode{name}}) where {name} = name -getparentnames(::TypeOrValue{ListNode{name,parentnames}}) where {name,parentnames} = - parentnames -getelementtype(::TypeOrValue{ListNode{name,parentnames,X}}) where {name,parentnames,X} = X -getelement(n::ListNode) = n.x -getnext(n::ListNode) = n.next -Base.eltype(x::TypeOrValue{<:ListNode}) = eltype(getelementtype(x)) - -mutable struct Graph - last::Ref{Union{Root,ListNode}} -end - -Graph() = Graph(Ref{Union{Root,ListNode}}(Root())) - -Base.:(==)(g1::Graph, g2::Graph) = g1.last == g2.last - -Base.getindex(g::Graph) = g.last[] -function Base.setindex!(g::Graph, x::Union{Root,ListNode}) - g.last[] = x - g +mutable struct Graph{N} + ref::Base.RefValue{Vector{N}} end -Base.map(f::Function, g::Graph) = map(f, g.last[]) -Base.map(::Function, ::Root) = nothing -function Base.map(f::Function, n::ListNode) - map(f, n.next) - f(getname(n), getparentnames(n), n.x) -end +Graph{N}() where {N} = Graph(Ref(N[])) +Base.push!(g::Graph{N}, n::N) where {N} = push!(g.ref[], n) -function Base.push!(graph::Graph, name, parentnames, x) - graph[] = ListNode(name, parentnames, x, graph[]) - graph +mutable struct Node + graph::Graph{Node} + @tryconst name::Symbol + @tryconst parentnames::Vector{Symbol} + @tryconst operation::Operation end -Base.merge!(graph::Graph) = graph -function Base.merge!(graph1::Graph, graph2::Graph, graphs...) - if graph1 != graph2 - map(graph2) do name, parentnames, x - push!(graph1, name, parentnames, x) - end - graph2.last = graph1.last - end - merge!(graph1, graphs...) -end - -getnode(graph::Graph, s::TypeSymbol) = getnode(graph[], s) -getnode(::Root, s::TypeSymbol) = - throw(ErrorException("symbol $(getsymbol(s)) not found in graph")) -function getnode(x::ListNode, v::TypeSymbol) - if getname(x) == getsymbol(v) - x - else - getnode(x.next, v) - end +# todo: make sure name is unique +function Node(name::Symbol, op::Operation, parents::Node...) + graph = mergegraphs!(parents...) + parentnames = [a.name for a in parents] + node = Node(graph, name, parentnames, op) + push!(graph, node) + node end -""" - Node{name} - -Objects of type `Node` correspond to the nodes of the computational graph. -Each node is identified by a uniquely generated name `name`. -""" -struct Node{name} - graph::Graph - Node(name::Symbol, graph::Graph) = new{name}(graph) -end +Base.eltype(node::Node) = eltype(node.operation) function Base.show(io::IO, node::Node) - name = getname(node) + name = node.name type = eltype(node) print(io, "Node($name,$type)") end -Base.eltype(node::Node) = node |> getnode |> eltype -getname(::TypeOrValue{Node{name}}) where {name} = name -getgraph(node::Node) = node.graph -getnode(node::Node) = getnode(getgraph(node), TypeSymbol(getname(node))) - -function Node(name::Symbol, x::X, parents::Node...) where {X} - graph = if isempty(parents) - Graph() - else - merge!((n.graph for n in parents)...) # return Root() if empty +mergegraphs!() = Graph{Node}() +mergegraphs!(node::Node) = node.graph +function mergegraphs!(node1::Node, node2::Node, nodes::Node...) + if node1.graph.ref != node2.graph.ref + append!(node1.graph.ref[], node2.graph.ref[]) + empty!(node2.graph.ref[]) + node2.graph.ref = node1.graph.ref end - parentnames = Tuple(getname(a) for a in parents) - push!(graph, name, parentnames, x) - Node(name, graph) + mergegraphs!(node1, nodes...) end - -update!(node, args...) = update!(getelement(node), args...) -isinitialized(node, args...) = isinitialized(getelement(node), args...) -(node::Node)(args...) = getelement(node)(args...) diff --git a/src/inlinedmap.jl b/src/inlinedmap.jl index c35fb5b..52ad78b 100644 --- a/src/inlinedmap.jl +++ b/src/inlinedmap.jl @@ -21,7 +21,7 @@ generated symbol that identifies the node. """ function inlinedmap(f, arg::Node, args::Node...; name = nothing) uniquename = genname(name) - argtypes = getoperationtype.((arg, args...)) + argtypes = eltype.((arg, args...)) T = Base._return_type(f, Tuple{argtypes...}) op = InlinedMap(T, f) Node(uniquename, op, arg, args...) @@ -44,9 +44,7 @@ function generate( end end -@generated function getvalue(node::ListNode, imap::InlinedMap) - names = getparentnames(node) - quote - Base.@ncall $(length(names)) imap.f i -> (getvalue(node, TypeSymbol($names[i]))) - end +function getvalue(graph::CompiledGraph, node::CompiledNode, imap::InlinedMap) + parentnames = getparentnames(node) .|> TypeSymbol + imap.f((getvalue(graph, n) for n in parentnames)...) end diff --git a/src/input.jl b/src/input.jl index a82f246..03bf49c 100644 --- a/src/input.jl +++ b/src/input.jl @@ -66,8 +66,6 @@ function input(x::T; name::Union{Nothing,Symbol} = nothing) where {T} Node(uniquename, op) end -getvalue(::ListNode, element::Input) = getvalue(element) - function generate( inputnames::NTuple{<:Any,Symbol}, name::Symbol, @@ -82,7 +80,7 @@ function generate( expr = Expr(:quote) push!(expr.args, :($updated_s = $updated)) - push!(expr.args, :($nodename_s = getnode(list, $(TypeSymbol(name))))) + push!(expr.args, :($nodename_s = getoperation(graph, $(TypeSymbol(name))))) if updated push!( expr.args, diff --git a/src/lag.jl b/src/lag.jl index 61fbf9a..d87a668 100644 --- a/src/lag.jl +++ b/src/lag.jl @@ -44,7 +44,7 @@ julia> i = input(Int) ``` """ function lag(n::Integer, node::Node; name = nothing) - T = getoperationtype(node) + T = eltype(node) lagnode = foldl(Lag(T, n), node; name) do state, x push!(state, x) state diff --git a/src/map.jl b/src/map.jl index 91a93f7..ea92686 100644 --- a/src/map.jl +++ b/src/map.jl @@ -5,7 +5,6 @@ mutable struct Map{T,F} <: Operation{T} end @inline getvalue(x::Map) = x.x -@inline getvalue(::ListNode, element::Map) = getvalue(element) @inline function update!(m::Map, args...) @tryinline m.x = m.f(args...) @@ -22,7 +21,7 @@ generated symbol that identifies the node. """ function Base.map(f::Function, arg::Node, args::Node...; name = nothing) uniquename = genname(name) - argtypes = getoperationtype.((arg, args...)) + argtypes = eltype.((arg, args...)) T = Base._return_type(f, Tuple{argtypes...}) op = Map{T}(f) Node(uniquename, op, arg, args...) @@ -32,7 +31,7 @@ end function generate(::Any, name::Symbol, parentnames::NTuple{<:Any,Symbol}, ::Type{<:Map}) updated_s = Symbol(:updated, name) initialized_s = Symbol(:initialized, name) - args = (:(getvalue(list, $(TypeSymbol(n)))) for n in parentnames) + args = (:(getvalue(graph, $(TypeSymbol(n)))) for n in parentnames) condition_updated = Expr(:call, :|, (Symbol(:updated, n) for n in parentnames)...) condition_initialized = Expr(:call, :&, (Symbol(:initialized, n) for n in parentnames)...) @@ -40,7 +39,7 @@ function generate(::Any, name::Symbol, parentnames::NTuple{<:Any,Symbol}, ::Type quote $initialized_s = $condition_initialized $updated_s = if $condition_initialized & $condition_updated - $nodename_s = getnode(list, $(TypeSymbol(name))) + $nodename_s = getoperation(graph, $(TypeSymbol(name))) $(Expr(:call, :update!, nodename_s, args...)) true else diff --git a/src/selecter.jl b/src/selecter.jl index 0379dc5..9c5072b 100644 --- a/src/selecter.jl +++ b/src/selecter.jl @@ -15,7 +15,7 @@ function select(x::Node, condition::Node; name::Union{Nothing,Symbol} = nothing) throw(ErrorException("eltype(condition) is $(eltype(condition)), expected Bool")) end uniquename = genname(name) - op = Selecter{getoperationtype(x)}() + op = Selecter{eltype(x)}() Node(uniquename, op, x, condition) end @@ -37,9 +37,9 @@ function select(f::Function, x::Node; name::Union{Nothing,Symbol} = nothing) select(x, condition; name) end -function getvalue(node::ListNode, ::Selecter) - node_name, _ = getparentnames(node) - getvalue(node, TypeSymbol(node_name)) # todo: avoid starting from the leaf +function getvalue(graph::CompiledGraph, node::CompiledNode, ::Selecter) + parent_name = getparentnames(node) |> first |> TypeSymbol + getvalue(graph, parent_name) end function generate( @@ -50,7 +50,7 @@ function generate( ) updated_s = Symbol(:updated, name) initialized_s = Symbol(:initialized, name) - args = [:(getvalue(list, $(TypeSymbol(n)))) for n in parentnames] + args = [:(getvalue(graph, $(TypeSymbol(n)))) for n in parentnames] condition_updated = Expr(:call, :|, (Symbol(:updated, n) for n in parentnames)...) condition_initialized = Expr(:call, :&, (Symbol(:initialized, n) for n in parentnames)...) diff --git a/src/trackers.jl b/src/trackers.jl index 93f2516..9034f33 100644 --- a/src/trackers.jl +++ b/src/trackers.jl @@ -1,10 +1,9 @@ abstract type AbstractGraphTracker end -using Base: nothing_sentinel struct NullGraphTracker <: AbstractGraphTracker end -on_update_start!(::AbstractGraphTracker, inputnames) = nothing -on_update_node!(::AbstractGraphTracker, name) = nothing +on_update_start!(::AbstractGraphTracker) = nothing +on_update_node!(::AbstractGraphTracker, name, isinput) = nothing on_update_stop!(::AbstractGraphTracker) = nothing struct TrackingNode @@ -12,31 +11,20 @@ struct TrackingNode id::Int64 elapsed_time::Int64 bytes_allocated::Int -end - -struct TrackingTriggers - name::Symbol - id::Int64 + isinput::Bool end mutable struct PerformanceGraphTracker <: AbstractGraphTracker @tryconst trackingnodes::Vector{TrackingNode} - @tryconst trackingtriggers::Vector{TrackingTriggers} currentid::Int64 lasttime::UInt64 @tryconst total_bytes_allocated::Base.RefValue{Int64} end -PerformanceGraphTracker() = PerformanceGraphTracker( - TrackingNode[], - TrackingTriggers[], - zero(Int64), - zero(UInt64), - Ref(zero(Int64)), -) +PerformanceGraphTracker() = + PerformanceGraphTracker(TrackingNode[], zero(Int64), zero(UInt64), Ref(zero(Int64))) gettrackingnodes(pm::PerformanceGraphTracker) = pm.trackingnodes -gettrackingtriggers(pm::PerformanceGraphTracker) = pm.trackingtriggers gettime(::PerformanceGraphTracker) = time_ns() getelapsedtime(pm::PerformanceGraphTracker) = reinterpret(Int64, gettime(pm) - pm.lasttime) @@ -49,23 +37,18 @@ function getallocatedbytes!(pm::PerformanceGraphTracker) pm.total_bytes_allocated[] - old_total_bytes_allocated end -function on_update_start!(pm::PerformanceGraphTracker, names) - id = pm.currentid += 1 - for name in names - rs = TrackingTriggers(name, id) - push!(pm.trackingtriggers, rs) - end - +function on_update_start!(pm::PerformanceGraphTracker) + pm.currentid += 1 setallocatedbytes!(pm) settime!(pm) nothing end -function on_update_node!(pm::PerformanceGraphTracker, name) +function on_update_node!(pm::PerformanceGraphTracker, name, isinput::Bool) elapsed_time = getelapsedtime(pm) bytes_allocated = getallocatedbytes!(pm) - rs = TrackingNode(name, pm.currentid, elapsed_time, bytes_allocated) + rs = TrackingNode(name, pm.currentid, elapsed_time, bytes_allocated, isinput) push!(pm.trackingnodes, rs) setallocatedbytes!(pm) diff --git a/src/updated.jl b/src/updated.jl index a4b32f8..71f2c53 100644 --- a/src/updated.jl +++ b/src/updated.jl @@ -32,10 +32,9 @@ function generate(::Any, name::Symbol, parentnames::NTuple{<:Any,Symbol}, ::Type quote $initialized_s = $condition_initialized $updated_s = $condition_initialized & $condition_updated - $nodename_s = getnode(list, $(TypeSymbol(name))) + $nodename_s = getoperation(graph, $(TypeSymbol(name))) $(Expr(:call, :update!, nodename_s, updated_s)) end end @inline getvalue(x::Updated) = x.updated -@inline getvalue(::ListNode, element::Updated) = getvalue(element) diff --git a/test/runtests.jl b/test/runtests.jl index 1795c84..a82d168 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,7 +2,7 @@ using BenchmarkTools using ReactiveGraphs using Test -import ReactiveGraphs: Graph, getoperationtype, Node +import ReactiveGraphs: getoperationtype, Node macro testnoalloc(expr, kws...) esc(quote @@ -11,45 +11,43 @@ macro testnoalloc(expr, kws...) end) end -function collectgraph(g::Graph) - c = [] - map((_, _, x) -> push!(c, x), g) - c -end +# function collectgraph(g::Graph) +# c = [] +# map((_, _, x) -> push!(c, x), g) +# c +# end sink(arg::Node) = sink(identity, arg) function sink(f::Function, args::Node...) - T = Base._return_type(f, Tuple{(getoperationtype(a) for a in args)...}) + m = inlinedmap(f, args...) + T = eltype(m) v = Vector{T}() - map(args...) do x::Vararg - push!(v, f(x...)) - nothing - end + foldl(push!, v, m) v end -@testset "graph" begin - g1 = Graph() - @test g1 == merge!(g1) +# @testset "graph" begin +# g1 = mergegraphs!() +# @test g1 == merge!(g1) - push!(g1, :a, (), 1) - push!(g1, :b, (), 2) - @test collectgraph(g1) == [1, 2] +# push!(g1, :a, (), 1) +# push!(g1, :b, (), 2) +# @test collectgraph(g1) == [1, 2] - g2 = Graph() - push!(g2, :c, (), 3) - merge!(g1, g2) - @test g1 == g2 +# g2 = Graph() +# push!(g2, :c, (), 3) +# merge!(g1, g2) +# @test g1 == g2 - @test collectgraph(g1) == [1, 2, 3] -end +# @test collectgraph(g1) == [1, 2, 3] +# end @testset "input" begin @testset "1 input" begin n1 = input(Int) c = sink(x -> 2x, n1) - s = Source(n1) - push!(s, 2) + g, s = compile(n1) + push!(g, s => 2) @test c == [4] end @@ -57,18 +55,17 @@ end n1 = input(Int) n2 = input(Int) c = sink(+, n1, n2) - s1 = Source(n1) - s2 = Source(n2) - push!(s1 => 1, s2 => 2) - push!(s1 => 3, s2 => 4) + g, s1, s2 = compile(n1, n2) + push!(g, s1 => 1, s2 => 2) + push!(g, s1 => 3, s2 => 4) @test c == [3, 7] end @testset "1 mutable input" begin n1 = input(Ref(0)) c = sink(x -> 2x[], n1) - s = Source(n1) - push!(s, x -> x[] = 2) + g, s = compile(n1) + push!(g, s, x -> x[] = 2) @test c == [4] end @@ -76,18 +73,15 @@ end n1 = input(Ref(0)) n2 = input(Int) c = sink((x, y) -> 2x[] + y, n1, n2) - s1 = Source(n1) - s2 = Source(n2) - push!(s1 => x -> x[] = 2, s2 => 3) + g, s1, s2 = compile(n1, n2) + push!(g, s1 => x -> x[] = 2, s2 => 3) @test c == [7] end @testset "disjoint graphs" begin n1 = input(Int) n2 = input(Int) - s1 = Source(n1) - s2 = Source(n2) - @test_throws ErrorException push!(s1 => 1, s2 => 2) + @test_throws AssertionError compile(n1, n2) end end @@ -96,11 +90,10 @@ end n1 = input(Int) n2 = input(Int) c = sink(+, n1, n2) - s1 = Source(n1) - s2 = Source(n2) - push!(s1, 1) - push!(s2, 2) - push!(s1, 3) + g, s1, s2 = compile(n1, n2) + push!(g, s1, 1) + push!(g, s2, 2) + push!(g, s1, 3) @test c == [3, 5] end @@ -112,11 +105,10 @@ end @test eltype(n2) == Int @test eltype(m) == Int c = sink(m) - s1 = Source(n1) - s2 = Source(n2) - push!(s1, 1) - push!(s2, 2) - push!(s1, 3) + g, s1, s2 = compile(n1, n2) + push!(g, s1, 1) + push!(g, s2, 2) + push!(g, s1, 3) @test c == [3, 5] end @@ -125,11 +117,10 @@ end n2 = input(Int) c1 = sink(+, n1, n2) c2 = sink((x, y) -> -(x + y), n1, n2) - s1 = Source(n1) - s2 = Source(n2) - push!(s1, 1) - push!(s2, 2) - push!(s1, 3) + g, s1, s2 = compile(n1, n2) + push!(g, s1, 1) + push!(g, s2, 2) + push!(g, s1, 3) @test c1 == [3, 5] @test c2 == [-3, -5] end @@ -140,9 +131,9 @@ end n1 = input(Int) n2 = foldl(+, 1, n1) c = sink(n2) - s = Source(n1) - push!(s, 2) - push!(s, 3) + g, s = compile(n1) + push!(g, s, 2) + push!(g, s, 3) @test c == [3, 6] end @@ -154,9 +145,9 @@ end end n3 = map(x -> x[], n2) c = sink(n3) - s = Source(n1) - push!(s, 2) - push!(s, 3) + g, s = compile(n1) + push!(g, s, 2) + push!(g, s, 3) @test c == [3, 6] end end @@ -166,10 +157,9 @@ end n2 = input(Int) n3 = inlinedmap(+, n1, n2) c = sink(n3) - s1 = Source(n1) - s2 = Source(n2) - push!(s1, 1) - push!(s2, 2) + g, s1, s2 = compile(n1, n2) + push!(g, s1, 1) + push!(g, s2, 2) @test c == [3] end @@ -179,15 +169,14 @@ end n2 = input(Bool) n3 = filter(n1, n2) c = sink(n3) - s1 = Source(n1) - s2 = Source(n2) - push!(s1, 2) - push!(s2, false) - push!(s1, 3) - push!(s2, true) - push!(s1, 4) - push!(s2, false) - push!(s1, 5) + g, s1, s2 = compile(n1, n2) + push!(g, s1, 2) + push!(g, s2, false) + push!(g, s1, 3) + push!(g, s2, true) + push!(g, s1, 4) + push!(g, s2, false) + push!(g, s1, 5) @test c == [3, 4] end @@ -195,9 +184,9 @@ end n1 = input(Int) n2 = filter(iseven, n1) c = sink(n2) - s1 = Source(n1) + g, s1 = compile(n1) for i = 1:4 - push!(s1, i) + push!(g, s1, i) end @test c == [2, 4] end @@ -218,18 +207,16 @@ end n4 = input(Int) n5 = map(+, n3, n4) c = sink(n5) - s1 = Source(n1) - s2 = Source(n2) - s4 = Source(n4) - push!(s1, 1) - push!(s2, false) - push!(s4, 2) - push!(s2, true) - push!(s1, 3) - push!(s4, 4) - push!(s2, false) - push!(s1, 5) - push!(s4, 6) + g, s1, s2, s4 = compile(n1, n2, n4) + push!(g, s1, 1) + push!(g, s2, false) + push!(g, s4, 2) + push!(g, s2, true) + push!(g, s1, 3) + push!(g, s4, 4) + push!(g, s2, false) + push!(g, s1, 5) + push!(g, s4, 6) @test c == [3, 5, 7] end @@ -237,9 +224,9 @@ end n1 = input(Int) n2 = select(iseven, n1) c = sink(n2) - s1 = Source(n1) + g, s1 = compile(n1) for i = 1:7 - push!(s1, i) + push!(g, s1, i) end @test c == [2, 4, 6] end @@ -256,16 +243,16 @@ end n1 = input(Int) n2 = constant(1) c = sink(+, n1, n2) - s = Source(n1) - push!(s, 2) + g, s = compile(n1) + push!(g, s, 2) @test c == [3] n1 = input(Bool) n2 = constant(true) c = sink(&, n1, n2) - s = Source(n1) - push!(s, true) - push!(s, false) + g, s = compile(n1) + push!(g, s, true) + push!(g, s, false) @test c == [true, false] end @@ -274,13 +261,12 @@ end n2 = input(Int) n3 = quiet(n2) c = sink(+, n1, n3) - s1 = Source(n1) - s2 = Source(n2) - push!(s1, 1) - push!(s2, 2) - push!(s1, 3) - push!(s2, 4) - push!(s1, 5) + g, s1, s2 = compile(n1, n2) + push!(g, s1, 1) + push!(g, s2, 2) + push!(g, s1, 3) + push!(g, s2, 4) + push!(g, s1, 5) @test c == [5, 9] end @@ -288,9 +274,9 @@ end n1 = input(Int) n2 = lag(2, n1) c = sink(n2) - s1 = Source(n1) + g, s1 = compile(n1) for i = 1:10 - push!(s1, i) + push!(g, s1, i) end @test c == 1:8 end @@ -302,9 +288,9 @@ end c1 = sink(identity, n3) c2 = sink((x, y) -> x, n3, n1) - s = Source(n1) + g, s = compile(n1) for i = 1:4 - push!(s, i) + push!(g, s, i) end @test c1 == [true, true] @test c2 == [false, true, false, true] @@ -320,14 +306,13 @@ end n3 = foldl((state, x) -> state + x, 1, i1s) n4 = inlinedmap(+, n2, n3) n5 = lag(1, n4) + n6 = updated(n5) - s1 = Source(i1) - s2 = Source(i2) - s3 = Source(i3) - push!(s1, 1) - push!(s2, true) - push!(s3, true) - @testnoalloc push!($s1, 1) + g, s1, s2, s3 = compile(i1, i2, i3) + push!(g, s1, 1) + push!(g, s2, true) + push!(g, s3, true) + @testnoalloc push!($g, $s1, 1) end @testset "PerformanceTrackers" begin @@ -335,19 +320,17 @@ end i2 = input(Int; name = "input2") n1 = map(x -> 2x, i1) n2 = map(+, n1, i2) - s1 = Source(i1) - s2 = Source(i2) - push!(s1, 1) - push!(s2, 2) - tracker = PerformanceGraphTracker() - push!(tracker, s1, 2) - push!(tracker, s2, 2) - push!(tracker, s1 => 3, s2 => 3) - triggers = gettrackingtriggers(tracker) - nodes = gettrackingnodes(tracker) - @test length(triggers) == 4 - @test map(x -> x.id, triggers) == [1, 2, 3, 3] + + g, s1, s2 = compile(i1, i2) + push!(g, s1, 1) + push!(g, s2, 2) + g, s1, s2 = compile(i1, i2; tracker = PerformanceGraphTracker()) + push!(g, s1, 2) + push!(g, s2, 2) + push!(g, s1 => 3, s2 => 3) + nodes = gettrackingnodes(g) @test length(nodes) == 3 * 4 @test map(x -> x.id, nodes) == [i for i = 1:3 for _ = 1:4] @test map(x -> x.bytes_allocated, nodes) == [0 for i = 1:3 for _ = 1:4] + @test map(x -> x.isinput, nodes) == [(i in x) for x in [(1,), (3,), (1, 3)] for i = 1:4] end diff --git a/todo.txt b/todo.txt new file mode 100644 index 0000000..c4a3a88 --- /dev/null +++ b/todo.txt @@ -0,0 +1,5 @@ +# refacto of the logic +- how to name the edges? can we use integers instead of Symbols, and make the name optional? +getelementtype -> getoperationtype +getelement->getoperation +getnode => getedge