Skip to content

Commit

Permalink
Merge pull request #79 from itan1/add-instancenorm
Browse files Browse the repository at this point in the history
Add InstanceNorm
  • Loading branch information
DrChainsaw authored Jul 11, 2023
2 parents ee7f5f0 + 8902277 commit 899343a
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 3 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ Flatten
Gemm
GlobalAveragePool
GlobalMaxPool
InstanceNormalization
LSTM
LeakyRelu
MatMul
Expand Down
4 changes: 2 additions & 2 deletions src/ONNXNaiveNASflux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ using NaiveNASflux
using NaiveNASflux: weights, bias
using NaiveNASflux: indim, outdim, actdim, actrank, layertype, wrapped
using NaiveNASflux: FluxLayer, FluxParLayer, FluxNoParLayer, FluxDense, FluxConvolutional, FluxConv, FluxConvTranspose,
FluxBatchNorm, FluxRecurrent, FluxRnn, FluxLstm, FluxGru, FluxTransparentLayer, FluxPoolLayer,
FluxDropOut, Flux2D, GenericFluxConvolutional, GenericFlux2D, GenericFluxRecurrent
FluxBatchNorm, FluxInstanceNorm, FluxRecurrent, FluxRnn, FluxLstm, FluxGru, FluxTransparentLayer,
FluxPoolLayer, FluxDropOut, Flux2D, GenericFluxConvolutional, GenericFlux2D, GenericFluxRecurrent
using Setfield
using Statistics
import Pkg
Expand Down
13 changes: 13 additions & 0 deletions src/deserialize/ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,19 @@ default_Wb_Rb(Wh_WBh) = fill!(similar(Wh_WBh, (size(Wh_WBh, 2) * 2, size(Wh_WBh,
default_init_h(Wb_Rb, sc) = fill!(similar(Wb_Rb, (size(Wb_Rb,1) ÷ sc, size(Wb_Rb,2))), 0)
# TODO when https://github.com/FluxML/Flux.jl/issues/1279 is resolved default_init_h(Wh_WBh, sc) = fill!(similar(Wh_WBh, (size(Wh_WBh, 2) ÷ sc, size(Wh_WBh, 3))), 0)

actlayers[:InstanceNormalization] = function(params, γ, β)
λ = get(params, :activation, identity)
ϵ = get(params, :epsilon, 1f-5)

# ONNX InstanceNormalization does not support tracking μ and σ²
momentum = NaN32
μ = zeros(length(γ))
σ² = ones(length(γ))

return InstanceNorm(λ, β, γ, μ, σ², ϵ, momentum, true, false, nothing, length(γ))
end
fluxlayertypes[:InstanceNormalization] = (pars...) -> FluxInstanceNorm()

fluxrecurrentlayers[:RNN] = function(params, Wi_WBi, Wh_WBh, Wb_Rb=default_Wb_Rb(Wh_WBh), seqlen=[], h3d = default_init_h(Wb_Rb, 2))
@assert size(Wi_WBi, 3) == 1 "Num directions must be 1! Bidirectional (num directions = 2) not supported!" # TODO: Add...

Expand Down
22 changes: 22 additions & 0 deletions src/serialize/serialize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,28 @@ function(l::Flux.BatchNorm)(pp::AbstractProbe)
end
actfun(::FluxBatchNorm, l) = l.λ

function(l::Flux.InstanceNorm)(pp::AbstractProbe)
@assert l.affine == true "ONNX InstanceNormalization does not support affine=false"
@assert l.track_stats == false "ONNX InstanceNormalization does not support track_stats=true"
lname = recursename(l, nextname(pp))
γname, βname = lname .* ("_scale", "_bias")

add!(pp, ONNX.NodeProto(
input=[name(pp), γname, βname],
output=[lname],
name=lname,
attribute = ONNX.AttributeProto.(["epsilon"], [l.ϵ]),
op_type="InstanceNormalization"))

add!(pp, ONNX.TensorProto(l.γ, γname))
add!(pp, ONNX.TensorProto(l.β, βname))


ppout = actfun(layertype(l), l)(newnamestrat(pp, f -> join([lname, genname(f)], "_"), lname))
return newnamestrat(ppout, nextname(pp))
end
actfun(::FluxInstanceNorm, l) = l.λ


# Dropdims because ONNX expects recurrent layers to output tensors of shape [seq_length, num_directions, batch_size, hidden_size] where num_directions is 2 in case of bidirectional and 1 otherwise
# Flux.Recur is not bidirectional so we'll just assume the user wants to also drop num_directions so that recurrent layers can be stacked without hassle.
Expand Down
6 changes: 6 additions & 0 deletions test/deserialize/Artifacts.toml
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,12 @@ git-tree-sha1 = "377710458916cc790bb7eec00c8e3f0719680cf8"
[test_globalmaxpool_precomputed]
git-tree-sha1 = "6d72b58370176351d46937ca3df65ba2fd114f04"

[test_instancenorm_epsilon]
git-tree-sha1 = "9b14c628a483dd94105ce207fca8e7ca6cb10e45"

[test_instancenorm_example]
git-tree-sha1 = "e63a946ae5c5b3becf847f4e3adaaf92ef9b358f"

[test_leakyrelu]
git-tree-sha1 = "07afe319b71db2cb6bc295ff9409482721473817"

Expand Down
2 changes: 2 additions & 0 deletions test/deserialize/deserialize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ end
(name="test_gemm_default_zero_bias", ninputs=3, noutputs=1),
#(name="test_gemm_transposeA", ninputs=3, noutputs=1), Not supported!
(name="test_gemm_transposeB", ninputs=3, noutputs=1),
(name="test_instancenorm_epsilon", ninputs=3, noutputs=1),
(name="test_instancenorm_example", ninputs=3, noutputs=1),
(name="test_lstm_defaults", ninputs=3, noutputs=1),
(name="test_lstm_with_initial_bias", ninputs=4, noutputs=1),
# (name="test_lstm_with_peepholes", ninputs=8, noutputs=1), Not supported!
Expand Down
35 changes: 34 additions & 1 deletion test/serialize/serialize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@
end

@testset "$(tc.layer) node" for tc in (
(layer=BatchNorm(3, relu; initβ = i -> collect(Float32, 1:i), initγ = i -> collect(Float32, i:-1:1), ϵ=1e-3, momentum = 0.78), indata=reshape(collect(Float32, 1:2*3*3), 2,3,3,1) .- 10),
(layer=BatchNorm(3, relu; initβ = i -> collect(Float32, 1:i), initγ = i -> collect(Float32, i:-1:1), eps=1e-3, momentum = 0.78), indata=reshape(collect(Float32, 1:2*3*3), 2,3,3,1) .- 10),
)

inprobe = NodeProbe("input", genname)
Expand Down Expand Up @@ -313,6 +313,39 @@
@test size(ortout) == size(expout)
@test ortout expout
end

@testset "$(tc.layer) node" for tc in (
(layer=InstanceNorm(3, relu, affine=true), indata=reshape(collect(Float32, 1:2*3*3), 2,3,3,1) .- 10),
(layer=InstanceNorm(3, relu, initβ = i -> collect(Float32, 1:i), initγ = i -> collect(Float32, i:-1:1), affine=true, track_stats=false, eps=1f-3), indata=reshape(collect(Float32, 1:2*3*3), 2,3,3,1) .- 10),
)

inprobe = NodeProbe("input", genname)
outprobe = tc.layer(inprobe)
@test length(outprobe.protos) == 4

ln, γ, β, an = Tuple(serdeser.(outprobe.protos))

@test size(β) == size(tc.layer.β)
@test size(γ) == size(tc.layer.γ)

@test β tc.layer.β
@test γ tc.layer.γ

ln.attribute[:activation] = actfuns[Symbol(optype(an))](an.attribute)
res = fluxlayers[optype(ln)](ln.attribute, γ, β)

@test string(res) == string(tc.layer)

resout = res(tc.indata)
expout = tc.layer(tc.indata)

@test size(resout) == size(expout)
@test resout expout

ortout, = onnxruntime_infer(tc.layer, tc.indata)
@test size(ortout) == size(expout)
@test ortout expout
end
end

@testset "Graphs" begin
Expand Down

0 comments on commit 899343a

Please sign in to comment.