Skip to content

Commit

Permalink
Adapt xgboost and sklearn iris examples
Browse files Browse the repository at this point in the history
  • Loading branch information
gmontamat committed Sep 17, 2024
1 parent 5a84791 commit 902422f
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 7 deletions.
6 changes: 3 additions & 3 deletions examples/scikit_iris.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from gentun.algorithms import Tournament
from gentun.genes import RandomChoice
from gentun.models.sklearn import SklearnCV
from gentun.models.sklearn import Sklearn
from gentun.populations import Population


Expand Down Expand Up @@ -47,12 +47,12 @@ def parse_iris(file_name: str) -> Tuple[np.ndarray, np.ndarray]:
"sklearn_model": RandomForestClassifier,
"sklearn_metric": f1_score,
"metric_kwargs": {"average": "macro"},
"kfold": 5,
"folds": 5,
}

# Fetch training data
x_train, y_train = parse_iris("iris.data")
# Run genetic algorithm on a population of 10 for 10 generations
population = Population(genes, SklearnCV, 10, x_train, y_train, **kwargs)
population = Population(genes, Sklearn, 10, x_train, y_train, **kwargs)
algorithm = Tournament(population)
algorithm.run(10)
4 changes: 2 additions & 2 deletions examples/xgboost_grid_iris.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from gentun.algorithms import Tournament
from gentun.genes import RandomChoice, RandomLogUniform
from gentun.models.xgboost import XGBoostCV
from gentun.models.xgboost import XGBoost
from gentun.populations import Grid


Expand Down Expand Up @@ -56,6 +56,6 @@ def parse_iris(file_name: str) -> Tuple[np.ndarray, np.ndarray]:
# Fetch training data
x_train, y_train = parse_iris("iris.data")
# Run genetic algorithm on a grid population for 1 generation
population = Grid(genes, XGBoostCV, gene_samples, x_train, y_train, **kwargs)
population = Grid(genes, XGBoost, gene_samples, x_train, y_train, **kwargs)
algorithm = Tournament(population)
algorithm.run(1, maximize=False)
4 changes: 2 additions & 2 deletions examples/xgboost_iris.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from gentun.algorithms import Tournament
from gentun.genes import RandomChoice, RandomLogUniform, RandomUniform
from gentun.models.xgboost import XGBoostCV
from gentun.models.xgboost import XGBoost
from gentun.populations import Population


Expand Down Expand Up @@ -63,6 +63,6 @@ def parse_iris(file_name: str) -> Tuple[np.ndarray, np.ndarray]:
# Fetch training data
x_train, y_train = parse_iris("iris.data")
# Run genetic algorithm on a population of 50 for 100 generations
population = Population(genes, XGBoostCV, 50, x_train, y_train, **kwargs)
population = Population(genes, XGBoost, 50, x_train, y_train, **kwargs)
algorithm = Tournament(population)
algorithm.run(100, maximize=False)

0 comments on commit 902422f

Please sign in to comment.