Skip to content

Commit

Permalink
add surrogare model
Browse files Browse the repository at this point in the history
  • Loading branch information
maypink committed Aug 15, 2023
1 parent 6eaa4d2 commit 9220f92
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 6 deletions.
2 changes: 1 addition & 1 deletion experiments/mab_experiment/mab_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ common_fedot_params:
problem: classification
n_jobs: -1
show_progress: false
context_agent_type: nodes_num
context_agent_type: surrogate
adaptive_mutation_type: contextual_bandit

FEDOT_Classic:
Expand Down
31 changes: 26 additions & 5 deletions experiments/mab_experiment/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import openml
import pandas as pd
from fedot.api.main import Fedot
from fedot.core.pipelines.node import PrimaryNode
from fedot.core.pipelines.pipeline import Pipeline
from tqdm import tqdm
import yaml

Expand All @@ -16,8 +18,15 @@
from meta_automl.data_preparation.dataset import OpenMLDataset
from meta_automl.data_preparation.datasets_loaders import OpenMLDatasetsLoader
from meta_automl.data_preparation.datasets_train_test_split import openml_datasets_train_test_split
from meta_automl.data_preparation.feature_preprocessors import FeaturesPreprocessor
from meta_automl.data_preparation.file_system import get_project_root, get_cache_dir
import warnings

from meta_automl.data_preparation.meta_features_extractors import OpenMLDatasetMetaFeaturesExtractor
from meta_automl.data_preparation.pipeline_features_extractors import FEDOTPipelineFeaturesExtractor
from meta_automl.surrogate.models import RankingPipelineDatasetSurrogateModel
from thegolem import DataPipelineSurrogate

warnings.filterwarnings("ignore")


Expand Down Expand Up @@ -125,11 +134,12 @@ def run_experiment_per_launch(experiment_params_dict, experiment_date, config, d
timeout = config['timeout']
run_date = datetime.now()

surrogate_model = RankingPipelineDatasetSurrogateModel.load_from_checkpoint(
checkpoint_path="./experiments/base/checkpoints/last.ckpt",
hparams_file="./experiments/base/hparams.yaml"
)
config['common_fedot_params']['FEDOT_MAB']['context_agent_type'] = S
# get surrogate model
if experiment_label == 'FEDOT_MAB':
context_agent_type = config['common_fedot_params']['FEDOT_MAB']['context_agent_type']
if context_agent_type == 'surrogate':
config['common_fedot_params']['FEDOT_MAB']['context_agent_type'] = _load_surrogate_model()

fedot, run_results = fit_fedot(dataset=dataset, timeout=timeout, run_label='FEDOT',
**config['common_fedot_params'][experiment_label])
save_evaluation(run_results, run_date, experiment_date, save_dir)
Expand All @@ -140,6 +150,17 @@ def run_experiment_per_launch(experiment_params_dict, experiment_date, config, d
best_models_per_dataset[dataset_id] = best_models


def _load_surrogate_model() -> RankingPipelineDatasetSurrogateModel:
checkpoint_path = os.path.join(get_project_root(), 'experiments', 'base', 'checkpoints', 'last.ckpt')
hparams_file = os.path.join(get_project_root(), 'experiments', 'base', 'hparams.yaml')
surrogate_model = RankingPipelineDatasetSurrogateModel.load_from_checkpoint(
checkpoint_path=checkpoint_path,
hparams_file=hparams_file
)

return surrogate_model


if __name__ == '__main__':
config_name = 'mab_config.yaml'
path_to_config = os.path.join(get_project_root(), 'experiments', 'mab_experiment', config_name)
Expand Down

0 comments on commit 9220f92

Please sign in to comment.