From 266c9afd89ef6d1fefd71c1c026123b1a22907a3 Mon Sep 17 00:00:00 2001 From: Bruno Hebling Vieira Date: Tue, 11 Feb 2020 09:36:35 -0500 Subject: [PATCH] WeightNormWeight is now called WeightNormParam WeightNorm for several params, single dim Test for Scalar and Vector dims Test newly created WN equality Simplified some bits Missing last constructor --- src/layers/normalise.jl | 57 ++++++++++++++++++++---------------- test/layers/normalisation.jl | 14 ++++++--- 2 files changed, 42 insertions(+), 29 deletions(-) diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 0ee64e15ca..f16db63e1c 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -371,13 +371,13 @@ end Weight Normalization. This layer reparametrizes weights (w) of a layer with its decomposition into magnitude (g) and direction (v). - WeightNorm(layer, weight::Union{Symbol,Int}, dim) + WeightNorm(layer, weight, dim) ``layer`` is the layer being normalized. -``weight`` is the parameter to be normalized. +``weight`` are the parameters to be normalized. -``dim`` is the dimension of normalization. +``dim`` are the dimensions of normalization. Often, its the dimension encoding the output channels. Example: @@ -390,55 +390,62 @@ wndB = WeightNorm(d, :W, 1:2); #Now we normalize all directions together, keepin Link : https://arxiv.org/pdf/1602.07868.pdf """ -struct WeightNormWeight{T,N,I} +struct WeightNormParam{T,N,I} g::AbstractArray{T,N} v::AbstractArray{T,N} dim::I end -Base.size(w::WeightNormWeight, i...) = size(w.v, i...) -Base.size(w::WeightNormWeight) = size(w.v) -Base.iterate(w::WeightNormWeight, i...) = iterate(w.g .* w.v ./ WN_mag(w.v, w.dim), i...) -Base.getindex(w::WeightNormWeight, i...) = getindex(w.g .* w.v ./ WN_mag(w.v, w.dim), i...) -Base.ndims(w::WeightNormWeight) = ndims(w.v) +Base.size(w::WeightNormParam, i...) = size(w.v, i...) +Base.size(w::WeightNormParam) = size(w.v) +Base.iterate(w::WeightNormParam, i...) = iterate(w.g .* w.v ./ WN_mag(w.v, w.dim), i...) 
+Base.getindex(w::WeightNormParam, i...) = getindex(w.g .* w.v ./ WN_mag(w.v, w.dim), i...) +Base.ndims(w::WeightNormParam) = ndims(w.v) +Base.length(w::WeightNormParam) = length(w.v) -Flux.@functor WeightNormWeight +@functor WeightNormParam -WN_mag(p, dim) = sqrt.(sum(abs2.(p), dims = dim)) -WN_dir(p, mag, eps) = p ./ (mag .+ eps) -WN_dir(p, mag) = WN_dir(p, mag, eps(eltype(p))) +WN_mag(p, dim, eps) = sqrt.(sum(abs2.(p), dims = dim)) .+ eps +WN_mag(p, dim) = WN_mag(p, dim, eps(eltype(p))) +WN_dir(p, mag) = p ./ mag import Base.*, Base./, Base.+, Base.- for f in (:+, :-, :*, :/) - @eval ($f)(z::AbstractArray, w::WeightNormWeight) = ($f)(z, w.g .* w.v ./ WN_mag(w.v, w.dim)) - @eval ($f)(w::WeightNormWeight, z::AbstractArray) = ($f)(w.g .* w.v ./ WN_mag(w.v, w.dim), z) + @eval ($f)(z::AbstractArray, w::WeightNormParam) = ($f)(z, w.g .* w.v ./ WN_mag(w.v, w.dim)) + @eval ($f)(w::WeightNormParam, z::AbstractArray) = ($f)(w.g .* w.v ./ WN_mag(w.v, w.dim), z) end -struct WeightNorm{L,E,I,W} +struct WeightNorm{L} layer::L - eps::E - weight::W - dim::I + eps::Number + weight::Vector + dim::Vector end -Flux.@functor WeightNorm +@functor WeightNorm function Base.show(io::IO, wn::WeightNorm) print(io, "WeightNorm(", wn.layer, ", ", wn.weight, ", ", wn.dim, ")") end -function WeightNorm(layer, weight::Union{Symbol,Int}, dim) +function WeightNorm(layer, weight::Vector, dim::Vector) #Expose layer fields and constructor func, re = Flux.functor(layer) #Get the fields par = [getfield(layer, fn) for fn in keys(func)] - w = getfield(layer, weight) - g = WN_mag(w, dim) - v = WN_dir(w, g) - par[findfirst(keys(func) .== weight)] = WeightNormWeight(g, v, dim) + w = map(weight) do W + getfield(layer, W) + end + g = map((W, D) -> WN_mag(W, D), w, dim) + v = map((W, G) -> WN_dir(W, G), w, g) + par[indexin(weight,collect(keys(func)))] = WeightNormParam.(g, v, dim) return WeightNorm(re(par), eps(Float32), weight, dim) end +WeightNorm(layer, weight::Symbol, dim::Vector) = WeightNorm(layer, 
[weight], dim) +WeightNorm(layer, weight::Symbol, dim::Integer) = WeightNorm(layer, [weight], [dim]) +WeightNorm(layer, weight::Vector, dim::Integer) = WeightNorm(layer, weight, [dim for _ in axes(weight,1)]) + function (wn::WeightNorm)(x) wn.layer(x) end \ No newline at end of file diff --git a/test/layers/normalisation.jl b/test/layers/normalisation.jl index d55dd7e731..fa060a04fc 100644 --- a/test/layers/normalisation.jl +++ b/test/layers/normalisation.jl @@ -196,20 +196,26 @@ end d = Dense(10, 9, tanh) gs = gradient(() -> sum(abs2, d(fake_data)), params(d)) W = d.W - for WN_dim in [1, 2, 1:2] + for WN_dim in [[1], 1, [2], 2, [1:2]] wnd = WeightNorm(d, :W, WN_dim) gswn = gradient(() -> sum(abs2, wnd(fake_data)), params(wnd)) g = wnd.layer.W.g v = wnd.layer.W.v - normv = sum(abs2, v, dims = WN_dim) ΔW = gs[W] Δg = gswn[g] Δv = gswn[v] - @test sum(ΔW .* v ./ normv, dims = WN_dim) ≈ Δg + @test wnd(fake_data) ≈ d(fake_data) + if isa(WN_dim, Int) + normv = sum(abs2, v, dims = WN_dim) + @test sum(ΔW .* v ./ normv, dims = WN_dim) ≈ Δg + else + normv = sum(abs2, v, dims = WN_dim[1]) + @test sum(ΔW .* v ./ normv, dims = WN_dim[1]) ≈ Δg + end @test g ./ normv .* ΔW - g .* Δg .* v ./ (normv.^2) ≈ Δv @test size(Δv) == size(ΔW) - @test isa(wnd.layer.W, WeightNormWeight) + @test isa(wnd.layer.W, Flux.WeightNormParam) end end end