Skip to content

Commit

Permalink
Merge pull request #332 from ASUS-AICS/search_param_no_retrain
Browse files Browse the repository at this point in the history
Rename the no_merge_train_val option (and the merge_train_val parameter) to no_retrain/retrain.
  • Loading branch information
Gordon119 authored Sep 15, 2023
2 parents 7e620ae + 0122791 commit 0491489
Showing 1 changed file with 26 additions and 16 deletions.
42 changes: 26 additions & 16 deletions search_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,15 +141,15 @@ def init_search_algorithm(search_alg, metric=None, mode=None):
logging.info(f"{search_alg} search is found, run BasicVariantGenerator().")


def prepare_retrain_config(best_config, best_log_dir, merge_train_val):
def prepare_retrain_config(best_config, best_log_dir, retrain):
"""Prepare the configuration for re-training.
Args:
best_config (AttributeDict): The best hyper-parameter configuration.
best_log_dir (str): The directory of the best trial of the experiment.
merge_train_val (bool): Whether to merge the training and validation data.
retrain (bool): Whether to retrain the model with merged training and validation data.
"""
if merge_train_val:
if retrain:
best_config.merge_train_val = True

log_path = os.path.join(best_log_dir, "logs.json")
Expand Down Expand Up @@ -205,31 +205,31 @@ def load_static_data(config, merge_train_val=False):
}


def retrain_best_model(exp_name, best_config, best_log_dir, merge_train_val):
def retrain_best_model(exp_name, best_config, best_log_dir, retrain):
"""Re-train the model with the best hyper-parameters.
A new model is trained on the combined training and validation data if `merge_train_val` is True.
A new model is trained on the combined training and validation data if `retrain` is True.
If a test set is provided, it will be evaluated by the obtained model.
Args:
exp_name (str): The directory to save trials generated by ray tune.
best_config (AttributeDict): The best hyper-parameter configuration.
best_log_dir (str): The directory of the best trial of the experiment.
merge_train_val (bool): Whether to merge the training and validation data.
retrain (bool): Whether to retrain the model with merged training and validation data.
"""
best_config.silent = False
checkpoint_dir = os.path.join(best_config.result_dir, exp_name, "trial_best_params")
os.makedirs(checkpoint_dir, exist_ok=True)
with open(os.path.join(checkpoint_dir, "params.yml"), "w") as fp:
yaml.dump(dict(best_config), fp)
best_config.run_name = "_".join(exp_name.split("_")[:-1]) + "_best"
best_config.checkpoint_dir = checkpoint_dir
best_config.log_path = os.path.join(best_config.checkpoint_dir, "logs.json")
prepare_retrain_config(best_config, best_log_dir, merge_train_val)
prepare_retrain_config(best_config, best_log_dir, retrain)
set_seed(seed=best_config.seed)
with open(os.path.join(checkpoint_dir, "params.yml"), "w") as fp:
yaml.dump(dict(best_config), fp)

data = load_static_data(best_config, merge_train_val=best_config.merge_train_val)

if merge_train_val:
if retrain:
logging.info(f"Re-training with best config: \n{best_config}")
trainer = TorchTrainer(config=best_config, **data)
trainer.train()
Expand All @@ -247,7 +247,7 @@ def retrain_best_model(exp_name, best_config, best_log_dir, merge_train_val):

if "test" in data["datasets"]:
test_results = trainer.test()
if merge_train_val:
if retrain:
logging.info(f"Test results after re-training: {test_results}")
else:
logging.info(f"Test results of best config: {test_results}")
Expand All @@ -260,8 +260,18 @@ def main():
"--config",
help="Path to configuration file (default: %(default)s). Please specify a config with all arguments in LibMultiLabel/main.py::get_config.",
)
parser.add_argument("--cpu_count", type=int, default=4, help="Number of CPU per trial (default: %(default)s)")
parser.add_argument("--gpu_count", type=int, default=1, help="Number of GPU per trial (default: %(default)s)")
parser.add_argument(
"--cpu_count",
type=int,
default=4,
help="Number of CPU per trial (default: %(default)s)",
)
parser.add_argument(
"--gpu_count",
type=int,
default=1,
help="Number of GPU per trial (default: %(default)s)",
)
parser.add_argument(
"--num_samples",
type=int,
Expand All @@ -275,9 +285,9 @@ def main():
help="Search algorithms (default: %(default)s)",
)
parser.add_argument(
"--no_merge_train_val",
"--no_retrain",
action="store_true",
help="Do not add the validation set in re-training the final model after hyper-parameter search.",
help="Do not retrain the model with validation set after hyperparameter search.",
)
args, _ = parser.parse_known_args()

Expand Down Expand Up @@ -343,7 +353,7 @@ def main():
# Save best model after parameter search.
best_config = analysis.get_best_config(f"val_{config.val_metric}", config.mode, scope="all")
best_log_dir = analysis.get_best_logdir(f"val_{config.val_metric}", config.mode, scope="all")
retrain_best_model(exp_name, best_config, best_log_dir, merge_train_val=not config.no_merge_train_val)
retrain_best_model(exp_name, best_config, best_log_dir, retrain=not config.no_retrain)


if __name__ == "__main__":
Expand Down

0 comments on commit 0491489

Please sign in to comment.