diff --git a/src/fairchem/core/common/logger.py b/src/fairchem/core/common/logger.py
index fb6e14fc9c..fd52756e20 100644
--- a/src/fairchem/core/common/logger.py
+++ b/src/fairchem/core/common/logger.py
@@ -27,7 +27,7 @@ def __init__(self, config) -> None:
         self.config = config

     @abstractmethod
-    def watch(self, model):
+    def watch(self, model, log_freq: int = 1000):
         """
         Monitor parameters and gradients.
         """
@@ -82,8 +82,8 @@ def __init__(self, config) -> None:
             resume="allow",
         )

-    def watch(self, model) -> None:
-        wandb.watch(model)
+    def watch(self, model, log_freq: int = 1000) -> None:
+        wandb.watch(model, log_freq=log_freq)

     def log(self, update_dict, step: int, split: str = "") -> None:
         update_dict = super().log(update_dict, step, split)
@@ -109,7 +109,7 @@ def __init__(self, config) -> None:
         self.writer = SummaryWriter(self.config["cmd"]["logs_dir"])

     # TODO: add a model hook for watching gradients.
-    def watch(self, model) -> bool:
+    def watch(self, model, log_freq: int = 1000) -> bool:
         logging.warning("Model gradient logging to tensorboard not yet supported.")
         return False

diff --git a/src/fairchem/core/trainers/base_trainer.py b/src/fairchem/core/trainers/base_trainer.py
index b44753f6eb..0da40320a8 100644
--- a/src/fairchem/core/trainers/base_trainer.py
+++ b/src/fairchem/core/trainers/base_trainer.py
@@ -431,7 +431,10 @@ def load_model(self) -> None:
            )

        if self.logger is not None:
-            self.logger.watch(self.model)
+            # Only "watch" the model if the user specifies a watch value in the logger config,
+            # because logging gradients spews too much data into W&B and makes the UI slow to respond.
+            if "watch" in self.config["logger"]:
+                self.logger.watch(self.model, log_freq=int(self.config["logger"]["watch"]))
            self.logger.log_summary({"num_params": self.model.num_params})

        if distutils.initialized() and not self.config["noddp"]:
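
Usage note: gradient watching becomes opt-in with this change. The sketch below shows the config shape that would enable it, assuming the logger section of the trainer config is a plain dict; apart from "watch", the keys and values are illustrative, since the diff only establishes that the trainer checks config["logger"]["watch"] and forwards its value to Logger.watch() as log_freq.

    # Illustrative sketch only -- key names other than "watch" are assumptions.
    config = {
        "logger": {
            "name": "wandb",  # hypothetical key, for illustration
            "watch": 1000,    # presence enables logger.watch(); value is used as log_freq
        },
    }

    if "watch" in config["logger"]:
        log_freq = int(config["logger"]["watch"])
        # The trainer would then call logger.watch(model, log_freq=log_freq),
        # which for the W&B logger forwards to wandb.watch(model, log_freq=log_freq).
        print(f"Gradients/parameters will be logged every {log_freq} steps.")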