Skip to content

Commit

Permalink
rename a struct in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ablaom committed Oct 11, 2024
1 parent 21383ab commit 3a56ad3
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions test/integration/static_algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,24 +69,24 @@ end
# This a variation of `Selector` above that stores the names of rejected features in the
# output of `fit`, for inspection by an accessor function called `rejected`.

struct Selector2
struct FancySelector
names::Vector{Symbol}
end
Selector2(; names=Symbol[]) = Selector2(names) # LearnAPI.constructor defined later
FancySelector(; names=Symbol[]) = FancySelector(names) # LearnAPI.constructor defined later

mutable struct Selector2Fit
algorithm::Selector2
mutable struct FancySelectorFitted
algorithm::FancySelector
rejected::Vector{Symbol}
Selector2Fit(algorithm) = new(algorithm)
FancySelectorFitted(algorithm) = new(algorithm)
end
LearnAPI.algorithm(model::Selector2Fit) = model.algorithm
rejected(model::Selector2Fit) = model.rejected
LearnAPI.algorithm(model::FancySelectorFitted) = model.algorithm
rejected(model::FancySelectorFitted) = model.rejected

# Here we are wrapping `algorithm` with a place-holder for the `rejected` feature names.
LearnAPI.fit(algorithm::Selector2; verbosity=1) = Selector2Fit(algorithm)
LearnAPI.fit(algorithm::FancySelector; verbosity=1) = FancySelectorFitted(algorithm)

# output the filtered table and add `rejected` field to model (mutatated!)
function LearnAPI.transform(model::Selector2Fit, X)
function LearnAPI.transform(model::FancySelectorFitted, X)
table = Tables.columntable(X)
names = Tables.columnnames(table)
keep = LearnAPI.algorithm(model).names
Expand All @@ -98,15 +98,15 @@ function LearnAPI.transform(model::Selector2Fit, X)
end

# fit and transform in one step:
function LearnAPI.transform(algorithm::Selector2, X)
function LearnAPI.transform(algorithm::FancySelector, X)
model = fit(algorithm)
transform(model, X)
end

# note the necessity of overloading `is_static` (`fit` consumes no data):
@trait(
Selector2,
constructor = Selector2,
FancySelector,
constructor = FancySelector,
is_static = true,
tags = ("feature engineering",),
functions = (
Expand All @@ -120,7 +120,7 @@ end
)

@testset "test a variation that reports byproducts" begin
algorithm = Selector2(names=[:x, :w])
algorithm = FancySelector(names=[:x, :w])
X = DataFrames.DataFrame(rand(3, 4), [:x, :y, :z, :w])
model = fit(algorithm) # no data arguments!
@test !isdefined(model, :reject)
Expand Down

0 comments on commit 3a56ad3

Please sign in to comment.