From 32537346a53c9795e9933bc2e8b8825b972603ea Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Sun, 5 Jun 2022 22:35:52 -0700
Subject: [PATCH] Style fix

---
 .JuliaFormatter.toml       |  1 +
 examples/NeuralODE/main.jl |  4 +-
 lib/Boltz/src/utils.jl     |  4 +-
 src/Lux.jl                 |  5 +++
 src/adapt.jl               |  6 +--
 src/autodiff.jl            |  5 ++-
 src/core.jl                | 18 ++++----
 src/layers/basic.jl        | 86 +++++++++++++++++++++-----------------
 src/layers/conv.jl         |  4 +-
 src/layers/display.jl      |  4 +-
 src/layers/normalize.jl    |  6 +--
 src/layers/recurrent.jl    | 18 ++++----
 src/nnlib.jl               |  8 ++--
 src/random.jl              |  2 -
 src/utils.jl               | 10 ++---
 test/layers/basic.jl       |  4 +-
 test/layers/recurrent.jl   |  2 +-
 test/runtests.jl           |  4 +-
 18 files changed, 102 insertions(+), 89 deletions(-)
 delete mode 100644 src/random.jl

diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml
index e56de732a7..f024c09775 100644
--- a/.JuliaFormatter.toml
+++ b/.JuliaFormatter.toml
@@ -1,2 +1,3 @@
 style = "sciml"
 whitespace_in_kwargs = false
+always_use_return = true
diff --git a/examples/NeuralODE/main.jl b/examples/NeuralODE/main.jl
index 062ad87d47..91e20d2cfd 100644
--- a/examples/NeuralODE/main.jl
+++ b/examples/NeuralODE/main.jl
@@ -14,7 +14,7 @@ CUDA.allowscalar(false)
 # ## Loading MNIST
 ## Use MLDataUtils LabelEnc for natural onehot conversion
 function onehot(labels_raw)
-    convertlabel(LabelEnc.OneOfK, labels_raw, LabelEnc.NativeLabels(collect(0:9)))
+    return convertlabel(LabelEnc.OneOfK, labels_raw, LabelEnc.NativeLabels(collect(0:9)))
 end
 
 function loadmnist(batchsize, train_split)
@@ -66,7 +66,7 @@ function (n::NeuralODE)(x, ps, st)
 end
 
 function diffeqsol_to_array(x::ODESolution{T, N, <:AbstractVector{<:CuArray}}) where {T, N}
-    dropdims(gpu(x); dims=3)
+    return dropdims(gpu(x); dims=3)
 end
 diffeqsol_to_array(x::ODESolution) = dropdims(Array(x); dims=3)
 
diff --git a/lib/Boltz/src/utils.jl b/lib/Boltz/src/utils.jl
index 39c0412172..b2a716c10e 100644
--- a/lib/Boltz/src/utils.jl
+++ b/lib/Boltz/src/utils.jl
@@ -5,10 +5,10 @@ Type-stable and faster version of `MLUtils.chunk`
 """
 @inline fast_chunk(h::Int, n::Int) = (1:h) .+ h * (n - 1)
 @inline function fast_chunk(x::AbstractArray, h::Int, n::Int, ::Val{dim}) where {dim}
-    selectdim(x, dim, fast_chunk(h, n))
+    return selectdim(x, dim, fast_chunk(h, n))
 end
 @inline function fast_chunk(x::AbstractArray, ::Val{N}, d::Val{D}) where {N, D}
-    fast_chunk.((x,), size(x, D) ÷ N, 1:N, d)
+    return fast_chunk.((x,), size(x, D) ÷ N, 1:N, d)
 end
 
 """
diff --git a/src/Lux.jl b/src/Lux.jl
index abc02b761f..4b8c87f830 100644
--- a/src/Lux.jl
+++ b/src/Lux.jl
@@ -22,6 +22,11 @@ using Requires
 
 const use_cuda = Ref{Union{Nothing, Bool}}(nothing)
 
+# NOTE: In theory, we can support any parameter type which allows us to access
+# parameters using getproperty. But I will be conservative here and only specify
+# NamedTuple and ComponentArray until we have tested other cases properly.
+const VALID_PARAMETER_TYPES = Union{NamedTuple, ComponentArray}
+
 # Data Transfer Utilities
 include("adapt.jl")
 # Utilities
diff --git a/src/adapt.jl b/src/adapt.jl
index e6ded3a9e7..a9f2af4207 100644
--- a/src/adapt.jl
+++ b/src/adapt.jl
@@ -7,7 +7,7 @@ adapt_storage(::LuxCUDAAdaptor, x) = CUDA.cu(x)
 adapt_storage(::LuxCUDAAdaptor, x::FillArrays.AbstractFill) = CUDA.cu(colelct(x))
 adapt_storage(::LuxCUDAAdaptor, x::Zygote.OneElement) = CUDA.cu(collect(x))
 function adapt_storage(to::LuxCUDAAdaptor, x::ComponentArray)
