[Enhancement] Support export logs of different ranks in debug mode #968

Merged: 19 commits, Mar 13, 2023
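In short: after this change, every rank can write its own log file when the log level is DEBUG or when the logger is created with `distributed` enabled; rank 0 always keeps its file handler. Per-rank files are renamed to `{stem}_{user}@{host}_device{device_id}_rank{rank}{suffix}`, and the resolved path is exposed through the new `log_file` property. A minimal usage sketch (illustrative only; the exact filename depends on the user, hostname, and launcher):

from mmengine.logging import MMLogger

# DEBUG level triggers the per-rank file naming introduced by this PR.
logger = MMLogger.get_instance(
    'mmengine', log_level='DEBUG', log_file='run.log')
logger.debug('per-rank message')
# On rank 1 of a 2-GPU job this could resolve to something like
# run_user@host_device1_rank1.log; query it via the new property:
print(logger.log_file)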
68 changes: 55 additions & 13 deletions mmengine/logging/logger.py
@@ -1,8 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import os
import os.path as osp
import sys
import warnings
from getpass import getuser
from logging import Logger, LogRecord
from socket import gethostname
from typing import Optional, Union

from termcolor import colored
@@ -152,7 +156,8 @@ def __init__(self,
ManagerMixin.__init__(self, name)
# Get rank in DDP mode.

rank = _get_rank()
global_rank = _get_rank()
device_id = _get_device_id()

# Configure stream_handler. If `rank != 0`, the stream_handler can only
# export ERROR logs.
@@ -162,24 +167,21 @@ def __init__(self,
stream_handler.setFormatter(
MMFormatter(color=True, datefmt='%m/%d %H:%M:%S'))
# Only rank0 `StreamHandler` will log messages below error level.
stream_handler.setLevel(log_level) if rank == 0 else \
stream_handler.setLevel(log_level) if global_rank == 0 else \
stream_handler.setLevel(logging.ERROR)
self.handlers.append(stream_handler)

if log_file is not None:
if rank != 0:
# rename `log_file` with rank suffix.
path_split = log_file.split(os.sep)
if '.' in path_split[-1]:
filename_list = path_split[-1].split('.')
filename_list[-2] = f'{filename_list[-2]}_rank{rank}'
path_split[-1] = '.'.join(filename_list)
else:
path_split[-1] = f'{path_split[-1]}_rank{rank}'
log_file = os.sep.join(path_split)
if global_rank != 0 or log_level == 'DEBUG' or distributed:
filename, suffix = osp.splitext(osp.basename(log_file))
hostname = _get_host_info()
filename = f'{filename}_{hostname}_device{device_id}_' \
f'rank{global_rank}{suffix}'
log_file = osp.join(osp.dirname(log_file), filename)
# Save multi-rank logs if distributed is True. The logs of rank0
# will always be saved.
if rank == 0 or distributed:
if global_rank == 0 or distributed or \
logging._nameToLevel[log_level] <= logging.DEBUG:
# Here, the default behaviour of the official logger is 'a'.
# Thus, we provide an interface to change the file mode to
# the default behaviour. `FileHandler` is not supported to
@@ -192,6 +194,11 @@ def __init__(self,
MMFormatter(color=False, datefmt='%Y/%m/%d %H:%M:%S'))
file_handler.setLevel(log_level)
self.handlers.append(file_handler)
self._log_file = log_file

@property
def log_file(self):
return self._log_file

@classmethod
def get_current_instance(cls) -> 'MMLogger':
@@ -297,3 +304,38 @@ def _get_rank():
return 0
else:
return get_rank()


def _get_device_id():
"""Get device id of current machine."""
try:
import torch
except ImportError:
return 0
else:
local_rank = int(os.getenv('LOCAL_RANK', '0'))
# TODO: return device id of npu and mlu.
if not torch.cuda.is_available():
return local_rank
num_device = torch.cuda.device_count()
cuda_visible_device = os.getenv('CUDA_VISIBLE_DEVICES', None)
if cuda_visible_device is None:
cuda_visible_device = list(range(num_device))
else:
cuda_visible_device = cuda_visible_device.split(',')
return int(cuda_visible_device[local_rank])


def _get_host_info() -> str:
"""Get hostname and username.

Return an empty string if an exception is raised, e.g. ``getpass.getuser()``
will raise an error inside a Docker container.
"""
host = ''
try:
host = f'{getuser()}@{gethostname()}'
except Exception as e:
warnings.warn(f'Host or user not found: {str(e)}')
finally:
return host
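For reference, `_get_device_id` above treats `LOCAL_RANK` as an index into `CUDA_VISIBLE_DEVICES`, so the filename suffix records the physical GPU id rather than the local process index. A small worked example of that mapping (illustration only; it assumes CUDA is available and the environment variables are set as shown):

import os

os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,3,5'  # physical GPUs visible to the job
os.environ['LOCAL_RANK'] = '2'                  # this process is the third local worker

visible = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
local_rank = int(os.environ['LOCAL_RANK'])
print(visible[local_rank])  # '3', which is what _get_device_id() returns in this case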
51 changes: 49 additions & 2 deletions tests/test_logging/test_logger.py
@@ -4,11 +4,13 @@
import re
import sys
from collections import OrderedDict
from contextlib import contextmanager
from unittest.mock import patch

import pytest

from mmengine.logging import MMLogger, print_log
from mmengine.logging.logger import _get_device_id


class TestLogger:
@@ -49,10 +51,12 @@ def test_init_rank0(self, tmp_path):
MMLogger._instance_dict.clear()

@patch('mmengine.logging.logger._get_rank', lambda: 1)
@patch('mmengine.logging.logger._get_device_id', lambda: 1)
@patch('mmengine.logging.logger._get_host_info', lambda: 'test')
def test_init_rank1(self, tmp_path):
# If `rank != 0`, the log level of the stream handler is `logging.ERROR`.
tmp_file = tmp_path / 'tmp_file.log'
log_path = tmp_path / 'tmp_file_rank1.log'
log_path = tmp_path / 'tmp_file_test_device1_rank1.log'
logger = MMLogger.get_instance(
'rank1.pkg2', log_level='INFO', log_file=str(tmp_file))
assert len(logger.handlers) == 1
@@ -64,7 +68,7 @@ def test_init_rank1(self, tmp_path):
assert logger.handlers[0].level == logging.ERROR
assert logger.handlers[1].level == logging.INFO
assert len(logger.handlers) == 2
assert os.path.exists(log_path)
assert os.path.exists(str(log_path))
# `FileHandler` should be closed in Windows, otherwise we cannot
# delete the temporary directory
logging.shutdown()
@@ -184,3 +188,46 @@ def test_set_level(self, capsys):
logger.warning('hello')
out, _ = capsys.readouterr()
assert 'WARNING' in out


@patch('torch.cuda.device_count', lambda: 4)
def test_get_device_id():

@contextmanager
def patch_env(local_rank, cuda_visible_device):
ori = os.environ.copy()
if local_rank is not None:
os.environ['LOCAL_RANK'] = str(local_rank)
os.environ['CUDA_VISIBLE_DEVICES'] = cuda_visible_device
yield
os.environ = ori

# cuda is not available and local_rank is not set
with patch('torch.cuda.is_available', lambda: False), \
patch_env(None, '0,1,2,3'):
assert _get_device_id() == 0

# cuda is not available and local_rank is set
with patch('torch.cuda.is_available', lambda: False), \
patch_env('1', '0,1,2,3'):
assert _get_device_id() == 1

# CUDA_VISIBLE_DEVICES will not influence non-cuda device
with patch('torch.cuda.is_available', lambda: False), \
patch_env('1', '0,100,2,3'):
assert _get_device_id() == 1

# cuda is available and local_rank is not set
with patch('torch.cuda.is_available', lambda: True), \
patch_env(None, '0,1,2,3'):
assert _get_device_id() == 0

# cuda is available and local_rank is set
with patch('torch.cuda.is_available', lambda: True), \
patch_env(2, '0,1,2,3'):
assert _get_device_id() == 2

# CUDA_VISIBLE_DEVICES worked
with patch('torch.cuda.is_available', lambda: True), \
patch_env(2, '0,1,3,5'):
assert _get_device_id() == 3
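To exercise the new behaviour locally, the standard pytest invocations should work (test paths taken from this diff):

pytest tests/test_logging/test_logger.py::TestLogger::test_init_rank1 -q
pytest tests/test_logging/test_logger.py::test_get_device_id -q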