Hide wandb.watch behind flag (FAIR-Chem#747)
* hide watch behind flag

* set frequency

Former-commit-id: d96aa52df1267efbab698cac5927a355fc2849fc
rayg1234 authored Jul 2, 2024
1 parent 6af9466 commit 1d77e20
Showing 2 changed files with 8 additions and 5 deletions.
8 changes: 4 additions & 4 deletions src/fairchem/core/common/logger.py
@@ -27,7 +27,7 @@ def __init__(self, config) -> None:
         self.config = config
 
     @abstractmethod
-    def watch(self, model):
+    def watch(self, model, log_freq: int = 1000):
         """
         Monitor parameters and gradients.
         """
@@ -82,8 +82,8 @@ def __init__(self, config) -> None:
             resume="allow",
         )
 
-    def watch(self, model) -> None:
-        wandb.watch(model)
+    def watch(self, model, log_freq: int = 1000) -> None:
+        wandb.watch(model, log_freq=log_freq)
 
     def log(self, update_dict, step: int, split: str = "") -> None:
         update_dict = super().log(update_dict, step, split)
@@ -109,7 +109,7 @@ def __init__(self, config) -> None:
         self.writer = SummaryWriter(self.config["cmd"]["logs_dir"])
 
     # TODO: add a model hook for watching gradients.
-    def watch(self, model) -> bool:
+    def watch(self, model, log_freq: int = 1000) -> bool:
         logging.warning("Model gradient logging to tensorboard not yet supported.")
         return False
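
Not part of this commit, but for context: a minimal, self-contained sketch of what the new log_freq parameter controls. wandb.watch registers hooks on the model and logs gradient histograms once every log_freq backward passes (1000 by default). The project name below is a placeholder.

    import torch
    import wandb

    wandb.init(project="demo")  # placeholder project name
    model = torch.nn.Linear(4, 1)

    # wandb.watch hooks the model's parameters; gradient histograms are
    # sent to W&B once every `log_freq` calls to backward().
    wandb.watch(model, log_freq=1000)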

Expand Down
5 changes: 4 additions & 1 deletion src/fairchem/core/trainers/base_trainer.py
@@ -431,7 +431,10 @@ def load_model(self) -> None:
            )
 
        if self.logger is not None:
-            self.logger.watch(self.model)
+            # only "watch" the model if the user sets a watch frequency, because
+            # logging gradients spews too much data into W&B and makes the UI slow
+            if "watch" in self.config["logger"]:
+                self.logger.watch(self.model, log_freq=int(self.config["logger"]["watch"]))
            self.logger.log_summary({"num_params": self.model.num_params})
 
        if distutils.initialized() and not self.config["noddp"]:
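
With this change, gradient watching is opt-in: the trainer only calls logger.watch when the logger config contains a "watch" key, whose value is cast to int and used as the logging frequency. A hedged sketch of such a config as base_trainer would read it (all keys besides "watch" are hypothetical placeholders):

    # Sketch only: a logger config dict that enables gradient watching.
    config = {
        "logger": {
            "name": "wandb",  # hypothetical key
            "watch": 500,     # log gradients every 500 steps; omit to disable
        },
    }

    # Mirrors the trainer's check above; note that watch: True coerces to 1.
    if "watch" in config["logger"]:
        log_freq = int(config["logger"]["watch"])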
