From 3295f197388f453ae7faea89cdd66b10a5cfa653 Mon Sep 17 00:00:00 2001 From: DrChainsaw Date: Wed, 31 Jul 2024 15:05:31 +0200 Subject: [PATCH 1/4] Slight formalization of naming strategies Add support for Chain with named layers --- src/serialize/namingutil.jl | 79 ++++++++++++++++++++------- src/serialize/serialize.jl | 17 ++++-- test/serialize/serialize.jl | 105 ++++++++++++++++++++++++++++++++++++ 3 files changed, 179 insertions(+), 22 deletions(-) diff --git a/src/serialize/namingutil.jl b/src/serialize/namingutil.jl index e89aea6..79de850 100644 --- a/src/serialize/namingutil.jl +++ b/src/serialize/namingutil.jl @@ -1,17 +1,68 @@ - default_namestrat(f) = name_runningnr() default_namestrat(g::CompGraph) = default_namestrat(vertices(g)) function default_namestrat(vs::AbstractVector{<:AbstractVertex}) # Even if all vertices have unique names, we can't be certain that no vertex produces more than one node # Therefore, we must take a pass through name_runningnr for each op even after we have mapped the vertex to a nextname. This is the reason for the v -> f -> namegen(v, "") wierdness + + all(isnamed, vs) && length(unique(name.(vs))) == length(name.(vs)) && return NamedNodeContext("", name_runningnr(;addtofirst=false)) namegen = name_runningnr() - all(isnamed, vs) && length(unique(name.(vs))) == length(name.(vs)) && return v -> f -> namegen(v, "") ng(v::AbstractVertex) = namegen ng(f) = namegen(f) return ng end +function default_namestrat(c::Chain) + !(eltype(keys(c)) <: NameType) && return name_runningnr() + NamedNodeContext("", name_runningnr(;addtofirst=false)) +end + +struct NameRunningNr{F} + addtofirst::Bool + init::Int + runningnrs::Dict{String, Int} + namefun::F +end + +Base.Broadcast.broadcastable(n::NameRunningNr) = Ref(n) + +function (n::NameRunningNr)(f) + bname = n.namefun(f) + nextnr = get(n.runningnrs, bname, n.init) + n.runningnrs[bname] = nextnr + 1 + return if nextnr != n.init || n.addtofirst + string(bname, "_", nextnr) + else + bname + end +end + +name_runningnr(namefun = genname; addtofirst=true, init=0) = NameRunningNr(addtofirst, init, Dict{String, Int}(), namefun) + +const NameType = Union{Symbol, AbstractString} + +struct NamedNodeContext{F} + prefix::String + namegen::F +end + +Base.Broadcast.broadcastable(n::NamedNodeContext) = Ref(n) + +(ctx::NamedNodeContext)(args...) = isempty(ctx.prefix) ? ctx.namegen(args...) : ctx.namegen(ctx.prefix) + +function (ctx::NamedNodeContext)(v::AbstractVertex) + @set ctx.prefix = string(ctx.prefix, isempty(ctx.prefix) ? "" : ".", name(v)) +end + +chainlayername(f, ::Any, ::Any) = f +function chainlayername(ctx::NamedNodeContext, name::NameType, layer) + @set ctx.prefix = string(ctx.prefix, isempty(ctx.prefix) ? 
"" : ".", name) +end +function chainlayername(ctx::NamedNodeContext{<:NameRunningNr}, nr::Integer, layer) + !isempty(ctx.prefix) && return @set ctx.prefix = string(ctx.prefix, '[', nr, ']') + return name_runningnr() +end + isnamed(v::AbstractVertex) = isnamed(base(v)) isnamed(v::CompVertex) = false isnamed(v::InputVertex) = true @@ -21,26 +72,16 @@ isnamed(t::DecoratingTrait) = isnamed(base(t)) isnamed(t) = false isnamed(::NamedTrait) = true -function name_runningnr(namefun = genname) - exists = Set{String}() - - return function(f, init="_0") - bname = namefun(f) - candname = bname * init - next = -1 - while candname in exists - next += 1 - candname = bname * "_" * string(next) - end - push!(exists, candname) - return candname - end -end - genname(v::AbstractVertex) = name(v) -genname(f::F) where F = lowercase(string(nameof(F))) +genname(::F) where F = lowercase(string(nameof(F))) genname(s::AbstractString) = s genname(f::Function) = lowercase(string(f)) recursename(f, namestrat) = recursename(f, namestrat(f)) recursename(f, fname::String) = fname +function recursename(f, ctx::NamedNodeContext) + res = ctx(f) + res isa NamedNodeContext && return recursename(f, res.prefix) + res +end + diff --git a/src/serialize/serialize.jl b/src/serialize/serialize.jl index 24c57b2..b228061 100644 --- a/src/serialize/serialize.jl +++ b/src/serialize/serialize.jl @@ -224,15 +224,26 @@ function add_outputs!(gp, namestrat, pps::Tuple) add_output!.(output_pps) end -# Only purpose is to snag the name in case this is the naming strategy -function (v::NaiveNASlib.MutationVertex)(pps::AbstractProbe...) +# Stuff whose only purpose is to override the name in case this is the naming strategy +function (v::MutationVertex)(pps::AbstractProbe...) ppsname = map(pps) do pp newnamestrat(pp, nextname(pp)(v)) end - ppout = base(v)(ppsname...) + ppout = base(v)(ppsname...) return newnamestrat(ppout, nextname(pps[1])) end +function (c::Chain)(pp::AbstractProbe) + ppnext = pp + for (k, l) in zip(keys(c), c) + ppnext = l(newnamestrat(ppnext, chainlayername(nextname(pp), k, l))) + end + return newnamestrat(ppnext, nextname(pp)) +end + + +# End of stuff whose only purpose is to override the name in case this is the naming strategy + actfun(::FluxLayer, l) = l.σ function weightlayer(lt::FluxParLayer, l, pp, optype;attributes = ONNX.AttributeProto[]) lname = recursename(l, nextname(pp)) diff --git a/test/serialize/serialize.jl b/test/serialize/serialize.jl index 5a14307..e1fcf3c 100644 --- a/test/serialize/serialize.jl +++ b/test/serialize/serialize.jl @@ -848,6 +848,111 @@ end end + @testset "Chains" begin + import ONNXNaiveNASflux: modelproto + + function remodel(m, args...; kwargs...) + pb = PipeBuffer() + save(pb, m, args...; kwargs...) 
+ return load(pb) + end + + @testset "Simple Chain" begin + org = Chain(Dense(1 => 2, relu), Dense(2 => 3, sigmoid), Dense(3 => 4)) + res = remodel(org) + + x = randn(Float32, 1, 4) + @test org(x) == res(x) ≈ only(onnxruntime_infer(org, x)) + end + + @testset "Simple Named Chain" begin + org = Chain(layer1 = Dense(1 => 2, relu), layer2 = Dense(2 => 3, sigmoid), layer3 = Dense(3 => 4)) + res = remodel(org) + + x = randn(Float32, 1, 4) + @test org(x) == res(x) ≈ only(onnxruntime_infer(org, x)) + mp = modelproto(org) + @test name.(mp.graph.node) == ["layer1", "layer1_relu", "layer2", "layer2_sigmoid", "layer3"] + end + + @testset "Simple Named Chain with name_runningnr" begin + org = Chain(layer1 = Dense(1 => 2, relu), layer2 = Dense(2 => 3, sigmoid), layer3 = Dense(3 => 4)) + res = remodel(org; namestrat=ONNXNaiveNASflux.name_runningnr()) + + x = randn(Float32, 1, 4) + @test org(x) == res(x) ≈ only(onnxruntime_infer(org, x)) + mp = modelproto(org, namestrat=ONNXNaiveNASflux.name_runningnr()) + @test name.(mp.graph.node) == ["dense_0", "dense_0_relu", "dense_1", "dense_1_sigmoid", "dense_2"] + end + + @testset "Nested Named Chain" begin + org = Chain( + layer1 =Dense(1 => 2, relu), + layer2 = Dense(2 => 3, sigmoid), + inner = Chain( + Dense(3 => 3, tanh), + Dense(3=>3)), + layer3 = Dense(3 => 4)) + res = remodel(org) + + x = randn(Float32, 1, 4) + @test org(x) == res(x) ≈ only(onnxruntime_infer(org, x)) + mp = modelproto(org) + @test name.(mp.graph.node) == ["layer1", "layer1_relu", "layer2", "layer2_sigmoid", "inner[1]", "inner[1]_tanh", "inner[2]", "layer3"] + end + + @testset "Nested Named Chain Named Inner" begin + org = Chain( + layer1 = Dense(1 => 2, relu), + layer2 = Dense(2 => 3, sigmoid), + inner = Chain( + ilayer1 = Dense(3 => 3, tanh), + ilayer2 = Dense(3=>3)), + layer3 = Dense(3 => 4)) + res = remodel(org) + + x = randn(Float32, 1, 4) + @test org(x) == res(x) ≈ only(onnxruntime_infer(org, x)) + mp = modelproto(org) + @test name.(mp.graph.node) == ["layer1", "layer1_relu", "layer2", "layer2_sigmoid", "inner.ilayer1", "inner.ilayer1_tanh", "inner.ilayer2", "layer3"] + end + + @testset "Nested Chain Named Inner" begin + org = Chain( + Dense(1 => 2, relu), + Dense(2 => 3, sigmoid), + Chain( + ilayer1 = Dense(3 => 3, tanh), + ilayer2 = Dense(3=>3)), + Dense(3 => 4)) + res = remodel(org) + + x = randn(Float32, 1, 4) + @test org(x) == res(x) ≈ only(onnxruntime_infer(org, x)) + mp = modelproto(org) + @test name.(mp.graph.node) == ["dense_0", "dense_0_relu", "dense_1", "dense_1_sigmoid", "dense_2", "dense_2_tanh", "dense_3", "dense_4"] + end + + @testset "Named Chain Parallel" begin + org = Chain( + layer1 =Dense(1 => 2, relu), + layer2 = Dense(2 => 3, sigmoid), + fork = Parallel(+, + Chain( + Dense(3 => 3, tanh), + Dense(3=>3)), + Dense(3 => 3, elu), + ), + layer3 = Dense(3 => 4)) + res = remodel(org) + + x = randn(Float32, 1, 4) + @test org(x) == res(x) ≈ only(onnxruntime_infer(org, x)) + mp = modelproto(org) + @test name.(mp.graph.node) == ["layer1", "layer1_relu", "layer2", "layer2_sigmoid", "fork[1]", "fork[1]_tanh", "fork[2]", "fork", "fork_elu", "fork_1", "layer3"] + end + end + @testset "Models" begin import ONNXNaiveNASflux: modelproto, sizes, clean_size From d53d410b2160bc6020fc5e09ee4773c664e24b8e Mon Sep 17 00:00:00 2001 From: DrChainsaw Date: Thu, 1 Aug 2024 00:35:57 +0200 Subject: [PATCH 2/4] Generalize name intercept strategy to work for selected functors --- Project.toml | 2 + src/ONNXNaiveNASflux.jl | 1 + src/serialize/namingutil.jl | 124 
++++++++++++++++++++++------
 src/serialize/serialize.jl | 111 +++++++++++++++++++++++++-------
 test/serialize/serialize.jl | 119 +++++++++++++++++++++++++++++++++-
 5 files changed, 312 insertions(+), 45 deletions(-)

