Skip to content

Commit

Permalink
update builder doc-strings
Browse files Browse the repository at this point in the history
  • Loading branch information
ablaom committed Jun 23, 2021
1 parent 74f0b6e commit 43c0857
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions src/builders.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@
abstract type Builder <: MLJModelInterface.MLJType end

"""
Linear(; σ=Flux.relu)
Linear(; σ=Flux.relu, rng=Random.GLOBAL_RNG)
MLJFlux builder that constructs a fully connected two layer network
with activation function `σ`. The number of input and output nodes is
determined from the data.
determined from the data. The bias and coefficients are initialized
using `Flux.glorot_uniform(rng)`. If `rng` is an integer, it is
instead used as the seed for a `MersenneTwister`.
"""
mutable struct Linear <: Builder
Expand All @@ -29,14 +31,18 @@ build(builder::Linear, rng, n::Integer, m::Integer) =
Flux.Chain(Flux.Dense(n, m, builder.σ, init=Flux.glorot_uniform(rng)))

"""
Short(; n_hidden=0, dropout=0.5, σ=Flux.sigmoid)
Short(; n_hidden=0, dropout=0.5, σ=Flux.sigmoid, rng=GLOBAL_RNG)
MLJFlux builder that constructs a full-connected three-layer network
using `n_hidden` nodes in the hidden layer and the specified `dropout`
(defaulting to 0.5). An activation function `σ` is applied between the
hidden and final layers. If `n_hidden=0` (the default) then `n_hidden`
is the geometric mean of the number of input and output nodes. The
number of input and output nodes is determined from the data.
number of input and output nodes is determined from the data.
The each layer is initialized using `Flux.glorot_uniform(rng)`. If
`rng` is an integer, it is instead used as the seed for a
`MersenneTwister`.
"""
mutable struct Short <: Builder
Expand Down

0 comments on commit 43c0857

Please sign in to comment.