1776: Use conjugates in optimizers to better learn on complex-valued inputs r=DhairyaLGandhi a=staticfloat

When weights are complex, the deltas to them will also be complex. In all optimizers that need a second-order estimate of gradient statistics, we generally want to use the `x * conj(x)` pattern rather than `x^2` (see the sketch at the end of this description). We can see the effect this has on ADAM with the following test:

```julia
using Flux          # provides Chain, Dense, glorot_uniform, params, train!, ADAM
using CairoMakie    # any Makie backend works; provides Figure, Axis, lines!, Legend

begin
    # This model will learn `W = I` and `bias = 0`
    complex_init(dims...) = Flux.glorot_uniform(dims...) .+ 1im .* Flux.glorot_uniform(dims...)
    model = Chain(
        Dense(4, 4, tanh; init=complex_init),
        Dense(4, 16, tanh; init=complex_init),
        Dense(16, 4, tanh; init=complex_init),
        Dense(4, 4, tanh; init=complex_init),
    )

    # Loss function; note we don't need the `abs()` if we update `Flux.Losses.mse()` as below
    function loss(x)
        return abs.(Flux.Losses.mse(model(x), x))
    end

    # Keep track of loss from epoch to epoch
    losses = Float64[]
    dataset = [(randn(ComplexF32, 4, 10),)]
    params = Flux.params(model)
    opt = Flux.Optimise.ADAM(0.001)
    for epoch_idx in 1:10000
        Flux.train!(loss, params, dataset, opt)
        epoch_loss = loss(dataset[1][1])
        push!(losses, epoch_loss)
        if epoch_idx % 100 == 0
            @info("epoch done", epoch_idx, epoch_loss)
        end
    end

    # Plot the training loss curve
    fig = Figure()
    meta_ax = Axis(fig[1,1])
    lines!(meta_ax, log.(losses); label="Training loss")
    fig[1,2] = Legend(fig, meta_ax, "Learning Stats")
    fig
end
```

The training loss before the fix looks like this:

![without_workaround](https://user-images.githubusercontent.com/130920/142955143-385c5ca9-b2d7-4129-aae0-152741661689.png)

Whereas after both of these commits, it looks like this:

![with_workaround](https://user-images.githubusercontent.com/130920/142955168-807943d7-a2d4-4f7a-82a6-fbab0610e407.png)

Note that while the absolute value of the loss is comparable in this simple example, the loss landscape without the fix is significantly more chaotic. With a higher learning rate, the fixed version is able to learn much faster:

![download-1](https://user-images.githubusercontent.com/130920/142955367-e945e6c2-7045-42f7-8a7f-9135ee40c5b4.png)

Whereas the unfixed version simply diverges:

![download-2](https://user-images.githubusercontent.com/130920/142955420-8f32bb3c-5add-4fcb-86a6-eff7fac6dfaf.png)

Co-authored-by: Elliot Saba <staticfloat@gmail.com>
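
---

For reference, a minimal sketch of the `x * conj(x)` pattern described above, applied to an ADAM-style second-moment accumulator. This is illustrative only, not the exact Flux.Optimise source; the names `β2`, `vt`, `Δ` and the `real.(...)` wrapper are assumptions made for the example.

```julia
# ADAM-style second-moment update, before vs. after the conjugate fix (sketch).
β2 = 0.999
vt = zeros(Float32, 4)        # running second-moment (variance-like) estimate
Δ  = randn(ComplexF32, 4)     # complex-valued gradient delta

# Before: Δ .^ 2 is generally complex when Δ is complex, so `vt` stops being a
# real, non-negative statistic and the downstream sqrt-based step-size scaling misbehaves.
vt_bad = β2 .* vt .+ (1 - β2) .* Δ .^ 2

# After: Δ .* conj.(Δ) (equivalently abs2.(Δ)) is real and non-negative,
# which is what a second-moment estimate is supposed to be.
vt = β2 .* vt .+ (1 - β2) .* real.(Δ .* conj.(Δ))
```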