diff --git a/Project.toml b/Project.toml
index 802b0ae..4d9a1bd 100644
--- a/Project.toml
+++ b/Project.toml
@@ -6,6 +6,7 @@ version = "0.2.13"
 [deps]
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
+Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
 NaiveNASflux = "85610aed-7d32-5e57-bb50-4c2e1c9e7997"
 NaiveNASlib = "bd45eb3e-47ce-54bd-9eaf-e86c5f900853"
@@ -17,6 +18,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 [compat]
 ChainRulesCore = "1"
 Flux = "0.13, 0.14"
+Functors = "0.4"
 JuMP = "0.21, 0.22, 0.23, 1"
 NaiveNASflux = "2.0.10"
 NaiveNASlib = "2.0.11"
diff --git a/src/ONNXNaiveNASflux.jl b/src/ONNXNaiveNASflux.jl
index 35b53b2..b309a8f 100644
--- a/src/ONNXNaiveNASflux.jl
+++ b/src/ONNXNaiveNASflux.jl
@@ -6,6 +6,7 @@ import .BaseOnnx: array
 const ONNX = BaseOnnx
 using Flux
 using Flux: params
+import Functors
 using NaiveNASflux
 using NaiveNASflux: weights, bias
 using NaiveNASflux: indim, outdim, actdim, actrank, layertype, wrapped
