Make log handlers configurable, shorten entries (#378)
1. Fix bugs: make `get_logger()` idempotent and stop trimming the actual logger name.
2. Allow a developer to choose where the default hivemind log handler is enabled (in `hivemind`, in the root logger, or nowhere); a usage sketch follows this list.
3. Enable the `in_root_logger` mode in `examples/albert`, so that all messages (from `__main__`, `transformers`, and `hivemind` itself) consistently follow the hivemind style.
4. Change some log messages to improve their presentation.
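A minimal usage sketch of the new API, mirroring what the `examples/albert` scripts below now do; the log message is a placeholder:

```python
from hivemind.utils.logging import get_logger, use_hivemind_log_handler

# Attach the hivemind handler to the root logger, so messages from __main__,
# transformers, and hivemind itself all follow the hivemind style.
use_hivemind_log_handler("in_root_logger")

# get_logger() is now idempotent: repeated calls no longer stack handlers,
# and the logger name is kept as-is instead of being trimmed.
logger = get_logger(__name__)

logger.info("Using the shared hivemind log format")
```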

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
borzunov and mryab committed Sep 6, 2021
1 parent fb3f57b commit b84f62b
Showing 5 changed files with 116 additions and 50 deletions.
2 changes: 1 addition & 1 deletion benchmarks/benchmark_averaging.py
@@ -80,7 +80,7 @@ def run_averager(index):
with lock_stats:
successful_steps += int(success)
total_steps += 1
logger.info(f"Averager {index}: {'finished' if success else 'failed'} step {step}")
logger.info(f"Averager {index}: {'finished' if success else 'failed'} step #{step}")
logger.info(f"Averager {index}: done.")

threads = []
32 changes: 11 additions & 21 deletions examples/albert/run_trainer.py
@@ -1,6 +1,5 @@
#!/usr/bin/env python

import logging
import os
import pickle
from dataclasses import asdict
@@ -18,32 +17,22 @@
from transformers.trainer_utils import is_main_process

import hivemind
from hivemind.utils.logging import get_logger, use_hivemind_log_handler

import utils
from arguments import AlbertTrainingArguments, AveragerArguments, CollaborationArguments, DatasetArguments

logger = logging.getLogger(__name__)
LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
use_hivemind_log_handler("in_root_logger")
logger = get_logger()

LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)

def setup_logging(training_args):
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO if is_main_process(training_args.local_rank) else logging.WARN,
)

# Log on each process the small summary:
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
)
# Set the verbosity to info of the Transformers logger (on main process only):
if is_main_process(training_args.local_rank):
def setup_transformers_logging(process_rank: int):
if is_main_process(process_rank):
transformers.utils.logging.set_verbosity_info()
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
logger.info("Training/evaluation parameters %s", training_args)
transformers.utils.logging.disable_default_handler()
transformers.utils.logging.enable_propagation()


def get_model(training_args, config, tokenizer):
@@ -149,7 +138,7 @@ def on_step_end(
loss=self.loss,
mini_steps=self.steps,
)
logger.info(f"Step {self.collaborative_optimizer.local_step}")
logger.info(f"Step #{self.collaborative_optimizer.local_step}")
logger.info(f"Your current contribution: {self.total_samples_processed} samples")
logger.info(f"Performance: {samples_per_second} samples per second.")
if self.steps:
@@ -220,7 +209,8 @@ def main():
if len(collaboration_args.initial_peers) == 0:
raise ValueError("Please specify at least one network endpoint in initial peers.")

setup_logging(training_args)
setup_transformers_logging(training_args.local_rank)
logger.info(f"Training/evaluation parameters:\n{training_args}")

# Set seed before initializing model.
set_seed(training_args.seed)
7 changes: 4 additions & 3 deletions examples/albert/run_training_monitor.py
@@ -1,6 +1,5 @@
#!/usr/bin/env python

import logging
import time
from dataclasses import asdict, dataclass, field
from ipaddress import ip_address
@@ -13,11 +12,13 @@
from transformers import AlbertConfig, AlbertForPreTraining, HfArgumentParser

import hivemind
from hivemind.utils.logging import get_logger, use_hivemind_log_handler

import utils
from arguments import AveragerArguments, BaseTrainingArguments, CollaborativeOptimizerArguments

logger = logging.getLogger(__name__)
use_hivemind_log_handler("in_root_logger")
logger = get_logger()


@dataclass
@@ -139,7 +140,7 @@ def upload_checkpoint(self, current_loss):
self.model.push_to_hub(
repo_name=self.repo_path,
repo_url=self.repo_url,
commit_message=f"Step {current_step}, loss {current_loss:.3f}",
commit_message=f"Step #{current_step}, loss {current_loss:.3f}",
)
logger.info("Finished uploading to Model Hub")

20 changes: 10 additions & 10 deletions hivemind/optim/collaborative.py
@@ -153,7 +153,7 @@ def __init__(
self.performance_ema = PerformanceEMA(alpha=performance_ema_alpha)
self.last_step_time = None

self.collaboration_state = self.fetch_collaboration_state()
self.collaboration_state = self._fetch_state()
self.lock_collaboration_state, self.collaboration_state_updated = Lock(), Event()
self.lock_local_progress, self.should_report_progress = Lock(), Event()
self.progress_reporter = Thread(target=self.report_training_progress, daemon=True, name=f"{self}.reporter")
@@ -237,8 +237,8 @@ def step(self, batch_size: Optional[int] = None, **kwargs):
if not self.collaboration_state.ready_for_step:
return

logger.log(self.status_loglevel, f"Beginning global optimizer step {self.collaboration_state.optimizer_step}")
self.collaboration_state = self.fetch_collaboration_state()
logger.log(self.status_loglevel, f"Beginning global optimizer step #{self.collaboration_state.optimizer_step}")
self.collaboration_state = self._fetch_state()
self.collaboration_state_updated.set()

if not self.is_synchronized:
@@ -288,8 +288,8 @@ def step_aux(self, **kwargs):
if not self.collaboration_state.ready_for_step:
return

logger.log(self.status_loglevel, f"Beginning global optimizer step {self.collaboration_state.optimizer_step}")
self.collaboration_state = self.fetch_collaboration_state()
logger.log(self.status_loglevel, f"Beginning global optimizer step #{self.collaboration_state.optimizer_step}")
self.collaboration_state = self._fetch_state()
self.collaboration_state_updated.set()

with self.lock_collaboration_state:
@@ -392,9 +392,9 @@ def check_collaboration_state_periodically(self):
continue # if state was updated externally, reset timer

with self.lock_collaboration_state:
self.collaboration_state = self.fetch_collaboration_state()
self.collaboration_state = self._fetch_state()

def fetch_collaboration_state(self) -> CollaborationState:
def _fetch_state(self) -> CollaborationState:
"""Read performance statistics reported by peers, estimate progress towards next batch"""
response, _expiration = self.dht.get(self.training_progress_key, latest=True) or (None, -float("inf"))
current_time = get_dht_time()
@@ -452,9 +452,9 @@
)
logger.log(
self.status_loglevel,
f"Collaboration accumulated {total_samples_accumulated} samples from "
f"{num_peers} peers; ETA {estimated_time_to_next_step:.2f} seconds "
f"(refresh in {time_to_next_fetch:.2f}s.)",
f"{self.prefix} accumulated {total_samples_accumulated} samples from "
f"{num_peers} peers for step #{global_optimizer_step}. "
f"ETA {estimated_time_to_next_step:.2f} sec (refresh in {time_to_next_fetch:.2f} sec)",
)
return CollaborationState(
global_optimizer_step,
105 changes: 90 additions & 15 deletions hivemind/utils/logging.py
@@ -1,6 +1,11 @@
import logging
import os
import sys
import threading
from enum import Enum
from typing import Optional, Union

logging.addLevelName(logging.WARNING, "WARN")

loglevel = os.getenv("LOGLEVEL", "INFO")

@@ -11,6 +16,17 @@
use_colors = sys.stderr.isatty()


class HandlerMode(Enum):
NOWHERE = 0
IN_HIVEMIND = 1
IN_ROOT_LOGGER = 2


_init_lock = threading.RLock()
_current_mode = HandlerMode.IN_HIVEMIND
_default_handler = None


class TextStyle:
"""
ANSI escape codes. Details: https://en.wikipedia.org/wiki/ANSI_escape_code#Colors
@@ -60,23 +76,82 @@ def format(self, record: logging.LogRecord) -> str:
return super().format(record)


def get_logger(module_name: str) -> logging.Logger:
# trim package name
name_without_prefix = ".".join(module_name.split(".")[1:])
def _initialize_if_necessary():
global _current_mode, _default_handler

logging.addLevelName(logging.WARNING, "WARN")
formatter = CustomFormatter(
fmt="{asctime}.{msecs:03.0f} [{bold}{levelcolor}{levelname}{reset}] [{bold}{caller}{reset}] {message}",
style="{",
datefmt="%b %d %H:%M:%S",
)
handler = logging.StreamHandler()
handler.setFormatter(formatter)
logger = logging.getLogger(name_without_prefix)
logger.setLevel(loglevel)
logger.addHandler(handler)
with _init_lock:
if _default_handler is not None:
return

formatter = CustomFormatter(
fmt="{asctime}.{msecs:03.0f} [{bold}{levelcolor}{levelname}{reset}] [{bold}{caller}{reset}] {message}",
style="{",
datefmt="%b %d %H:%M:%S",
)
_default_handler = logging.StreamHandler()
_default_handler.setFormatter(formatter)

_enable_default_handler("hivemind")


def get_logger(name: Optional[str] = None) -> logging.Logger:
"""
Same as ``logging.getLogger()`` but ensures that the default log handler is initialized.
"""

_initialize_if_necessary()
return logging.getLogger(name)


def _enable_default_handler(name: str) -> None:
logger = get_logger(name)
logger.addHandler(_default_handler)
logger.propagate = False
return logger
logger.setLevel(loglevel)


def _disable_default_handler(name: str) -> None:
logger = get_logger(name)
logger.removeHandler(_default_handler)
logger.propagate = True
logger.setLevel(logging.NOTSET)


def use_hivemind_log_handler(where: Union[HandlerMode, str]) -> None:
"""
Choose loggers where the default hivemind log handler is applied. Options for the ``where`` argument are:
* "in_hivemind" (default): Use the hivemind log handler in the loggers of the ``hivemind`` package.
Don't propagate their messages to the root logger.
* "nowhere": Don't use the hivemind log handler anywhere.
Propagate the ``hivemind`` messages to the root logger.
* "in_root_logger": Use the hivemind log handler in the root logger
(that is, in all application loggers until they disable propagation to the root logger).
Propagate the ``hivemind`` messages to the root logger.
The options may be defined as strings (case-insensitive) or values from the HandlerMode enum.
"""

global _current_mode

if isinstance(where, str):
# We allow `where` to be a string, so a developer does not have to import the enum for one usage
where = HandlerMode[where.upper()]

if where == _current_mode:
return

if _current_mode == HandlerMode.IN_HIVEMIND:
_disable_default_handler("hivemind")
elif _current_mode == HandlerMode.IN_ROOT_LOGGER:
_disable_default_handler(None)

_current_mode = where

if _current_mode == HandlerMode.IN_HIVEMIND:
_enable_default_handler("hivemind")
elif _current_mode == HandlerMode.IN_ROOT_LOGGER:
_enable_default_handler(None)


def golog_level_to_python(level: str) -> int:
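As a quick illustration of the modes documented in the `use_hivemind_log_handler` docstring above, here is a short sketch of the `nowhere` mode; the `logging.basicConfig` call stands in for whatever handler setup an application already has:

```python
import logging

from hivemind.utils.logging import HandlerMode, use_hivemind_log_handler

# The application keeps full control of handlers and formatting;
# hivemind loggers simply propagate their records to the root logger.
logging.basicConfig(level=logging.INFO)

# The mode can be passed as a HandlerMode value or a case-insensitive string.
use_hivemind_log_handler(HandlerMode.NOWHERE)  # equivalent to "nowhere"
```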
