add test (and StableRNGs to [extras])
add StableRNGs for testing

add test
ablaom committed Jun 21, 2021
1 parent 6628c4b commit d2b51af
Showing 3 changed files with 17 additions and 3 deletions.
3 changes: 2 additions & 1 deletion Project.toml
@@ -30,9 +30,10 @@ MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJScientificTypes = "2e2323e0-db8b-457b-ae0d-bdfb3bc63afd"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["LinearAlgebra", "MLDatasets", "MLJBase", "MLJScientificTypes", "Random", "Statistics", "StatsBase", "Test"]
test = ["LinearAlgebra", "MLDatasets", "MLJBase", "MLJScientificTypes", "Random", "StableRNGs", "Statistics", "StatsBase", "Test"]
12 changes: 12 additions & 0 deletions test/builders.jl
@@ -36,3 +36,15 @@ MLJFlux.build(builder::TESTBuilder, rng, n_in, n_out) =
    @test pretraining_loss ≈ pretraining_loss_by_hand

end

@testset_accelerated "Short" accel begin
    builder = MLJFlux.Short(n_hidden=4, σ=Flux.relu, dropout=0)
    chain = MLJFlux.build(builder, StableRNGs.StableRNG(123), 5, 3)
    ps = Flux.params(chain)
    @test size.(ps) == [(4, 5), (4,), (3, 4), (3,)]

    # reproducibility (without dropout):
    chain2 = MLJFlux.build(builder, StableRNGs.StableRNG(123), 5, 3)
    x = rand(5)
    @test chain(x) ≈ chain2(x)
end
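The expected parameter shapes follow from the architecture Short is intended to build: a hidden Dense layer of width n_hidden=4 acting on 5 inputs, a dropout layer, and an output Dense layer onto 3 targets. A hand-rolled Flux chain along those lines (an illustrative sketch under that assumption, not the builder's actual implementation) yields the same shapes:

using Flux
chain = Chain(Dense(5, 4, relu), Dropout(0), Dense(4, 3))
size.(Flux.params(chain))   # returns [(4, 5), (4,), (3, 4), (3,)]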
5 changes: 3 additions & 2 deletions test/runtests.jl
Expand Up @@ -10,6 +10,7 @@ import Random.seed!
using Statistics
import StatsBase
using MLJModelInterface.ScientificTypes
using StableRNGs

using ComputationalResources
using ComputationalResources: CPU1, CUDALibs
@@ -19,9 +20,9 @@ EXCLUDED_RESOURCE_TYPES = Any[]

MLJFlux.gpu_isdead() && push!(EXCLUDED_RESOURCE_TYPES, CUDALibs)

@info "Available computational resources: $RESOURCES"
@info "MLJFlux supports these computational resources:\n$RESOURCES"
@info "Current test run to exclude resources with "*
"these types: $EXCLUDED_RESOURCE_TYPES\n"*
"these types, as unavailable:\n$EXCLUDED_RESOURCE_TYPES\n"*
"Excluded tests marked as \"broken\"."

# alternative version of Short builder with no dropout; see