Make log handlers configurable, shorten entries #378

Merged · 17 commits · Sep 6, 2021
2 changes: 1 addition & 1 deletion benchmarks/benchmark_averaging.py
@@ -80,7 +80,7 @@ def run_averager(index):
             with lock_stats:
                 successful_steps += int(success)
                 total_steps += 1
-            logger.info(f"Averager {index}: {'finished' if success else 'failed'} step {step}")
+            logger.info(f"Averager {index}: {'finished' if success else 'failed'} step #{step}")
         logger.info(f"Averager {index}: done.")

     threads = []
32 changes: 11 additions & 21 deletions examples/albert/run_trainer.py
@@ -1,6 +1,5 @@
 #!/usr/bin/env python

-import logging
 import os
 import pickle
 from dataclasses import asdict
@@ -18,32 +17,22 @@
 from transformers.trainer_utils import is_main_process

 import hivemind
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler

 import utils
 from arguments import AlbertTrainingArguments, AveragerArguments, CollaborationArguments, DatasetArguments

-logger = logging.getLogger(__name__)
-LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger()

+LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)

-def setup_logging(training_args):
-    logging.basicConfig(
-        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
-        datefmt="%m/%d/%Y %H:%M:%S",
-        level=logging.INFO if is_main_process(training_args.local_rank) else logging.WARN,
-    )
Member (Author) commented:
This basicConfig() call had no effect; from the logging docs:

> This function does nothing if the root logger already has handlers configured, unless the keyword argument force is set to True.
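For context, a minimal standalone sketch of that pitfall (plain standard-library behavior, not part of this diff); the force=True escape hatch exists since Python 3.8:

import logging

# The root logger gets a handler (e.g., because a library configured logging first).
logging.getLogger().addHandler(logging.StreamHandler())

# Silently ignored: the root logger already has a handler.
logging.basicConfig(format="%(levelname)s %(message)s", level=logging.INFO)

# Python 3.8+: force=True removes the existing root handlers and applies the new config.
logging.basicConfig(format="%(levelname)s %(message)s", level=logging.INFO, force=True)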


-    # Log on each process the small summary:
-    logger.warning(
-        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
-        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
-    )
Member (Author) commented:
This message is removed since training_args are already logged.

-    # Set the verbosity to info of the Transformers logger (on main process only):
-    if is_main_process(training_args.local_rank):
+def setup_transformers_logging(process_rank: int):
+    if is_main_process(process_rank):
         transformers.utils.logging.set_verbosity_info()
-        transformers.utils.logging.enable_default_handler()
-        transformers.utils.logging.enable_explicit_format()
-    logger.info("Training/evaluation parameters %s", training_args)
+        transformers.utils.logging.disable_default_handler()
+        transformers.utils.logging.enable_propagation()


 def get_model(training_args, config, tokenizer):
@@ -149,7 +138,7 @@ def on_step_end(
                     loss=self.loss,
                     mini_steps=self.steps,
                 )
-                logger.info(f"Step {self.collaborative_optimizer.local_step}")
+                logger.info(f"Step #{self.collaborative_optimizer.local_step}")
                 logger.info(f"Your current contribution: {self.total_samples_processed} samples")
                 logger.info(f"Performance: {samples_per_second} samples per second.")
                 if self.steps:
@@ -220,7 +209,8 @@ def main():
     if len(collaboration_args.initial_peers) == 0:
         raise ValueError("Please specify at least one network endpoint in initial peers.")

-    setup_logging(training_args)
+    setup_transformers_logging(training_args.local_rank)
+    logger.info(f"Training/evaluation parameters:\n{training_args}")

     # Set seed before initializing model.
     set_seed(training_args.seed)
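Taken together, the example script now configures logging in two module-level lines instead of a basicConfig() block. A condensed sketch of the new pattern (the message text is illustrative; the calls are the ones this PR introduces):

from hivemind.utils.logging import get_logger, use_hivemind_log_handler

use_hivemind_log_handler("in_root_logger")  # hivemind's handler now formats all root-logger records
logger = get_logger()

logger.info("This record is formatted by the hivemind handler")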
7 changes: 4 additions & 3 deletions examples/albert/run_training_monitor.py
@@ -1,6 +1,5 @@
 #!/usr/bin/env python

-import logging
 import time
 from dataclasses import asdict, dataclass, field
 from ipaddress import ip_address
@@ -13,11 +12,13 @@
 from transformers import AlbertConfig, AlbertForPreTraining, HfArgumentParser

 import hivemind
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler

 import utils
 from arguments import AveragerArguments, BaseTrainingArguments, CollaborativeOptimizerArguments

-logger = logging.getLogger(__name__)
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger()


 @dataclass
@@ -139,7 +140,7 @@ def upload_checkpoint(self, current_loss):
             self.model.push_to_hub(
                 repo_name=self.repo_path,
                 repo_url=self.repo_url,
-                commit_message=f"Step {current_step}, loss {current_loss:.3f}",
+                commit_message=f"Step #{current_step}, loss {current_loss:.3f}",
             )
             logger.info("Finished uploading to Model Hub")
20 changes: 10 additions & 10 deletions hivemind/optim/collaborative.py
@@ -153,7 +153,7 @@ def __init__(
         self.performance_ema = PerformanceEMA(alpha=performance_ema_alpha)
         self.last_step_time = None

-        self.collaboration_state = self.fetch_collaboration_state()
+        self.collaboration_state = self._fetch_state()
         self.lock_collaboration_state, self.collaboration_state_updated = Lock(), Event()
         self.lock_local_progress, self.should_report_progress = Lock(), Event()
         self.progress_reporter = Thread(target=self.report_training_progress, daemon=True, name=f"{self}.reporter")
@@ -237,8 +237,8 @@ def step(self, batch_size: Optional[int] = None, **kwargs):
         if not self.collaboration_state.ready_for_step:
             return

-        logger.log(self.status_loglevel, f"Beginning global optimizer step {self.collaboration_state.optimizer_step}")
-        self.collaboration_state = self.fetch_collaboration_state()
+        logger.log(self.status_loglevel, f"Beginning global optimizer step #{self.collaboration_state.optimizer_step}")
+        self.collaboration_state = self._fetch_state()
         self.collaboration_state_updated.set()

         if not self.is_synchronized:
@@ -288,8 +288,8 @@ def step_aux(self, **kwargs):
         if not self.collaboration_state.ready_for_step:
             return

-        logger.log(self.status_loglevel, f"Beginning global optimizer step {self.collaboration_state.optimizer_step}")
-        self.collaboration_state = self.fetch_collaboration_state()
+        logger.log(self.status_loglevel, f"Beginning global optimizer step #{self.collaboration_state.optimizer_step}")
+        self.collaboration_state = self._fetch_state()
         self.collaboration_state_updated.set()

         with self.lock_collaboration_state:
@@ -392,9 +392,9 @@ def check_collaboration_state_periodically(self):
                 continue  # if state was updated externally, reset timer

             with self.lock_collaboration_state:
-                self.collaboration_state = self.fetch_collaboration_state()
+                self.collaboration_state = self._fetch_state()

-    def fetch_collaboration_state(self) -> CollaborationState:
+    def _fetch_state(self) -> CollaborationState:
         """Read performance statistics reported by peers, estimate progress towards next batch"""
         response, _expiration = self.dht.get(self.training_progress_key, latest=True) or (None, -float("inf"))
         current_time = get_dht_time()
@@ -452,9 +452,9 @@
         )
         logger.log(
             self.status_loglevel,
-            f"Collaboration accumulated {total_samples_accumulated} samples from "
-            f"{num_peers} peers; ETA {estimated_time_to_next_step:.2f} seconds "
-            f"(refresh in {time_to_next_fetch:.2f}s.)",
+            f"{self.prefix} accumulated {total_samples_accumulated} samples from "
+            f"{num_peers} peers for step #{global_optimizer_step}. "
+            f"ETA {estimated_time_to_next_step:.2f} sec (refresh in {time_to_next_fetch:.2f} sec)",
         )
         return CollaborationState(
             global_optimizer_step,
105 changes: 90 additions & 15 deletions hivemind/utils/logging.py
@@ -1,6 +1,11 @@
 import logging
 import os
 import sys
+import threading
+from enum import Enum
+from typing import Optional, Union
+
+logging.addLevelName(logging.WARNING, "WARN")

 loglevel = os.getenv("LOGLEVEL", "INFO")
@@ -11,6 +16,17 @@
 use_colors = sys.stderr.isatty()


+class HandlerMode(Enum):
+    NOWHERE = 0
+    IN_HIVEMIND = 1
+    IN_ROOT_LOGGER = 2
+
+
+_init_lock = threading.RLock()
+_current_mode = HandlerMode.IN_HIVEMIND
+_default_handler = None
+
+
 class TextStyle:
     """
     ANSI escape codes. Details: https://en.wikipedia.org/wiki/ANSI_escape_code#Colors
@@ -60,23 +76,82 @@ def format(self, record: logging.LogRecord) -> str:
         return super().format(record)


-def get_logger(module_name: str) -> logging.Logger:
-    # trim package name
-    name_without_prefix = ".".join(module_name.split(".")[1:])
-
-    logging.addLevelName(logging.WARNING, "WARN")
-    formatter = CustomFormatter(
-        fmt="{asctime}.{msecs:03.0f} [{bold}{levelcolor}{levelname}{reset}] [{bold}{caller}{reset}] {message}",
-        style="{",
-        datefmt="%b %d %H:%M:%S",
-    )
-    handler = logging.StreamHandler()
-    handler.setFormatter(formatter)
-    logger = logging.getLogger(name_without_prefix)
-    logger.setLevel(loglevel)
-    logger.addHandler(handler)
-    logger.propagate = False
-    return logger
+def _initialize_if_necessary():
+    global _current_mode, _default_handler
+
+    with _init_lock:
+        if _default_handler is not None:
+            return
+
+        formatter = CustomFormatter(
+            fmt="{asctime}.{msecs:03.0f} [{bold}{levelcolor}{levelname}{reset}] [{bold}{caller}{reset}] {message}",
+            style="{",
+            datefmt="%b %d %H:%M:%S",
+        )
+        _default_handler = logging.StreamHandler()
+        _default_handler.setFormatter(formatter)
+
+        _enable_default_handler("hivemind")
+
+
+def get_logger(name: Optional[str] = None) -> logging.Logger:
+    """
+    Same as ``logging.getLogger()`` but ensures that the default log handler is initialized.
+    """
+
+    _initialize_if_necessary()
+    return logging.getLogger(name)
+
+
+def _enable_default_handler(name: str) -> None:
+    logger = get_logger(name)
+    logger.addHandler(_default_handler)
+    logger.propagate = False
+    logger.setLevel(loglevel)
+
+
+def _disable_default_handler(name: str) -> None:
+    logger = get_logger(name)
+    logger.removeHandler(_default_handler)
+    logger.propagate = True
+    logger.setLevel(logging.NOTSET)


+def use_hivemind_log_handler(where: Union[HandlerMode, str]) -> None:
+    """
+    Choose loggers where the default hivemind log handler is applied. Options for the ``where`` argument are:
+
+    * "in_hivemind" (default): Use the hivemind log handler in the loggers of the ``hivemind`` package.
+      Don't propagate their messages to the root logger.
+    * "nowhere": Don't use the hivemind log handler anywhere.
+      Propagate the ``hivemind`` messages to the root logger.
+    * "in_root_logger": Use the hivemind log handler in the root logger
+      (that is, in all application loggers unless they disable propagation to the root logger).
+      Propagate the ``hivemind`` messages to the root logger.
+
+    The options may be defined as strings (case-insensitive) or values from the HandlerMode enum.
+    """
+
+    global _current_mode
+
+    if isinstance(where, str):
+        # We allow `where` to be a string, so a developer does not have to import the enum for one usage
+        where = HandlerMode[where.upper()]
+
+    if where == _current_mode:
+        return
+
+    if _current_mode == HandlerMode.IN_HIVEMIND:
+        _disable_default_handler("hivemind")
+    elif _current_mode == HandlerMode.IN_ROOT_LOGGER:
+        _disable_default_handler(None)
+
+    _current_mode = where
+
+    if _current_mode == HandlerMode.IN_HIVEMIND:
+        _enable_default_handler("hivemind")
+    elif _current_mode == HandlerMode.IN_ROOT_LOGGER:
+        _enable_default_handler(None)
 def golog_level_to_python(level: str) -> int:
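To summarize the new API, a short usage sketch of the three handler modes described in the docstring above (both the case-insensitive strings and the HandlerMode enum values are accepted):

from hivemind.utils.logging import HandlerMode, use_hivemind_log_handler

use_hivemind_log_handler("in_hivemind")               # default: handler attached only to the "hivemind" package logger
use_hivemind_log_handler("nowhere")                   # handler detached; hivemind records propagate to the root logger
use_hivemind_log_handler(HandlerMode.IN_ROOT_LOGGER)  # handler attached to the root logger itself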