Some fast paths + type fixes #2137

Status: Closed · wants to merge 5 commits
Changes from 2 commits
6 changes: 6 additions & 0 deletions src/layers/basic.jl
@@ -172,6 +172,9 @@ function (a::Dense)(x::AbstractVecOrMat)
return σ.(a.weight * x .+ a.bias)
end

(a::Dense{typeof(identity), <:AbstractMatrix, Bool})(x::AbstractVecOrMat) =
a.weight * x # fast path, no broadcast

(a::Dense)(x::AbstractArray) =
reshape(a(reshape(x, size(x,1), :)), :, size(x)[2:end]...)

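For illustration only (not part of the diff): the new method dispatches when a Dense layer is built with the default identity activation and bias=false, assuming the usual Flux constructor stores `false` in the bias field.

```julia
using Flux

d = Dense(3 => 4; bias=false)   # σ == identity and the bias field is `false`
x = rand(Float32, 3, 16)
d(x) ≈ d.weight * x             # true -- no broadcast, no bias add
```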
@@ -246,6 +249,9 @@ function (a::Scale)(x::AbstractArray)
σ.(a.scale .* x .+ a.bias)
end

(a::Scale{typeof(identity), <:AbstractArray, Bool})(x::AbstractArray) =
a.scale .* x

function Base.show(io::IO, l::Scale)
print(io, "Scale(", join(size(l.scale), ", "))
l.σ == identity || print(io, ", ", l.σ)
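Similarly for Scale, a quick sketch (illustrative only, assuming the `Scale(n; bias=false)` constructor which initialises the scale to ones):

```julia
using Flux

s = Scale(4; bias=false)    # σ == identity, bias == false
x = rand(Float32, 4, 10)
s(x) ≈ s.scale .* x         # true -- the fast path is just the elementwise product
```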
12 changes: 12 additions & 0 deletions src/layers/conv.jl
@@ -199,6 +199,10 @@ function (c::Conv)(x::AbstractArray)
cdims = conv_dims(c, x)
σ.(conv(x, c.weight, cdims) .+ conv_reshape_bias(c))
end
function (c::Conv{<:Any,<:Any,typeof(identity),<:AbstractArray,Bool})(x::AbstractArray)
cdims = conv_dims(c, x)
conv(x, c.weight, cdims) # fast path, no broadcast
end

_channels_in(l::Conv) = size(l.weight, ndims(l.weight)-1) * l.groups
_channels_out(l::Conv) = size(l.weight, ndims(l.weight))
@@ -332,6 +336,10 @@ function (c::ConvTranspose)(x::AbstractArray)
cdims = conv_transpose_dims(c, x)
σ.(∇conv_data(x, c.weight, cdims) .+ conv_reshape_bias(c))
end
function (c::ConvTranspose{<:Any,<:Any,typeof(identity),<:AbstractArray,Bool})(x::AbstractArray)
cdims = conv_transpose_dims(c, x)
∇conv_data(x, c.weight, cdims) # fast path, no broadcast
end

function Base.show(io::IO, l::ConvTranspose)
print(io, "ConvTranspose(", size(l.weight)[1:ndims(l.weight)-2])
@@ -470,6 +478,10 @@ function (c::CrossCor)(x::AbstractArray)
cdims = crosscor_dims(c, x)
σ.(crosscor(x, c.weight, cdims) .+ conv_reshape_bias(c))
end
function (c::CrossCor{<:Any,<:Any,typeof(identity),<:AbstractArray,Bool})(x::AbstractArray)
cdims = crosscor_dims(c, x)
crosscor(x, c.weight, cdims) # fast path, no broadcast
end

function Base.show(io::IO, l::CrossCor)
print(io, "CrossCor(", size(l.weight)[1:ndims(l.weight)-2])
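The same pattern applies to Conv, ConvTranspose and CrossCor above. A sketch of when the Conv fast path fires (illustrative only, assuming the standard Flux constructor):

```julia
using Flux

c = Conv((3, 3), 16 => 32; bias=false)   # σ == identity, bias == false
x = rand(Float32, 28, 28, 16, 8)
y = c(x)                                 # dispatches to the no-broadcast method
size(y)                                  # (26, 26, 32, 8)
```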
25 changes: 12 additions & 13 deletions src/layers/normalise.jl
@@ -210,7 +210,7 @@ true
```
"""
struct LayerNorm{F,D,T,N}
λ::F
λ::F # this field is not used
diag::D
ϵ::T
size::NTuple{N,Int}
@@ -254,16 +254,16 @@ function _norm_layer_forward(
end
end

o = _norm_layer_forward(x, μ, σ², l.ϵ)
hasaffine(l) || return l.λ.(o)

γ = reshape(l.γ, affine_shape)
β = reshape(l.β, affine_shape)
return l.λ.(γ .* o .+ β)
s = (inv∘sqrt).(σ² .+ l.ϵ) # faster to un-fuse this, fewer inv∘sqrt calls
if hasaffine(l)
γ = reshape(l.γ, affine_shape) # ideally reshape on construction?
β = reshape(l.β, affine_shape)
return l.λ.(γ .* s .* (x .- μ) .+ β)
else
return l.λ.(s .* (x .- μ))
end
end

@inline _norm_layer_forward(x, μ, σ², ϵ) = (x .- μ) ./ sqrt.(σ² .+ ϵ)

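For context (illustrative only, not from the diff): σ² has already been reduced over all but the channel dimension, so it is tiny compared with x. A fused broadcast re-evaluates the sqrt for every element of x, whereas computing s first evaluates it only once per channel.

```julia
using Statistics

x  = rand(Float32, 28, 28, 64, 32)           # W × H × C × N
μ  = mean(x; dims=(1, 2, 4))                 # size (1, 1, 64, 1)
σ² = var(x; dims=(1, 2, 4), corrected=false) # size (1, 1, 64, 1)
ϵ  = 1f-5

y_fused = (x .- μ) ./ sqrt.(σ² .+ ϵ)  # sqrt evaluated once per element of x
s = (inv∘sqrt).(σ² .+ ϵ)              # sqrt evaluated only 64 times
y = s .* (x .- μ)                     # the big broadcast is now just subtract and multiply
y ≈ y_fused                           # true
```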
function _track_stats!(
bn, x::AbstractArray{T, N}, μ, σ², reduce_dims,
) where {T, N}
@@ -356,10 +356,9 @@ end
@functor BatchNorm
trainable(bn::BatchNorm) = hasaffine(bn) ? (β = bn.β, γ = bn.γ) : (;)

function (BN::BatchNorm)(x)
@assert size(x, ndims(x)-1) == BN.chs
N = ndims(x)
reduce_dims = [1:N-2; N]
function (BN::BatchNorm)(x::AbstractArray{T,N}) where {T,N}
size(x, N-1) == BN.chs || error("BatchNorm expected an input with $(BN.chs) channels, got size(x) == $(size(x))")
reduce_dims = ntuple(d -> d + (d==N-1), N-1) # i.e. 1:N with N-1 removed
Member Author:
This change hits the following failure:

julia> Zygote.hessian_reverse(x -> sum(sin, mean(x.^2; dims=1)), [1 2; 3 4.0])
4×4 Matrix{Float64}:
 1.24259  2.87677  0.0      0.0
 2.87677  8.91398  0.0      0.0
 0.0      0.0      1.33701  4.35217
 0.0      0.0      4.35217  7.86527

julia> Zygote.hessian_reverse(x -> sum(sin, mean(x.^2; dims=[1])), [1 2; 3 4.0])
4×4 Matrix{Float64}:
 1.24259  2.87677  0.0      0.0
 2.87677  8.91398  0.0      0.0
 0.0      0.0      1.33701  4.35217
 0.0      0.0      4.35217  7.86527

julia> Zygote.hessian_reverse(x -> sum(sin, mean(x.^2; dims=(1,))), [1 2; 3 4.0])
ERROR: Mutating arrays is not supported -- called push!(Vector{Int64}, ...)
Stacktrace:
  [3] (::Zygote.var"#397#398"{Vector{Int64}})(#unused#::Nothing)
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/lib/array.jl:105
  [4] (::Zygote.var"#2529#back#399"{Zygote.var"#397#398"{Vector{Int64}}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
  [5] unique
    @ ./set.jl:176 [inlined]
  [6] (::typeof((unique)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
  [7] _denom
    @ ~/.julia/packages/ChainRules/hVHC4/src/rulesets/Statistics/statistics.jl:7 [inlined]
  [8] (::typeof((_denom)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
  [9] #rrule#1801
    @ ~/.julia/packages/ChainRules/hVHC4/src/rulesets/Statistics/statistics.jl:13 [inlined]
 [10] (::typeof((#rrule#1801)))(Δ::Tuple{Matrix{Float64}, NamedTuple{(:n, :sum_pullback), Tuple{Float64, Nothing}}})

So it's differentiating this:

https://github.com/JuliaDiff/ChainRules.jl/blob/9a405f732758552cd945a110adb6828a997887a8/src/rulesets/Statistics/statistics.jl#L7

and differentiating the rule for unique, which doesn't handle this case.

Zygote differentiates so many things it need not touch, surely this adds startup time... you only notice when it fails.

Member:
One of its fatal flaws, you might say. Usually first-order differentiation is well-behaved because control flow and possible mutation are hidden away, but all bets are off with second order...

Member Author:
Even at first order, I think it does a lot that it need not do; it's just that most of the resulting errors have already been found. Same thing when trying out Diffractor -- lots of errors from obscure code calculating indices for views or whatever, which to a human is obviously non-differentiable.

Member Author:
This one is fixed in JuliaDiff/ChainRules.jl#687

Member:
That's one definite benefit of tracing/overload-based ADs. Anything not numerically interesting gets ignored or falls away in the final tape/graph.

Member Author:
Yup. I presume that any kind of activity tracking would also let you eliminate most off-track things. Maybe declaring integers (and all structs not containing floats) non-diff would also help.

Member:
There's certainly a lot we could learn from projects like differentiable Swift (which uses activity analysis). It seems unlikely Zygote will be where such knowledge is applied given how poorly integrated it is with the compiler.

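Returning to the reduce_dims change above, here is what the two definitions evaluate to for a 4-d input (illustrative only; the key difference is that the new version returns a Tuple, while the old one built a Vector):

```julia
N = 4                               # e.g. a WHCN image batch
ntuple(d -> d + (d == N-1), N-1)    # (1, 2, 4) -- 1:N with the channel dim N-1 removed
[1:N-2; N]                          # [1, 2, 4] -- the old Vector-valued definition
```

Passing a Tuple to dims is what exposed the ChainRules _denom / unique issue discussed above, since dims=1 and dims=[1] take different code paths from dims=(1,).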
Comment on lines +360 to +361
Member:
Might as well take the opportunity to mark these lines and the definition of affine_shape below as ignored. BN causes a decent amount of Zygote compilation latency, so hiding anything that doesn't need to go through AD seems reasonable.

affine_shape = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)
return _norm_layer_forward(BN, x; reduce_dims, affine_shape)
end
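A rough sketch of what the reviewer's suggestion could look like, assuming ChainRulesCore's @ignore_derivatives macro is available in this codebase (not part of the PR as written):

```julia
using ChainRulesCore: @ignore_derivatives

function (BN::BatchNorm)(x::AbstractArray{T,N}) where {T,N}
  size(x, N-1) == BN.chs || error("BatchNorm expected an input with $(BN.chs) channels, got size(x) == $(size(x))")
  # Shape bookkeeping is not numerically interesting, so keep it out of the AD graph.
  reduce_dims  = @ignore_derivatives ntuple(d -> d + (d==N-1), N-1)
  affine_shape = @ignore_derivatives ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)
  return _norm_layer_forward(BN, x; reduce_dims, affine_shape)
end
```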
6 changes: 3 additions & 3 deletions test/layers/normalisation.jl
@@ -166,11 +166,11 @@ end
end

# with activation function
let m = BatchNorm(2, sigmoid), x = [1.0 3.0 5.0;
2.0 4.0 6.0]
let m = BatchNorm(2, sigmoid)
x = Float32[1.0 3.0 5.0; 2.0 4.0 6.0]
y = m(x)
@test isapprox(y, sigmoid.((x .- m.μ) ./ sqrt.(m.σ² .+ m.ϵ)), atol = 1.0e-7)
@inferred m(x)
@inferred m(x) # fails when x::Matrix{Float64}, do we care?
Member:
do you know why this fails?

Member Author:
I do not. Checking the branches of if !_isactive(l) && l.track_stats, I get the same types on all paths.

end

let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:6), 3, 2, 1)
Expand Down