diff --git a/Project.toml b/Project.toml index 50fb5dc8..389245ea 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLJFlux" uuid = "094fc8d1-fd35-5302-93ea-dabda2abf845" authors = ["Anthony D. Blaom ", "Ayush Shridhar "] -version = "0.1.17" +version = "0.2.0" [deps] CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" @@ -10,6 +10,7 @@ ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" @@ -29,9 +30,10 @@ MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" MLJScientificTypes = "2e2323e0-db8b-457b-ae0d-bdfb3bc63afd" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["LinearAlgebra", "MLDatasets", "MLJBase", "MLJScientificTypes", "Random", "Statistics", "StatsBase", "Test"] +test = ["LinearAlgebra", "MLDatasets", "MLJBase", "MLJScientificTypes", "Random", "StableRNGs", "Statistics", "StatsBase", "Test"] diff --git a/README.md b/README.md index a7ece050..c5d27c35 100644 --- a/README.md +++ b/README.md @@ -84,18 +84,18 @@ NeuralNetworkClassifier = @load NeuralNetworkClassifier julia> clf = NeuralNetworkClassifier() NeuralNetworkClassifier( - builder = Short( - n_hidden = 0, - dropout = 0.5, - σ = NNlib.σ), - finaliser = NNlib.softmax, - optimiser = ADAM(0.001, (0.9, 0.999), IdDict{Any,Any}()), - loss = Flux.crossentropy, - epochs = 10, - batch_size = 1, - lambda = 0.0, - alpha = 0.0, - optimiser_changes_trigger_retraining = false) @ 1…60 + builder = Short( + n_hidden = 0, + dropout = 0.5, + σ = NNlib.σ), + finaliser = NNlib.softmax, + optimiser = ADAM(0.001, (0.9, 0.999), IdDict{Any,Any}()), + loss = Flux.crossentropy, + epochs = 10, + batch_size = 1, + lambda = 0.0, + alpha = 0.0, + optimiser_changes_trigger_retraining = false) @ 1…60 ``` #### Incremental training @@ -121,8 +121,8 @@ julia> fit!(mach, verbosity=2) [ Info: Loss is 0.7347 Machine{NeuralNetworkClassifier{Short,…},…} @804 trained 2 times; caches data args: - 1: Source @985 ⏎ `Table{AbstractVector{Continuous}}` - 2: Source @367 ⏎ `AbstractVector{Multiclass{3}}` + 1: Source @985 ⏎ `Table{AbstractVector{Continuous}}` + 2: Source @367 ⏎ `AbstractVector{Multiclass{3}}` julia> training_loss = cross_entropy(predict(mach, X), y) |> mean 0.7347092796453824 @@ -140,15 +140,15 @@ Chain(Chain(Dense(4, 3, σ), Flux.Dropout{Float64}(0.5, false), Dense(3, 3)), so ```julia r = range(clf, :epochs, lower=1, upper=200, scale=:log10) curve = learning_curve(clf, X, y, - range=r, - resampling=Holdout(fraction_train=0.7), - measure=cross_entropy) + range=r, + resampling=Holdout(fraction_train=0.7), + measure=cross_entropy) using Plots plot(curve.parameter_values, - curve.measurements, - xlab=curve.parameter_name, - xscale=curve.parameter_scale, - ylab = "Cross Entropy") + curve.measurements, + xlab=curve.parameter_name, + xscale=curve.parameter_scale, + ylab = "Cross Entropy") ``` @@ -239,13 +239,31 @@ CPU at then conclusion of `fit!`, and made available as `fitted_params(mach)`. 
+### Random number generators and reproducibility + +Every MLJFlux model includes an `rng` hyper-parameter that is passed +to builders for the purposes of weight initialization. This can be +any `AbstractRNG` or the seed (integer) for a `MersenneTwister` that +will be reset on every cold restart of model (machine) training. + +Until there is a [mechanism for +doing so](https://github.com/FluxML/Flux.jl/issues/1617), `rng` is *not* +passed to dropout layers and one must manually seed the `GLOBAL_RNG` +for reproducibility purposes, when using a builder that includes +`Dropout` (such as `MLJFlux.Short`). If training models on a +GPU (i.e., `acceleration isa CUDALibs`), one must additionally call +`CUDA.seed!(...)`. + + ### Built-in builders -MLJ provides two simple builders out of the box: +MLJ provides two simple builders out of the box. In all cases weights + are initialized using `glorot_uniform(rng)`, where `rng` is the RNG + (or `MersenneTwister` seed) specified by the MLJFlux model. -- `MLJFlux.Linear(σ=...)` builds a fully connected two layer - network with `n_in` inputs and `n_out` outputs, with activation - function `σ`, defaulting to a `MLJFlux.relu`. +- `MLJFlux.Linear(σ=...)` builds a fully connected two-layer network + with `n_in` inputs and `n_out` outputs, with activation function + `σ`, defaulting to `Flux.relu`. - `MLJFlux.Short(n_hidden=..., dropout=..., σ=...)` builds a full-connected three-layer network with `n_in` inputs and `n_out` @@ -268,7 +286,8 @@ All models share the following hyper-parameters: 2. `optimiser`: The optimiser to use for training. Default = `Flux.ADAM()` -3. `loss`: The loss function used for training. Default = `Flux.mse` (regressors) and `Flux.crossentropy` (classifiers) +3. `loss`: The loss function used for training. Default = `Flux.mse` + (regressors) and `Flux.crossentropy` (classifiers) 4. `n_epochs`: Number of epochs to train for. Default = `10` @@ -278,9 +297,15 @@ All models share the following hyper-parameters: 7. `alpha`: The L2/L1 mix of regularization. Default = 0. Range = [0, 1] -8. `acceleration`: Use `CUDALibs()` for training on GPU; default is `CPU1()`. +8. `rng`: The random number generator (RNG) passed to builders, for + weight initialization, for example. Can be any `AbstractRNG` or + the seed (integer) for a `MersenneTwister` that is reset on every + cold restart of model (machine) training. Default = + `GLOBAL_RNG`. + +9. `acceleration`: Use `CUDALibs()` for training on GPU; default is `CPU1()`. -9. `optimiser_changes_trigger_retraining`: True if fitting an +10. `optimiser_changes_trigger_retraining`: True if fitting an associated machine should trigger retraining from scratch whenever the optimiser changes. Default = `false` @@ -309,13 +334,16 @@ any of the first three models in Table 1. 
The definition includes one mutable struct and one method: ```julia -mutable struct MyNetwork <: MLJFlux.Builder - n1 :: Int - n2 :: Int +mutable struct MyBuilder <: MLJFlux.Builder + n1 :: Int + n2 :: Int end -function MLJFlux.build(nn::MyNetwork, n_in, n_out) - return Chain(Dense(n_in, nn.n1), Dense(nn.n1, nn.n2), Dense(nn.n2, n_out)) +function MLJFlux.build(nn::MyBuilder, rng, n_in, n_out) + init = Flux.glorot_uniform(rng) + return Chain(Dense(n_in, nn.n1, init=init), + Dense(nn.n1, nn.n2, init=init), + Dense(nn.n2, n_out, init=init)) end ``` @@ -330,8 +358,8 @@ sub-typing `MLJFlux.Builder` and defining a new `MLJFlux.build` method with one of these signatures: ```julia -MLJFlux.build(builder::MyNetwork, n_in, n_out) -MLJFlux.build(builder::MyNetwork, n_in, n_out, n_channels) # for use with `ImageClassifier` +MLJFlux.build(builder::MyBuilder, rng, n_in, n_out) +MLJFlux.build(builder::MyBuilder, rng, n_in, n_out, n_channels) # for use with `ImageClassifier` ``` This method must return a `Flux.Chain` instance, `chain`, subject to the @@ -339,12 +367,13 @@ following conditions: - `chain(x)` must make sense: - - for any `x <: Vector{<:AbstractFloat}` of length `n_in` (for use - with one of the first three model types); or + - for any `x <: Array{<:AbstractFloat, 2}` of size `(n_in, + batch_size)` where `batch_size` is any integer (for use with one + of the first three model types); or - - for any `x <: Array{<:Float32, 4}` of size `(W, H, n_channels, - batch_size)`, where `(W, H) = n_in`, `n_channels` is 1 or 3, and - `batch_size` is any integer (for use with `ImageClassifier`) + - for any `x <: Array{<:Float32, 4}` of size `(W, H, n_channels, + batch_size)`, where `(W, H) = n_in`, `n_channels` is 1 or 3, and + `batch_size` is any integer (for use with `ImageClassifier`) - The object returned by `chain(x)` must be an `AbstractFloat` vector of length `n_out`. @@ -388,40 +417,36 @@ using MLDatasets # helper function function flatten(x::AbstractArray) - return reshape(x, :, size(x)[end]) + return reshape(x, :, size(x)[end]) end import MLJFlux mutable struct MyConvBuilder - filter_size::Int - channels1::Int - channels2::Int - channels3::Int + filter_size::Int + channels1::Int + channels2::Int + channels3::Int end -function MLJFlux.build(b::MyConvBuilder, n_in, n_out, n_channels) - - k, c1, c2, c3 = b.filter_size, b.channels1, b.channels2, b.channels3 +function MLJFlux.build(b::MyConvBuilder, rng, n_in, n_out, n_channels) - mod(k, 2) == 1 || error("`filter_size` must be odd. ") + k, c1, c2, c3 = b.filter_size, b.channels1, b.channels2, b.channels3 - # padding to preserve image size on convolution: - p = div(k - 1, 2) + mod(k, 2) == 1 || error("`filter_size` must be odd. 
") - # compute size, in first two dims, of output of final maxpool layer: - half(x) = div(x, 2) - h = n_in[1] |> half |> half |> half - w = n_in[2] |> half |> half |> half + # padding to preserve image size on convolution: + p = div(k - 1, 2) - return Chain( - Conv((k, k), n_channels => c1, pad=(p, p), relu), - MaxPool((2, 2)), - Conv((k, k), c1 => c2, pad=(p, p), relu), - MaxPool((2, 2)), - Conv((k, k), c2 => c3, pad=(p, p), relu), - MaxPool((2 ,2)), - flatten, - Dense(h*w*c3, n_out)) + front = Chain( + Conv((k, k), n_channels => c1, pad=(p, p), relu), + MaxPool((2, 2)), + Conv((k, k), c1 => c2, pad=(p, p), relu), + MaxPool((2, 2)), + Conv((k, k), c2 => c3, pad=(p, p), relu), + MaxPool((2 ,2)), + flatten) + d = Flux.outputsize(front, (n_in..., n_channels, 1)) |> first + return Chain(front, Dense(d, n_out)) end ``` @@ -429,26 +454,28 @@ Next, we load some of the MNIST data and check scientific types conform to those is the table above: ```julia -N = 1000 -X, y = MNIST.traindata(); +N = 500 +Xraw, yraw = MNIST.traindata(); +Xraw = Xraw[:,:,1:N]; +yraw = yraw[1:N]; -julia> scitype(X) -AbstractArray{GrayImage{28,28},1} +julia> scitype(Xraw) +AbstractArray{Unknown, 3} -julia> scitype(y) +julia> scitype(yraw) AbstractArray{Count,1} ``` -Inputs should have scitype `GrayImage` +Inputs should have element scitype `GrayImage`: ```julia -X = coerce(X, GrayImage); +X = coerce(Xraw, GrayImage); ``` -For classifiers, target must have element scitype `<: Finite`, so we fix this: +For classifiers, target must have element scitype `<: Finite`: ```julia -y = coerce(y, Multiclass); +y = coerce(yraw, Multiclass); ``` Instantiating an image classifier model: @@ -456,8 +483,8 @@ Instantiating an image classifier model: ```julia ImageClassifier = @load ImageClassifier clf = ImageClassifier(builder=MyConvBuilder(3, 16, 32, 32), - epochs=10, - loss=Flux.crossentropy) + epochs=10, + loss=Flux.crossentropy) ``` And evaluating the accuracy of the model on a 30% holdout set: @@ -466,9 +493,9 @@ And evaluating the accuracy of the model on a 30% holdout set: mach = machine(clf, X, y) julia> evaluate!(mach, - resampling=Holdout(rng=123, fraction_train=0.7), - operation=predict_mode, - measure=misclassification_rate) + resampling=Holdout(rng=123, fraction_train=0.7), + operation=predict_mode, + measure=misclassification_rate) ┌────────────────────────┬───────────────┬────────────┐ │ _.measure │ _.measurement │ _.per_fold │ ├────────────────────────┼───────────────┼────────────┤ diff --git a/src/MLJFlux.jl b/src/MLJFlux.jl index 805113e4..bc480b62 100644 --- a/src/MLJFlux.jl +++ b/src/MLJFlux.jl @@ -12,6 +12,7 @@ using Tables using Statistics using ColorTypes using ComputationalResources +using Random include("core.jl") include("builders.jl") diff --git a/src/builders.jl b/src/builders.jl index 6f68ea13..45316161 100644 --- a/src/builders.jl +++ b/src/builders.jl @@ -14,29 +14,35 @@ abstract type Builder <: MLJModelInterface.MLJType end """ - Linear(; σ=Flux.relu) + Linear(; σ=Flux.relu, rng=Random.GLOBAL_RNG) MLJFlux builder that constructs a fully connected two layer network with activation function `σ`. The number of input and output nodes is -determined from the data. +determined from the data. The bias and coefficients are initialized +using `Flux.glorot_uniform(rng)`. If `rng` is an integer, it is +instead used as the seed for a `MersenneTwister`. 
""" mutable struct Linear <: Builder σ end Linear(; σ=Flux.relu) = Linear(σ) -build(builder::Linear, n::Integer, m::Integer) = - Flux.Chain(Flux.Dense(n, m, builder.σ)) +build(builder::Linear, rng, n::Integer, m::Integer) = + Flux.Chain(Flux.Dense(n, m, builder.σ, init=Flux.glorot_uniform(rng))) """ - Short(; n_hidden=0, dropout=0.5, σ=Flux.sigmoid) + Short(; n_hidden=0, dropout=0.5, σ=Flux.sigmoid, rng=GLOBAL_RNG) MLJFlux builder that constructs a full-connected three-layer network using `n_hidden` nodes in the hidden layer and the specified `dropout` (defaulting to 0.5). An activation function `σ` is applied between the hidden and final layers. If `n_hidden=0` (the default) then `n_hidden` is the geometric mean of the number of input and output nodes. The -number of input and output nodes is determined from the data. +number of input and output nodes is determined from the data. + +The each layer is initialized using `Flux.glorot_uniform(rng)`. If +`rng` is an integer, it is instead used as the seed for a +`MersenneTwister`. """ mutable struct Short <: Builder @@ -45,11 +51,13 @@ mutable struct Short <: Builder σ end Short(; n_hidden=0, dropout=0.5, σ=Flux.sigmoid) = Short(n_hidden, dropout, σ) -function build(builder::Short, n, m) +function build(builder::Short, rng, n, m) n_hidden = builder.n_hidden == 0 ? round(Int, sqrt(n*m)) : builder.n_hidden - return Flux.Chain(Flux.Dense(n, n_hidden, builder.σ), - Flux.Dropout(builder.dropout), - Flux.Dense(n_hidden, m)) + init=Flux.glorot_uniform(rng) + Flux.Chain( + Flux.Dense(n, n_hidden, builder.σ, init=init), + # TODO: fix next after https://github.com/FluxML/Flux.jl/issues/1617 + Flux.Dropout(builder.dropout), + Flux.Dense(n_hidden, m, init=init)) end - diff --git a/src/classifier.jl b/src/classifier.jl index 084193be..82d4efc9 100644 --- a/src/classifier.jl +++ b/src/classifier.jl @@ -1,4 +1,4 @@ -# if `b` is a builder, then `b(model, shape...)` is called to make a +# if `b` is a builder, then `b(model, rng, shape...)` is called to make a # new chain, where `shape` is the return value of this method: function MLJFlux.shape(model::NeuralNetworkClassifier, X, y) levels = MLJModelInterface.classes(y[1]) @@ -8,8 +8,8 @@ function MLJFlux.shape(model::NeuralNetworkClassifier, X, y) end # builds the end-to-end Flux chain needed, given the `model` and `shape`: -MLJFlux.build(model::NeuralNetworkClassifier, shape) = - Flux.Chain(build(model.builder, shape...), +MLJFlux.build(model::NeuralNetworkClassifier, rng, shape) = + Flux.Chain(build(model.builder, rng, shape...), model.finaliser) # returns the model `fitresult` (see "Adding Models for General Use" diff --git a/src/common.jl b/src/common.jl index bd7f765f..47b77186 100644 --- a/src/common.jl +++ b/src/common.jl @@ -40,6 +40,8 @@ end # # FIT AND UPDATE +true_rng(model) = model.rng isa Integer ? 
MersenneTwister(model.rng) : model.rng + function MLJModelInterface.fit(model::MLJFluxModel, verbosity::Int, X, @@ -47,8 +49,10 @@ function MLJModelInterface.fit(model::MLJFluxModel, data = collate(model, X, y) + rng = true_rng(model) + shape = MLJFlux.shape(model, X, y) - chain = build(model, shape) + chain = build(model, rng, shape) optimiser = deepcopy(model.optimiser) @@ -65,7 +69,7 @@ function MLJModelInterface.fit(model::MLJFluxModel, # `optimiser` is now mutated - cache = (deepcopy(model), data, history, shape, optimiser) + cache = (deepcopy(model), data, history, shape, optimiser, deepcopy(rng)) fitresult = MLJFlux.fitresult(model, chain, y) report = (training_losses=history, ) @@ -80,7 +84,7 @@ function MLJModelInterface.update(model::MLJFluxModel, X, y) - old_model, data, old_history, shape, optimiser = old_cache + old_model, data, old_history, shape, optimiser, rng = old_cache old_chain = old_fitresult[1] optimiser_flag = model.optimiser_changes_trigger_retraining && @@ -93,7 +97,8 @@ function MLJModelInterface.update(model::MLJFluxModel, chain = old_chain epochs = model.epochs - old_model.epochs else - chain = build(model, shape) + rng = true_rng(model) + chain = build(model, rng, shape) data = collate(model, X, y) epochs = model.epochs end @@ -123,7 +128,7 @@ function MLJModelInterface.update(model::MLJFluxModel, end fitresult = MLJFlux.fitresult(model, chain, y) - cache = (deepcopy(model), data, history, shape, optimiser) + cache = (deepcopy(model), data, history, shape, optimiser, deepcopy(rng)) report = (training_losses=history, ) return fitresult, cache, report diff --git a/src/image.jl b/src/image.jl index 766a33cf..5c973eb7 100644 --- a/src/image.jl +++ b/src/image.jl @@ -11,8 +11,8 @@ function shape(model::ImageClassifier, X, y) return (n_input, n_output, n_channels) end -build(model::ImageClassifier, shape) = - Flux.Chain(build(model.builder, shape...), +build(model::ImageClassifier, rng, shape) = + Flux.Chain(build(model.builder, rng, shape...), model.finaliser) fitresult(model::ImageClassifier, chain, y) = diff --git a/src/regressor.jl b/src/regressor.jl index 32abbc87..f932bff7 100644 --- a/src/regressor.jl +++ b/src/regressor.jl @@ -6,8 +6,8 @@ function shape(model::NeuralNetworkRegressor, X, y) return (n_input, 1) end -build(model::NeuralNetworkRegressor, shape) = - build(model.builder, shape...) +build(model::NeuralNetworkRegressor, rng, shape) = + build(model.builder, rng, shape...) fitresult(model::NeuralNetworkRegressor, chain, y) = (chain, nothing) @@ -38,8 +38,8 @@ function shape(model::MultitargetNeuralNetworkRegressor, X, y) return (n_input, n_output) end -build(model::MultitargetNeuralNetworkRegressor, shape) = - build(model.builder, shape...) +build(model::MultitargetNeuralNetworkRegressor, rng, shape) = + build(model.builder, rng, shape...) 
function fitresult(model::MultitargetNeuralNetworkRegressor, chain, y) target_column_names = Tables.schema(y).names diff --git a/src/types.jl b/src/types.jl index 6afc5f27..86bef752 100644 --- a/src/types.jl +++ b/src/types.jl @@ -15,6 +15,7 @@ for Model in [:NeuralNetworkClassifier, :ImageClassifier] batch_size::Int # size of a batch lambda::Float64 # regularization strength alpha::Float64 # regularizaton mix (0 for all l2, 1 for all l1) + rng::Union{AbstractRNG,Int64} optimiser_changes_trigger_retraining::Bool acceleration::AbstractResource # eg, `CPU1()` or `CUDALibs()` end @@ -27,6 +28,7 @@ for Model in [:NeuralNetworkClassifier, :ImageClassifier] , batch_size = 1 , lambda = 0 , alpha = 0 + , rng = Random.GLOBAL_RNG , optimiser_changes_trigger_retraining = false , acceleration = CPU1() ) where {B,F,O,L} @@ -39,6 +41,7 @@ for Model in [:NeuralNetworkClassifier, :ImageClassifier] , batch_size , lambda , alpha + , rng , optimiser_changes_trigger_retraining , acceleration ) @@ -64,6 +67,7 @@ for Model in [:NeuralNetworkRegressor, :MultitargetNeuralNetworkRegressor] batch_size::Int # size of a batch lambda::Float64 # regularization strength alpha::Float64 # regularizaton mix (0 for all l2, 1 for all l1) + rng::Union{AbstractRNG,Integer} optimiser_changes_trigger_retraining::Bool acceleration::AbstractResource # eg, `CPU1()` or `CUDALibs()` end @@ -75,6 +79,7 @@ for Model in [:NeuralNetworkRegressor, :MultitargetNeuralNetworkRegressor] , batch_size = 1 , lambda = 0 , alpha = 0 + , rng = Random.GLOBAL_RNG , optimiser_changes_trigger_retraining=false , acceleration = CPU1() ) where {B,O,L} @@ -86,6 +91,7 @@ for Model in [:NeuralNetworkRegressor, :MultitargetNeuralNetworkRegressor] , batch_size , lambda , alpha + , rng , optimiser_changes_trigger_retraining , acceleration) diff --git a/test/builders.jl b/test/builders.jl index d3ec2e66..bddcdf61 100644 --- a/test/builders.jl +++ b/test/builders.jl @@ -2,7 +2,7 @@ myinit(n, m) = reshape(float(1:n*m), n , m) mutable struct TESTBuilder <: MLJFlux.Builder end -MLJFlux.build(builder::TESTBuilder, n_in, n_out) = +MLJFlux.build(builder::TESTBuilder, rng, n_in, n_out) = Flux.Chain(Flux.Dense(n_in, n_out, init=myinit)) @testset_accelerated "issue #152" accel begin @@ -15,10 +15,11 @@ MLJFlux.build(builder::TESTBuilder, n_in, n_out) = y = X.x1 .^2 + X.x2 .* X.x3 - 4 * X.x4 # train a model on all the data using batch size > 1: - model = MLJFlux.NeuralNetworkRegressor(builder = TESTBuilder(), - batch_size=25, - epochs=1, - loss=Flux.mse) + model = MLJFlux.NeuralNetworkRegressor(builder=TESTBuilder(), + batch_size=25, + epochs=1, + loss=Flux.mse, + acceleration=accel) mach = machine(model, X, y) fit!(mach, verbosity=0) @@ -35,3 +36,15 @@ MLJFlux.build(builder::TESTBuilder, n_in, n_out) = @test pretraining_loss ≈ pretraining_loss_by_hand end + +@testset_accelerated "Short" accel begin + builder = MLJFlux.Short(n_hidden=4, σ=Flux.relu, dropout=0) + chain = MLJFlux.build(builder, StableRNGs.StableRNG(123), 5, 3) + ps = Flux.params(chain) + @test size.(ps) == [(4, 5), (4,), (3, 4), (3,)] + + # reproducibility (without dropout): + chain2 = MLJFlux.build(builder, StableRNGs.StableRNG(123), 5, 3) + x = rand(5) + @test chain(x) ≈ chain2(x) +end diff --git a/test/image.jl b/test/image.jl index a31d54ce..f3a1ceb9 100644 --- a/test/image.jl +++ b/test/image.jl @@ -2,18 +2,21 @@ Random.seed!(123) -mutable struct mynn <: MLJFlux.Builder +mutable struct MyNeuralNetwork <: MLJFlux.Builder kernel1 kernel2 end -MLJFlux.build(model::mynn, ip, op, n_channels) = - 
Flux.Chain(Flux.Conv(model.kernel1, n_channels=>2), - Flux.Conv(model.kernel2, 2=>1), - x->reshape(x, :, size(x)[end]), - Flux.Dense(16, op)) +function MLJFlux.build(model::MyNeuralNetwork, rng, ip, op, n_channels) + init = Flux.glorot_uniform(rng) + Flux.Chain( + Flux.Conv(model.kernel1, n_channels=>2, init=init), + Flux.Conv(model.kernel2, 2=>1, init=init), + x->reshape(x, :, size(x)[end]), + Flux.Dense(16, op, init=init)) +end -builder = mynn((2,2), (2,2)) +builder = MyNeuralNetwork((2,2), (2,2)) # collection of gray images as a 4D array in WHCN format: raw_images = rand(Float32, 6, 6, 1, 50); @@ -78,18 +81,18 @@ function flatten(x::AbstractArray) return reshape(x, :, size(x)[end]) end -function MLJFlux.build(builder::MyConvBuilder, n_in, n_out, n_channels) +function MLJFlux.build(builder::MyConvBuilder, rng, n_in, n_out, n_channels) cnn_output_size = [3,3,32] - + init = Flux.glorot_uniform(rng) return Chain( - Conv((3, 3), n_channels=>16, pad=(1,1), relu), + Conv((3, 3), n_channels=>16, pad=(1,1), relu, init=init), MaxPool((2,2)), - Conv((3, 3), 16=>32, pad=(1,1), relu), + Conv((3, 3), 16=>32, pad=(1,1), relu, init=init), MaxPool((2,2)), - Conv((3, 3), 32=>32, pad=(1,1), relu), + Conv((3, 3), 32=>32, pad=(1,1), relu, init=init), MaxPool((2,2)), flatten, - Dense(prod(cnn_output_size), n_out)) + Dense(prod(cnn_output_size), n_out, init=init)) end losses = [] @@ -121,7 +124,7 @@ reference = losses[1] ## BASIC IMAGE TESTS COLOR -builder = mynn((2,2), (2,2)) +builder = MyNeuralNetwork((2,2), (2,2)) # collection of color images as a 4D array in WHCN format: raw_images = rand(Float32, 6, 6, 3, 50); diff --git a/test/runtests.jl b/test/runtests.jl index 52a8f635..6e20e184 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,6 +10,7 @@ import Random.seed! using Statistics import StatsBase using MLJModelInterface.ScientificTypes +using StableRNGs using ComputationalResources using ComputationalResources: CPU1, CUDALibs @@ -19,9 +20,9 @@ EXCLUDED_RESOURCE_TYPES = Any[] MLJFlux.gpu_isdead() && push!(EXCLUDED_RESOURCE_TYPES, CUDALibs) -@info "Available computational resources: $RESOURCES" +@info "MLJFlux supports these computational resources:\n$RESOURCES" @info "Current test run to exclude resources with "* - "these types: $EXCLUDED_RESOURCE_TYPES\n"* + "these types, as unavailable:\n$EXCLUDED_RESOURCE_TYPES\n"* "Excluded tests marked as \"broken\"." # alternative version of Short builder with no dropout; see @@ -31,11 +32,13 @@ mutable struct Short2 <: MLJFlux.Builder σ end Short2(; n_hidden=0, σ=Flux.sigmoid) = Short2(n_hidden, σ) -function MLJFlux.build(builder::Short2, n, m) +function MLJFlux.build(builder::Short2, rng, n, m) n_hidden = builder.n_hidden == 0 ? 
round(Int, sqrt(n*m)) : builder.n_hidden - return Flux.Chain(Flux.Dense(n, n_hidden, builder.σ), - Flux.Dense(n_hidden, m)) + init = Flux.glorot_uniform(rng) + return Flux.Chain( + Flux.Dense(n, n_hidden, builder.σ, init=init), + Flux.Dense(n_hidden, m, init=init)) end seed!(123) diff --git a/test/test_utils.jl b/test/test_utils.jl index 61a4db76..942643a9 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -134,6 +134,8 @@ function optimisertest(ModelType, X, y, builder, optimiser, accel) mach = machine(model, $X, $y); + # USING GLOBAL RNG + # two epochs in stages: Random.seed!(123) # chains are always initialized on CPU fit!(mach, verbosity=0, force=true); @@ -152,6 +154,27 @@ function optimisertest(ModelType, X, y, builder, optimiser, accel) @test_broken isapprox(l1, l2, rtol=1e-8) end + # USING USER SPECIFIED RNG SEED + + # two epochs in stages: + model.rng = 1234 + mach = machine(model, $X, $y); + + fit!(mach, verbosity=0, force=true); + model.epochs = model.epochs + 1 + fit!(mach, verbosity=0); # update + l1 = MLJBase.report(mach).training_losses[end] + + # two epochs in one go: + fit!(mach, verbosity=1, force=true) + l2 = MLJBase.report(mach).training_losses[end] + + if accel isa CPU1 + @test isapprox(l1, l2) + else + @test_broken isapprox(l1, l2, rtol=1e-8) + end + end) return true
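
For a quick end-to-end check of the new `rng` hyper-parameter, here is a minimal usage sketch (assuming MLJBase's `make_blobs` is available for synthetic data; any labelled classification table would do). As noted in the README section added above, builders containing `Dropout` (such as `MLJFlux.Short`) still draw from the global RNG, so that is seeded as well:

```julia
using MLJBase, StableRNGs
import MLJFlux
import Random

# synthetic classification data (a stand-in for a real dataset):
X, y = make_blobs(100, 4)

# `rng` can be any AbstractRNG, or an integer seed for a MersenneTwister:
clf = MLJFlux.NeuralNetworkClassifier(builder=MLJFlux.Short(n_hidden=5),
                                      epochs=10,
                                      rng=StableRNGs.StableRNG(123))

# `Short` includes a `Dropout` layer, which is not yet wired to `rng`
# (see the Flux issue linked in the README), so seed the global RNG too:
Random.seed!(123)

mach = machine(clf, X, y)
fit!(mach, verbosity=0)
report(mach).training_losses[end]
```

Re-seeding the global RNG and repeating the last three lines with `fit!(mach, verbosity=0, force=true)` should reproduce the same training losses when running on the CPU, mirroring the reproducibility check added to `test/test_utils.jl` above.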