diff --git a/benchmarks/benchmark_averaging.py b/benchmarks/benchmark_averaging.py
index bda82610d..61b29b740 100644
--- a/benchmarks/benchmark_averaging.py
+++ b/benchmarks/benchmark_averaging.py
@@ -80,7 +80,7 @@ def run_averager(index):
             with lock_stats:
                 successful_steps += int(success)
                 total_steps += 1
-            logger.info(f"Averager {index}: {'finished' if success else 'failed'} step {step}")
+            logger.info(f"Averager {index}: {'finished' if success else 'failed'} step #{step}")
         logger.info(f"Averager {index}: done.")
 
     threads = []
diff --git a/examples/albert/run_trainer.py b/examples/albert/run_trainer.py
index 861d9267f..ad7f8d7ce 100644
--- a/examples/albert/run_trainer.py
+++ b/examples/albert/run_trainer.py
@@ -1,6 +1,5 @@
 #!/usr/bin/env python
 
-import logging
 import os
 import pickle
 from dataclasses import asdict
@@ -18,32 +17,22 @@
 from transformers.trainer_utils import is_main_process
 
 import hivemind
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
 import utils
 from arguments import AlbertTrainingArguments, AveragerArguments, CollaborationArguments, DatasetArguments
 
-logger = logging.getLogger(__name__)
-LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger()
+
+LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
 
 
-def setup_logging(training_args):
-    logging.basicConfig(
-        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
-        datefmt="%m/%d/%Y %H:%M:%S",
-        level=logging.INFO if is_main_process(training_args.local_rank) else logging.WARN,
-    )
-
-    # Log on each process the small summary:
-    logger.warning(
-        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
-        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
-    )
-    # Set the verbosity to info of the Transformers logger (on main process only):
-    if is_main_process(training_args.local_rank):
+def setup_transformers_logging(process_rank: int):
+    if is_main_process(process_rank):
         transformers.utils.logging.set_verbosity_info()
-        transformers.utils.logging.enable_default_handler()
-        transformers.utils.logging.enable_explicit_format()
-    logger.info("Training/evaluation parameters %s", training_args)
+        transformers.utils.logging.disable_default_handler()
+        transformers.utils.logging.enable_propagation()
 
 
 def get_model(training_args, config, tokenizer):
@@ -149,7 +138,7 @@ def on_step_end(
                     loss=self.loss,
                     mini_steps=self.steps,
                 )
-                logger.info(f"Step {self.collaborative_optimizer.local_step}")
+                logger.info(f"Step #{self.collaborative_optimizer.local_step}")
                 logger.info(f"Your current contribution: {self.total_samples_processed} samples")
                 logger.info(f"Performance: {samples_per_second} samples per second.")
                 if self.steps:
@@ -220,7 +209,8 @@ def main():
     if len(collaboration_args.initial_peers) == 0:
         raise ValueError("Please specify at least one network endpoint in initial peers.")
 
-    setup_logging(training_args)
+    setup_transformers_logging(training_args.local_rank)
+    logger.info(f"Training/evaluation parameters:\n{training_args}")
 
     # Set seed before initializing model.
     set_seed(training_args.seed)
diff --git a/examples/albert/run_training_monitor.py b/examples/albert/run_training_monitor.py
index 69c6b8d6b..3b42760d1 100644
--- a/examples/albert/run_training_monitor.py
+++ b/examples/albert/run_training_monitor.py
@@ -1,6 +1,5 @@
 #!/usr/bin/env python
 
-import logging
 import time
 from dataclasses import asdict, dataclass, field
 from ipaddress import ip_address
@@ -13,11 +12,13 @@
 from transformers import AlbertConfig, AlbertForPreTraining, HfArgumentParser
 
 import hivemind
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
 import utils
 from arguments import AveragerArguments, BaseTrainingArguments, CollaborativeOptimizerArguments
 
-logger = logging.getLogger(__name__)
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger()
 
 
 @dataclass