diff --git a/src/serialize/namingutil.jl b/src/serialize/namingutil.jl
index 79de850..eb5d54b 100644
--- a/src/serialize/namingutil.jl
+++ b/src/serialize/namingutil.jl
@@ -1,17 +1,99 @@
+# The naming strategies are not well organized and should probably be rewritten. Until that happens,
+# here is a little synopsis of what's going on here:
+
+# There are two main ways to create names:
+# 1) By just using the function name (or the lowercase name of the struct, e.g. dense for Flux.Dense)
+# and adding a running number after it to make names unique (e.g. dense_1)
+
+# 2) By intercepting names (e.g. the named tuple keys of Flux.Chain or the names of NaiveNASlib's
+# vertices) and making sure all ops "below" in the call stack (e.g. the actual layers) use those names
+
+# Both strategies will be attached to the input ProtoProbe as the (badly named) field nextname
+# when we start traversing the call stack.
+# The strategy might be changed based on what we encounter. One example of when this happens
+# is for the activation functions of the Flux layers, which need to be separate nodes in ONNX,
+# so we make them inherit the name of the layer node (e.g. dense, conv) and suffix them with
+# the name of the function.
+
+# For strategy 1) we just use NameRunningNr and that's the end of the story, no fuss!
+
+# For strategy 2) however, there are quite a few moving pieces, three to be exact.
+# The whole thing is probably overengineered and could/should be done in a simpler way.
+
+# The first entity is the strategy itself: NamedNodeContext, which signals that we are in the
+# business of trying to find named things in the call stack. It also remembers what the last
+# name we saw was (and the hierarchy of names in case of nesting, e.g. a Chain inside a Parallel
+# inside a Chain). Note that for the sake of safety NamedNodeContext wraps a NameRunningNr to
+# ensure that we don't end up with duplicate names.
+
+# In addition, whenever we create an input ProtoProbe with the nextname strategy being a
+# NamedNodeContext we wrap the ProtoProbe in a NameInterceptProbe, which is the second entity in this
+# little scheme. This allows us to dispatch on high level (i.e. non-primitive) function calls such as
+# Chain, Parallel, SkipConnection and MutationVertex to just catch the names we are after, as they
+# are not visible on the primitive (e.g. Flux.Dense) level.
+
+# We could also have tried to dispatch on something like
+# (c::Chain)(pp::ProtoProbe{<:Any, <:NamedNodeContext}), but this has the following problem: What do
+# we do then after we have caught the name? We then want Chain to just do its thing on pp, but if
+# we call c(pp) we just end up with infinite recursion.
+
+# We also don't want to replace the naming strategy at this point since we want to handle nested
+# name holders (i.e. the Chain inside a Parallel inside a Chain, or a Chain inside a MutationVertex
+# if you'd like).
+
+# Instead we just unwrap the NameInterceptProbe whenever we hit a name holder before we forward the
+# call.
+
+# The annoying thing here is that when we hold an object like a Flux.Parallel in our hand we know
+# its structure and can see how it names things (e.g. through Functors.fmap_with_path), but we
+# don't know for sure in what order the stuff inside it will be called when we call it as a function
+# (ok, in practice we know since it has a doc-string contract on what it does, but we also don't
+# want to reimplement it and all other things we want to catch names from).
+
+# To avoid having to encode in this library what each possible high level name holder (i.e.
+# Chain, Parallel, etc.) does when called, we use a third entity, the NamedFunction, which just
+# wraps anything callable (well, it does not care if it is callable or not, but obviously things
+# will not work if it is not) and the name it has. We then use a shallow Functors.fmap_with_path
+# to wrap all named things (with methods) inside the name holder in a NamedFunction.
+
+# Yup, you read that right: for example, if Chain(layer1=l1, layer2=l2) is called with a
+# NameInterceptProbe as input it will first be fmap_with_path:ed into
+# Chain(layer1=NamedFunction(l1, "layer1"), layer2=NamedFunction(l2, "layer2")), then the new chain
+# will be called with the ProtoProbe wrapped inside the NameInterceptProbe.
+
+# When a NamedFunction is called with an AbstractProbe as input, it does two things:
+# 1) Wrap the AbstractProbe in a NameInterceptProbe so we can intercept nested name holders (e.g. a
+# Chain inside a Parallel inside a Chain).
+# 2) Create a new naming strategy for all ONNX ops encountered when calling the function wrapped
+# in the NamedFunction, which is to use the name of the NamedFunction.
+
+# Note that at step 2) we append the name to any previously encountered names so that nested name holders
+# get the correct path. In other words, the new naming strategy is a new NamedNodeContext where we have
+# appended the name of the named node to the prefix.
+
+# One option to simplify could have been to just fmap_with_path the entire model before calling it.
+# This might turn out to be simpler. The reasons I did not go for it were:
+# 1) There would not be any way to name things inside things that are not fmap:able
+# 2) It is still not easy to know which things are named (although we could just always use
+# fieldnames)
+# 3) Less control over what we actually want to wrap inside a NamedFunction (maybe we could use
+# dispatch for this when walking through)
+# 4) The functor structure for CompGraphs is pretty horrible (but we don't use the
+# fmap_with_path method for it anyway since the vertices are the universal name holders for it)
+# None of the above seem like showstoppers; the current way was just the path of least resistance.
+
+
 default_namestrat(f) = name_runningnr()
 default_namestrat(g::CompGraph) = default_namestrat(vertices(g))
 function default_namestrat(vs::AbstractVector{<:AbstractVertex})
