Skip to content

Commit

Permalink
Ensure tags are defined asking interactively for them (#30)
Browse files Browse the repository at this point in the history
* Ensure tags are defined asking interactively for them

* Add testing tag in tests

* Improve logging of used tags

* Refactor the "tags" attribute to be completely optional.

* Raise error on missing "core.tags" in multi-runs.

Co-authored-by: Valentino Maiorca <valentino@maiorca.xyz>
  • Loading branch information
lucmos and Flegyas committed Jan 16, 2022
1 parent 926284f commit 6cc6302
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 2 deletions.
3 changes: 1 addition & 2 deletions conf/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@ core:
project_name: nn-template
storage_dir: ${oc.env:PROJECT_ROOT}/storage
version: 0.0.1
tags:
- mytag
tags: null

defaults:
- hydra: default
Expand Down
19 changes: 19 additions & 0 deletions src/nn_template/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
import hydra
import omegaconf
import pytorch_lightning as pl
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig
from pytorch_lightning import Callback, seed_everything
from rich.prompt import Prompt

from nn_core.callbacks import NNTemplateCore
from nn_core.common import PROJECT_ROOT
Expand Down Expand Up @@ -73,6 +75,22 @@ def parse_restore(restore_cfg: DictConfig) -> Tuple[Optional[str], Optional[str]
return resume_ckpt_path, resume_run_version


def enforce_tags(tags: Optional[List[str]]) -> List[str]:
if tags is None:
if "id" in HydraConfig().cfg.hydra.job:
# We are in multi-run setting (either via a sweep or a scheduler)
message: str = "You need to specify 'core.tags' in a multi-run setting!"
pylogger.error(message)
raise ValueError(message)

pylogger.warning("No tags provided, asking for tags...")
tags = Prompt.ask("Enter a list of comma separated tags", default="develop")
tags = [x.strip() for x in tags.split(",")]

pylogger.info(f"Tags: {tags if tags is not None else []}")
return tags


def run(cfg: DictConfig) -> str:
"""Generic train loop.
Expand All @@ -90,6 +108,7 @@ def run(cfg: DictConfig) -> str:
cfg.nn.data.num_workers.val = 0
cfg.nn.data.num_workers.test = 0

cfg.core.tags = enforce_tags(cfg.core.get("tags", None))
resume_ckpt_path, resume_run_version = parse_restore(cfg.train.restore)

# Instantiate datamodule
Expand Down
3 changes: 3 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ def cfg(tmp_path_factory: TempPathFactory) -> DictConfig:
def cfg_simple_train(cfg: DictConfig) -> DictConfig:
cfg = OmegaConf.create(cfg)

# Add test tag
cfg.core.tags = ["testing"]

# Disable gpus
cfg.train.trainer.gpus = 0

Expand Down

0 comments on commit 6cc6302

Please sign in to comment.