Skip to content

Commit

Permalink
Added the optional max_in_ensemble parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
varunlakshmanan authored Jun 8, 2020
1 parent 82c4929 commit f471336
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 10 deletions.
6 changes: 3 additions & 3 deletions src/GlassClassifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ class GlassClassifier:
def __init__(self):
pass

def fit(self, x_train, y_train, x_test, y_test):
def fit(self, x_train, y_train, x_test, y_test, timeout=5, max_in_ensemble=4):
is_classifier = True
global ensemble
ensemble = ensemble_models(optimize_hyperparams(build_models(is_classifier), x_train, y_train),
x_train, y_train, x_test, y_test, is_classifier)
ensemble = ensemble_models(optimize_hyperparams(build_models(is_classifier), x_train, y_train, timeout),
x_train, y_train, x_test, y_test, is_classifier, max_in_ensemble)
return ensemble

def predict(self, x_test):
Expand Down
4 changes: 2 additions & 2 deletions src/GlassRegressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ class GlassRegressor:
def __init__(self):
pass

def fit(self, x_train, y_train, x_test, y_test, timeout=5):
def fit(self, x_train, y_train, x_test, y_test, timeout=5, max_in_ensemble=4):
is_classifier = False
global ensemble
ensemble = ensemble_models(optimize_hyperparams(build_models(is_classifier), x_train, y_train, timeout),
x_train, y_train, x_test, y_test, is_classifier)
x_train, y_train, x_test, y_test, is_classifier, max_in_ensemble)
return ensemble

def predict(self, x_test):
Expand Down
8 changes: 3 additions & 5 deletions src/ensemble_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,12 @@
best_voting_estimator = []


def ensemble_models(optimized_estimators, x_train, y_train, x_test, y_test, is_classifier):
def ensemble_models(optimized_estimators, x_train, y_train, x_test, y_test, is_classifier, max_in_ensemble):
print("Finding the best model ensemble...")

MAX_ENSEMBLES = 4
all_estimator_combinations = []

# Store combinations of length 2 to length MAX_ENSEMBLES of all estimators in a list
for i in reversed(range(2, MAX_ENSEMBLES + 1)):
# Store combinations of length 2 to length max_in_ensemble of all estimators in a list
for i in reversed(range(2, max_in_ensemble + 1)):
temp_estimator_combinations = combinations(optimized_estimators, i)
all_estimator_combinations.extend(temp_estimator_combinations)

Expand Down

0 comments on commit f471336

Please sign in to comment.