-    # Even if all vertices have unique names, we can't be certain that no vertex produces more than one node
-    # Therefore, we must take a pass through name_runningnr for each op even after we have mapped the vertex to a nextname. This is the reason for the v -> f -> namegen(v, "") wierdness
-    all(isnamed, vs) && length(unique(name.(vs))) == length(name.(vs)) && return NamedNodeContext("", name_runningnr(;addtofirst=false))
-    namegen = name_runningnr()
-    ng(v::AbstractVertex) = namegen
-    ng(f) = namegen(f)
-    return ng
+    # TODO: Maybe we should use NamedNodeContext here as well and just add running numbers to duplicated names?
+    name_runningnr()
 end
+const NameType = Union{Symbol, AbstractString}
+
 function default_namestrat(c::Chain)
     !(eltype(keys(c)) <: NameType) && return name_runningnr()
     NamedNodeContext("", name_runningnr(;addtofirst=false))
@@ -39,7 +121,15 @@ end
 name_runningnr(namefun = genname; addtofirst=true, init=0) = NameRunningNr(addtofirst, init, Dict{String, Int}(), namefun)
-const NameType = Union{Symbol, AbstractString}
+struct NamedFunction{F}
+    f::F
+    name::String
+end
+NaiveNASlib.base(n::NamedFunction) = n.f
+NaiveNASlib.name(n::NamedFunction) = n.name
+(n::NamedFunction)(args...; kwargs...) = n.f(args...; kwargs...)
+
+const NamedNode = Union{NamedFunction, MutationVertex}
 struct NamedNodeContext{F}
     prefix::String
