
Style fix
avik-pal committed Jun 6, 2022
1 parent f1ca511 commit 3253734
Showing 18 changed files with 102 additions and 89 deletions.
.JuliaFormatter.toml (1 change: 1 addition & 0 deletions)
@@ -1,2 +1,3 @@
 style = "sciml"
 whitespace_in_kwargs = false
+always_use_return = true
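
A note on the new option (an illustration, not part of the diff): `always_use_return = true` tells JuliaFormatter to insert an explicit `return` before the final expression of every function body, which is exactly the mechanical change repeated across the files below. A minimal before/after sketch with a hypothetical function:

# Before formatting: the value of the last expression is returned implicitly.
function double(x)
    2x
end

# After formatting with always_use_return = true:
function double(x)
    return 2x
end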
examples/NeuralODE/main.jl (4 changes: 2 additions & 2 deletions)
@@ -14,7 +14,7 @@ CUDA.allowscalar(false)
 # ## Loading MNIST
 ## Use MLDataUtils LabelEnc for natural onehot conversion
 function onehot(labels_raw)
-    convertlabel(LabelEnc.OneOfK, labels_raw, LabelEnc.NativeLabels(collect(0:9)))
+    return convertlabel(LabelEnc.OneOfK, labels_raw, LabelEnc.NativeLabels(collect(0:9)))
 end

 function loadmnist(batchsize, train_split)
@@ -66,7 +66,7 @@ function (n::NeuralODE)(x, ps, st)
 end

 function diffeqsol_to_array(x::ODESolution{T, N, <:AbstractVector{<:CuArray}}) where {T, N}
-    dropdims(gpu(x); dims=3)
+    return dropdims(gpu(x); dims=3)
 end
 diffeqsol_to_array(x::ODESolution) = dropdims(Array(x); dims=3)
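
A note on the two `diffeqsol_to_array` methods above (my illustration, not part of the diff): the first method dispatches on solutions whose saved states are stored in `CuArray`s and moves the result through `gpu`, while the second is the CPU fallback; both drop the trailing singleton save-point dimension so downstream layers receive a plain matrix. A standalone sketch with plain arrays and hypothetical sizes:

# An ODESolution saved at a single time point behaves like a
# (features, batch, save_points) array with save_points == 1.
x = rand(Float32, 10, 4, 1)
y = dropdims(x; dims=3)   # size(y) == (10, 4), ready for the next layer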
lib/Boltz/src/utils.jl (4 changes: 2 additions & 2 deletions)
@@ -5,10 +5,10 @@ Type-stable and faster version of `MLUtils.chunk`
 """
 @inline fast_chunk(h::Int, n::Int) = (1:h) .+ h * (n - 1)
 @inline function fast_chunk(x::AbstractArray, h::Int, n::Int, ::Val{dim}) where {dim}
-    selectdim(x, dim, fast_chunk(h, n))
+    return selectdim(x, dim, fast_chunk(h, n))
 end
 @inline function fast_chunk(x::AbstractArray, ::Val{N}, d::Val{D}) where {N, D}
-    fast_chunk.((x,), size(x, D) ÷ N, 1:N, d)
+    return fast_chunk.((x,), size(x, D) ÷ N, 1:N, d)
 end

 """
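
A usage sketch for `fast_chunk` (my example, assuming the definitions above are in scope): splitting a matrix into two equal views along its second dimension.

x = reshape(collect(1.0:12.0), 3, 4)   # 3×4 matrix
a, b = fast_chunk(x, Val(2), Val(2))   # two 3×2 views along dimension 2
a == x[:, 1:2] && b == x[:, 3:4]       # true
# selectdim returns views (no copies), and passing the chunk count and
# dimension as Vals keeps the return type inferable, hence the
# type-stability claim in the docstring.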
src/Lux.jl (5 changes: 5 additions & 0 deletions)
@@ -22,6 +22,11 @@ using Requires

 const use_cuda = Ref{Union{Nothing, Bool}}(nothing)

+# NOTE: In theory, we can support any parameter type which allows us to access
+# parameters using getproperty. But I will be conservative here and only specify
+# NamedTuple and ComponentArray until we have tested other cases properly.
+const VALID_PARAMETER_TYPES = Union{NamedTuple, ComponentArray}
+
 # Data Transfer Utilities
 include("adapt.jl")
 # Utilities
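
To illustrate the NOTE above (my example, not from the commit): both members of the union expose parameters through `getproperty`, so layer code written against field access works unchanged with either representation.

using ComponentArrays

ps_nt = (weight=rand(3, 2), bias=zeros(3))   # NamedTuple parameters
ps_ca = ComponentArray(ps_nt)                # same fields, flat storage
ps_nt.weight == ps_ca.weight                 # true: identical getproperty access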
src/adapt.jl (6 changes: 3 additions & 3 deletions)
@@ -7,7 +7,7 @@ adapt_storage(::LuxCUDAAdaptor, x) = CUDA.cu(x)
 adapt_storage(::LuxCUDAAdaptor, x::FillArrays.AbstractFill) = CUDA.cu(collect(x))
 adapt_storage(::LuxCUDAAdaptor, x::Zygote.OneElement) = CUDA.cu(collect(x))
 function adapt_storage(to::LuxCUDAAdaptor, x::ComponentArray)
-    ComponentArray(adapt_storage(to, getdata(x)), getaxes(x))
+    return ComponentArray(adapt_storage(to, getdata(x)), getaxes(x))
 end
 adapt_storage(::LuxCUDAAdaptor, rng::AbstractRNG) = rng

@@ -18,13 +18,13 @@ function adapt_storage(::LuxCPUAdaptor,
 end
 adapt_storage(::LuxCPUAdaptor, x::AbstractArray) = adapt(Array, x)
 function adapt_storage(to::LuxCPUAdaptor, x::ComponentArray)
-    ComponentArray(adapt_storage(to, getdata(x)), getaxes(x))
+    return ComponentArray(adapt_storage(to, getdata(x)), getaxes(x))
 end
 adapt_storage(::LuxCPUAdaptor, rng::AbstractRNG) = rng
 # TODO: SparseArrays
 function adapt_storage(::LuxCPUAdaptor,
                        x::CUDA.CUSPARSE.AbstractCuSparseMatrix)
-    adapt(Array, x)
+    return adapt(Array, x)
 end

 _isbitsarray(::AbstractArray{<:Number}) = true
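
These `adapt_storage` methods are what Adapt.jl dispatches to when values move between devices; note how the `ComponentArray` methods rebuild the wrapper so only the underlying data is transferred while the named axes are preserved. A hedged round-trip sketch (assumes a functional CUDA device and the `gpu`/`cpu` helpers from this version of Lux):

using Lux, ComponentArrays, Random

ps, st = Lux.setup(Random.default_rng(), Dense(2 => 3))
ps_ca = ComponentArray(ps)   # flat parameter vector with named axes
ps_gpu = gpu(ps_ca)          # LuxCUDAAdaptor path: data becomes a CuArray
ps_cpu = cpu(ps_gpu)         # LuxCPUAdaptor path: back to Array storage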
src/autodiff.jl (5 changes: 3 additions & 2 deletions)
@@ -14,7 +14,7 @@ ChainRulesCore.Tangent{P}(; kwargs...) where {P <: AbstractExplicitLayer} = NoTangent()
 ChainRulesCore.rrule(::typeof(istraining)) = true, _ -> (NoTangent(),)

 function ChainRulesCore.rrule(::typeof(Base.broadcasted), ::typeof(identity), x)
-    x, Δ -> (NoTangent(), NoTangent(), Δ)
+    return x, Δ -> (NoTangent(), NoTangent(), Δ)
 end

 # NNlib Functions
@@ -43,7 +43,8 @@ function ChainRulesCore.rrule(::typeof(dropout),
                               t::Val{training}) where {T, N, training}
     y, mask, rng = dropout(rng, x, p, q, dims, t)
     function dropout_pullback((dy, dmask, drng))
-        return (NoTangent(), NoTangent(), elementwise_mul(dy, mask), NoTangent(), NoTangent(),
+        return (NoTangent(), NoTangent(), elementwise_mul(dy, mask), NoTangent(),
+                NoTangent(),
                 NoTangent(), NoTangent())
     end
     return (y, mask, rng), dropout_pullback
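
Why the dropout pullback has this shape (a note, not part of the diff): the forward pass computes `y = x .* mask` with the 1/(1 - p) scaling already folded into `mask`, so with `mask` treated as a constant the only nonzero cotangent is the one for `x`, namely `dy .* mask`; the six `NoTangent()`s correspond to the function itself and the `rng`, `p`, `q`, `dims`, and `t` arguments. A standalone numerical sketch (plain arrays; the scaling convention is my assumption):

x = randn(4)
mask = [2.0, 0.0, 2.0, 0.0]   # p = 0.5: dropped entries zeroed, kept ones scaled by 2
y = x .* mask                 # forward pass, mask held constant
dy = ones(4)                  # incoming cotangent
dx = dy .* mask               # matches elementwise_mul(dy, mask) in the rrule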
src/core.jl (18 changes: 9 additions & 9 deletions)
@@ -15,7 +15,7 @@ Generate the initial parameters of the layer `l`.
 """
 initialparameters(::AbstractRNG, ::Any) = NamedTuple()
 function initialparameters(rng::AbstractRNG, l::NamedTuple)
-    map(Base.Fix1(initialparameters, rng), l)
+    return map(Base.Fix1(initialparameters, rng), l)
 end

 """
@@ -32,10 +32,10 @@ initialstates(rng::AbstractRNG, l::NamedTuple) = map(Base.Fix1(initialstates, rng), l)
 Return the total number of parameters of the layer `l`.
 """
 function parameterlength(l::AbstractExplicitLayer)
-    parameterlength(initialparameters(Random.default_rng(), l))
+    return parameterlength(initialparameters(Random.default_rng(), l))
 end
 function parameterlength(nt::Union{NamedTuple, Tuple})
-    length(nt) == 0 ? 0 : sum(parameterlength, nt)
+    return length(nt) == 0 ? 0 : sum(parameterlength, nt)
 end
 parameterlength(a::AbstractArray) = length(a)
 parameterlength(x) = 0
@@ -57,23 +57,23 @@ statelength(x) = 0
 Shorthand for getting the parameters and states of the layer `l`. Is equivalent to `(initialparameters(rng, l), initialstates(rng, l))`.
 """
 function setup(rng::AbstractRNG, l::AbstractExplicitLayer)
-    (initialparameters(rng, l), initialstates(rng, l))
+    return (initialparameters(rng, l), initialstates(rng, l))
 end

 """
     apply(model::AbstractExplicitLayer, x, ps::Union{ComponentArray,NamedTuple}, st::NamedTuple)
 Simply calls `model(x, ps, st)`
 """
-function apply(model::AbstractExplicitLayer, x, ps::Union{ComponentArray, NamedTuple},
+function apply(model::AbstractExplicitLayer, x, ps::VALID_PARAMETER_TYPES,
                st::NamedTuple)
-    model(x, ps, st)
+    return model(x, ps, st)
 end

 function Base.show(io::IO, x::AbstractExplicitLayer)
     __t = rsplit(string(get_typename(x)), "."; limit=2)
     T = length(__t) == 2 ? __t[2] : __t[1]
-    print(io, "$T()")
+    return print(io, "$T()")
 end

 # Abstract Container Layers
@@ -92,11 +92,11 @@ function initialstates(rng::AbstractRNG,
 end

 function parameterlength(l::AbstractExplicitContainerLayer{layers}) where {layers}
-    sum(parameterlength, getfield.((l,), layers))
+    return sum(parameterlength, getfield.((l,), layers))
 end

 function statelength(l::AbstractExplicitContainerLayer{layers}) where {layers}
-    sum(statelength, getfield.((l,), layers))
+    return sum(statelength, getfield.((l,), layers))
 end

 function Base.keys(l::AbstractExplicitContainerLayer{layers}) where {layers}
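
Putting the core.jl API together (a usage sketch, assuming `Dense` and the exported names from this version of Lux):

using Lux, Random

rng = Random.default_rng()
model = Dense(2 => 3, tanh)
ps, st = Lux.setup(rng, model)   # (initialparameters(rng, model), initialstates(rng, model))
x = randn(Float32, 2, 8)
y, st = Lux.apply(model, x, ps, st)   # simply calls model(x, ps, st)
parameterlength(model) == 9           # 2*3 weights + 3 biases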
(The diffs for the remaining 11 changed files were not loaded in this view.)