-    ComponentArray(adapt_storage(to, getdata(x)), getaxes(x))
+    return ComponentArray(adapt_storage(to, getdata(x)), getaxes(x))
 end
 adapt_storage(::LuxCUDAAdaptor, rng::AbstractRNG) = rng
 
@@ -18,13 +18,13 @@ function adapt_storage(::LuxCPUAdaptor,
 end
 adapt_storage(::LuxCPUAdaptor, x::AbstractArray) = adapt(Array, x)
 function adapt_storage(to::LuxCPUAdaptor, x::ComponentArray)
-    ComponentArray(adapt_storage(to, getdata(x)), getaxes(x))
+    return ComponentArray(adapt_storage(to, getdata(x)), getaxes(x))
 end
 adapt_storage(::LuxCPUAdaptor, rng::AbstractRNG) = rng
 # TODO: SparseArrays
 function adapt_storage(::LuxCPUAdaptor,
                        x::CUDA.CUSPARSE.CUDA.CUSPARSE.AbstractCuSparseMatrix)
-    adapt(Array, x)
+    return adapt(Array, x)
 end
 
 _isbitsarray(::AbstractArray{<:Number}) = true
diff --git a/src/autodiff.jl b/src/autodiff.jl
index 470b478440..d8c1a1715f 100644
--- a/src/autodiff.jl
+++ b/src/autodiff.jl
@@ -14,7 +14,7 @@ ChainRulesCore.Tangent{P}(; kwargs...) where {P <: AbstractExplicitLayer} = NoTa
 ChainRulesCore.rrule(::typeof(istraining)) = true, _ -> (NoTangent(),)
 
 function ChainRulesCore.rrule(::typeof(Base.broadcasted), ::typeof(identity), x)
-    x, Δ -> (NoTangent(), NoTangent(), Δ)
+    return x, Δ -> (NoTangent(), NoTangent(), Δ)
 end
 
 # NNlib Functions
@@ -43,7 +43,8 @@ function ChainRulesCore.rrule(::typeof(dropout),
                               t::Val{training}) where {T, N, training}
     y, mask, rng = dropout(rng, x, p, q, dims, t)
     function dropout_pullback((dy, dmask, drng))
-        return (NoTangent(), NoTangent(), elementwise_mul(dy, mask), NoTangent(), NoTangent(),
+        return (NoTangent(), NoTangent(), elementwise_mul(dy, mask), NoTangent(),
+                NoTangent(),
                 NoTangent(), NoTangent())
     end
     return (y, mask, rng), dropout_pullback
diff --git a/src/core.jl b/src/core.jl
index 7fcc1d89af..94b164c9c0 100644
--- a/src/core.jl
+++ b/src/core.jl
@@ -15,7 +15,7 @@ Generate the initial parameters of the layer `l`.
 """
 initialparameters(::AbstractRNG, ::Any) = NamedTuple()
 function initialparameters(rng::AbstractRNG, l::NamedTuple)
-    map(Base.Fix1(initialparameters, rng), l)
+    return map(Base.Fix1(initialparameters, rng), l)
 end
 
 """
@@ -32,10 +32,10 @@ initialstates(rng::AbstractRNG, l::NamedTuple) = map(Base.Fix1(initialstates, rn
 Return the total number of parameters of the layer `l`.
 """
 function parameterlength(l::AbstractExplicitLayer)
-    parameterlength(initialparameters(Random.default_rng(), l))
+    return parameterlength(initialparameters(Random.default_rng(), l))
 end
 function parameterlength(nt::Union{NamedTuple, Tuple})
-    length(nt) == 0 ? 0 : sum(parameterlength, nt)
+    return length(nt) == 0 ? 0 : sum(parameterlength, nt)
 end
 parameterlength(a::AbstractArray) = length(a)
 parameterlength(x) = 0
@@ -57,7 +57,7 @@ statelength(x) = 0
 Shorthand for getting the parameters and states of the layer `l`. Is equivalent to
 `(initialparameters(rng, l), initialstates(rng, l))`.
""" function setup(rng::AbstractRNG, l::AbstractExplicitLayer) - (initialparameters(rng, l), initialstates(rng, l)) + return (initialparameters(rng, l), initialstates(rng, l)) end """ @@ -65,15 +65,15 @@ end Simply calls `model(x, ps, st)` """ -function apply(model::AbstractExplicitLayer, x, ps::Union{ComponentArray, NamedTuple}, +function apply(model::AbstractExplicitLayer, x, ps::VALID_PARAMETER_TYPES, st::NamedTuple) - model(x, ps, st) + return model(x, ps, st) end function Base.show(io::IO, x::AbstractExplicitLayer) __t = rsplit(string(get_typename(x)), "."; limit=2) T = length(__t) == 2 ? __t[2] : __t[1] - print(io, "$T()") + return print(io, "$T()") end # Abstract Container Layers @@ -92,11 +92,11 @@ function initialstates(rng::AbstractRNG, end function parameterlength(l::AbstractExplicitContainerLayer{layers}) where {layers} - sum(parameterlength, getfield.((l,), layers)) + return sum(parameterlength, getfield.((l,), layers)) end function statelength(l::AbstractExplicitContainerLayer{layers}) where {layers} - sum(statelength, getfield.((l,), layers)) + return sum(statelength, getfield.((l,), layers)) end function Base.keys(l::AbstractExplicitContainerLayer{layers}) where {layers} diff --git a/src/layers/basic.jl b/src/layers/basic.jl index c2b4472a3c..ec2ce3b638 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -25,7 +25,7 @@ end end function Base.show(io::IO, r::ReshapeLayer) - print(io, "ReshapeLayer(output_dims = (", join(r.dims, ", "), ", :))") + return print(io, "ReshapeLayer(output_dims = (", join(r.dims, ", "), ", :))") end """ @@ -72,11 +72,12 @@ struct SelectDim{dim, index} <: AbstractExplicitLayer end SelectDim(dim, index) = SelectDim{Val(dim), Val(index)}() @inline function (s::SelectDim{dim, index})(x, ps, st::NamedTuple) where {dim, index} - selectdim(x, get_known(dim), get_known(index)), st + return selectdim(x, get_known(dim), get_known(index)), st end function Base.show(io::IO, s::SelectDim{dim, index}) where {dim, index} - print(io, "SelectDim(dim = ", get_known(dim), ", index = ", get_known(index), ")") + return print(io, "SelectDim(dim = ", get_known(dim), ", index = ", get_known(index), + ")") end """ @@ -181,7 +182,7 @@ struct SkipConnection{T <: AbstractExplicitLayer, F} <: connection::F end -@inline function (skip::SkipConnection)(x, ps::Union{ComponentArray, NamedTuple}, +@inline function (skip::SkipConnection)(x, ps::VALID_PARAMETER_TYPES, st::NamedTuple) mx, st = skip.layers(x, ps, st) return skip.connection(mx, x), st @@ -226,12 +227,12 @@ function Parallel(connection, layers...) return Parallel(connection, NamedTuple{names}(layers)) end -function (m::Parallel)(x, ps::Union{ComponentArray, NamedTuple}, st::NamedTuple) +function (m::Parallel)(x, ps::VALID_PARAMETER_TYPES, st::NamedTuple) return applyparallel(m.layers, m.connection, x, ps, st) end @generated function applyparallel(layers::NamedTuple{names}, connection::C, x::T, - ps::Union{ComponentArray, NamedTuple}, + ps::VALID_PARAMETER_TYPES, st::NamedTuple) where {names, C, T} N = length(names) y_symbols = [gensym() for _ in 1:(N + 1)] @@ -309,12 +310,12 @@ function BranchLayer(layers...) 
     return BranchLayer(NamedTuple{names}(layers))
 end
 
-function (m::BranchLayer)(x, ps::Union{ComponentArray, NamedTuple}, st::NamedTuple)
-    applybranching(m.layers, x, ps, st)
+function (m::BranchLayer)(x, ps::VALID_PARAMETER_TYPES, st::NamedTuple)
+    return applybranching(m.layers, x, ps, st)
 end
 
 @generated function applybranching(layers::NamedTuple{names}, x,
-                                   ps::Union{ComponentArray, NamedTuple},
+                                   ps::VALID_PARAMETER_TYPES,
                                    st::NamedTuple) where {names}
     N = length(names)
     y_symbols = [gensym() for _ in 1:N]
@@ -393,12 +394,12 @@ function PairwiseFusion(connection, layers...)
     return PairwiseFusion(connection, NamedTuple{names}(layers))
 end
 
-function (m::PairwiseFusion)(x, ps::Union{ComponentArray, NamedTuple}, st::NamedTuple)
+function (m::PairwiseFusion)(x, ps::VALID_PARAMETER_TYPES, st::NamedTuple)
     return applypairwisefusion(m.layers, m.connection, x, ps, st)
 end
 
 @generated function applypairwisefusion(layers::NamedTuple{names}, connection::C, x::T,
-                                        ps::Union{ComponentArray, NamedTuple},
+                                        ps::VALID_PARAMETER_TYPES,
                                         st::NamedTuple) where {names, C, T}
     N = length(names)
     y_symbols = [gensym() for _ in 1:(N + 1)]
@@ -418,6 +419,7 @@ end
 
 """
     Chain(layers...; disable_optimizations::Bool = false)
+    Chain(; kwargs...)
     Chain(nt::NamedTuple)
 
 Collects multiple layers / functions to be called in sequence on a given input.
@@ -425,7 +427,9 @@ Collects multiple layers / functions to be called in sequence on a given input.
 
 ## Arguments
 
 * `layers`: A list of `N` Lux layers
-* `nt`: Alternatively, a NamedTuple `nt` can be passed to enable prettier naming for the layers
+Alternatively either of the following could be done:
+* `nt`: a NamedTuple `nt` can be passed to enable prettier naming for the layers
+* `kwargs`: Pass the layers and the layer names as keyword arguments
 
 ## Keyword Arguments
@@ -442,11 +446,11 @@ Input `x` is passed sequentially to each layer, and must conform to the input re
 
 ## Parameters
 
-* Parameters of each `layer` wrapped in a NamedTuple with `fields = layer_1, layer_2, ..., layer_N`
+* Parameters of each `layer` wrapped in a NamedTuple with `fields = layer_1, layer_2, ..., layer_N` or custom names
 
 ## States
 
-* States of each `layer` wrapped in a NamedTuple with `fields = layer_1, layer_2, ..., layer_N`
+* States of each `layer` wrapped in a NamedTuple with `fields = layer_1, layer_2, ..., layer_N` or custom names
 
 ## Optimizations
 
 Performs a few optimizations to generate reasonable architectures. Can be disabl
 
 ```julia
 c = Chain(Dense(2, 3, relu), BatchNorm(3), Dense(3, 2))
-c = Chain((feature_extractor=Chain(Dense(2, 3, relu), BatchNorm(3)),
-           classifier=Chain(Dense(3, 2))))
+c = Chain(feature_extractor=Chain(Dense(2, 3, relu), BatchNorm(3)),
+          classifier=Chain(Dense(3, 2)))
 ```
 """
 struct Chain{T} <: AbstractExplicitContainerLayer{(:layers,)}
@@ -478,13 +482,17 @@ struct Chain{T} <: AbstractExplicitContainerLayer{(:layers,)}
         return new{typeof(layers)}(layers)
     end
     function Chain(xs::AbstractVector; disable_optimizations::Bool=false)
-        Chain(xs...; disable_optimizations)
+        return Chain(xs...; disable_optimizations)
     end
     function Chain(nt::NamedTuple; disable_optimizations::Bool=false)
-        disable_optimizations &&
-            @warn "Chain(::NamedTuple) ignores `disable_optimizations`" maxlog=1
+        if disable_optimizations
+            error("`Chain(::NamedTuple)` or `Chain(; kwargs...)` doesn't accept `disable_optimizations`")
+        end
         return new{typeof(nt)}(nt)
     end
+    function Chain(; disable_optimizations::Bool=false, kwargs...)
+        return Chain((; kwargs...); disable_optimizations)
+    end
 end
 
 function flatten_model(layers::Union{AbstractVector, Tuple})
@@ -494,7 +502,7 @@ function flatten_model(layers::Union{AbstractVector, Tuple})
         if f isa Tuple || f isa AbstractVector
             append!(new_layers, f)
         elseif f isa Function
-            if !hasmethod(f, (Any, Union{ComponentArray, NamedTuple}, NamedTuple))
+            if !hasmethod(f, (Any, VALID_PARAMETER_TYPES, NamedTuple))
                 if f === identity
                     continue
                 else
@@ -516,12 +524,12 @@ end
 
 flatten_model(x) = x
 
-function (c::Chain)(x, ps::Union{ComponentArray, NamedTuple}, st::NamedTuple)
-    applychain(c.layers, x, ps, st)
+function (c::Chain)(x, ps::VALID_PARAMETER_TYPES, st::NamedTuple)
+    return applychain(c.layers, x, ps, st)
 end
 
 @generated function applychain(layers::NamedTuple{fields}, x,
-                               ps::Union{ComponentArray, NamedTuple},
+                               ps::VALID_PARAMETER_TYPES,
                                st::NamedTuple{fields}) where {fields}
     N = length(fields)
     x_symbols = [gensym("x") for _ in 1:N]
@@ -617,19 +625,19 @@ end
 statelength(d::Dense) = 0
 
 @inline function (d::Dense{false})(x::AbstractVecOrMat,
-                                   ps::Union{ComponentArray, NamedTuple},
+                                   ps::VALID_PARAMETER_TYPES,
                                    st::NamedTuple)
     return applyactivation(d.activation, ps.weight * x), st
 end
 
 @inline function (d::Dense{false, typeof(identity)})(x::AbstractVecOrMat,
-                                                     ps::Union{ComponentArray, NamedTuple},
+                                                     ps::VALID_PARAMETER_TYPES,
                                                      st::NamedTuple)
     return ps.weight * x, st
 end
 
 @inline function (d::Dense{false})(x::AbstractArray,
-                                   ps::Union{ComponentArray, NamedTuple},
+                                   ps::VALID_PARAMETER_TYPES,
                                    st::NamedTuple)
     sz = size(x)
     x_reshaped = reshape(x, sz[1], :)
@@ -638,36 +646,36 @@ end
 
 @inline function (d::Dense{false, typeof(identity)})(x::AbstractArray,
-                                                     ps::Union{ComponentArray, NamedTuple},
+                                                     ps::VALID_PARAMETER_TYPES,
                                                      st::NamedTuple)
     sz = size(x)
     x_reshaped = reshape(x, sz[1], :)
     return reshape(ps.weight * x_reshaped, d.out_dims, sz[2:end]...), st
 end
 
-@inline function (d::Dense{true})(x::AbstractVector, ps::Union{ComponentArray, NamedTuple},
+@inline function (d::Dense{true})(x::AbstractVector, ps::VALID_PARAMETER_TYPES,
                                   st::NamedTuple)
     return applyactivation(d.activation, elementwise_add(ps.weight * x, vec(ps.bias))), st
 end
 
 @inline function (d::Dense{true, typeof(identity)})(x::AbstractVector,
-                                                    ps::Union{ComponentArray, NamedTuple},
+                                                    ps::VALID_PARAMETER_TYPES,
                                                     st::NamedTuple)
     return elementwise_add(ps.weight * x, vec(ps.bias)), st
 end
 
-@inline function (d::Dense{true})(x::AbstractMatrix, ps::Union{ComponentArray, NamedTuple},
+@inline function (d::Dense{true})(x::AbstractMatrix, ps::VALID_PARAMETER_TYPES,
                                   st::NamedTuple)
     return applyactivation(d.activation, elementwise_add(ps.weight * x, ps.bias)), st
 end
 
 @inline function (d::Dense{true, typeof(identity)})(x::AbstractMatrix,
-                                                    ps::Union{ComponentArray, NamedTuple},
+                                                    ps::VALID_PARAMETER_TYPES,
                                                     st::NamedTuple)
     return elementwise_add(ps.weight * x, ps.bias), st
 end
 
-@inline function (d::Dense{true})(x::AbstractArray, ps::Union{ComponentArray, NamedTuple},
+@inline function (d::Dense{true})(x::AbstractArray, ps::VALID_PARAMETER_TYPES,
                                   st::NamedTuple)
     sz = size(x)
     x_reshaped = reshape(x, sz[1], :)
@@ -677,7 +685,7 @@ end
 
 @inline function (d::Dense{true, typeof(identity)})(x::AbstractArray,
-                                                    ps::Union{ComponentArray, NamedTuple},
+                                                    ps::VALID_PARAMETER_TYPES,
                                                     st::NamedTuple)
     sz = size(x)
     x_reshaped = reshape(x, sz[1], :)
@@ -740,7 +748,7 @@ function Scale(dims::Tuple{Vararg{Integer}}, activation=identity;
 end
 
 function Scale(s1::Integer, s23::Integer...; _act=identity, kw...)
-    Scale(tuple(s1, s23...), _act; kw...)
+    return Scale(tuple(s1, s23...), _act; kw...)
 end
 
 Scale(size_act...; kw...) = Scale(size_act[1:(end - 1)]...; _act=size_act[end], kw...)
@@ -748,31 +756,31 @@ function initialparameters(rng::AbstractRNG, d::Scale{true})
     return (weight=d.init_weight(rng, d.dims...), bias=d.init_bias(rng, d.dims...))
 end
 function initialparameters(rng::AbstractRNG, d::Scale{false})
-    (weight=d.init_weight(rng, d.dims...),)
+    return (weight=d.init_weight(rng, d.dims...),)
 end
 
 parameterlength(d::Scale{bias}) where {bias} = (1 + bias) * prod(d.dims)
 statelength(d::Scale) = 0
 
-function (d::Scale{true})(x::AbstractArray, ps::Union{ComponentArray, NamedTuple},
+function (d::Scale{true})(x::AbstractArray, ps::VALID_PARAMETER_TYPES,
                           st::NamedTuple)
     return applyactivation(d.activation,
                            elementwise_add(elementwise_mul(ps.weight, x), ps.bias)), st
 end
 
 function (d::Scale{true, typeof(identity)})(x::AbstractArray,
-                                            ps::Union{ComponentArray, NamedTuple},
+                                            ps::VALID_PARAMETER_TYPES,
                                             st::NamedTuple)
     return elementwise_add(elementwise_mul(ps.weight, x), ps.bias), st
 end
 
-function (d::Scale{false})(x::AbstractArray, ps::Union{ComponentArray, NamedTuple},
+function (d::Scale{false})(x::AbstractArray, ps::VALID_PARAMETER_TYPES,
                            st::NamedTuple)
     return applyactivation(d.activation, elementwise_mul(ps.weight, x)), st
 end
 
 function (d::Scale{false, typeof(identity)})(x::AbstractArray,
-                                             ps::Union{ComponentArray, NamedTuple},
+                                             ps::VALID_PARAMETER_TYPES,
                                              st::NamedTuple)
     return elementwise_mul(ps.weight, x), st
 end
diff --git a/src/layers/conv.jl b/src/layers/conv.jl
index 908cbf0dde..0fcf07734a 100644
--- a/src/layers/conv.jl
+++ b/src/layers/conv.jl
@@ -92,14 +92,14 @@ function parameterlength(c::Conv{N, bias}) where {N, bias}
 end
 
 @inline function (c::Conv{N, false})(x::AbstractArray,
-                                     ps::Union{ComponentArray, NamedTuple},
+                                     ps::VALID_PARAMETER_TYPES,
                                      st::NamedTuple) where {N}
     cdims = DenseConvDims(x, ps.weight; stride=c.stride, padding=c.pad,
                           dilation=c.dilation, groups=c.groups)
     return applyactivation(c.activation, conv_wrapper(x, ps.weight, cdims)), st
 end
 
-@inline function (c::Conv{N, true})(x::AbstractArray, ps::Union{ComponentArray, NamedTuple},
+@inline function (c::Conv{N, true})(x::AbstractArray, ps::VALID_PARAMETER_TYPES,
                                     st::NamedTuple) where {N}
     cdims = DenseConvDims(x, ps.weight; stride=c.stride, padding=c.pad,
                           dilation=c.dilation, groups=c.groups)
diff --git a/src/layers/display.jl b/src/layers/display.jl
index 641f4c829a..c3ec81f7e6 100644
--- a/src/layers/display.jl
+++ b/src/layers/display.jl
@@ -66,7 +66,7 @@ function _get_children(l::AbstractExplicitContainerLayer{names}) where {names}
     return NamedTuple{names}(getfield.((l,), names))
 end
 function _get_children(p::Parallel)
-    p.connection === nothing ? p.layers : (p.connection, p.layers...)
+    return p.connection === nothing ? p.layers : (p.connection, p.layers...)
 end
 _get_children(s::SkipConnection) = (s.layers, s.connection)
 _get_children(s::WeightNorm) = (s.layer,)
@@ -116,7 +116,7 @@ end
 
 # utility functions
 function underscorise(n::Integer)
-    join(reverse(join.(reverse.(Iterators.partition(digits(n), 3)))), '_')
+    return join(reverse(join.(reverse.(Iterators.partition(digits(n), 3)))), '_')
 end
 
 function _nan_show(io::IO, x)
diff --git a/src/layers/normalize.jl b/src/layers/normalize.jl
index b8a5a4db0b..e15d995a51 100644
--- a/src/layers/normalize.jl
+++ b/src/layers/normalize.jl
@@ -102,7 +102,7 @@ end
 parameterlength(l::BatchNorm{affine}) where {affine} = affine ? (l.chs * 2) : 0
 
 function statelength(l::BatchNorm{affine, track_stats}) where {affine, track_stats}
-    (track_stats ? 2 * l.chs : 0) + 1
+    return (track_stats ? 2 * l.chs : 0) + 1
 end
 
 function (BN::BatchNorm)(x::AbstractArray{T, N}, ps, st::NamedTuple) where {T, N}
@@ -279,7 +279,7 @@ end
 parameterlength(l::GroupNorm{affine}) where {affine} = affine ? (l.chs * 2) : 0
 
 function statelength(l::GroupNorm{affine, track_stats}) where {affine, track_stats}
-    (track_stats ? 2 * l.groups : 0) + 1
+    return (track_stats ? 2 * l.groups : 0) + 1
 end
 
 function (GN::GroupNorm)(x::AbstractArray{T, N}, ps, st::NamedTuple) where {T, N}
@@ -379,7 +379,7 @@ end
 
 initialstates(rng::AbstractRNG, wn::WeightNorm) = initialstates(rng, wn.layer)
 
-function (wn::WeightNorm)(x, ps::Union{ComponentArray, NamedTuple}, s::NamedTuple)
+function (wn::WeightNorm)(x, ps::VALID_PARAMETER_TYPES, s::NamedTuple)
     _ps = get_normalized_parameters(wn, wn.dims, ps.normalized)
     return wn.layer(x, merge(_ps, ps.unnormalized), s)
 end
diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl
index 2423d665be..8de7467ce8 100644
--- a/src/layers/recurrent.jl
+++ b/src/layers/recurrent.jl
@@ -69,7 +69,7 @@ function initialstates(rng::AbstractRNG, ::RNNCell)
     return (rng=replicate(rng),)
 end
 
-function (rnn::RNNCell)(x::AbstractMatrix, ps::Union{ComponentArray, NamedTuple},
+function (rnn::RNNCell)(x::AbstractMatrix, ps::VALID_PARAMETER_TYPES,
                         st::NamedTuple)
     rng = replicate(st.rng)
     @set! st.rng = rng
@@ -78,7 +78,7 @@ function (rnn::RNNCell)(x::AbstractMatrix, ps::Union{ComponentArray, NamedTuple}
 end
 
 function (rnn::RNNCell{true})((x, hidden_state)::Tuple{<:AbstractMatrix, <:AbstractMatrix},
-                              ps::Union{ComponentArray, NamedTuple}, st::NamedTuple)
+                              ps::VALID_PARAMETER_TYPES, st::NamedTuple)
     h_new = rnn.activation.(ps.weight_ih * x .+ ps.weight_hh * hidden_state .+ ps.bias)
     return h_new, st
 end
@@ -86,14 +86,14 @@ end
 function (rnn::RNNCell{true, typeof(identity)})((x, hidden_state)::Tuple{<:AbstractMatrix,
                                                                          <:AbstractMatrix},
-                                                ps::Union{ComponentArray, NamedTuple},
+                                                ps::VALID_PARAMETER_TYPES,
                                                 st::NamedTuple)
     h_new = ps.weight_ih * x .+ ps.weight_hh * hidden_state .+ ps.bias
     return h_new, st
 end
 
 function (rnn::RNNCell{false})((x, hidden_state)::Tuple{<:AbstractMatrix, <:AbstractMatrix},
-                               ps::Union{ComponentArray, NamedTuple}, st::NamedTuple)
+                               ps::VALID_PARAMETER_TYPES, st::NamedTuple)
     h_new = rnn.activation.(ps.weight_ih * x .+ ps.weight_hh * hidden_state)
     return h_new, st
 end
@@ -101,7 +101,7 @@ end
 function (rnn::RNNCell{false, typeof(identity)})((x, hidden_state)::Tuple{<:AbstractMatrix,
                                                                           <:AbstractMatrix},
-                                                 ps::Union{ComponentArray, NamedTuple},
+                                                 ps::VALID_PARAMETER_TYPES,
                                                  st::NamedTuple)
     h_new = ps.weight_ih * x .+ ps.weight_hh * hidden_state
     return h_new, st
@@ -196,7 +196,7 @@ function initialstates(rng::AbstractRNG, ::LSTMCell)
     return (rng=replicate(rng),)
 end
 
-function (lstm::LSTMCell)(x::AbstractMatrix, ps::Union{ComponentArray, NamedTuple},
+function (lstm::LSTMCell)(x::AbstractMatrix, ps::VALID_PARAMETER_TYPES,
                           st::NamedTuple)
     rng = replicate(st.rng)
     @set! st.rng = rng
@@ -208,7 +208,7 @@ end
 function (lstm::LSTMCell)((x, hidden_state, memory)::Tuple{<:AbstractMatrix,
                                                            <:AbstractMatrix,
                                                            <:AbstractMatrix},
-                          ps::Union{ComponentArray, NamedTuple},
+                          ps::VALID_PARAMETER_TYPES,
                           st::NamedTuple)
     g = ps.weight_i * x .+ ps.weight_h * hidden_state .+ ps.bias
     input, forget, cell, output = multigate(g, Val(4))
@@ -294,7 +294,7 @@ function initialstates(rng::AbstractRNG, ::GRUCell)
     return (rng=replicate(rng),)
 end
 
-function (gru::GRUCell)(x::AbstractMatrix, ps::Union{ComponentArray, NamedTuple},
+function (gru::GRUCell)(x::AbstractMatrix, ps::VALID_PARAMETER_TYPES,
                         st::NamedTuple)
     rng = replicate(st.rng)
     @set! st.rng = rng
@@ -303,7 +303,7 @@ function (gru::GRUCell)(x::AbstractMatrix, ps::Union{ComponentArray, NamedTuple}
 end
 
 function (gru::GRUCell)((x, hidden_state)::Tuple{<:AbstractMatrix, <:AbstractMatrix},
-                        ps::Union{ComponentArray, NamedTuple}, st::NamedTuple)
+                        ps::VALID_PARAMETER_TYPES, st::NamedTuple)
     gxs = multigate(ps.weight_i * x, Val(3))
     ghbs = multigate(ps.weight_h * hidden_state .+ ps.bias_h, Val(3))
 
diff --git a/src/nnlib.jl b/src/nnlib.jl
index 2c86d3ca2c..af32d00ff2 100644
--- a/src/nnlib.jl
+++ b/src/nnlib.jl
@@ -121,7 +121,7 @@ end
 # Dropout
 @inline _dropout_shape(s, ::Colon) = size(s)
 @inline function _dropout_shape(s, dims)
-    tuple((i ∉ dims ? 1 : si for (i, si) in enumerate(size(s)))...)
+    return tuple((i ∉ dims ? 1 : si for (i, si) in enumerate(size(s)))...)
 end
 
 @inline _dropout_kernel(y::T, p, q) where {T} = y > p ? q : zero(T)
@@ -186,10 +186,10 @@ const cudnnValidActivationTypes = Union{
 ## I think this is handled by NNlibCUDA. But currently leaving here for
 ## benchmarking larger models
 function getCUDNNActivationMode(::Union{typeof(tanh), typeof(tanh_fast)})
-    CUDNN.CUDNN_ACTIVATION_TANH
+    return CUDNN.CUDNN_ACTIVATION_TANH
 end
 function getCUDNNActivationMode(::Union{typeof(sigmoid), typeof(sigmoid_fast)})
-    CUDNN.CUDNN_ACTIVATION_SIGMOID
+    return CUDNN.CUDNN_ACTIVATION_SIGMOID
 end
 getCUDNNActivationMode(::Union{typeof(relu)}) = CUDNN.CUDNN_ACTIVATION_RELU
 getCUDNNActivationMode(::Union{typeof(elu)}) = CUDNN.CUDNN_ACTIVATION_ELU
@@ -232,7 +232,7 @@ Computes `x .+ y`. Dispatches to CUDNN if possible
 end
 
 @inline function elementwise_add_pullback(x, y, Δ)
-    broadcast_shape_pullback(x, Δ), broadcast_shape_pullback(y, Δ)
+    return broadcast_shape_pullback(x, Δ), broadcast_shape_pullback(y, Δ)
 end
 
 """
diff --git a/src/random.jl b/src/random.jl
deleted file mode 100644
index 980d947057..0000000000
--- a/src/random.jl
+++ /dev/null
@@ -1,2 +0,0 @@
-function get_prng()
-end
\ No newline at end of file
diff --git a/src/utils.jl b/src/utils.jl
index 4e91e3d055..2e511f01d6 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -67,7 +67,7 @@ replicate(rng::CUDA.RNG) = deepcopy(rng)
 # Linear Algebra
 @inline _norm(x; dims=Colon()) = sqrt.(sum(abs2, x; dims=dims))
 @inline function _norm_except(x::AbstractArray{T, N}, except_dim=N) where {T, N}
-    _norm(x; dims=filter(i -> i != except_dim, 1:N))
+    return _norm(x; dims=filter(i -> i != except_dim, 1:N))
 end
 
 # Convolution
@@ -90,7 +90,7 @@ _maybetuple_string(pad::Tuple) = all(==(pad[1]), pad) ? string(pad[1]) : string(
 struct SamePad end
 
 function calc_padding(lt, pad, k::NTuple{N, T}, dilation, stride) where {T, N}
-    expand(Val(2 * N), pad)
+    return expand(Val(2 * N), pad)
 end
 
 function calc_padding(lt, ::SamePad, k::NTuple{N, T}, dilation, stride) where {N, T}
@@ -106,13 +106,13 @@ end
 # Handling ComponentArrays
 ## NOTE: We should probably upsteam some of these
 function Base.zero(c::ComponentArray{T, N, <:CuArray{T}}) where {T, N}
-    ComponentArray(zero(getdata(c)), getaxes(c))
+    return ComponentArray(zero(getdata(c)), getaxes(c))
 end
 
 Base.vec(c::ComponentArray{T, N, <:CuArray{T}}) where {T, N} = getdata(c)
 
 function Base.:-(x::ComponentArray{T, N, <:CuArray{T}}) where {T, N}
-    ComponentArray(-getdata(x), getaxes(x))
+    return ComponentArray(-getdata(x), getaxes(x))
 end
 
 function Base.similar(c::ComponentArray{T, N, <:CuArray{T}},
@@ -156,7 +156,7 @@ end
 ComponentArrays.recursive_length(nt::NamedTuple{(), Tuple{}}) = 0
 
 # Return Nothing if field not present
-function safe_getproperty(x::Union{ComponentArray, NamedTuple}, k::Symbol)
+function safe_getproperty(x::VALID_PARAMETER_TYPES, k::Symbol)
     k ∈ propertynames(x) && return getproperty(x, k)
     return nothing
 end
diff --git a/test/layers/basic.jl b/test/layers/basic.jl
index f496116613..b2e3274fff 100644
--- a/test/layers/basic.jl
+++ b/test/layers/basic.jl
@@ -206,8 +206,8 @@ end
               par.layers[2](ip2.x, ps.layer_2, st.layer_2)[1]
     gs = gradient((p, x...) -> sum(par(x, p, st)[1]), ps, ip, ip2)
     gs_reg = gradient(ps, ip, ip2) do p, x, y
-        sum(par.layers[1](x.x, p.layer_1, st.layer_1)[1] +
-            par.layers[2](y.x, p.layer_2, st.layer_2)[1])
+        return sum(par.layers[1](x.x, p.layer_1, st.layer_1)[1] +
+                   par.layers[2](y.x, p.layer_2, st.layer_2)[1])
     end
 
     @test gs[1] ≈ gs_reg[1]
diff --git a/test/layers/recurrent.jl b/test/layers/recurrent.jl
index 4f49a503ad..ccf531c898 100644
--- a/test/layers/recurrent.jl
+++ b/test/layers/recurrent.jl
@@ -80,7 +80,7 @@ end
     x = rand(6, 5)
     res, (dx,) = Zygote.withgradient(x) do x
         x1, _, x3 = Lux.multigate(x, Val(3))
-        sum(x1) + sum(x3 .* 2)
+        return sum(x1) + sum(x3 .* 2)
     end
     @test res == sum(x[1:2, :]) + 2sum(x[5:6, :])
    @test dx == [ones(2, 5); zeros(2, 5); fill(2, 2, 5)]
diff --git a/test/runtests.jl b/test/runtests.jl
index 1f190b1d71..9e98cb17a4 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -4,14 +4,14 @@ const GROUP = get(ENV, "GROUP", "All")
 
 function dev_subpkg(subpkg)
     subpkg_path = joinpath(dirname(@__DIR__), "lib", subpkg)
-    Pkg.develop(PackageSpec(path=subpkg_path))
+    return Pkg.develop(PackageSpec(path=subpkg_path))
 end
 
 function activate_subpkg_env(subpkg)
     subpkg_path = joinpath(dirname(@__DIR__), "lib", subpkg)
     Pkg.activate(subpkg_path)
     Pkg.develop(PackageSpec(path=subpkg_path))
-    Pkg.instantiate()
+    return Pkg.instantiate()
 end
 
 groups = if GROUP == "All"