diff --git a/src/builders.jl b/src/builders.jl index a7123fb6..45316161 100644 --- a/src/builders.jl +++ b/src/builders.jl @@ -14,11 +14,13 @@ abstract type Builder <: MLJModelInterface.MLJType end """ - Linear(; σ=Flux.relu) + Linear(; σ=Flux.relu, rng=Random.GLOBAL_RNG) MLJFlux builder that constructs a fully connected two layer network with activation function `σ`. The number of input and output nodes is -determined from the data. +determined from the data. The bias and coefficients are initialized +using `Flux.glorot_uniform(rng)`. If `rng` is an integer, it is +instead used as the seed for a `MersenneTwister`. """ mutable struct Linear <: Builder @@ -29,14 +31,18 @@ build(builder::Linear, rng, n::Integer, m::Integer) = Flux.Chain(Flux.Dense(n, m, builder.σ, init=Flux.glorot_uniform(rng))) """ - Short(; n_hidden=0, dropout=0.5, σ=Flux.sigmoid) + Short(; n_hidden=0, dropout=0.5, σ=Flux.sigmoid, rng=Random.GLOBAL_RNG) MLJFlux builder that constructs a full-connected three-layer network using `n_hidden` nodes in the hidden layer and the specified `dropout` (defaulting to 0.5). An activation function `σ` is applied between the hidden and final layers. If `n_hidden=0` (the default) then `n_hidden` is the geometric mean of the number of input and output nodes. The -number of input and output nodes is determined from the data. +number of input and output nodes is determined from the data. + +Each layer is initialized using `Flux.glorot_uniform(rng)`. If +`rng` is an integer, it is instead used as the seed for a +`MersenneTwister`. """ mutable struct Short <: Builder