Skip to content

Commit

Permalink
Merge pull request #397 from vincentmolin/master
Browse files Browse the repository at this point in the history
Move exports to main source file
  • Loading branch information
CarloLucibello authored Mar 3, 2022
2 parents aa86827 + 5a13f3f commit d8b9b41
Show file tree
Hide file tree
Showing 19 changed files with 38 additions and 60 deletions.
38 changes: 38 additions & 0 deletions src/NNlib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,19 @@ using Requires
using ChainRulesCore
import ChainRulesCore: rrule
using Base.Broadcast: broadcasted
using Base.Threads
using Statistics
using Statistics: mean
using LinearAlgebra
using LinearAlgebra: BlasFloat, Transpose, Adjoint, AdjOrTransAbsMat
using LinearAlgebra.BLAS: libblas, BlasInt, @blasfunc

const IntOrIntTuple = Union{Integer, NTuple{N,<:Integer} where N}
const Numeric = Union{AbstractArray{<:T}, T} where {T<:Number}

# Include APIs
include("dim_helpers.jl")
export ConvDims, DenseConvDims, PoolDims, DepthwiseConvDims

is_nnpack_available() = false

Expand All @@ -27,14 +33,46 @@ is_nnpack_available() = false
end

include("activations.jl")
for f in ACTIVATIONS
@eval export $(f)
end
export sigmoid, hardsigmoid, logsigmoid, thresholdrelu # Aliases

include("softmax.jl")
export softmax, softmax!, ∇softmax, ∇softmax!, logsoftmax,
logsoftmax!, ∇logsoftmax, ∇logsoftmax!, logsumexp

include("batched/batchedadjtrans.jl")
include("batched/batchedmul.jl")
export batched_mul, batched_mul!, ⊠, batched_vec,
    batched_transpose, batched_adjoint

include("gemm.jl")
export grid_sample, ∇grid_sample

include("conv.jl")
export conv, conv!, ∇conv_data, ∇conv_data!, ∇conv_filter,
∇conv_filter!, depthwiseconv, depthwiseconv!,
∇depthwiseconv_data, ∇depthwiseconv_data!,
∇depthwiseconv_filter, ∇depthwiseconv_filter!

include("conv_bias_act.jl")
export conv_bias_act, conv_bias_act!

include("pooling.jl")
export maxpool, maxpool!, meanpool, meanpool!,
∇maxpool, ∇maxpool!, ∇meanpool, ∇meanpool!

include("padding.jl")
export pad_constant, pad_repeat, pad_reflect, pad_zeros

include("upsample.jl")
export upsample_nearest, ∇upsample_nearest,
upsample_linear, ∇upsample_linear,
upsample_bilinear, ∇upsample_bilinear,
upsample_trilinear, ∇upsample_trilinear,
pixel_shuffle

include("gather.jl")
include("scatter.jl")
include("utils.jl")
Expand Down
7 changes: 0 additions & 7 deletions src/activations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,6 @@ ACTIVATIONS = [
:tanh_fast, :sigmoid_fast,
]

for f in ACTIVATIONS
@eval export $(f)
end

# Aliases
export sigmoid, hardsigmoid, logsigmoid, thresholdrelu

# of type float (to allow for integer inputs)
oftf(x, y) = oftype(float(x), y)

Expand Down
2 changes: 0 additions & 2 deletions src/batched/batchedadjtrans.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
using LinearAlgebra

import Base: -
import Adapt: adapt_structure, adapt

Expand Down
8 changes: 0 additions & 8 deletions src/batched/batchedmul.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,3 @@

export batched_mul, batched_mul!, ⊠, batched_vec
export batched_transpose, batched_adjoint

include("./batchedadjtrans.jl")

using LinearAlgebra: BlasFloat, Transpose, Adjoint, AdjOrTransAbsMat

_unbatch(A) = A
_unbatch(A::BatchedAdjOrTrans) = parent(A)

Expand Down
4 changes: 0 additions & 4 deletions src/conv.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
export conv, conv!, ∇conv_data, ∇conv_data!, ∇conv_filter, ∇conv_filter!, depthwiseconv,
depthwiseconv!, ∇depthwiseconv_data, ∇depthwiseconv_data!, ∇depthwiseconv_filter,
∇depthwiseconv_filter!

