Some fast paths + type fixes #2137

Closed · wants to merge 5 commits
Changes from 4 commits
28 changes: 26 additions & 2 deletions src/layers/basic.jl
@@ -169,9 +169,13 @@ end

function (a::Dense)(x::AbstractVecOrMat)
σ = NNlib.fast_act(a.σ, x) # replaces tanh => tanh_fast, etc
return σ.(a.weight * x .+ a.bias)
xT = _match_eltype(a, eltype(a.weight), x) # fixes Float64 input, etc.
return σ.(a.weight * xT .+ a.bias)
end

(a::Dense{typeof(identity), <:AbstractMatrix, Bool})(x::AbstractVecOrMat) =
a.weight * _match_eltype(a, eltype(a.weight), x) # fast path, no broadcast

(a::Dense)(x::AbstractArray) =
reshape(a(reshape(x, size(x,1), :)), :, size(x)[2:end]...)

@@ -185,6 +189,22 @@ end
Dense(W::LinearAlgebra.Diagonal, bias = true, σ = identity) =
Scale(W.diag, bias, σ)

_match_eltype(layer, ::Type{T}, x::AbstractArray{T}) where {T} = x # best case
function _match_eltype(layer, ::Type{Float32}, x::AbstractArray{Float64}) # common mistake
@warn "Layer with Float32 parameters got Float64 input.
The input will be converted, but any earlier layers may be very slow" layer summary(x) maxlog=1
convert(AbstractArray{Float32}, x)
end
function _match_eltype(layer, ::Type{T}, x::AbstractArray{<:Union{AbstractFloat, Integer}}) where {T}
convert(AbstractArray{T}, x)
end
_match_eltype(layer, ::Type, x::OneHotLike) = x
_match_eltype(layer, ::Type, x::AbstractArray) = x # weird types

function ChainRulesCore.rrule(::typeof(_match_eltype), layer, ::Type{T}, x::AbstractArray) where {T}
_match_eltype(layer, T, x), dx -> (NoTangent(), ZeroTangent(), NoTangent(), dx) # does not un-thunk dx
end
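
Not part of the diff, but as a quick illustration of what the new _match_eltype path buys, here is a rough sketch assuming this branch of Flux:

using Flux

d = Dense(rand(Float32, 3, 2))              # Float32 weights, identity activation
eltype(d(rand(Float64, 2, 5)))              # Float32; warns once (maxlog=1) and converts the input
eltype(d(rand(-5:5, 2, 5)))                 # Float32; integer input is converted silently
eltype(d(Flux.onehotbatch([1, 2, 1], 1:2))) # Float32; OneHot input passes through unconverted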

"""
Scale(size::Integer..., σ=identity; bias=true, init=ones32)
Scale(scale::AbstractArray, [bias, σ])
@@ -246,6 +266,9 @@ function (a::Scale)(x::AbstractArray)
σ.(a.scale .* x .+ a.bias)
end

(a::Scale{typeof(identity), <:AbstractArray, Bool})(x::AbstractArray) =
a.scale .* x

function Base.show(io::IO, l::Scale)
print(io, "Scale(", join(size(l.scale), ", "))
l.σ == identity || print(io, ", ", l.σ)
@@ -421,6 +444,7 @@ Bilinear((in12, out)::Pair{<:Integer, <:Integer}, σ = identity; kw...) = Biline

function (a::Bilinear)(x::AbstractMatrix, y::AbstractMatrix)
W, b, σ = a.weight, a.bias, a.σ
xT = _match_eltype(a, eltype(a.weight), x)

d_z, d_x, d_y = size(W)
d_x == size(x,1) && d_y == size(y,1) || throw(DimensionMismatch("number of rows in data must match W"))
@@ -430,7 +454,7 @@ function (a::Bilinear)(x::AbstractMatrix, y::AbstractMatrix)
Wy = reshape(reshape(W, (:, d_y)) * y, (d_z, d_x, :))

# @einsum Z[o,s] := Wy[o,i,s] * x[i,s]
Wyx = batched_mul(Wy, reshape(x, (d_x, 1, :)))
Wyx = batched_mul(Wy, reshape(xT, (d_x, 1, :)))
Z = reshape(Wyx, (d_z, :))

# @einsum out[o,s] := σ(Z[o,i] + b[o])
23 changes: 20 additions & 3 deletions src/layers/conv.jl
@@ -197,7 +197,12 @@ ChainRulesCore.@non_differentiable conv_dims(::Any, ::Any)
function (c::Conv)(x::AbstractArray)
σ = NNlib.fast_act(c.σ, x)
cdims = conv_dims(c, x)
σ.(conv(x, c.weight, cdims) .+ conv_reshape_bias(c))
xT = _match_eltype(c, eltype(c.weight), x)
σ.(conv(xT, c.weight, cdims) .+ conv_reshape_bias(c))
end
function (c::Conv{<:Any,<:Any,typeof(identity),<:AbstractArray,Bool})(x::AbstractArray)
cdims = conv_dims(c, x)
conv(x, c.weight, cdims) # fast path, no broadcast
end

_channels_in(l::Conv) = size(l.weight, ndims(l.weight)-1) * l.groups
@@ -330,7 +335,13 @@ ChainRulesCore.@non_differentiable conv_transpose_dims(::Any, ::Any)
function (c::ConvTranspose)(x::AbstractArray)
σ = NNlib.fast_act(c.σ, x)
cdims = conv_transpose_dims(c, x)
σ.(∇conv_data(x, c.weight, cdims) .+ conv_reshape_bias(c))
xT = _match_eltype(c, eltype(c.weight), x)
σ.(∇conv_data(xT, c.weight, cdims) .+ conv_reshape_bias(c))
end
function (c::ConvTranspose{<:Any,<:Any,typeof(identity),<:AbstractArray,Bool})(x::AbstractArray)
cdims = conv_transpose_dims(c, x)
xT = _match_eltype(c, eltype(c.weight), x)
∇conv_data(xT, c.weight, cdims) # fast path, no broadcast
end

function Base.show(io::IO, l::ConvTranspose)
@@ -468,7 +479,13 @@ ChainRulesCore.@non_differentiable crosscor_dims(::Any, ::Any)
function (c::CrossCor)(x::AbstractArray)
σ = NNlib.fast_act(c.σ, x)
cdims = crosscor_dims(c, x)
σ.(crosscor(x, c.weight, cdims) .+ conv_reshape_bias(c))
xT = _match_eltype(c, eltype(c.weight), x)
σ.(crosscor(xT, c.weight, cdims) .+ conv_reshape_bias(c))
end
function (c::CrossCor{<:Any,<:Any,typeof(identity),<:AbstractArray,Bool})(x::AbstractArray)
cdims = crosscor_dims(c, x)
xT = _match_eltype(c, eltype(c.weight), x)
crosscor(xT, c.weight, cdims) # fast path, no broadcast
end

function Base.show(io::IO, l::CrossCor)
25 changes: 12 additions & 13 deletions src/layers/normalise.jl
@@ -210,7 +210,7 @@ true
```
"""
struct LayerNorm{F,D,T,N}
λ::F
λ::F # this field is not used
diag::D
ϵ::T
size::NTuple{N,Int}
@@ -254,16 +254,16 @@ function _norm_layer_forward(
end
end

o = _norm_layer_forward(x, μ, σ², l.ϵ)
hasaffine(l) || return l.λ.(o)

γ = reshape(l.γ, affine_shape)
β = reshape(l.β, affine_shape)
return l.λ.(γ .* o .+ β)
s = (inv∘sqrt).(σ² .+ l.ϵ) # faster to un-fuse this, fewer inv∘sqrt calls
if hasaffine(l)
γ = reshape(l.γ, affine_shape) # ideally reshape on construction?
β = reshape(l.β, affine_shape)
return l.λ.(γ .* s .* (x .- μ) .+ β)
else
return l.λ.(s .* (x .- μ))
end
end

@inline _norm_layer_forward(x, μ, σ², ϵ) = (x .- μ) ./ sqrt.(σ² .+ ϵ)
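
As a rough sanity check of the un-fusing claim (arbitrary sizes, assumes BenchmarkTools; not part of the diff):

using BenchmarkTools, Statistics
x  = randn(Float32, 784, 128)
μ  = mean(x; dims=2)
σ² = var(x; dims=2, corrected=false)
ϵ  = 1f-5
@btime ($x .- $μ) ./ sqrt.($σ² .+ $ϵ)   # fused: sqrt runs once per element of x
@btime begin                            # un-fused: inv∘sqrt runs once per element of σ²
    s = (inv∘sqrt).($σ² .+ $ϵ)
    s .* ($x .- $μ)
end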

function _track_stats!(
bn, x::AbstractArray{T, N}, μ, σ², reduce_dims,
) where {T, N}
@@ -356,10 +356,9 @@ end
@functor BatchNorm
trainable(bn::BatchNorm) = hasaffine(bn) ? (β = bn.β, γ = bn.γ) : (;)

function (BN::BatchNorm)(x)
@assert size(x, ndims(x)-1) == BN.chs
N = ndims(x)
reduce_dims = [1:N-2; N]
function (BN::BatchNorm)(x::AbstractArray{T,N}) where {T,N}
size(x, N-1) == BN.chs || error("BatchNorm expected an input with $(BN.chs) channels, got size(x) == $(size(x))")
reduce_dims = ntuple(d -> d + (d==N-1), N-1) # i.e. 1:N with N-1 removed
Member Author

This change hits the following failure:

julia> Zygote.hessian_reverse(x -> sum(sin, mean(x.^2; dims=1)), [1 2; 3 4.0])
4×4 Matrix{Float64}:
 1.24259  2.87677  0.0      0.0
 2.87677  8.91398  0.0      0.0
 0.0      0.0      1.33701  4.35217
 0.0      0.0      4.35217  7.86527

julia> Zygote.hessian_reverse(x -> sum(sin, mean(x.^2; dims=[1])), [1 2; 3 4.0])
4×4 Matrix{Float64}:
 1.24259  2.87677  0.0      0.0
 2.87677  8.91398  0.0      0.0
 0.0      0.0      1.33701  4.35217
 0.0      0.0      4.35217  7.86527

julia> Zygote.hessian_reverse(x -> sum(sin, mean(x.^2; dims=(1,))), [1 2; 3 4.0])
ERROR: Mutating arrays is not supported -- called push!(Vector{Int64}, ...)
Stacktrace:
  [3] (::Zygote.var"#397#398"{Vector{Int64}})(#unused#::Nothing)
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/lib/array.jl:105
  [4] (::Zygote.var"#2529#back#399"{Zygote.var"#397#398"{Vector{Int64}}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
  [5] unique
    @ ./set.jl:176 [inlined]
  [6] (::typeof((unique)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
  [7] _denom
    @ ~/.julia/packages/ChainRules/hVHC4/src/rulesets/Statistics/statistics.jl:7 [inlined]
  [8] (::typeof((_denom)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
  [9] #rrule#1801
    @ ~/.julia/packages/ChainRules/hVHC4/src/rulesets/Statistics/statistics.jl:13 [inlined]
 [10] (::typeof((#rrule#1801)))(Δ::Tuple{Matrix{Float64}, NamedTuple{(:n, :sum_pullback), Tuple{Float64, Nothing}}})

So it's differentiating this:

https://github.com/JuliaDiff/ChainRules.jl/blob/9a405f732758552cd945a110adb6828a997887a8/src/rulesets/Statistics/statistics.jl#L7

and differentiating the rule for unique, which doesn't handle this case.

Zygote differentiates so many things it need not touch, which surely adds startup time; you only notice when it fails.
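
For what it's worth, a hypothetical sketch of the kind of fix this points at, using a stand-in for ChainRules' internal helper (not necessarily what was merged): the denominator is pure index bookkeeping, so marking it non-differentiable keeps Zygote out of unique entirely.

using ChainRulesCore

_denom(x, dims) = prod(d -> size(x, d), unique(dims); init = 1)  # stand-in: counts the elements reduced over dims

ChainRulesCore.@non_differentiable _denom(::Any, ::Any)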

Member

One of its fatal flaws, you might say. Usually first-order differentiation is well-behaved because control flow and possible mutation are hidden away, but all bets are off with second order...

Member Author

Even at first order, I think it does a lot that it need not do; it's just that most of the resulting errors have already been found. The same thing shows up when trying out Diffractor -- lots of errors from obscure code calculating indices for views or the like, which to a human is obviously non-differentiable.

Member Author

This one is fixed in JuliaDiff/ChainRules.jl#687

Member

That's one definite benefit of tracing/overload-based ADs. Anything not numerically interesting gets ignored or falls away in the final tape/graph.

Member Author

Yup. I presume that any kind of activity tracking would also let you eliminate most off-track things. Maybe declaring integers (and all structs not containing floats) non-diff would also help.

Member

There's certainly a lot we could learn from projects like differentiable Swift (which uses activity analysis). It seems unlikely Zygote will be where such knowledge is applied given how poorly integrated it is with the compiler.

Comment on lines +360 to +361
Member

Might as well take the opportunity to mark these lines and the definition of affine_shape below as ignored. BN causes a decent amount of Zygote compilation latency, so hiding anything that doesn't need to go through AD seems reasonable.
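
A minimal sketch of that suggestion, assuming ChainRulesCore.ignore_derivatives is acceptable here (illustrative only, not part of the diff):

function (BN::BatchNorm)(x::AbstractArray{T,N}) where {T,N}
  size(x, N-1) == BN.chs || error("BatchNorm expected an input with $(BN.chs) channels, got size(x) == $(size(x))")
  reduce_dims, affine_shape = ChainRulesCore.ignore_derivatives() do
    ntuple(d -> d + (d==N-1), N-1),              # i.e. 1:N with N-1 removed
    ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)  # 1 everywhere except the channel dim
  end
  return _norm_layer_forward(BN, x; reduce_dims, affine_shape)
end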

affine_shape = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)
return _norm_layer_forward(BN, x; reduce_dims, affine_shape)
end
20 changes: 12 additions & 8 deletions src/layers/recurrent.jl
@@ -200,10 +200,11 @@ end
RNNCell((in, out)::Pair, σ=tanh; init=Flux.glorot_uniform, initb=zeros32, init_state=zeros32) =
RNNCell(σ, init(out, in), init(out, out), initb(out), init_state(out,1))

function (m::RNNCell{F,I,H,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {F,I,H,V,T}
function (m::RNNCell{F,I,H,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{<:AbstractFloat},OneHotArray}) where {F,I,H,V,T}
Wi, Wh, b = m.Wi, m.Wh, m.b
σ = NNlib.fast_act(m.σ, x)
h = σ.(Wi*x .+ Wh*h .+ b)
xT = _match_eltype(m, T, x)::AbstractArray{T} # any AbstractFloat is so converted
h = σ.(Wi*xT .+ Wh*h .+ b)
return h, reshape_cell_output(h, x)
end

@@ -305,9 +306,10 @@ function LSTMCell((in, out)::Pair;
return cell
end

function (m::LSTMCell{I,H,V,<:NTuple{2,AbstractMatrix{T}}})((h, c), x::Union{AbstractVecOrMat{T},OneHotArray}) where {I,H,V,T}
function (m::LSTMCell{I,H,V,<:NTuple{2,AbstractMatrix{T}}})((h, c), x::Union{AbstractVecOrMat{<:AbstractFloat},OneHotArray}) where {I,H,V,T}
b, o = m.b, size(h, 1)
g = muladd(m.Wi, x, muladd(m.Wh, h, b))
xT = _match_eltype(m, T, x)::AbstractArray{T}
g = muladd(m.Wi, xT, muladd(m.Wh, h, b))
input, forget, cell, output = multigate(g, o, Val(4))
c′ = @. sigmoid_fast(forget) * c + sigmoid_fast(input) * tanh_fast(cell)
h′ = @. sigmoid_fast(output) * tanh_fast(c′)
@@ -376,9 +378,10 @@ end
GRUCell((in, out)::Pair; init = glorot_uniform, initb = zeros32, init_state = zeros32) =
GRUCell(init(out * 3, in), init(out * 3, out), initb(out * 3), init_state(out,1))

function (m::GRUCell{I,H,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {I,H,V,T}
function (m::GRUCell{I,H,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{<:AbstractFloat},OneHotArray}) where {I,H,V,T}
Wi, Wh, b, o = m.Wi, m.Wh, m.b, size(h, 1)
gxs, ghs, bs = multigate(Wi*x, o, Val(3)), multigate(Wh*h, o, Val(3)), multigate(b, o, Val(3))
xT = _match_eltype(m, T, x)::AbstractArray{T}
gxs, ghs, bs = multigate(Wi*xT, o, Val(3)), multigate(Wh*h, o, Val(3)), multigate(b, o, Val(3))
r, z = _gru_output(gxs, ghs, bs)
h̃ = @. tanh_fast(gxs[3] + r * ghs[3] + bs[3])
h′ = @. (1 - z) * h̃ + z * h
@@ -444,9 +447,10 @@ GRUv3Cell((in, out)::Pair; init = glorot_uniform, initb = zeros32, init_state =
GRUv3Cell(init(out * 3, in), init(out * 2, out), initb(out * 3),
init(out, out), init_state(out,1))

function (m::GRUv3Cell{I,H,V,HH,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {I,H,V,HH,T}
function (m::GRUv3Cell{I,H,V,HH,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{<:AbstractFloat},OneHotArray}) where {I,H,V,HH,T}
Wi, Wh, b, Wh_h̃, o = m.Wi, m.Wh, m.b, m.Wh_h̃, size(h, 1)
gxs, ghs, bs = multigate(Wi*x, o, Val(3)), multigate(Wh*h, o, Val(2)), multigate(b, o, Val(3))
xT = _match_eltype(m, T, x)::AbstractArray{T}
gxs, ghs, bs = multigate(Wi*xT, o, Val(3)), multigate(Wh*h, o, Val(2)), multigate(b, o, Val(3))
r, z = _gru_output(gxs, ghs, bs)
h̃ = tanh_fast.(gxs[3] .+ (Wh_h̃ * (r .* h)) .+ bs[3])
h′ = @. (1 - z) * h̃ + z * h
16 changes: 16 additions & 0 deletions test/layers/basic.jl
@@ -89,6 +89,22 @@ import Flux: activations
@test Dense(10, 2, identity, init = ones)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20]
@test Dense(10, 2, identity, init = ones, bias = false)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20]
end
@testset "fast paths, type fixes, ambiguities" begin
d1 = Dense(2 => 3)
d2 = Dense(d1.weight, false)
x1 = randn(Float32, 2, 4)
@test d1(x1) ≈ d2(x1) ≈ d1.weight * x1
x2 = Float64.(x1)
@test d1(x2) ≈ d2(x2) ≈ d1.weight * x2
@test d1(x2) isa Array{Float32}
@test d2(x2) isa Array{Float32}
x3 = rand(-5:5, 2, 4)
@test d1(x3) ≈ d2(x3) ≈ d1.weight * x3
x4 = rand(Bool, 2, 4)
@test d1(x4) ≈ d2(x4) ≈ d1.weight * x4
x5 = Flux.onehotbatch(rand(Bool, 5), (true, false))
@test d1(x5) ≈ d2(x5) ≈ d1.weight * x5
end
end

@testset "Scale" begin
6 changes: 3 additions & 3 deletions test/layers/normalisation.jl
@@ -166,11 +166,11 @@ end
end

# with activation function
let m = BatchNorm(2, sigmoid), x = [1.0 3.0 5.0;
2.0 4.0 6.0]
let m = BatchNorm(2, sigmoid)
x = Float32[1.0 3.0 5.0; 2.0 4.0 6.0]
y = m(x)
@test isapprox(y, sigmoid.((x .- m.μ) ./ sqrt.(m.σ² .+ m.ϵ)), atol = 1.0e-7)
@inferred m(x)
@inferred m(x) # fails when x::Matrix{Float64}, do we care?
Member

Do you know why this fails?

Member Author

I do not. Checking the branches of if !_isactive(l) && l.track_stats, I get the same types on all paths.
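
A low-tech way to hunt for it, purely as a debugging sketch:

using Flux, InteractiveUtils
m = BatchNorm(2, sigmoid)
@code_warntype m([1.0 3.0 5.0; 2.0 4.0 6.0])         # Float64 input: look for Union/Any in the body
@code_warntype m(Float32[1.0 3.0 5.0; 2.0 4.0 6.0])  # Float32 input, for comparison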

end

let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:6), 3, 2, 1)
5 changes: 3 additions & 2 deletions test/layers/recurrent.jl
@@ -93,9 +93,10 @@ end
@testset "RNN-input-state-eltypes" begin
@testset for R in [RNN, GRU, LSTM, GRUv3]
m = R(3 => 5)
x = rand(Float64, 3, 1)
x = rand(Float64, 3, 1) # Float64 input is now converted
Flux.reset!(m)
@test_throws MethodError m(x)
@test m(x) isa Array{Float32}
@test m.state isa Array{Float32}
end
end
