diff --git a/conf/default.yaml b/conf/default.yaml index f89ac1c..6882596 100644 --- a/conf/default.yaml +++ b/conf/default.yaml @@ -3,8 +3,7 @@ core: project_name: nn-template storage_dir: ${oc.env:PROJECT_ROOT}/storage version: 0.0.1 - tags: - - mytag + tags: null defaults: - hydra: default diff --git a/src/nn_template/run.py b/src/nn_template/run.py index 77db836..7f44315 100644 --- a/src/nn_template/run.py +++ b/src/nn_template/run.py @@ -5,8 +5,10 @@ import hydra import omegaconf import pytorch_lightning as pl +from hydra.core.hydra_config import HydraConfig from omegaconf import DictConfig from pytorch_lightning import Callback, seed_everything +from rich.prompt import Prompt from nn_core.callbacks import NNTemplateCore from nn_core.common import PROJECT_ROOT @@ -73,6 +75,22 @@ def parse_restore(restore_cfg: DictConfig) -> Tuple[Optional[str], Optional[str] return resume_ckpt_path, resume_run_version +def enforce_tags(tags: Optional[List[str]]) -> List[str]: + if tags is None: + if "id" in HydraConfig().cfg.hydra.job: + # We are in multi-run setting (either via a sweep or a scheduler) + message: str = "You need to specify 'core.tags' in a multi-run setting!" + pylogger.error(message) + raise ValueError(message) + + pylogger.warning("No tags provided, asking for tags...") + tags = Prompt.ask("Enter a list of comma separated tags", default="develop") + tags = [x.strip() for x in tags.split(",")] + + pylogger.info(f"Tags: {tags if tags is not None else []}") + return tags + + def run(cfg: DictConfig) -> str: """Generic train loop. @@ -90,6 +108,7 @@ def run(cfg: DictConfig) -> str: cfg.nn.data.num_workers.val = 0 cfg.nn.data.num_workers.test = 0 + cfg.core.tags = enforce_tags(cfg.core.get("tags", None)) resume_ckpt_path, resume_run_version = parse_restore(cfg.train.restore) # Instantiate datamodule diff --git a/tests/conftest.py b/tests/conftest.py index 2a00ba2..8e778a6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -47,6 +47,9 @@ def cfg(tmp_path_factory: TempPathFactory) -> DictConfig: def cfg_simple_train(cfg: DictConfig) -> DictConfig: cfg = OmegaConf.create(cfg) + # Add test tag + cfg.core.tags = ["testing"] + # Disable gpus cfg.train.trainer.gpus = 0