Skip to content

Commit

Permalink
Improve names, add docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
borzunov committed Sep 6, 2021
1 parent 1b00d2b commit 18ec6ef
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 13 deletions.
4 changes: 2 additions & 2 deletions examples/albert/run_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@
from transformers.trainer_utils import is_main_process

import hivemind
from hivemind.utils.logging import get_logger, use_hivemind_log_format
from hivemind.utils.logging import get_logger, use_hivemind_log_handler

import utils
from arguments import AlbertTrainingArguments, AveragerArguments, CollaborationArguments, DatasetArguments

use_hivemind_log_format("in_root_logger")
use_hivemind_log_handler("in_root_logger")
logger = get_logger()

LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
Expand Down
4 changes: 2 additions & 2 deletions examples/albert/run_training_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
from transformers import AlbertConfig, AlbertForPreTraining, HfArgumentParser

import hivemind
from hivemind.utils.logging import get_logger, use_hivemind_log_format
from hivemind.utils.logging import get_logger, use_hivemind_log_handler

import utils
from arguments import AveragerArguments, BaseTrainingArguments, CollaborativeOptimizerArguments

use_hivemind_log_format("in_root_logger")
use_hivemind_log_handler("in_root_logger")
logger = get_logger()


Expand Down
36 changes: 27 additions & 9 deletions hivemind/utils/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def format(self, record: logging.LogRecord) -> str:
_PACKAGE_NAME = __name__.split(".")[0]

_init_lock = threading.RLock()
_current_mode = StyleMode.NOWHERE # This is the initial state before module initialization but not an actual default
_current_mode = HandlerMode.NOWHERE # This is the initial state before module initialization but not an actual default
_default_handler = None


Expand All @@ -90,10 +90,14 @@ def _initialize_if_necessary():
_default_handler = logging.StreamHandler()
_default_handler.setFormatter(formatter)

use_hivemind_log_format(StyleMode.IN_HIVEMIND) # Overriding it to the desired default
use_hivemind_log_handler(HandlerMode.IN_HIVEMIND) # Overriding it to the desired default


def get_logger(name: Optional[str] = None) -> logging.Logger:
"""
Same as ``logging.getLogger()`` but ensures that the default log handler is initialized.
"""

_initialize_if_necessary()
return logging.getLogger(name)

Expand All @@ -112,29 +116,43 @@ def _disable_default_handler(name: str) -> None:
logger.setLevel(logging.NOTSET)


class StyleMode(Enum):
class HandlerMode(Enum):
NOWHERE = 0
IN_HIVEMIND = 1
IN_ROOT_LOGGER = 2


def use_hivemind_log_format(where: Union[StyleMode, str]) -> None:
def use_hivemind_log_handler(where: Union[HandlerMode, str]) -> None:
"""
Choose loggers where the default hivemind log handler is applied. Options for the ``where`` argument are:
* "in_hivemind" (default): Use the hivemind log handler in the loggers of the ``hivemind`` package.
Don't propagate their messages to the root logger.
* "nowhere": Don't use the hivemind log handler anywhere.
Propagate the ``hivemind`` messages to the root logger.
* "in_root_logger": Use the hivemind log handler in the root logger
(that is, in all application loggers until they disable propagation to the root logger).
Propagate the ``hivemind`` messages to the root logger.
The options may be defined as strings (case-insensitive) or values from the HandlerMode enum.
"""

global _current_mode

if isinstance(where, str):
# We allow `where` to be a string, so a developer does not have to import the enum for one usage
where = StyleMode[where.upper()]
where = HandlerMode[where.upper()]

if _current_mode == StyleMode.IN_HIVEMIND:
if _current_mode == HandlerMode.IN_HIVEMIND:
_disable_default_handler(_PACKAGE_NAME)
elif _current_mode == StyleMode.IN_ROOT_LOGGER:
elif _current_mode == HandlerMode.IN_ROOT_LOGGER:
_disable_default_handler(None)

_current_mode = where

if _current_mode == StyleMode.IN_HIVEMIND:
if _current_mode == HandlerMode.IN_HIVEMIND:
_enable_default_handler(_PACKAGE_NAME)
elif _current_mode == StyleMode.IN_ROOT_LOGGER:
elif _current_mode == HandlerMode.IN_ROOT_LOGGER:
_enable_default_handler(None)


Expand Down

0 comments on commit 18ec6ef

Please sign in to comment.