Merge #1776
1776: Use conjugates in optimizers to better learn on complex-valued inputs r=DhairyaLGandhi a=staticfloat

When weights are complex, the deltas to them will also be complex.  In
all optimizers that need a second-order estimate of gradient statistics,
we generally want to use the `x * conj(x)` pattern, rather than `x^2`.
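
For intuition, here is a minimal REPL sketch (not part of the PR) of what that difference means for a single complex gradient value:

```julia
# For a complex gradient, `Δ^2` is itself complex, whereas `Δ * conj(Δ)` is the
# real, non-negative squared magnitude we actually want to accumulate.
Δ = 3.0 + 4.0im

Δ^2          # -7.0 + 24.0im  (not a meaningful second-moment term)
Δ * conj(Δ)  # 25.0 + 0.0im   (|Δ|², real and non-negative)
abs2(Δ)      # 25.0           (the same magnitude as a plain real number)
```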

We can see the effect this has on ADAM with the following test:

```julia
using Flux
using CairoMakie  # the plotting at the end uses Makie; any backend (CairoMakie/GLMakie) works

begin
    # This model will learn `W = I` and `bias = 0`
    complex_init(dims...) = Flux.glorot_uniform(dims...) .+ 1im .* Flux.glorot_uniform(dims...)
    model = Chain(
        Dense(4, 4, tanh; init=complex_init),
        Dense(4, 16, tanh; init=complex_init),
        Dense(16, 4, tanh; init=complex_init),
        Dense(4, 4, tanh; init=complex_init),
    )

    # Loss function; note we don't need the `abs()` if we update `Flux.Losses.mse()` as below
    function loss(x)
        return abs.(Flux.Losses.mse(model(x), x))
    end

    # Keep track of loss from epoch to epoch
    losses = Float64[]
    dataset = [(randn(ComplexF32, 4, 10),)]
    params = Flux.params(model)
    opt = Flux.Optimise.ADAM(0.001)
    for epoch_idx in 1:10000
        Flux.train!(loss, params, dataset, opt)
        epoch_loss = loss(dataset[1][1])
        push!(losses, epoch_loss)
        if epoch_idx % 100 == 0
            @info("epoch done", epoch_idx, epoch_loss)
        end
    end

    # Plot the loss
    fig = Figure()
    meta_ax = Axis(fig[1,1])
    lines!(meta_ax, log.(losses); label="Training loss")
    fig[1,2] = Legend(fig, meta_ax, "Learning Stats")
    fig
end
```

The training loss before the fix looks like this:

![without_workaround](https://user-images.githubusercontent.com/130920/142955143-385c5ca9-b2d7-4129-aae0-152741661689.png)


Whereas after both of these commits, it looks like this:

![with_workaround](https://user-images.githubusercontent.com/130920/142955168-807943d7-a2d4-4f7a-82a6-fbab0610e407.png)

Note that while the absolute value of the loss is actually comparable in this simple example, the loss landscape is significantly more chaotic.  With a higher learning rate, the "fixed" version is able to learn much faster: 

![download-1](https://user-images.githubusercontent.com/130920/142955367-e945e6c2-7045-42f7-8a7f-9135ee40c5b4.png)

Whereas the unfixed version simply diverges:

![download-2](https://user-images.githubusercontent.com/130920/142955420-8f32bb3c-5add-4fcb-86a6-eff7fac6dfaf.png)




Co-authored-by: Elliot Saba <staticfloat@gmail.com>
bors[bot] and staticfloat authored Nov 30, 2021
2 parents bb88c55 + 8c3d852 commit cbc1275
Showing 4 changed files with 51 additions and 10 deletions.
3 changes: 2 additions & 1 deletion src/losses/functions.jl
@@ -44,7 +44,8 @@ julia> Flux.mse(y_model, y_true)
"""
function mse(ŷ, y; agg = mean)
_check_sizes(ŷ, y)
agg((ŷ .- y) .^ 2)
error = ŷ .- y
real(agg(error .* conj(error)))
end

"""
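For reference, an illustrative standalone sketch of the behaviour the new definition gives (`mse_conj` here is a hypothetical stand-in, not the Flux function):

```julia
using Statistics: mean

# Same computation as the updated `mse` above, written out standalone.
function mse_conj(ŷ, y; agg = mean)
    err = ŷ .- y
    return real(agg(err .* conj(err)))
end

mse_conj([0.0 + 0.0im], [1.0 + 1.0im])  # 2.0, i.e. |0 - (1 + 1im)|²
mse_conj([0.1, 0.9], [0.0, 0.0])        # 0.41, same as the old (ŷ .- y).^2 form on real inputs
```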
18 changes: 9 additions & 9 deletions src/optimise/optimisers.jl
@@ -141,7 +141,7 @@ RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, IdDict())
function apply!(o::RMSProp, x, Δ)
η, ρ = o.eta, o.rho
acc = get!(() -> zero(x), o.acc, x)::typeof(x)
@. acc = ρ * acc + (1 - ρ) * Δ^2
@. acc = ρ * acc + (1 - ρ) * Δ * conj(Δ)
@. Δ *= η / (acc + ϵ)
end

@@ -179,7 +179,7 @@ function apply!(o::ADAM, x, Δ)
end :: Tuple{typeof(x),typeof(x),Vector{Float64}}

@. mt = β[1] * mt + (1 - β[1]) * Δ
@. vt = β[2] * vt + (1 - β[2]) * Δ^2
@. vt = β[2] * vt + (1 - β[2]) * Δ * conj(Δ)
@. Δ = mt / (1 - βp[1]) / ((vt / (1 - βp[2])) + ϵ) * η
βp .= βp .* β

@@ -221,7 +221,7 @@ function apply!(o::RADAM, x, Δ)
end :: Tuple{typeof(x),typeof(x),Vector{Float64},Ref{Int}}

@. mt = β[1] * mt + (1 - β[1]) * Δ
@. vt = β[2] * vt + (1 - β[2]) * Δ^2
@. vt = β[2] * vt + (1 - β[2]) * Δ * conj(Δ)
ρ = ρ∞ - 2t[] * βp[2] / (1 - βp[2])
if ρ > 4
r = sqrt((ρ-4)*(ρ-2)*ρ∞/((ρ∞-4)*(ρ∞-2)*ρ))
@@ -311,7 +311,7 @@ function apply!(o::OADAM, x, Δ)
end :: Tuple{typeof(x),typeof(x),typeof(x),Vector{Float64}}

@. mt = β[1] * mt + (1 - β[1]) * Δ
@. vt = β[2] * vt + (1 - β[2]) * Δ^2
@. vt = β[2] * vt + (1 - β[2]) * Δ * conj(Δ)
@. Δ = -Δ_
@. Δ_ = η * mt / (1 - βp[1]) / ((vt / (1 - βp[2])) + ϵ)
@. Δ += 2Δ_
@@ -348,7 +348,7 @@ ADAGrad(η = 0.1) = ADAGrad(η, IdDict())
function apply!(o::ADAGrad, x, Δ)
η = o.eta
acc = get!(() -> fill!(similar(x), ϵ), o.acc, x)::typeof(x)
@. acc += Δ^2
@. acc += Δ * conj(Δ)
@. Δ *= η / (acc + ϵ)
end

@@ -379,11 +379,11 @@ ADADelta(ρ = 0.9) = ADADelta(ρ, IdDict())
function apply!(o::ADADelta, x, Δ)
ρ = o.rho
acc, Δacc = get!(() -> (zero(x), zero(x)), o.state, x)::NTuple{2,typeof(x)}
@. acc = ρ * acc + (1 - ρ) * Δ^2
@. acc = ρ * acc + (1 - ρ) * Δ * conj(Δ)
# DON'T remove epsilon from numerator
# or even out of the square roots
@. Δ *= (Δacc + ϵ) / (acc + ϵ)
@. Δacc = ρ * Δacc + (1 - ρ) * Δ^2
@. Δacc = ρ * Δacc + (1 - ρ) * Δ * conj(Δ)
return Δ
end

@@ -463,7 +463,7 @@ function apply!(o::NADAM, x, Δ)
β1p, β2p = βp

@. mt = β[1] * mt + (1 - β[1]) * Δ
@. vt = β[2] * vt + (1 - β[2]) * Δ^2
@. vt = β[2] * vt + (1 - β[2]) * Δ * conj(Δ)
@. Δ = (β[1] * mt / (1 - β[1] * β1p) + (1 - β[1]) * Δ / (1 - β1p)) / ((vt * β[2] / (1 - β2p)) + ϵ) * η
βp .= βp .* β

@@ -524,7 +524,7 @@ function apply!(o::AdaBelief, x, Δ)
η, β = o.eta, o.beta
mt, st = get!(() -> (zero(x), zero(x)), o.state, x)::NTuple{2,typeof(x)}
@. mt = β[1] * mt + (1 - β[1]) * Δ
@. st = β[2] * st + (1 - β[2]) * (Δ - mt)^2
@. st = β[2] * st + (1 - β[2]) * (Δ - mt) * conj(Δ - mt)
@. Δ = η * mt / ((st) + ϵ)
return Δ
end
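All of the optimiser changes above follow the same pattern; an illustrative standalone sketch of the shared second-moment update (`second_moment!` is a hypothetical helper, not Flux's `apply!`):

```julia
# ADAM-style second-moment accumulation: with `conj`, the accumulator stays
# real and non-negative even when the gradient Δ is complex.
function second_moment!(vt, Δ; β2 = 0.999)
    @. vt = β2 * vt + (1 - β2) * Δ * conj(Δ)
    return vt
end

vt = zeros(ComplexF32, 1)
second_moment!(vt, [3.0f0 + 4.0f0im])  # ≈ [0.025 + 0.0im], i.e. (1 - β2) * |Δ|²
```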
3 changes: 3 additions & 0 deletions test/losses.jl
@@ -39,6 +39,9 @@ y = [1, 1, 0, 0]

@testset "mse" begin
@test mse(ŷ, y) ≈ (.1^2 + .9^2)/2

# Test that mse() loss works on complex values:
@test mse(0 + 0im, 1 + 1im) == 2
end

@testset "mae" begin
37 changes: 37 additions & 0 deletions test/optimise.jl
@@ -190,3 +190,40 @@ end
Flux.update!(opt, θ, gs)
@test w ≈ wold .- 0.1
end

# Flux PR #1776
# We need to test that optimisers like ADAM that maintain an internal momentum
# estimate properly calculate the second-order statistics on the gradients as
# they flow backward through the model. Previously, we would calculate second-
# order statistics via `Δ^2` rather than the complex-aware `Δ * conj(Δ)`, which
# wreaks all sorts of havoc on our training loops. This test ensures that
# a simple optimization is monotonically decreasing (up to learning step effects).
@testset "Momentum Optimisers and complex values" begin
# Test every optimizer that has momentum internally
for opt_ctor in [ADAM, RMSProp, RADAM, OADAM, ADAGrad, ADADelta, NADAM, AdaBelief]
# Our "model" is just a complex number
w = zeros(ComplexF32, 1)

# Our model attempts to learn `f(x) = conj(x)` where `f(x) = w*x`
function loss()
# Deterministic training data is the best training data
x = ones(1, 1) + 1im*ones(1, 1)

# Manually implement `mse()` to allow demonstration of brokenness
# on older Flux builds that don't have a fixed `mse()`
return sum(abs2.(w * x .- conj(x)))
end

params = Flux.Params([w])
opt = opt_ctor(1e-2)

# Train for 10 iterations, enforcing that loss is monotonically decreasing
last_loss = Inf
for idx in 1:10
grads = Flux.gradient(loss, params)
@test loss() < last_loss
last_loss = loss()
Flux.update!(opt, params, grads)
end
end
end
