Commit
Merge pull request #95 from DrChainsaw/fluxnames
Try to extract layer names from Chain, Parallel and SkipConnection
DrChainsaw authored Aug 1, 2024
2 parents 63d5faf + ad1de43 commit b63f06c
Showing 5 changed files with 476 additions and 36 deletions.
2 changes: 2 additions & 0 deletions Project.toml
@@ -6,6 +6,7 @@ version = "0.2.13"
[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
NaiveNASflux = "85610aed-7d32-5e57-bb50-4c2e1c9e7997"
NaiveNASlib = "bd45eb3e-47ce-54bd-9eaf-e86c5f900853"
@@ -17,6 +18,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
[compat]
ChainRulesCore = "1"
Flux = "0.13, 0.14"
Functors = "0.4"
JuMP = "0.21, 0.22, 0.23, 1"
NaiveNASflux = "2.0.10"
NaiveNASlib = "2.0.11"
1 change: 1 addition & 0 deletions src/ONNXNaiveNASflux.jl
@@ -6,6 +6,7 @@ import .BaseOnnx: array
const ONNX = BaseOnnx
using Flux
using Flux: params
import Functors
using NaiveNASflux
using NaiveNASflux: weights, bias
using NaiveNASflux: indim, outdim, actdim, actrank, layertype, wrapped
173 changes: 149 additions & 24 deletions src/serialize/namingutil.jl
@@ -1,15 +1,150 @@
# The naming strategies are not well organized and should probably be rewritten. Until that happens,
# here is a little synopsis of what's going on here:

# There are two main ways to create names:
# 1) By just using the function name (or lowercase name of the struct, e.g. dense for Flux.Dense)
# and adding a running number after it to make names unique (e.g. dense_1)

# 2) By intercepting names (e.g. the named tuple keys of Flux.Chain or the names of NaiveNASlib's
# vertices) and making sure all ops "below" in the call stack (e.g. the actual layers) use those names

# Both strategies will be attached to the input ProtoProbe as the (badly named) field nextname
# when we start traversing the call stack.
# The strategy might be changed based on what we encounter. One example of when this happens
# is for the activation functions of the Flux layers which need to be separate nodes in ONNX
# so we make them inherit the name of the layer node (e.g. dense, conv) and suffix them with
# the name of the function.

# For strategy 1) we just use NameRunningNr and that's the end of the story, no fuss!

# For strategy 2) however, there are quite a few moving pieces, three to be exact.
# The whole thing is probably overengineered and could/should be done in a simpler way.

# The first entity is the strategy itself: NamedNodeContext which signals that we are in the
# business of trying to find named things in the call stack. It also remembers what the last
# name we saw was (and the hierarchy of names in case of nesting, e.g. a Chain inside a Parallel
# inside a Chain). Note that for the sake of safety NamedNodeContext wraps a NameRunningNr to
# ensure that we don't end up with duplicate names.

# In addition, whenever we create an input ProtoProbe with the nextname strategy being a
# NamedNodeContext we wrap the ProtoProbe in a NameInterceptProbe which is the second entity in this
# little scheme. This allows us to dispatch on high level (i.e. non-primitive) function calls such as
# Chain, Parallel, SkipConnection and MutationVertex to just catch the names we are after as they
# are not visible on the primitive (e.g. Flux.Dense) level.

# We could also have tried to dispatch on something like
# (c::Chain)(pp::ProtoProbe{<:Any, <:NamedNodeContext}), but this has the following problem: What do
# we do then after we have caught the name? We then want Chain to just do its thing on pp, but if
# we call c(pp) we just end up with infinite recursion.

# We also don't want to replace the naming strategy at this point since we want to handle nested
# name holders (i.e. the Chain inside a Parallel inside a Chain, or a Chain inside a MutationVertex
# if you'd like).

# Instead we just unwrap the NameInterceptProbe whenever we hit a name holder before we forward the
# call.

# The annoying thing here is that when we hold an object like a Flux.Parallel in our hand we know
# its structure and can see how it names things (e.g. through Functors.fmap_with_path), but we
# don't know for sure in what order the stuff inside it will be called when we call it as a function
# (ok, in practice we know since it has a doc-string contract on what it does, but we also don't
# want to reimplement it and all other things we want to catch names from).

# To prevent that we need to encode in this library what each possible high level name holder (i.e.
# Chain, Parallel, etc.) does when called, we use a third entity, the NamedFunction, which just
# wraps anything callable (well, it does not care if it is callable or not, but obviously things
# will not work if it is not) and the name it has. We then use a shallow Functors.fmap_with_path
# to wrap all named things (with methods) inside the name holder in a NamedFunction.

# Yup, you read that right, for example, if Chain(layer1=l1, layer2=l2) is called with a
# NameInterceptProbe as input it will first be fmap_with_path:ed into
# Chain(layer1=NamedFunction(l1, "layer1"), layer2=NamedFunction(l2, "layer2")), then the new chain
# will be called with the ProtoProbe wrapped inside the NameInterceptProbe.

# When a NamedFunction is called with an AbstractProbe as input, it does two things:
# 1) Wrap the AbstractProbe in a NameInterceptProbe so we can intercept nested name holders (e.g. a
# Chain inside a Parallel inside a Chain).
# 2) Create a new naming strategy for all ONNX ops encountered when calling the function wrapped
# in the NamedFunction which is to use the name of the NamedFunction.

# Note that at step 2) we append the name to any previously encountered names so that nested name holders
# get the correct path. In other words, the new naming strategy is a new NamedNodeContext where we have
# appended the name of the named node to the prefix.

# One option to simplify could have been to just fmap_with_path the entire model before calling it.
# This might turn out to be simpler. The reasons I did not go for it were:
# 1) There would not be any way to name things inside things that are not fmap:able
# 2) It is still not easy to know which things are named (although we could just always use
# fieldnames)
# 3) Less control over what we actually want to wrap inside a NamedFunction (maybe we could use
# dispatch for this when walking though)
# 4) The functor structure for CompGraphs is pretty horrible (but we anyways don't use the
# fmap_with_path method for it since the vertices are the universal name holders for it)
# None of the above seem like showstoppers, the current way was just the path of least resistance.



default_namestrat(f) = name_runningnr()
default_namestrat(g::CompGraph) = default_namestrat(vertices(g))
function default_namestrat(vs::AbstractVector{<:AbstractVertex})
    # Even if all vertices have unique names, we can't be certain that no vertex produces more than one node
    # Therefore, we must take a pass through name_runningnr for each op even after we have mapped the vertex to a nextname. This is the reason for the v -> f -> namegen(v, "") weirdness
    namegen = name_runningnr()
    all(isnamed, vs) && length(unique(name.(vs))) == length(name.(vs)) && return v -> f -> namegen(v, "")
    ng(v::AbstractVertex) = namegen
    ng(f) = namegen(f)
    return ng
    all(isnamed, vs) && length(unique(name.(vs))) == length(name.(vs)) && return NamedNodeContext("", name_runningnr(;addtofirst=false))
    # TODO: Maybe we should use NamedNodeContext here as well and just add runningnumbers to duplicated names?
    name_runningnr()
end

const NameType = Union{Symbol, AbstractString}

function default_namestrat(c::Chain)
    !(eltype(keys(c)) <: NameType) && return name_runningnr()
    NamedNodeContext("", name_runningnr(;addtofirst=false))
end

struct NameRunningNr{F}
    addtofirst::Bool
    init::Int
    runningnrs::Dict{String, Int}
    namefun::F
end

Base.Broadcast.broadcastable(n::NameRunningNr) = Ref(n)

function (n::NameRunningNr)(f)
    bname = n.namefun(f)
    nextnr = get(n.runningnrs, bname, n.init)
    n.runningnrs[bname] = nextnr + 1
    return if nextnr != n.init || n.addtofirst
        string(bname, "_", nextnr)
    else
        bname
    end
end

name_runningnr(namefun = genname; addtofirst=true, init=0) = NameRunningNr(addtofirst, init, Dict{String, Int}(), namefun)
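
# Illustration only, not part of this commit: a minimal standalone sketch of the counter logic above,
# assuming nothing beyond Base. The names DemoRunningNr and demo_runningnr are hypothetical, and
# identity stands in for genname so that plain strings can be used as inputs.
struct DemoRunningNr{F}
    addtofirst::Bool
    init::Int
    runningnrs::Dict{String, Int}
    namefun::F
end

function (n::DemoRunningNr)(f)
    bname = n.namefun(f)
    nextnr = get(n.runningnrs, bname, n.init)   # the first time a base name is seen we start at init
    n.runningnrs[bname] = nextnr + 1
    # The suffix is skipped only for the first occurrence, and only when addtofirst is false
    return (nextnr != n.init || n.addtofirst) ? string(bname, "_", nextnr) : bname
end

demo_runningnr(namefun = identity; addtofirst=true, init=0) = DemoRunningNr(addtofirst, init, Dict{String, Int}(), namefun)

numbered = demo_runningnr()
first_dense  = numbered("dense")   # "dense_0"
second_dense = numbered("dense")   # "dense_1"

# With addtofirst=false (as used together with NamedNodeContext) the first occurrence keeps the bare name
plain = demo_runningnr(; addtofirst=false)
first_plain  = plain("dense")      # "dense"
second_plain = plain("dense")      # "dense_1"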

struct NamedFunction{F}
    f::F
    name::String
end
NaiveNASlib.base(n::NamedFunction) = n.f
NaiveNASlib.name(n::NamedFunction) = n.name
(n::NamedFunction)(args...; kwargs...) = n.f(args...; kwargs...)

const NamedNode = Union{NamedFunction, MutationVertex}

struct NamedNodeContext{F}
    prefix::String
    namegen::F
end

Base.Broadcast.broadcastable(n::NamedNodeContext) = Ref(n)

(ctx::NamedNodeContext)(args...) = isempty(ctx.prefix) ? ctx.namegen(args...) : ctx.namegen(ctx.prefix)
function (ctx::NamedNodeContext)(v::Union{NamedNode, AbstractVertex})
    nodename = name(v)
    # Maybe add some field in NamedFunction to indicate whether it is an array element or a field instead of
    # startswith here
    sep = isempty(ctx.prefix) || startswith(nodename, '[') ? "" : "."
    @set ctx.prefix = string(ctx.prefix, sep, name(v))
end
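
# Illustration only, not part of this commit: a standalone sketch of how the prefix above grows
# as nested name holders are entered. demo_appendname is a hypothetical helper which mirrors the
# separator rule: no separator for the first name or for array elements like "[1]", otherwise ".".
demo_appendname(prefix::AbstractString, nodename::AbstractString) =
    string(prefix, (isempty(prefix) || startswith(nodename, '[')) ? "" : ".", nodename)

# E.g. the first element of a collection stored under the key layer1, which in turn has a child named dense
demo_path = foldl(demo_appendname, ["layer1", "[1]", "dense"]; init="")   # "layer1[1].dense"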

isnamed(v::AbstractVertex) = isnamed(base(v))
@@ -21,26 +156,16 @@ isnamed(t::DecoratingTrait) = isnamed(base(t))
isnamed(t) = false
isnamed(::NamedTrait) = true

function name_runningnr(namefun = genname)
    exists = Set{String}()

    return function(f, init="_0")
        bname = namefun(f)
        candname = bname * init
        next = -1
        while candname in exists
            next += 1
            candname = bname * "_" * string(next)
        end
        push!(exists, candname)
        return candname
    end
end

genname(v::AbstractVertex) = name(v)
genname(f::F) where F = lowercase(string(nameof(F)))
genname(::F) where F = lowercase(string(nameof(F)))
genname(s::AbstractString) = s
genname(f::Function) = lowercase(string(f))

recursename(f, namestrat) = recursename(f, namestrat(f))
recursename(f, fname::String) = fname
function recursename(f, ctx::NamedNodeContext)
    res = ctx(f)
    res isa NamedNodeContext && return recursename(f, res.prefix)
    res
end

100 changes: 88 additions & 12 deletions src/serialize/serialize.jl
@@ -105,6 +105,29 @@ Idea is that probe will "record" all seen operations based on how methods for th
"""
abstract type AbstractProbe end

nextshape(p::AbstractProbe, f::Function) = f(shape(p))
Base.ndims(p::AbstractProbe) = length(shape(p))

"""
WrappedProbe
Abstract type which promises that it wraps another `AbstractProbe` to which it delegates most of its tasks.
"""
abstract type WrappedProbe <: AbstractProbe end

unwrap(p::WrappedProbe) = p.wrapped

Base.Broadcast.broadcastable(p::WrappedProbe) = Ref(p)

NaiveNASlib.name(p::WrappedProbe) = NaiveNASlib.name(unwrap(p))

nextname(p::WrappedProbe) = nextname(unwrap(p))
add!(p::WrappedProbe, args...) = add!(unwrap(p), args...)
shape(p::WrappedProbe) = shape(unwrap(p))
add_output!(p::WrappedProbe) = add_output!(unwrap(p))
newnamestrat(p::WrappedProbe, args...) = rewrap(p, newnamestrat(unwrap(p), args...))
newfrom(p::WrappedProbe, args...) = rewrap(p, newfrom(unwrap(p), args...))

# Called by several activation functions
Base.oftype(::AbstractProbe, x) = x

@@ -126,8 +149,9 @@ nextname(p::ProtoProbe) = p.nextname
add!(p::ProtoProbe, n) = add!(p.graph, n)
shape(p::ProtoProbe) = p.shape
add_output!(p::ProtoProbe) = push!(p.graph.output, ONNX.ValueInfoProto(name(p), shape(p)))
Base.ndims(p::AbstractProbe) = length(shape(p))
function inputprotoprobe!(gp, name, shape, namestrat)

inputprotoprobe!(args...) = _inputprotoprobe!(args...)
function _inputprotoprobe!(gp, name, shape, namestrat)
    push!(gp.input, ONNX.ValueInfoProto(name, shape))
    ProtoProbe(name, shape, namestrat, gp)
end
@@ -145,12 +169,72 @@ newnamestrat(p::ProtoProbe, f, pname=p.name) = ProtoProbe(pname, p.shape, f, p.g
Return a new `ProtoProbe` with name `outname`. Argument `fshape` is used to determine a new shape (typically a function).
"""
newfrom(p::ProtoProbe, outname::AbstractString, fshape) = ProtoProbe(outname, nextshape(p, fshape), p.nextname, p.graph)
nextshape(p::AbstractProbe, f::Function) = f(shape(p))

add!(gp::ONNX.GraphProto, np::ONNX.NodeProto) = push!(gp.node, np)

add!(gp::ONNX.GraphProto, tp::ONNX.TensorProto) = push!(gp.initializer, tp)

## Don't forget to check if new methods need to be added for any WrappedProbe implementations if you add something here!

# Stuff whose only purpose is to override the name in case this is the naming strategy
# See namingutil for a little story about the design since it is a bit messy :(

"""
NameInterceptProbe <: WrappedProbe
An `AbstractProbe` which is only used to intercept method calls above the primitive level to catch names (e.g. a `Chain` with a `NamedTuple` as layers).
"""
struct NameInterceptProbe{P<:AbstractProbe} <: WrappedProbe
    wrapped::P
end
rewrap(::NameInterceptProbe, p) = NameInterceptProbe(p)

inputprotoprobe!(gp, name, shape, namestrat::NamedNodeContext) = NameInterceptProbe(_inputprotoprobe!(gp, name, shape, namestrat))

# This little dance is just to avoid ambiguities of (n::NamedNode)(pps...) since NamedNode is abstract
# TODO: Generate with macro so we can add types to the union without worry?
(v::MutationVertex)(pps::AbstractProbe...) = _apply_probe_call(v, pps...)
(n::NamedFunction)(pps::AbstractProbe...) = _apply_probe_call(n, pps...)

_apply_probe_call(n::NamedNode, pps::AbstractProbe...) = __apply_probe_call(n, nextname(first(pps)), pps...)

# We are not in the business of catching names here, so just forward the call
__apply_probe_call(n::NamedNode, ::Any, pps::AbstractProbe...) = base(n)(pps...)
# We just want to rewrap probes in NameInterceptProbes for the sole reason that we might encounter other
# objects that we want to intercept names from (e.g. a Chain inside a Parallel)
__apply_probe_call(n::NamedNode, ::NamedNodeContext, pps::AbstractProbe...) = _apply_probe_call(n, map(NameInterceptProbe, pps)...)

function _apply_probe_call(n::NamedNode, pps::NameInterceptProbe...)
    ppsname = map(pps) do pp
        newnamestrat(pp, nextname(pp)(n))
    end
    ppout = base(n)(ppsname...)
    return newnamestrat(ppout, nextname(pps[1]))
end

# Here we wrap all callable children in a NamedFunction so that we can access the name when the child is called
(c::Chain)(pp::NameInterceptProbe) = _instrument_named_functor(c; addbasename=false)(unwrap(pp))
(p::Parallel)(pp::NameInterceptProbe) = _instrument_named_functor(p; addbasename=false)(unwrap(pp))
(p::SkipConnection)(pp::NameInterceptProbe) = _instrument_named_functor(p; addbasename=false)(unwrap(pp))

# The exclude is to ensure that we only see the fields of w, not the fields of its children
_instrument_named_functor(w; addbasename=true) = Functors.fmap_with_path(w; exclude = (k,c) -> c != w) do keypath, child
    _instrument_named_child(child, string(only(keypath)); addbasename)
end

_instrument_named_child(child, name; kwargs...) = !isempty(methods(child)) ? NamedFunction(child, name) : child
_instrument_named_child(child::Tuple, name; addbasename) = ntuple(length(child)) do i
    _instrument_named_child(child[i], string(addbasename ? name : "", '[', i, ']'))
end
_instrument_named_child(child::AbstractArray, name; addbasename) = map(enumerate(child)) do (i, elem)
    _instrument_named_child(elem, string(addbasename ? name : "", '[', i, ']'))
end
_instrument_named_child(child::NamedTuple{K}, name; addbasename) where K = ntuple(length(child)) do i
    _instrument_named_child(child[i], string(addbasename ? string(name, '.') : "", K[i]))
end |> NamedTuple{K}

# End of stuff whose only purpose is to override the name in case this is the naming strategy


"""
Used to get activation functions as [`ONNX.AttributeProto`](@ref)s.
@@ -224,14 +308,6 @@ function add_outputs!(gp, namestrat, pps::Tuple)
add_output!.(output_pps)
end

# Only purpose is to snag the name in case this is the naming strategy
function (v::NaiveNASlib.MutationVertex)(pps::AbstractProbe...)
    ppsname = map(pps) do pp
        newnamestrat(pp, nextname(pp)(v))
    end
    ppout = base(v)(ppsname...)
    return newnamestrat(ppout, nextname(pps[1]))
end

actfun(::FluxLayer, l) = l.σ
function weightlayer(lt::FluxParLayer, l, pp, optype;attributes = ONNX.AttributeProto[])
@@ -419,7 +495,7 @@ globalmaxpool(pp::AbstractProbe, wrap) = globalpool(pp, wrap, "GlobalMaxPool")

function globalpool(pp::AbstractProbe, wrap, type)
    gpp = attribfun(s -> ismissing(s) ? s : (1, 1, s[3:end]...), type, pp)
    ppnext = newnamestrat(gpp, f -> join([gpp.name, genname(f)], "_"), gpp.name)
    ppnext = newnamestrat(gpp, f -> join([name(gpp), genname(f)], "_"), name(gpp))
    wpp = wrap(ppnext)
    return newnamestrat(wpp, nextname(gpp))
end
