Skip to content

Commit

Permalink
[python] allow to register any custom logger (fixes #4783) (#4880)
Browse files Browse the repository at this point in the history
* [python] allow to register any custom logger

* allow customizable logging method name; add unit test

* [python] allow to register any custom logger

* allow customizable logging method name; add unit test

* update tests

* fix lint error

* remove unused method

* fix docstring style

Co-authored-by: gongxudong <gongxudong@kuaishou.com>
  • Loading branch information
RustingSword and gongxudong authored Mar 28, 2022
1 parent d163c2c commit 60e72d5
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 9 deletions.
34 changes: 25 additions & 9 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,21 +41,37 @@ def warning(self, msg: str) -> None:
warnings.warn(msg, stacklevel=3)


_LOGGER: Union[_DummyLogger, Logger] = _DummyLogger()
_LOGGER: Any = _DummyLogger()
_INFO_METHOD_NAME = "info"
_WARNING_METHOD_NAME = "warning"


def register_logger(logger: Logger) -> None:
def register_logger(
logger: Any, info_method_name: str = "info", warning_method_name: str = "warning"
) -> None:
"""Register custom logger.
Parameters
----------
logger : logging.Logger
logger : Any
Custom logger.
info_method_name : str, optional (default="info")
Method used to log info messages.
warning_method_name : str, optional (default="warning")
Method used to log warning messages.
"""
if not isinstance(logger, Logger):
raise TypeError("Logger should inherit logging.Logger class")
global _LOGGER
def _has_method(logger: Any, method_name: str) -> bool:
return callable(getattr(logger, method_name, None))

if not _has_method(logger, info_method_name) or not _has_method(logger, warning_method_name):
raise TypeError(
f"Logger must provide '{info_method_name}' and '{warning_method_name}' method"
)

global _LOGGER, _INFO_METHOD_NAME, _WARNING_METHOD_NAME
_LOGGER = logger
_INFO_METHOD_NAME = info_method_name
_WARNING_METHOD_NAME = warning_method_name


def _normalize_native_string(func: Callable[[str], None]) -> Callable[[str], None]:
Expand All @@ -76,16 +92,16 @@ def wrapper(msg: str) -> None:


def _log_info(msg: str) -> None:
_LOGGER.info(msg)
getattr(_LOGGER, _INFO_METHOD_NAME)(msg)


def _log_warning(msg: str) -> None:
_LOGGER.warning(msg)
getattr(_LOGGER, _WARNING_METHOD_NAME)(msg)


@_normalize_native_string
def _log_native(msg: str) -> None:
_LOGGER.info(msg)
getattr(_LOGGER, _INFO_METHOD_NAME)(msg)


def _log_callback(msg: bytes) -> None:
Expand Down
68 changes: 68 additions & 0 deletions tests/python_package_test/test_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging

import numpy as np
import pytest

import lightgbm as lgb

Expand Down Expand Up @@ -97,3 +98,70 @@ def dummy_metric(_, __):
actual_log_wo_gpu_stuff.append(line)

assert "\n".join(actual_log_wo_gpu_stuff) == expected_log


def test_register_invalid_logger():
class LoggerWithoutInfoMethod:
def warning(self, msg: str) -> None:
print(msg)

class LoggerWithoutWarningMethod:
def info(self, msg: str) -> None:
print(msg)

class LoggerWithAttributeNotCallable:
def __init__(self):
self.info = 1
self.warning = 2

expected_error_message = "Logger must provide 'info' and 'warning' method"

with pytest.raises(TypeError, match=expected_error_message):
lgb.register_logger(LoggerWithoutInfoMethod())

with pytest.raises(TypeError, match=expected_error_message):
lgb.register_logger(LoggerWithoutWarningMethod())

with pytest.raises(TypeError, match=expected_error_message):
lgb.register_logger(LoggerWithAttributeNotCallable())


def test_register_custom_logger():
logged_messages = []

class CustomLogger:
def custom_info(self, msg: str) -> None:
logged_messages.append(msg)

def custom_warning(self, msg: str) -> None:
logged_messages.append(msg)

custom_logger = CustomLogger()
lgb.register_logger(
custom_logger,
info_method_name="custom_info",
warning_method_name="custom_warning"
)

lgb.basic._log_info("info message")
lgb.basic._log_warning("warning message")

expected_log = ["info message", "warning message"]
assert logged_messages == expected_log

logged_messages = []
X = np.array([[1, 2, 3],
[1, 2, 4],
[1, 2, 4],
[1, 2, 3]],
dtype=np.float32)
y = np.array([0, 1, 1, 0])
lgb_data = lgb.Dataset(X, y)
lgb.train(
{'objective': 'binary', 'metric': 'auc'},
lgb_data,
num_boost_round=10,
valid_sets=[lgb_data],
categorical_feature=[1]
)
assert logged_messages, "custom logger was not called"

0 comments on commit 60e72d5

Please sign in to comment.