## Convolution API
#
# We provide the following generic methods, for 3d, 4d, and 5d tensors, calculating 1d,
Expand Down
2 changes: 0 additions & 2 deletions src/conv_bias_act.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
export conv_bias_act, conv_bias_act!

function conv_bias_act(x::AbstractArray{xT,N}, w::AbstractArray{wT,N},
cdims::ConvDims, b::AbstractArray{bT,N}, σ=identity; kwargs...) where {xT, wT, bT, N}
y = similar(x, promote_type(xT, wT, bT), output_size(cdims)..., channels_out(cdims), size(x,N))
Expand Down
2 changes: 0 additions & 2 deletions src/dim_helpers/ConvDims.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
export ConvDims

"""
ConvDims
Expand Down
2 changes: 0 additions & 2 deletions src/dim_helpers/DenseConvDims.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
export DenseConvDims

"""
DenseConvDims
Expand Down
2 changes: 0 additions & 2 deletions src/dim_helpers/DepthwiseConvDims.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
export DepthwiseConvDims

"""
DepthwiseConvDims
Expand Down
2 changes: 0 additions & 2 deletions src/dim_helpers/PoolDims.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
export PoolDims

"""
PoolDims(x_size::NTuple{M}, k::Union{NTuple{L, Int}, Int};
stride=k, padding=0, dilation=1) where {M, L}
Expand Down
1 change: 0 additions & 1 deletion src/functions.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
using NNlib: sigmoid
"""
glu(x, dim = 1)
Expand Down
3 changes: 0 additions & 3 deletions src/gemm.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
## Low level gemm! call with pointers
## Borrowed from Knet.jl, adapted for compile-time constants

using LinearAlgebra
using LinearAlgebra.BLAS: libblas, BlasInt, @blasfunc

using Compat: get_num_threads, set_num_threads # needs Compat 3.13, for any Julia < 1.6

"""
Expand Down
1 change: 0 additions & 1 deletion src/impl/conv_direct.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
## This file contains direct Julia implementations of 2d and 3d convolutions
using Base.Threads

# Helper functions for restricting x/w overreach
function clamp_lo(x, w)
Expand Down
2 changes: 0 additions & 2 deletions src/impl/pooling_direct.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
using Statistics

# Pooling is so similar, we abstract over meanpooling and maxpooling, simply replacing
# the inner loop operation and a few initialization parameters.
for name in (:max, :mean)
Expand Down
2 changes: 0 additions & 2 deletions src/padding.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
export pad_constant, pad_repeat, pad_reflect, pad_zeros

"""
pad_zeros(x, pad::Tuple; [dims])
pad_zeros(x, pad::Int; [dims])
Expand Down
2 changes: 0 additions & 2 deletions src/pooling.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
export maxpool, maxpool!, meanpool, meanpool!, ∇maxpool, ∇maxpool!, ∇meanpool, ∇meanpool!

## Pooling API
#
# We provide the following generic methods, for 3d, 4d, and 5d tensors, calculating 1d,
Expand Down
2 changes: 0 additions & 2 deletions src/sampling.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
export grid_sample, ∇grid_sample

@inline in_bounds(h, w, H, W) = 1 ≤ h ≤ H && 1 ≤ w ≤ W
# Borders are considered out-of-bounds for gradient.
@inline clip_coordinate(coordinate, dim_size) = min(dim_size, max(1, coordinate))
Expand Down
10 changes: 0 additions & 10 deletions src/softmax.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,3 @@
export softmax,
softmax!,
∇softmax,
∇softmax!,
logsoftmax,
logsoftmax!,
∇logsoftmax,
∇logsoftmax!,
logsumexp

"""
softmax(x; dims = 1)
Expand Down
6 changes: 0 additions & 6 deletions src/upsample.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,3 @@
export upsample_nearest, ∇upsample_nearest,
upsample_linear, ∇upsample_linear,
upsample_bilinear, ∇upsample_bilinear,
upsample_trilinear, ∇upsample_trilinear,
pixel_shuffle

"""
upsample_nearest(x, scale::NTuple{S,Int})
upsample_nearest(x; size::NTuple{S,Int})
Expand Down

0 comments on commit d8b9b41

Please sign in to comment.