From 4e4d27a8f937f646169d80c16555a415f1e02880 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Mon, 4 Sep 2023 14:56:58 +1200 Subject: [PATCH 1/4] rethrow the error caught when builder fails to close #237 --- src/mlj_model_interface.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/mlj_model_interface.jl b/src/mlj_model_interface.jl index f3af0774..d488eb24 100644 --- a/src/mlj_model_interface.jl +++ b/src/mlj_model_interface.jl @@ -59,6 +59,7 @@ function MLJModelInterface.fit(model::MLJFluxModel, build(model, rng, shape) |> move catch ex @error ERR_BUILDER + throw(ex) end penalty = Penalty(model) From bbc1a446a81275e6203328db88ae6b3ec8444c89 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Tue, 5 Sep 2023 09:18:35 +1200 Subject: [PATCH 2/4] add test --- test/mlj_model_interface.jl | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/test/mlj_model_interface.jl b/test/mlj_model_interface.jl index 6f07a661..a2e6c8fe 100644 --- a/test/mlj_model_interface.jl +++ b/test/mlj_model_interface.jl @@ -52,3 +52,31 @@ end @test length(losses) == 10 end +mutable struct LisasBuilder + n1::Int +end + +@testset "builder errors and issue #237" begin + # create a builder with an intentional flaw; + # `Chains` is undefined - it should be `Chain` + function MLJFlux.build(builder::LisasBuilder, rng, nin, nout) + return Flux.Chains( + Flux.Dense(nin, builder.n1), + Flux.Dense(builder.n1, nout) + ) + end + + model = NeuralNetworkRegressor( + epochs = 2, + batch_size = 32, + builder = LisasBuilder(10), + ) + + X, y = @load_boston + @test_logs( + (:error, MLJFlux.ERR_BUILDER), + @test_throws UndefVarError(:Chains) MLJBase.fit(model, 0, X, y) + ) +end + +true From 98de344b29bf7a9214feadbb410cb881282e6456 Mon Sep 17 00:00:00 2001 From: "Anthony Blaom, PhD" Date: Tue, 5 Sep 2023 09:22:49 +1200 Subject: [PATCH 3/4] Update src/mlj_model_interface.jl Co-authored-by: Okon Samuel <39421418+OkonSamuel@users.noreply.github.com> --- src/mlj_model_interface.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mlj_model_interface.jl b/src/mlj_model_interface.jl index d488eb24..5ffe903f 100644 --- a/src/mlj_model_interface.jl +++ b/src/mlj_model_interface.jl @@ -59,7 +59,7 @@ function MLJModelInterface.fit(model::MLJFluxModel, build(model, rng, shape) |> move catch ex @error ERR_BUILDER - throw(ex) + rethrow() end penalty = Penalty(model) From d604792cefe0fdabc4859d2a15e9d630f1465171 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Mon, 11 Sep 2023 14:12:40 -0700 Subject: [PATCH 4/4] bump 0.3.1 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 1bc53ff4..cac18ad2 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLJFlux" uuid = "094fc8d1-fd35-5302-93ea-dabda2abf845" authors = ["Anthony D. Blaom ", "Ayush Shridhar "] -version = "0.3.0" +version = "0.3.1" [deps] CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"