Merge pull request #18 from JuliaAI/compatibility-fix
✨ Allow specifying positional argument for model #17
EssamWisam authored Oct 9, 2023
2 parents e6b1aaa + 7d86463 commit 8245f06
Showing 4 changed files with 66 additions and 15 deletions.
19 changes: 15 additions & 4 deletions src/balanced_bagging.jl
@@ -85,14 +85,25 @@ const ERR_MISSING_CLF = "No model specified. Please specify a probabilistic clas
const ERR_BAD_T = "The number of ensemble models `T` cannot be negative."
const INFO_DEF_T(T_def) = "The number of ensemble models was not given and was thus automatically set to $T_def"*
" which is the ratio of the frequency of the majority class to that of the minority class"
function BalancedBaggingClassifier(;
const ERR_NUM_ARGS_BB = "`BalancedBaggingClassifier` can have at most one non-keyword argument where the model is passed."
const WRN_MODEL_GIVEN = "Ignoring keyword argument `model=...` as model already given as positional argument. "

function BalancedBaggingClassifier(args...;
model = nothing,
T = 0,
rng = Random.default_rng(),
)
model === nothing && error(ERR_MISSING_CLF)
T < 0 && error(ERR_BAD_T)
rng = rng_handler(rng)
length(args) <= 1 || throw(ERR_NUM_ARGS_BB)
if length(args) === 1
atom = first(args)
model === nothing ||
@warn WRN_MODEL_GIVEN
model = atom
else
model === nothing && throw(ERR_MISSING_CLF)
end
T < 0 && error(ERR_BAD_T)
rng = rng_handler(rng)
return BalancedBaggingClassifier(model, T, rng)
end

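For context, a minimal usage sketch of the two construction forms this change supports (the classifier and the value of T are illustrative; `LogisticClassifier` from MLJLinearModels is the one used in the tests below):

using MLJ, MLJBalancing
LogisticClassifier = @load LogisticClassifier pkg=MLJLinearModels verbosity=0
clf = LogisticClassifier()
# keyword form (pre-existing) and new positional form construct the same wrapper
bagging_kw  = BalancedBaggingClassifier(model=clf, T=5)
bagging_pos = BalancedBaggingClassifier(clf; T=5)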
22 changes: 16 additions & 6 deletions src/balanced_model.jl
@@ -55,33 +55,42 @@ const UNION_MODEL_TYPES = Union{keys(MODELTYPE_TO_COMPOSITETYPE_EVAL)...}


# Possible Errors (for the constructor as well)
const ERR_MODEL_UNSPECIFIED = ArgumentError("Expected an atomic model as argument. None specified. ")
const ERR_MODEL_UNSPECIFIED = ErrorException("Expected an atomic model as argument. None specified. ")

const WRN_BALANCER_UNSPECIFIED = "No balancer was provided. Data will be directly passed to the model. "

const PRETTY_SUPPORTED_MODEL_TYPES = join([string("`", opt, "`") for opt in SUPPORTED_MODEL_TYPES], ", ",", and ")

const ERR_UNSUPPORTED_MODEL(model) = ArgumentError(
const ERR_UNSUPPORTED_MODEL(model) = ErrorException(
"Only these model supertypes support wrapping: "*
"$PRETTY_SUPPORTED_MODEL_TYPES.\n"*
"Model provided has type `$(typeof(model))`. "
)
const ERR_NUM_ARGS_BM = "`BalancedModel` can have at most one non-keyword argument where the model is passed."


"""
BalancedModel(; balancers=[], model=nothing)
BalancedModel(; model=nothing, balancer1=balancer_model1, balancer2=balancer_model2, ...)
BalancedModel(model; balancer1=balancer_model1, balancer2=balancer_model2, ...)
Wraps a classification model with balancers that resample the data before passing it to the model.
# Arguments
- `balancers::AbstractVector=[]`: A vector of balancers (i.e., resampling models).
Data passed to the model will be first passed to the balancers sequentially.
- `model=nothing`: The classification model which must be provided.
"""
function BalancedModel(; model=nothing, named_balancers...)
function BalancedModel(args...; model=nothing, named_balancers...)
# check model and balancer are given
model === nothing && throw(ERR_MODEL_UNSPECIFIED)
length(args) <= 1 || throw(ERR_NUM_ARGS_BM)
if length(args) === 1
atom = first(args)
model === nothing ||
@warn WRN_MODEL_GIVEN
model = atom
else
model === nothing && throw(ERR_MODEL_UNSPECIFIED)
end
# check model is supported
model isa UNION_MODEL_TYPES || throw(ERR_UNSUPPORTED_MODEL(model))

@@ -116,6 +125,7 @@ for model_type in SUPPORTED_MODEL_TYPES
eval(ex)
end


const ERR_NO_PROP = ArgumentError("trying to access property $name which does not exist")
# overload set property to set the property from the vector in the struct
for model_type in SUPPORTED_MODEL_TYPES
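Likewise, a minimal sketch of the `BalancedModel` forms shown in the docstring above (the oversampler is illustrative and assumes Imbalance.jl, as used in the tests below):

using MLJ, MLJBalancing, Imbalance
LogisticClassifier = @load LogisticClassifier pkg=MLJLinearModels verbosity=0
clf = LogisticClassifier()
oversampler = Imbalance.MLJ.RandomOversampler(ratios=1.0, rng=42)
# keyword form and new positional form are equivalent
balanced_kw  = BalancedModel(model=clf, balancer1=oversampler)
balanced_pos = BalancedModel(clf; balancer1=oversampler)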
20 changes: 18 additions & 2 deletions test/balanced_bagging.jl
@@ -1,7 +1,7 @@

@testset "group_inds and get_majority_minority_inds_counts" begin
y = [0, 0, 0, 0, 1, 1, 1, 0]
@test MLJBalancing.group_inds(y) == Dict(0 => [1, 2, 3, 4, 8], 1 => [5, 6, 7])
@test MLJBalancing.group_inds(y) == Dict(0 => [1, 2, 3, 4, 8], 1 => [5, 6, 7])
@test MLJBalancing.get_majority_minority_inds_counts(y) ==
([1, 2, 3, 4, 8], [5, 6, 7], 5, 3)
y = [0, 0, 0, 0, 1, 1, 1, 0, 2, 2, 2]
@@ -120,5 +120,21 @@ end
mach = machine(modelo, X, y)
fit!(mach)
@test report(mach) == (chosen_T = 5,)

end




@testset "Equivalence of Constructions" begin
## setup parameters
R = Random.Xoshiro(42)
T = 2
LogisticClassifier = @load LogisticClassifier pkg = MLJLinearModels verbosity = 0
model = LogisticClassifier()
@test BalancedBaggingClassifier(model=model, T=T, rng=R) == BalancedBaggingClassifier(model; T=T, rng=R)

@test_throws MLJBalancing.ERR_NUM_ARGS_BB BalancedBaggingClassifier(model, model; T=T, rng=R)
@test_logs (:warn, MLJBalancing.WRN_MODEL_GIVEN) begin
BalancedBaggingClassifier(model; model=model, T=T, rng=R)
end
end
20 changes: 17 additions & 3 deletions test/balanced_model.jl
@@ -35,8 +35,8 @@
fit!(mach)
y_pred = MLJBase.predict(mach, X_test)

# with MLJ balancing
@test_throws MLJBalancing.ERR_MODEL_UNSPECIFIED begin
# with MLJ balancing
@test_throws MLJBalancing.ERR_MODEL_UNSPECIFIED begin
BalancedModel(b1 = balancer1, b2 = balancer2, b3 = balancer3)
end
@test_throws(
@@ -46,7 +46,6 @@
@test_logs (:warn, MLJBalancing.WRN_BALANCER_UNSPECIFIED) begin
BalancedModel(model = model_prob)
end

balanced_model =
BalancedModel(model = model_prob, b1 = balancer1, b2 = balancer2, b3 = balancer3)
mach = machine(balanced_model, X_train, y_train)
@@ -86,3 +85,18 @@
Base.setproperty!(balanced_model, :name11, balancer2),
)
end


@testset "Equivalence of Constructions" begin
## setup parameters
R = Random.Xoshiro(42)
LogisticClassifier = @load LogisticClassifier pkg = MLJLinearModels verbosity = 0
balancer1 = Imbalance.MLJ.RandomOversampler(ratios = 1.0, rng = 42)
model = LogisticClassifier()
@test BalancedModel(model=model, balancer1=balancer1) == BalancedModel(model; balancer1=balancer1)

@test_throws MLJBalancing.ERR_NUM_ARGS_BM BalancedModel(model, model; balancer1=balancer1)
@test_logs (:warn, MLJBalancing.WRN_MODEL_GIVEN) begin
BalancedModel(model; model=model, balancer1=balancer1)
end
end
