improve type support for fbrlogger (#3238)
* fbr logger: improve types and kwargs supported

* remove autolist for utils

* add clean directive to docs Makefile

* tidy matrix display

* make reporting of shape more compact

* remove superfluous import

* fix bug in autosummary
leej3 committed Jun 28, 2024
1 parent 24e71af commit 5a66d9e
Showing 7 changed files with 194 additions and 14 deletions.
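
The headline change is that FBResearchLogger.attach now accepts output_transform and state_attributes alongside the existing every and optimizer arguments, with engine outputs rendered through a new ignite.utils._to_str_list helper. A minimal usage sketch, mirroring the new tests in this commit (the model, data, and the alpha attribute are illustrative):

import logging

import torch
import torch.nn as nn
import torch.optim as optim

from ignite.engine import create_supervised_trainer
from ignite.handlers.fbresearch_logger import FBResearchLogger
from ignite.utils import setup_logger

# Illustrative model and data; any supervised setup works the same way.
model = nn.Linear(10, 5)
opt = optim.SGD(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
data = [(torch.rand(4, 10), torch.randint(0, 5, size=(4,))) for _ in range(100)]

trainer = create_supervised_trainer(model, opt, criterion)
trainer.state.alpha = 0.1  # hypothetical extra state attribute to include in log lines

fbr = FBResearchLogger(logger=setup_logger("trainer", level=logging.INFO), show_output=True)
fbr.attach(
    trainer,
    name="Train",
    every=20,                                    # log every 20 iterations
    output_transform=lambda out: {"loss": out},  # select what to log from engine.state.output
    state_attributes=["alpha"],                  # engine.state attributes appended to each line
    optimizer=opt,                               # current learning rates are logged as well
)
trainer.run(data, max_epochs=2)
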
7 changes: 7 additions & 0 deletions docs/Makefile
@@ -22,6 +22,13 @@ docset: html
rebuild:
	rm -rf source/generated && make clean && make html

clean:
	@echo "Cleaning up..."
	python -c "import shutil; shutil.rmtree('$(BUILDDIR)', ignore_errors=True)"
	python -c "import shutil; shutil.rmtree('$(SOURCEDIR)/generated', ignore_errors=True)"
	python -c "import os; [os.remove(f) for f in os.listdir('.') if f.endswith('.pyc')]"
	python -c "import shutil; import os; [shutil.rmtree(f) for f in os.listdir('.') if f == '__pycache__' and os.path.isdir(f)]"

.PHONY: help Makefile docset

# Catch-all target: route all unknown targets to Sphinx using the new
10 changes: 9 additions & 1 deletion docs/source/conf.py
@@ -301,7 +301,15 @@ def run(self):
        names = [name[0] for name in getmembers(module)]

        # Filter out members w/o doc strings
        names = [name for name in names if getattr(module, name).__doc__ is not None]
        filtered_names = []
        for name in names:
            try:
                if not name.startswith("_") and getattr(module, name).__doc__ is not None:
                    filtered_names.append(name)
            except AttributeError:
                continue

        names = filtered_names

        if auto == "autolist":
            # Get list of all classes and functions inside module
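
The autosummary fix replaces the single-line filter, which kept private names and could raise AttributeError for some listed members, with an explicit loop. A standalone sketch of the same rule, under the assumption that getattr can fail for some names reported by getmembers (the helper name is hypothetical):

from inspect import getmembers
from types import ModuleType
from typing import List


def public_documented_members(module: ModuleType) -> List[str]:
    """Return names of public, documented members, skipping attribute-access errors."""
    filtered: List[str] = []
    for name, _ in getmembers(module):
        try:
            if not name.startswith("_") and getattr(module, name).__doc__ is not None:
                filtered.append(name)
        except AttributeError:
            # Some names reported by getmembers may not resolve via getattr; skip them.
            continue
    return filtered
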
34 changes: 25 additions & 9 deletions ignite/handlers/fbresearch_logger.py
@@ -1,18 +1,18 @@
"""FBResearch logger and its helper handlers."""

import datetime
from typing import Any, Optional

# from typing import Any, Dict, Optional, Union
from typing import Any, Callable, List, Optional

import torch

from ignite import utils
from ignite.engine import Engine, Events
from ignite.handlers import Timer


MB = 1024.0 * 1024.0

__all__ = ["FBResearchLogger"]


class FBResearchLogger:
"""Logs training and validation metrics for research purposes.
Expand Down Expand Up @@ -98,16 +98,27 @@ def __init__(self, logger: Any, delimiter: str = " ", show_output: bool = False
        self.show_output: bool = show_output

    def attach(
        self, engine: Engine, name: str, every: int = 1, optimizer: Optional[torch.optim.Optimizer] = None
        self,
        engine: Engine,
        name: str,
        every: int = 1,
        output_transform: Optional[Callable] = None,
        state_attributes: Optional[List[str]] = None,
        optimizer: Optional[torch.optim.Optimizer] = None,
    ) -> None:
        """Attaches all the logging handlers to the given engine.

        Args:
            engine: The engine to attach the logging handlers to.
            name: The name of the engine (e.g., "Train", "Validate") to include in log messages.
            every: Frequency of iterations to log information. Logs are generated every 'every' iterations.
            output_transform: A function to select the value to log.
            state_attributes: A list of attributes to log.
            optimizer: The optimizer used during training to log current learning rates.
        """
        self.name = name
        self.output_transform = output_transform
        self.state_attributes = state_attributes
        engine.add_event_handler(Events.EPOCH_STARTED, self.log_epoch_started, engine, name)
        engine.add_event_handler(Events.ITERATION_COMPLETED(every=every), self.log_every, engine, optimizer=optimizer)
        engine.add_event_handler(Events.EPOCH_COMPLETED, self.log_epoch_completed, engine, name)
@@ -151,10 +162,9 @@ def log_every(self, engine: Engine, optimizer: Optional[torch.optim.Optimizer] =
        outputs = []
        if self.show_output and engine.state.output is not None:
            output = engine.state.output
            if isinstance(output, dict):
                outputs += [f"{k}: {v:.4f}" for k, v in output.items()]
            else:
                outputs += [f"{v:.4f}" if isinstance(v, float) else f"{v}" for v in output]  # type: ignore
            if self.output_transform is not None:
                output = self.output_transform(output)
            outputs = utils._to_str_list(output)

        lrs = ""
        if optimizer is not None:
@@ -164,6 +174,11 @@ def log_every(self, engine: Engine, optimizer: Optional[torch.optim.Optimizer] =
            for i, g in enumerate(optimizer.param_groups):
                lrs += f"lr [g{i}]: {g['lr']:.5f}"

        state_attrs = []
        if self.state_attributes is not None:
            state_attrs = utils._to_str_list(
                {name: getattr(engine.state, name, None) for name in self.state_attributes}
            )
        msg = self.delimiter.join(
            [
                f"Epoch [{engine.state.epoch}/{engine.state.max_epochs}]",
@@ -172,6 +187,7 @@ def log_every(self, engine: Engine, optimizer: Optional[torch.optim.Optimizer] =
f"{lrs}",
]
+ outputs
+ [" ".join(state_attrs)]
+ [
f"Iter time: {iter_avg_time:.4f} s",
f"Data prep time: {self.data_timer.value():.4f} s",
Expand Down
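
Both the engine output and the state attributes are rendered by the new ignite.utils._to_str_list helper added below. A small sketch of what the state-attribute piece of the message looks like, using the values from the new test_fbrlogger_with_state_attrs test:

import torch

from ignite.utils import _to_str_list

# What log_every builds for state_attributes=["alpha", "beta", "gamma"]:
state_attrs = _to_str_list({"alpha": 3.899, "beta": torch.tensor(12.21), "gamma": torch.tensor([21.0, 6.0])})
print(" ".join(state_attrs))  # -> alpha: 3.8990 beta: 12.2100 gamma: [21.0000, 6.0000]
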
2 changes: 1 addition & 1 deletion ignite/metrics/maximum_mean_discrepancy.py
@@ -29,7 +29,7 @@ class MaximumMeanDiscrepancy(Metric):
More details can be found in `Gretton et al. 2012`__.
__ https://jmlr.csail.mit.edu/papers/v13/gretton12a.html
__ https://www.jmlr.org/papers/volume13/gretton12a/gretton12a.pdf
- ``update`` must receive output of the form ``(x, y)``.
- ``x`` and ``y`` are expected to be in the same shape :math:`(B, \ldots)`.
78 changes: 78 additions & 0 deletions ignite/utils.py
@@ -2,6 +2,7 @@
import functools
import hashlib
import logging
import numbers
import random
import shutil
import warnings
@@ -14,6 +15,7 @@
    "convert_tensor",
    "apply_to_tensor",
    "apply_to_type",
    "_to_str_list",
    "to_onehot",
    "setup_logger",
    "manual_seed",
@@ -90,6 +92,82 @@ def _tree_map(
    return func(x, key=key)


def _to_str_list(data: Any) -> List[str]:
    """
    Recursively flattens and formats complex data structures, including keys for
    dictionaries, into a list of human-readable strings.

    This function processes nested dictionaries, lists, tuples, numbers, and
    PyTorch tensors, formatting numbers to four decimal places and handling
    tensors with special formatting rules. It's particularly useful for logging,
    debugging, or any scenario where a human-readable representation of complex,
    nested data structures is required.

    The function handles the following types:

    - Numbers: Formatted to four decimal places.
    - PyTorch tensors:
        - Scalars are formatted to four decimal places.
        - 1D tensors with more than 10 elements show the first 10 elements
          followed by an ellipsis.
        - 1D tensors with 10 or fewer elements are fully listed.
        - Multi-dimensional tensors display their shape.
    - Dictionaries: Each key-value pair is included in the output with the key
      as a prefix.
    - Lists and tuples: Flattened and included in the output. Empty lists/tuples are represented
      by an empty string.
    - None values: Represented by an empty string.

    Args:
        data: The input data to be flattened and formatted. It can be a nested
            combination of dictionaries, lists, tuples, numbers, and PyTorch
            tensors.

    Returns:
        A list of formatted strings, each representing a part of the input data
        structure.
    """
    formatted_items: List[str] = []

    def format_item(item: Any, prefix: str = "") -> Optional[str]:
        if isinstance(item, numbers.Number):
            return f"{prefix}{item:.4f}"
        elif torch.is_tensor(item):
            if item.dim() == 0:
                return f"{prefix}{item.item():.4f}"  # Format scalar tensor without brackets
            elif item.dim() == 1 and item.size(0) > 10:
                return f"{prefix}[" + ", ".join(f"{x.item():.4f}" for x in item[:10]) + ", ...]"
            elif item.dim() == 1:
                return f"{prefix}[" + ", ".join(f"{x.item():.4f}" for x in item) + "]"
            else:
                return f"{prefix}Shape{list(item.shape)}"
        elif isinstance(item, dict):
            for key, value in item.items():
                formatted_value = format_item(value, f"{key}: ")
                if formatted_value is not None:
                    formatted_items.append(formatted_value)
        elif isinstance(item, (list, tuple)):
            if not item:
                if prefix:
                    formatted_items.append(f"{prefix}")
            else:
                values = [format_item(x) for x in item]
                values_str = [v for v in values if v is not None]
                if values_str:
                    formatted_items.append(f"{prefix}" + ", ".join(values_str))
        elif item is None:
            if prefix:
                formatted_items.append(f"{prefix}")
        return None

    # Directly handle single numeric values
    if isinstance(data, numbers.Number):
        return [f"{data:.4f}"]

    format_item(data)
    return formatted_items


class _CollectionItem:
    types_as_collection_item: Tuple = (int, float, torch.Tensor)

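For a quick feel of the formatting rules described in the docstring, a few calls taken from the parametrized cases added to tests/ignite/test_utils.py below:

import torch

from ignite.utils import _to_str_list

print(_to_str_list(42))                                  # ['42.0000']
print(_to_str_list({"a": 10, "b": 2.33333}))             # ['a: 10.0000', 'b: 2.3333']
print(_to_str_list(({"nested": [3.1415, torch.tensor(0.0001)]},)))  # ['nested: 3.1415, 0.0001']
print(_to_str_list({"large_vector": torch.tensor(range(20))}))      # first 10 elements, then ", ...]"
print(_to_str_list({"large_matrix": torch.randn(5, 5)}))            # ['large_matrix: Shape[5, 5]']
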
52 changes: 50 additions & 2 deletions tests/ignite/handlers/test_fbresearch_logger.py
@@ -3,9 +3,13 @@
from unittest.mock import MagicMock

import pytest
import torch
import torch.nn as nn
import torch.optim as optim

from ignite.engine import Engine, Events
from ignite.handlers.fbresearch_logger import FBResearchLogger # Adjust the import path as necessary
from ignite.engine import create_supervised_trainer, Engine, Events
from ignite.handlers.fbresearch_logger import FBResearchLogger
from ignite.utils import setup_logger


@pytest.fixture
@@ -56,3 +60,47 @@ def test_output_formatting(mock_engine, fb_research_logger, output, expected_pat

    actual_output = fb_research_logger.logger.info.call_args_list[0].args[0]
    assert re.search(expected_pattern, actual_output)


def test_logger_type_support():
    model = nn.Linear(10, 5)
    opt = optim.SGD(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    data = [(torch.rand(4, 10), torch.randint(0, 5, size=(4,))) for _ in range(100)]

    trainer = create_supervised_trainer(model, opt, criterion)

    logger = setup_logger("trainer", level=logging.INFO)
    logger = FBResearchLogger(logger=logger, show_output=True)
    logger.attach(trainer, name="Train", every=20, optimizer=opt)

    trainer.run(data, max_epochs=4)
    trainer.state.output = {"loss": 4.2}
    trainer.fire_event(Events.ITERATION_COMPLETED)
    trainer.state.output = "4.2"
    trainer.fire_event(Events.ITERATION_COMPLETED)
    trainer.state.output = [4.2, 4.2]
    trainer.fire_event(Events.ITERATION_COMPLETED)
    trainer.state.output = (4.2, 4.2)
    trainer.fire_event(Events.ITERATION_COMPLETED)


def test_fbrlogger_with_output_transform(mock_logger):
    trainer = Engine(lambda e, b: 42)
    fbr = FBResearchLogger(logger=mock_logger, show_output=True)
    fbr.attach(trainer, "Training", output_transform=lambda x: {"loss": x})
    trainer.run(data=[10], epoch_length=1, max_epochs=1)
    assert "loss: 42.0000" in fbr.logger.info.call_args_list[-2].args[0]


def test_fbrlogger_with_state_attrs(mock_logger):
    trainer = Engine(lambda e, b: 42)
    fbr = FBResearchLogger(logger=mock_logger, show_output=True)
    fbr.attach(trainer, "Training", state_attributes=["alpha", "beta", "gamma"])
    trainer.state.alpha = 3.899
    trainer.state.beta = torch.tensor(12.21)
    trainer.state.gamma = torch.tensor([21.0, 6.0])
    trainer.run(data=[10], epoch_length=1, max_epochs=1)
    attrs = "alpha: 3.8990 beta: 12.2100 gamma: [21.0000, 6.0000]"
    assert attrs in fbr.logger.info.call_args_list[-2].args[0]
25 changes: 24 additions & 1 deletion tests/ignite/test_utils.py
@@ -8,7 +8,7 @@
from packaging.version import Version

from ignite.engine import Engine, Events
from ignite.utils import convert_tensor, deprecated, hash_checkpoint, setup_logger, to_onehot
from ignite.utils import _to_str_list, convert_tensor, deprecated, hash_checkpoint, setup_logger, to_onehot


def test_convert_tensor():
@@ -55,6 +55,29 @@ def test_convert_tensor():

        convert_tensor(12345)

@pytest.mark.parametrize(
    "input_data,expected",
    [
        (42, ["42.0000"]),
        ([{"a": 15, "b": torch.tensor([2.0])}], ["a: 15.0000", "b: [2.0000]"]),
        ({"a": 10, "b": 2.33333}, ["a: 10.0000", "b: 2.3333"]),
        ({"x": torch.tensor(0.1234), "y": [1, 2.3567]}, ["x: 0.1234", "y: 1.0000, 2.3567"]),
        (({"nested": [3.1415, torch.tensor(0.0001)]},), ["nested: 3.1415, 0.0001"]),
        (
            {"large_vector": torch.tensor(range(20))},
            ["large_vector: [0.0000, 1.0000, 2.0000, 3.0000, 4.0000, 5.0000, 6.0000, 7.0000, 8.0000, 9.0000, ...]"],
        ),
        ({"large_matrix": torch.randn(5, 5)}, ["large_matrix: Shape[5, 5]"]),
        ({"empty": []}, ["empty: "]),
        ([], []),
        ({"none": None}, ["none: "]),
        ({1: 100, 2: 200}, ["1: 100.0000", "2: 200.0000"]),
    ],
)
def test__to_str_list(input_data, expected):
    assert _to_str_list(input_data) == expected


def test_to_onehot():
    indices = torch.tensor([0, 1, 2, 3], dtype=torch.long)
    actual = to_onehot(indices, 4)
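
These cases can be exercised locally with pytest tests/ignite/test_utils.py::test__to_str_list and pytest tests/ignite/handlers/test_fbresearch_logger.py.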
