diff --git a/configs/test.yaml b/configs/test.yaml index 1b4e5f5..f3965f2 100644 --- a/configs/test.yaml +++ b/configs/test.yaml @@ -9,6 +9,7 @@ batch_size: 8 # EXPERIMENT +use_final_ckpt: false finetune: false ckpt_dir: ??? diff --git a/configs/train.yaml b/configs/train.yaml index a1ea932..28b7fa7 100644 --- a/configs/train.yaml +++ b/configs/train.yaml @@ -19,6 +19,7 @@ limited_label_val: 1 limited_label_strategy: stratified # Options: stratified, oversampled, random stratification_bins: 3 # number of bins for stratified sampling, only for stratified data_replicate: 1 +use_final_ckpt: false defaults: diff --git a/pangaea/run.py b/pangaea/run.py index 82f2c9a..d6f4ec1 100644 --- a/pangaea/run.py +++ b/pangaea/run.py @@ -24,6 +24,7 @@ from pangaea.utils.utils import ( fix_seed, get_best_model_ckpt_path, + get_final_model_ckpt_path, get_generator, seed_worker, ) @@ -278,8 +279,12 @@ def main(cfg: DictConfig) -> None: test_evaluator: Evaluator = instantiate( cfg.task.evaluator, val_loader=test_loader, exp_dir=exp_dir, device=device ) - best_model_ckpt_path = get_best_model_ckpt_path(exp_dir) - test_evaluator.evaluate(decoder, "best_model", best_model_ckpt_path) + + if cfg.use_final_ckpt: + model_ckpt_path = get_final_model_ckpt_path(exp_dir) + else: + model_ckpt_path = get_best_model_ckpt_path(exp_dir) + test_evaluator.evaluate(decoder, "test_model", model_ckpt_path) if cfg.use_wandb and rank == 0: wandb.finish()