Commit
Add singleton logger (#873)
* add singleton wandb logger

* add trainer update

* update comment

* update comment
rayg1234 authored Oct 9, 2024
1 parent 3012925 commit f38767c
Showing 2 changed files with 100 additions and 1 deletion.
86 changes: 86 additions & 0 deletions src/fairchem/core/common/logger.py
@@ -148,3 +148,89 @@ def log_summary(self, summary_dict: dict[str, Any]) -> None:

    def log_artifact(self, name: str, type: str, file_location: str) -> None:
        logging.warning("log_artifact for Tensorboard not supported")


class WandBSingletonLogger:
    """
    Singleton version of the wandb logger. This forces a single instance of the logger to
    be created and used from anywhere in the code (not just from the trainer). It will
    replace the original WandBLogger.

    Initialize the wandb run once, globally, in the trainer/runner:

        WandBSingletonLogger.init_wandb(...)

    Then, from anywhere in the code, fetch the singleton instance and log to wandb. Note
    that this lets you log without explicitly knowing which step you are on; see
    https://docs.wandb.ai/ref/python/log/#the-wb-step for more details.

        WandBSingletonLogger.get_instance().log({"some_value": value}, commit=False)
    """

    _instance = None

    def __init__(self):
        raise RuntimeError("Call get_instance() instead")

    @classmethod
    def init_wandb(
        cls,
        config: dict,
        run_id: str,
        run_name: str,
        log_dir: str,
        project: str,
        entity: str,
        group: str | None = None,
    ) -> None:
        wandb.init(
            config=config,
            id=run_id,
            name=run_name,
            dir=log_dir,
            project=project,
            entity=entity,
            resume="allow",
            group=group,
        )

    @classmethod
    def get_instance(cls):
        assert wandb.run is not None, "wandb is not initialized, call init_wandb first!"
        if cls._instance is None:
            cls._instance = cls.__new__(cls)
        return cls._instance

    def watch(self, model, log_freq: int = 1000) -> None:
        wandb.watch(model, log_freq=log_freq)

    def log(
        self, update_dict: dict, step: int | None = None, commit=False, split: str = ""
    ) -> None:
        # HACK: ugly backward-compat logic that we should eventually get rid of;
        # the split prefix shouldn't be inserted here.
        if split != "":
            new_dict = {}
            for key in update_dict:
                new_dict[f"{split}/{key}"] = update_dict[key]
            update_dict = new_dict

        # If step is not specified, wandb uses an auto-incremented step: https://docs.wandb.ai/ref/python/log/
        # Otherwise the user must increment it manually (not recommended).
        wandb.log(update_dict, step=step, commit=commit)

    def log_plots(self, plots, caption: str = "") -> None:
        assert isinstance(plots, list)
        plots = [wandb.Image(x, caption=caption) for x in plots]
        wandb.log({"data": plots})

    def log_summary(self, summary_dict: dict[str, Any]):
        for k, v in summary_dict.items():
            wandb.run.summary[k] = v

    def mark_preempting(self) -> None:
        wandb.mark_preempting()

    def log_artifact(self, name: str, type: str, file_location: str) -> None:
        art = wandb.Artifact(name=name, type=type)
        art.add_file(file_location)
        art.save()
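
For reference, here is a minimal usage sketch of the new singleton logger, following the docstring above; the config values, run identifiers, and metric names are placeholders, not part of this commit.

# Hypothetical usage sketch for WandBSingletonLogger (placeholder values throughout).
from fairchem.core.common.logger import WandBSingletonLogger

# Initialize the wandb run once, e.g. in the trainer/runner.
WandBSingletonLogger.init_wandb(
    config={"lr": 1e-3},        # placeholder config dict
    run_id="example-run-id",    # placeholder run id
    run_name="example-run",     # placeholder run name
    log_dir="./logs",           # placeholder log directory
    project="example-project",  # placeholder wandb project
    entity="example-entity",    # placeholder wandb entity
)

# Later, from any module, fetch the singleton and log without tracking the step yourself.
logger = WandBSingletonLogger.get_instance()
logger.log({"loss": 0.123}, commit=False, split="train")  # logged under "train/loss"
logger.log({"lr": 1e-3}, commit=True)                     # commits the accumulated step
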
15 changes: 14 additions & 1 deletion src/fairchem/core/trainers/base_trainer.py
@@ -31,6 +31,7 @@
from fairchem.core import __version__
from fairchem.core.common import distutils, gp_utils
from fairchem.core.common.data_parallel import BalancedBatchSampler
from fairchem.core.common.logger import WandBSingletonLogger
from fairchem.core.common.registry import registry
from fairchem.core.common.slurm import (
    add_timestamp_id_to_submission_pickle,
@@ -275,7 +276,19 @@ def load_logger(self) -> None:
        logger_name = logger if isinstance(logger, str) else logger["name"]
        assert logger_name, "Specify logger name"

        self.logger = registry.get_logger_class(logger_name)(self.config)
        if logger_name == "wandb_singleton":
            WandBSingletonLogger.init_wandb(
                config=self.config,
                run_id=self.config["cmd"]["timestamp_id"],
                run_name=self.config["cmd"]["identifier"],
                log_dir=self.config["cmd"]["logs_dir"],
                project=self.config["logger"]["project"],
                entity=self.config["logger"]["entity"],
                group=self.config["logger"].get("group", ""),
            )
            self.logger = WandBSingletonLogger.get_instance()
        else:
            self.logger = registry.get_logger_class(logger_name)(self.config)

    def get_sampler(
        self, dataset, batch_size: int, shuffle: bool
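
For illustration, a sketch of the config keys that this load_logger branch reads when the wandb_singleton logger is selected; the overall config layout and the concrete values are assumptions, not taken from this commit.

# Hypothetical config fragment; only keys referenced by load_logger above are shown.
config = {
    "logger": {
        "name": "wandb_singleton",  # selects the WandBSingletonLogger branch
        "project": "example-project",
        "entity": "example-entity",
        "group": "example-group",   # optional; defaults to "" via .get("group", "")
    },
    "cmd": {
        "timestamp_id": "20241009-000000",  # becomes the wandb run id
        "identifier": "example-run",        # becomes the wandb run name
        "logs_dir": "./logs",               # becomes the wandb log dir
    },
}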
