Skip to content

Commit

Permalink
Merge pull request #239 from FluxML/dev
Browse files Browse the repository at this point in the history
For a 0.3.1 release
  • Loading branch information
ablaom authored Sep 11, 2023
2 parents 37b5f31 + d604792 commit ab630f5
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJFlux"
uuid = "094fc8d1-fd35-5302-93ea-dabda2abf845"
authors = ["Anthony D. Blaom <anthony.blaom@gmail.com>", "Ayush Shridhar <ayush.shridhar1999@gmail.com>"]
version = "0.3.0"
version = "0.3.1"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Expand Down
1 change: 1 addition & 0 deletions src/mlj_model_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ function MLJModelInterface.fit(model::MLJFluxModel,
build(model, rng, shape) |> move
catch ex
@error ERR_BUILDER
rethrow()
end

penalty = Penalty(model)
Expand Down
28 changes: 28 additions & 0 deletions test/mlj_model_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,31 @@ end
@test length(losses) == 10
end

mutable struct LisasBuilder
n1::Int
end

@testset "builder errors and issue #237" begin
# create a builder with an intentional flaw;
# `Chains` is undefined - it should be `Chain`
function MLJFlux.build(builder::LisasBuilder, rng, nin, nout)
return Flux.Chains(
Flux.Dense(nin, builder.n1),
Flux.Dense(builder.n1, nout)
)
end

model = NeuralNetworkRegressor(
epochs = 2,
batch_size = 32,
builder = LisasBuilder(10),
)

X, y = @load_boston
@test_logs(
(:error, MLJFlux.ERR_BUILDER),
@test_throws UndefVarError(:Chains) MLJBase.fit(model, 0, X, y)
)
end

true

0 comments on commit ab630f5

Please sign in to comment.