@@ -139,7 +140,7 @@ def upload_checkpoint(self, current_loss):
         self.model.push_to_hub(
             repo_name=self.repo_path,
             repo_url=self.repo_url,
-            commit_message=f"Step {current_step}, loss {current_loss:.3f}",
+            commit_message=f"Step #{current_step}, loss {current_loss:.3f}",
         )
         logger.info("Finished uploading to Model Hub")
 
diff --git a/hivemind/optim/collaborative.py b/hivemind/optim/collaborative.py
index 8f9523f19..d4aac9a61 100644
--- a/hivemind/optim/collaborative.py
+++ b/hivemind/optim/collaborative.py
@@ -153,7 +153,7 @@ def __init__(
         self.performance_ema = PerformanceEMA(alpha=performance_ema_alpha)
         self.last_step_time = None
 
-        self.collaboration_state = self.fetch_collaboration_state()
+        self.collaboration_state = self._fetch_state()
         self.lock_collaboration_state, self.collaboration_state_updated = Lock(), Event()
         self.lock_local_progress, self.should_report_progress = Lock(), Event()
         self.progress_reporter = Thread(target=self.report_training_progress, daemon=True, name=f"{self}.reporter")
@@ -237,8 +237,8 @@ def step(self, batch_size: Optional[int] = None, **kwargs):
         if not self.collaboration_state.ready_for_step:
             return
 
-        logger.log(self.status_loglevel, f"Beginning global optimizer step {self.collaboration_state.optimizer_step}")
-        self.collaboration_state = self.fetch_collaboration_state()
+        logger.log(self.status_loglevel, f"Beginning global optimizer step #{self.collaboration_state.optimizer_step}")
+        self.collaboration_state = self._fetch_state()
         self.collaboration_state_updated.set()
 
         if not self.is_synchronized:
@@ -288,8 +288,8 @@ def step_aux(self, **kwargs):
         if not self.collaboration_state.ready_for_step:
             return
 
-        logger.log(self.status_loglevel, f"Beginning global optimizer step {self.collaboration_state.optimizer_step}")
-        self.collaboration_state = self.fetch_collaboration_state()
+        logger.log(self.status_loglevel, f"Beginning global optimizer step #{self.collaboration_state.optimizer_step}")
+        self.collaboration_state = self._fetch_state()
         self.collaboration_state_updated.set()
 
         with self.lock_collaboration_state:
@@ -392,9 +392,9 @@ def check_collaboration_state_periodically(self):
                 continue  # if state was updated externally, reset timer
 
             with self.lock_collaboration_state:
-                self.collaboration_state = self.fetch_collaboration_state()
+                self.collaboration_state = self._fetch_state()
 
-    def fetch_collaboration_state(self) -> CollaborationState:
+    def _fetch_state(self) -> CollaborationState:
         """Read performance statistics reported by peers, estimate progress towards next batch"""
         response, _expiration = self.dht.get(self.training_progress_key, latest=True) or (None, -float("inf"))
         current_time = get_dht_time()
@@ -452,9 +452,9 @@ def fetch_collaboration_state(self) -> CollaborationState:
         )
         logger.log(
             self.status_loglevel,
-            f"Collaboration accumulated {total_samples_accumulated} samples from "
-            f"{num_peers} peers; ETA {estimated_time_to_next_step:.2f} seconds "
-            f"(refresh in {time_to_next_fetch:.2f}s.)",
+            f"{self.prefix} accumulated {total_samples_accumulated} samples from "
+            f"{num_peers} peers for step #{global_optimizer_step}. "
+            f"ETA {estimated_time_to_next_step:.2f} sec (refresh in {time_to_next_fetch:.2f} sec)",
         )
         return CollaborationState(
             global_optimizer_step,
diff --git a/hivemind/utils/logging.py b/hivemind/utils/logging.py
index 6715b3386..f74b364ed 100644
--- a/hivemind/utils/logging.py
+++ b/hivemind/utils/logging.py
@@ -1,6 +1,11 @@
 import logging
 import os
 import sys
+import threading
+from enum import Enum
+from typing import Optional, Union
+
+logging.addLevelName(logging.WARNING, "WARN")
 
 loglevel = os.getenv("LOGLEVEL", "INFO")
 
@@ -11,6 +16,17 @@
 use_colors = sys.stderr.isatty()
 
 
+class HandlerMode(Enum):
+    NOWHERE = 0
+    IN_HIVEMIND = 1
+    IN_ROOT_LOGGER = 2
+
+
+_init_lock = threading.RLock()
+_current_mode = HandlerMode.IN_HIVEMIND
+_default_handler = None
+
+
 class TextStyle:
     """
     ANSI escape codes. Details: https://en.wikipedia.org/wiki/ANSI_escape_code#Colors
@@ -60,23 +76,82 @@ def format(self, record: logging.LogRecord) -> str:
         return super().format(record)
 
 
-def get_logger(module_name: str) -> logging.Logger:
-    # trim package name
-    name_without_prefix = ".".join(module_name.split(".")[1:])
+def _initialize_if_necessary():
+    global _current_mode, _default_handler
 
-    logging.addLevelName(logging.WARNING, "WARN")
-    formatter = CustomFormatter(
-        fmt="{asctime}.{msecs:03.0f} [{bold}{levelcolor}{levelname}{reset}] [{bold}{caller}{reset}] {message}",
-        style="{",
-        datefmt="%b %d %H:%M:%S",
-    )
-    handler = logging.StreamHandler()
-    handler.setFormatter(formatter)
-    logger = logging.getLogger(name_without_prefix)
-    logger.setLevel(loglevel)
-    logger.addHandler(handler)
+    with _init_lock:
+        if _default_handler is not None:
+            return
+
+        formatter = CustomFormatter(
+            fmt="{asctime}.{msecs:03.0f} [{bold}{levelcolor}{levelname}{reset}] [{bold}{caller}{reset}] {message}",
+            style="{",
+            datefmt="%b %d %H:%M:%S",
+        )
+        _default_handler = logging.StreamHandler()
+        _default_handler.setFormatter(formatter)
+
+        _enable_default_handler("hivemind")
+
+
+def get_logger(name: Optional[str] = None) -> logging.Logger:
+    """
+    Same as ``logging.getLogger()`` but ensures that the default log handler is initialized.
+    """
+
+    _initialize_if_necessary()
+    return logging.getLogger(name)
+
+
+def _enable_default_handler(name: str) -> None:
+    logger = get_logger(name)
+    logger.addHandler(_default_handler)
     logger.propagate = False
-    return logger
+    logger.setLevel(loglevel)
+
+
+def _disable_default_handler(name: str) -> None:
+    logger = get_logger(name)
+    logger.removeHandler(_default_handler)
+    logger.propagate = True
+    logger.setLevel(logging.NOTSET)
+
+
+def use_hivemind_log_handler(where: Union[HandlerMode, str]) -> None:
+    """
+    Choose loggers where the default hivemind log handler is applied. Options for the ``where`` argument are:
+
+    * "in_hivemind" (default): Use the hivemind log handler in the loggers of the ``hivemind`` package.
+      Don't propagate their messages to the root logger.
+    * "nowhere": Don't use the hivemind log handler anywhere.
+      Propagate the ``hivemind`` messages to the root logger.
+    * "in_root_logger": Use the hivemind log handler in the root logger
+      (that is, in all application loggers until they disable propagation to the root logger).
+      Propagate the ``hivemind`` messages to the root logger.
+
+    The options may be defined as strings (case-insensitive) or values from the HandlerMode enum.
+    """
+
+    global _current_mode
+
+    if isinstance(where, str):
+        # We allow `where` to be a string, so a developer does not have to import the enum for one usage
+        where = HandlerMode[where.upper()]
+
+    if where == _current_mode:
+        return
+
+    if _current_mode == HandlerMode.IN_HIVEMIND:
+        _disable_default_handler("hivemind")
+    elif _current_mode == HandlerMode.IN_ROOT_LOGGER:
+        _disable_default_handler(None)
+
+    _current_mode = where
+
+    if _current_mode == HandlerMode.IN_HIVEMIND:
+        _enable_default_handler("hivemind")
+    elif _current_mode == HandlerMode.IN_ROOT_LOGGER:
+        _enable_default_handler(None)
 
 
 def golog_level_to_python(level: str) -> int: