Skip to content

Commit

Permalink
in multithreading stack test replace stack with double stack
Browse files Browse the repository at this point in the history
  • Loading branch information
ablaom committed Jun 14, 2022
1 parent 9a0cd54 commit d6cf7a1
Showing 1 changed file with 24 additions and 11 deletions.
35 changes: 24 additions & 11 deletions test/composition/models/stacking.jl
Original file line number Diff line number Diff line change
Expand Up @@ -547,31 +547,44 @@ end
end
end

@testset "Test multithreaded version" begin
X, y = make_regression(100, 5; rng=StableRNG(1234))
# a regression `Stack` which has `model` as one of the base models:
function _stack(model, resource)
models = (constant=DeterministicConstantRegressor(),
ridge_lambda=FooBarRegressor(;lambda=0.1),
ridge=FooBarRegressor(;lambda=0))
model=model)
Stack(;
metalearner=FooBarRegressor(;lambda=0.05),
resampling=CV(;nfolds=3),
acceleration=resource,
models...
)
end

stack = Stack(;metalearner=FooBarRegressor(),
resampling=CV(;nfolds=3),
acceleration=CPU1(),
models...)
# return a nested stack in which `model` appears at two levels, with
# both layers accelerated using `resource`:
_double_stack(model, resource) =
_stack(_stack(model, resource), resource)

@testset "Test multithreaded version" begin
X, y = make_regression(100, 5; rng=StableRNG(1234))

stack = _double_stack(FooBarRegressor(;lambda=0.07), CPU1())

mach = machine(stack, X, y)
fit!(mach, verbosity=0)
cpu_fp = fitted_params(mach)
cpu_ypred = predict(mach)

stack.acceleration = CPUThreads()
stack = _double_stack(FooBarRegressor(;lambda=0.07), CPUThreads())

mach = machine(stack, X, y)
fit!(mach, verbosity=0)
thread_fp = fitted_params(mach)
thread_ypred = predict(mach)

@test cpu_ypred thread_ypred
@test cpu_fp.metalearner thread_fp.metalearner
@test cpu_fp.ridge thread_fp.ridge
@test cpu_ypred thread_ypred
@test cpu_fp.metalearner thread_fp.metalearner
@test cpu_fp.ridge_lambda thread_fp.ridge_lambda
end

end
Expand Down

0 comments on commit d6cf7a1

Please sign in to comment.