@@ -49,18 +139,12 @@ end
 Base.Broadcast.broadcastable(n::NamedNodeContext) = Ref(n)
 (ctx::NamedNodeContext)(args...) = isempty(ctx.prefix) ? ctx.namegen(args...) : ctx.namegen(ctx.prefix)
-
-function (ctx::NamedNodeContext)(v::AbstractVertex)
-    @set ctx.prefix = string(ctx.prefix, isempty(ctx.prefix) ? "" : ".", name(v))
-end
-
-chainlayername(f, ::Any, ::Any) = f
-function chainlayername(ctx::NamedNodeContext, name::NameType, layer)
-    @set ctx.prefix = string(ctx.prefix, isempty(ctx.prefix) ? "" : ".", name)
-end
-function chainlayername(ctx::NamedNodeContext{<:NameRunningNr}, nr::Integer, layer)
-    !isempty(ctx.prefix) && return @set ctx.prefix = string(ctx.prefix, '[', nr, ']')
-    return name_runningnr()
+function (ctx::NamedNodeContext)(v::Union{NamedNode, AbstractVertex})
+    nodename = name(v)
+    # Maybe add some field in NamedFunction to indicate whether it is an array element or a field instead of
+    # startswith here
+    sep = isempty(ctx.prefix) || startswith(nodename, '[') ? "" : "."
+ @set ctx.prefix = string(ctx.prefix, sep, name(v)) end isnamed(v::AbstractVertex) = isnamed(base(v)) diff --git a/src/serialize/serialize.jl b/src/serialize/serialize.jl index b228061..1217c27 100644 --- a/src/serialize/serialize.jl +++ b/src/serialize/serialize.jl @@ -105,6 +105,29 @@ Idea is that probe will "record" all seen operations based on how methods for th """ abstract type AbstractProbe end +nextshape(p::AbstractProbe, f::Function) = f(shape(p)) +Base.ndims(p::AbstractProbe) = length(shape(p)) + +""" + WrappedProbe + +Abstract class which promises that it wraps another `AbstractProbe` to which it delegates most of its tasks. +""" +abstract type WrappedProbe <: AbstractProbe end + +unwrap(p::WrappedProbe) = p.wrapped + +Base.Broadcast.broadcastable(p::WrappedProbe) = Ref(p) + +NaiveNASlib.name(p::WrappedProbe) = NaiveNASlib.name(unwrap(p)) + +nextname(p::WrappedProbe) = nextname(unwrap(p)) +add!(p::WrappedProbe, args...) = add!(unwrap(p), args...) +shape(p::WrappedProbe) = shape(unwrap(p)) +add_output!(p::WrappedProbe) = add_output!(unwrap(p)) +newnamestrat(p::WrappedProbe, args...) = rewrap(p, newnamestrat(unwrap(p), args...)) +newfrom(p::WrappedProbe, args...) = rewrap(p, newfrom(unwrap(p), args...)) + # Called by several activation functions Base.oftype(::AbstractProbe, x) = x @@ -126,8 +149,9 @@ nextname(p::ProtoProbe) = p.nextname add!(p::ProtoProbe, n) = add!(p.graph, n) shape(p::ProtoProbe) = p.shape add_output!(p::ProtoProbe) = push!(p.graph.output, ONNX.ValueInfoProto(name(p), shape(p))) -Base.ndims(p::AbstractProbe) = length(shape(p)) -function inputprotoprobe!(gp, name, shape, namestrat) + +inputprotoprobe!(args...) = _inputprotoprobe!(args...) +function _inputprotoprobe!(gp, name, shape, namestrat) push!(gp.input, ONNX.ValueInfoProto(name, shape)) ProtoProbe(name, shape, namestrat, gp) end @@ -145,12 +169,72 @@ newnamestrat(p::ProtoProbe, f, pname=p.name) = ProtoProbe(pname, p.shape, f, p.g Return a new `ProtoProbe` with name `outname`. Argument `fshape` is used to determine a new shape (typically a function). """ newfrom(p::ProtoProbe, outname::AbstractString, fshape) = ProtoProbe(outname, nextshape(p, fshape), p.nextname, p.graph) -nextshape(p::AbstractProbe, f::Function) = f(shape(p)) add!(gp::ONNX.GraphProto, np::ONNX.NodeProto) = push!(gp.node, np) add!(gp::ONNX.GraphProto, tp::ONNX.TensorProto) = push!(gp.initializer, tp) +## Don't forget to check if new methods need to be added for any WrappedProbe implementations if you add something here! + +# Stuff whose only purpose is to override the name in case this is the naming strategy +# See namingutil for a little story about the design since it is a bit messy :( + +""" + NameInterceptProbe <: WrappedProbe + +An AbstractProbe which is only used to intercept methods call above the primitive level to catch names (e.g. a `Chain` with a `NamedTuple` as layers). +""" +struct NameInterceptProbe{P<:AbstractProbe} <: WrappedProbe + wrapped::P +end +rewrap(::NameInterceptProbe, p) = NameInterceptProbe(p) + +inputprotoprobe!(gp, name, shape, namestrat::NamedNodeContext) = NameInterceptProbe(_inputprotoprobe!(gp, name, shape, namestrat)) + +# This little dance is just to avoid ambiguities of (n::NamedNode)(pps...) since NamedNode is abstract +# TODO: Generate with macro so we can add types to the union without worry? +(v::MutationVertex)(pps::AbstractProbe...) = _apply_probe_call(v, pps...) +(n::NamedFunction)(pps::AbstractProbe...) = _apply_probe_call(n, pps...) + +_apply_probe_call(n::NamedNode, pps::AbstractProbe...) 
= __apply_probe_call(n, nextname(first(pps)), pps...) + +# We are not in the business of catching names here, so just forward the call +__apply_probe_call(n::NamedNode, ::Any, pps::AbstractProbe...) = base(n)(pps...) +# We just want to rewrap probes in NameInterceptProbes for the sole reason that we might encounter other +# objects that we want to intercept names from (e.g. a Chain inside a Parallel) +__apply_probe_call(n::NamedNode, ::NamedNodeContext, pps::AbstractProbe...) = _apply_probe_call(n, map(NameInterceptProbe, pps)...) + +function _apply_probe_call(n::NamedNode, pps::NameInterceptProbe...) + ppsname = map(pps) do pp + newnamestrat(pp, nextname(pp)(n)) + end + ppout = base(n)(ppsname...) + return newnamestrat(ppout, nextname(pps[1])) +end + +# Here we wrap all callable children in a NamedFunction so that we can access the name when the child is called +(c::Chain)(pp::NameInterceptProbe) = _instrument_named_functor(c; addbasename=false)(unwrap(pp)) +(p::Parallel)(pp::NameInterceptProbe) = _instrument_named_functor(p; addbasename=false)(unwrap(pp)) +(p::SkipConnection)(pp::NameInterceptProbe) = _instrument_named_functor(p; addbasename=false)(unwrap(pp)) + +# The exclude is to ensure that we only see the fields of w, not the fields of its children +_instrument_named_functor(w; addbasename=true) = Functors.fmap_with_path(w; exclude = (k,c) -> c != w) do keypath, child + _instrument_named_child(child, string(only(keypath)); addbasename) +end + +_instrument_named_child(child, name; kwargs...) = !isempty(methods(child)) ? NamedFunction(child, name) : child +_instrument_named_child(child::Tuple, name; addbasename) = ntuple(length(child)) do i + _instrument_named_child(child[i], string(addbasename ? name : "", '[', i, ']')) +end +_instrument_named_child(child::AbstractArray, name; addbasename) = map(enumerate(child)) do (i, elem) + _instrument_named_child(elem, string(addbasename ? name : "", '[', i, ']')) +end +_instrument_named_child(child::NamedTuple{K}, name; addbasename) where K = ntuple(length(child)) do i + _instrument_named_child(child[i], string(addbasename ? string(name, '.') : "", K[i])) +end |> NamedTuple{K} + +# End of stuff whose only purpose is to override the name in case this is the naming strategy + """ Used to get activation functions as [`ONNX.AttributeProto`](@ref)s. @@ -224,25 +308,6 @@ function add_outputs!(gp, namestrat, pps::Tuple) add_output!.(output_pps) end -# Stuff whose only purpose is to override the name in case this is the naming strategy -function (v::MutationVertex)(pps::AbstractProbe...) - ppsname = map(pps) do pp - newnamestrat(pp, nextname(pp)(v)) - end - ppout = base(v)(ppsname...) - return newnamestrat(ppout, nextname(pps[1])) -end - -function (c::Chain)(pp::AbstractProbe) - ppnext = pp - for (k, l) in zip(keys(c), c) - ppnext = l(newnamestrat(ppnext, chainlayername(nextname(pp), k, l))) - end - return newnamestrat(ppnext, nextname(pp)) -end - - -# End of stuff whose only purpose is to override the name in case this is the naming strategy actfun(::FluxLayer, l) = l.σ function weightlayer(lt::FluxParLayer, l, pp, optype;attributes = ONNX.AttributeProto[]) @@ -430,7 +495,7 @@ globalmaxpool(pp::AbstractProbe, wrap) = globalpool(pp, wrap, "GlobalMaxPool") function globalpool(pp::AbstractProbe, wrap, type) gpp = attribfun(s -> ismissing(s) ? 
s : (1, 1, s[3:end]...), type, pp) - ppnext = newnamestrat(gpp, f -> join([gpp.name, genname(f)], "_"), gpp.name) + ppnext = newnamestrat(gpp, f -> join([name(gpp), genname(f)], "_"), name(gpp)) wpp = wrap(ppnext) return newnamestrat(wpp, nextname(gpp)) end diff --git a/test/serialize/serialize.jl b/test/serialize/serialize.jl index e1fcf3c..8c59c6a 100644 --- a/test/serialize/serialize.jl +++ b/test/serialize/serialize.jl @@ -933,6 +933,25 @@ @test name.(mp.graph.node) == ["dense_0", "dense_0_relu", "dense_1", "dense_1_sigmoid", "dense_2", "dense_2_tanh", "dense_3", "dense_4"] end + @testset "Chain Parallel" begin + org = Chain( + Dense(1 => 2, relu), + Dense(2 => 3, sigmoid), + Parallel(+, + Chain( + Dense(3 => 3, tanh), + Dense(3 => 3)), + Dense(3 => 3, elu), + ), + Dense(3 => 4)) + res = remodel(org) + + x = randn(Float32, 1, 4) + @test org(x) == res(x) ≈ only(onnxruntime_infer(org, x)) + mp = modelproto(org) + @test name.(mp.graph.node) == ["dense_0", "dense_0_relu", "dense_1", "dense_1_sigmoid", "dense_2", "dense_2_tanh", "dense_3", "dense_4", "dense_4_elu", "add_0", "dense_5"] + end + @testset "Named Chain Parallel" begin org = Chain( layer1 =Dense(1 => 2, relu), @@ -940,7 +959,7 @@ fork = Parallel(+, Chain( Dense(3 => 3, tanh), - Dense(3=>3)), + Dense(3 => 3)), Dense(3 => 3, elu), ), layer3 = Dense(3 => 4)) @@ -949,7 +968,103 @@ x = randn(Float32, 1, 4) @test org(x) == res(x) ≈ only(onnxruntime_infer(org, x)) mp = modelproto(org) - @test name.(mp.graph.node) == ["layer1", "layer1_relu", "layer2", "layer2_sigmoid", "fork[1]", "fork[1]_tanh", "fork[2]", "fork", "fork_elu", "fork_1", "layer3"] + @test name.(mp.graph.node) == ["layer1", "layer1_relu", "layer2", "layer2_sigmoid", "fork[1][1]", "fork[1][1]_tanh", "fork[1][2]", "fork[2]", "fork[2]_elu", "fork.connection", "layer3"] + end + + @testset "Named Chain Named Parallel" begin + org = Chain( + layer1 =Dense(1 => 2, relu), + layer2 = Dense(2 => 3, sigmoid), + fork = Parallel(+, + path1 = Chain( + Dense(3 => 3, tanh), + Dense(3 => 3)), + path2 = Dense(3 => 3, elu), + ), + layer3 = Dense(3 => 4)) + res = remodel(org) + + x = randn(Float32, 1, 4) + @test org(x) == res(x) ≈ only(onnxruntime_infer(org, x)) + mp = modelproto(org) + @test name.(mp.graph.node) == ["layer1", "layer1_relu", "layer2", "layer2_sigmoid", "fork.path1[1]", "fork.path1[1]_tanh", "fork.path1[2]", "fork.path2", "fork.path2_elu", "fork.connection", "layer3"] + end + + + @testset "Named Chain Parallel Named Chain" begin + org = Chain( + layer1 =Dense(1 => 2, relu), + layer2 = Dense(2 => 3, sigmoid), + fork = Parallel(+, + Chain( + l1 = Dense(3 => 3, tanh), + l2 = Dense(3 => 3) + ), + Chain( + Dense(3 => 3, elu), + Dense(3 => 3, leakyrelu) + ) + ), + layer3 = Dense(3 => 4)) + res = remodel(org) + + x = randn(Float32, 1, 4) + @test org(x) == res(x) ≈ only(onnxruntime_infer(org, x)) + mp = modelproto(org) + @test name.(mp.graph.node) == ["layer1", "layer1_relu", "layer2", "layer2_sigmoid", "fork[1].l1", "fork[1].l1_tanh", "fork[1].l2", "fork[2][1]", "fork[2][1]_elu", "fork[2][2]", "fork[2][2]_leakyrelu", "fork.connection", "layer3"] + end + + @testset "Named Chain SkipConnection" begin + org = Chain( + layer1 =Dense(1 => 2, relu), + layer2 = Dense(2 => 3, sigmoid), + fork = SkipConnection( + Chain( + Dense(3 => 3, tanh), + Dense(3 => 3)), + +), + layer3 = Dense(3 => 4)) + res = remodel(org) + + x = randn(Float32, 1, 4) + @test org(x) == res(x) ≈ only(onnxruntime_infer(org, x)) + mp = modelproto(org) + @test name.(mp.graph.node) == ["layer1", "layer1_relu", "layer2", 
"layer2_sigmoid", "fork.layers[1]", "fork.layers[1]_tanh", "fork.layers[2]", "fork.connection", "layer3"] + end + + @testset "Named Chain CompGraph" begin + org = Chain( + layer1 =Dense(1 => 2, relu), + layer2 = Dense(2 => 3, sigmoid), + graph = let + iv = denseinputvertex("graphin", 3) + v1 = fluxvertex("v1", Dense(3 => 3, elu), iv) + v2 = "v2" >> iv + v1 + CompGraph(iv, v2) + end, + layer3 = Dense(3 => 4)) + res = remodel(org) + + x = randn(Float32, 1, 4) + @test org(x) == res(x) ≈ only(onnxruntime_infer(org, x)) + mp = modelproto(org) + @test name.(mp.graph.node) == ["layer1", "layer1_relu", "layer2", "layer2_sigmoid", "graph.v1", "graph.v1_elu", "graph.v2", "layer3"] + end + + @testset "CompGraph Named Chain" begin + org = let + iv = denseinputvertex("graphin", 1) + v1 = fluxvertex("v1", Dense(1 => 3, elu), iv) + v2 = invariantvertex("chain", Chain(l1 = Dense(3 => 3, relu), l2 = Dense(3 => 3, tanh)), v1) + v3 = "v3" >> v1 + v2 + CompGraph(iv, v3) + end + res = remodel(org) + + x = randn(Float32, 1, 4) + @test org(x) == res(x) ≈ only(onnxruntime_infer(org, x)) + mp = modelproto(org) + @test name.(mp.graph.node) == ["v1", "v1_elu", "chain.l1", "chain.l1_relu", "chain.l2", "chain.l2_tanh", "v3"] end end From 025011078b8e1a030372f70dc0f092fc56dda37a Mon Sep 17 00:00:00 2001 From: DrChainsaw Date: Thu, 1 Aug 2024 11:09:40 +0200 Subject: [PATCH 3/4] Add testcase for named Chain with nested Array-Chain --- test/serialize/serialize.jl | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/test/serialize/serialize.jl b/test/serialize/serialize.jl index 8c59c6a..e17efdc 100644 --- a/test/serialize/serialize.jl +++ b/test/serialize/serialize.jl @@ -901,6 +901,22 @@ @test name.(mp.graph.node) == ["layer1", "layer1_relu", "layer2", "layer2_sigmoid", "inner[1]", "inner[1]_tanh", "inner[2]", "layer3"] end + @testset "Nested Named Chain Array" begin + org = Chain( + layer1 =Dense(1 => 2, relu), + layer2 = Dense(2 => 3, sigmoid), + inner = Chain([ + Dense(3 => 3, tanh), + Dense(3=>3)]), + layer3 = Dense(3 => 4)) + res = remodel(org) + + x = randn(Float32, 1, 4) + @test org(x) == res(x) ≈ only(onnxruntime_infer(org, x)) + mp = modelproto(org) + @test name.(mp.graph.node) == ["layer1", "layer1_relu", "layer2", "layer2_sigmoid", "inner[1]", "inner[1]_tanh", "inner[2]", "layer3"] + end + @testset "Nested Named Chain Named Inner" begin org = Chain( layer1 = Dense(1 => 2, relu), From ad1de4380ea24b7475f8924d5c93469abdd90aa9 Mon Sep 17 00:00:00 2001 From: DrChainsaw Date: Thu, 1 Aug 2024 11:10:24 +0200 Subject: [PATCH 4/4] Formatting change --- test/serialize/serialize.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/test/serialize/serialize.jl b/test/serialize/serialize.jl index e17efdc..5d96a2d 100644 --- a/test/serialize/serialize.jl +++ b/test/serialize/serialize.jl @@ -887,7 +887,7 @@ @testset "Nested Named Chain" begin org = Chain( - layer1 =Dense(1 => 2, relu), + layer1 = Dense(1 => 2, relu), layer2 = Dense(2 => 3, sigmoid), inner = Chain( Dense(3 => 3, tanh), @@ -903,7 +903,7 @@ @testset "Nested Named Chain Array" begin org = Chain( - layer1 =Dense(1 => 2, relu), + layer1 = Dense(1 => 2, relu), layer2 = Dense(2 => 3, sigmoid), inner = Chain([ Dense(3 => 3, tanh), @@ -970,7 +970,7 @@ @testset "Named Chain Parallel" begin org = Chain( - layer1 =Dense(1 => 2, relu), + layer1 = Dense(1 => 2, relu), layer2 = Dense(2 => 3, sigmoid), fork = Parallel(+, Chain( @@ -989,7 +989,7 @@ @testset "Named Chain Named Parallel" begin org = Chain( - layer1 =Dense(1 => 2, 
relu), + layer1 = Dense(1 => 2, relu), layer2 = Dense(2 => 3, sigmoid), fork = Parallel(+, path1 = Chain( @@ -1009,7 +1009,7 @@ @testset "Named Chain Parallel Named Chain" begin org = Chain( - layer1 =Dense(1 => 2, relu), + layer1 = Dense(1 => 2, relu), layer2 = Dense(2 => 3, sigmoid), fork = Parallel(+, Chain( @@ -1032,7 +1032,7 @@ @testset "Named Chain SkipConnection" begin org = Chain( - layer1 =Dense(1 => 2, relu), + layer1 = Dense(1 => 2, relu), layer2 = Dense(2 => 3, sigmoid), fork = SkipConnection( Chain( @@ -1050,7 +1050,7 @@ @testset "Named Chain CompGraph" begin org = Chain( - layer1 =Dense(1 => 2, relu), + layer1 = Dense(1 => 2, relu), layer2 = Dense(2 => 3, sigmoid), graph = let iv = denseinputvertex("graphin", 3)