Add EvaluationDistributedSampler and examples on distributed evaluation #1886

Closed
wants to merge 44 commits

Changes from all commits · 44 commits
6bade5b
distributed sampler + example
SkafteNicki Jul 6, 2023
e4a0992
add tests
SkafteNicki Jul 6, 2023
fa8034f
Merge branch 'master' into distributed
Borda Jul 6, 2023
cb4d6c8
Merge branch 'master' into distributed
SkafteNicki Jul 19, 2023
af52d07
changelog
SkafteNicki Jul 19, 2023
64f4899
update api description
SkafteNicki Jul 19, 2023
c594de9
update documentation
SkafteNicki Jul 19, 2023
f16955e
lightning integration
SkafteNicki Jul 19, 2023
4e949f4
more documentation
SkafteNicki Jul 19, 2023
248a08c
update example
SkafteNicki Jul 19, 2023
6228f76
improve testing
SkafteNicki Jul 19, 2023
9d78350
improve example
SkafteNicki Jul 19, 2023
8171502
revert some
SkafteNicki Jul 19, 2023
de2c0cf
update tests
SkafteNicki Jul 26, 2023
7e267e9
improve testing
SkafteNicki Jul 26, 2023
4983914
Merge branch 'master' into distributed
SkafteNicki Jul 26, 2023
983785a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 26, 2023
1939942
reorder test
SkafteNicki Jul 26, 2023
f728b52
fix mistakes
SkafteNicki Jul 26, 2023
c5a4b3e
fix inputs in tests
SkafteNicki Jul 26, 2023
bdbf05d
Merge branch 'master' into distributed
SkafteNicki Jul 28, 2023
56f17fd
Merge branch 'master' into distributed
SkafteNicki Jul 29, 2023
5dc8fb4
Merge branch 'master' into distributed
Borda Aug 1, 2023
e6cecd1
Merge branch 'master' into distributed
SkafteNicki Aug 3, 2023
c7459ee
Merge branch 'master' into distributed
Borda Aug 7, 2023
489c68f
Apply suggestions from code review
Borda Aug 8, 2023
79c357a
Merge branch 'master' into distributed
Borda Aug 8, 2023
e5675db
Merge branch 'master' into distributed
SkafteNicki Aug 9, 2023
43be66c
fix docs
SkafteNicki Aug 9, 2023
ead0380
fix ruff
SkafteNicki Aug 9, 2023
1a9ca0d
Merge branch 'master' into distributed
SkafteNicki Aug 9, 2023
46d3821
fix integration tests
SkafteNicki Aug 10, 2023
90741fa
fix
SkafteNicki Aug 10, 2023
d51ee1e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 10, 2023
7daeeaa
fix
SkafteNicki Aug 10, 2023
0b6577d
Merge branch 'distributed' of https://github.com/PyTorchLightning/met…
SkafteNicki Aug 10, 2023
b80f567
skip on windows
SkafteNicki Aug 10, 2023
ae89c27
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 10, 2023
80f0e37
Merge branch 'master' into distributed
Borda Aug 18, 2023
f4d5d44
Merge branch 'master' into distributed
SkafteNicki Aug 21, 2023
a8adc34
fix headings
SkafteNicki Aug 21, 2023
26ed08d
Merge branch 'master' into distributed
SkafteNicki Aug 21, 2023
311bce3
fix mypy
SkafteNicki Aug 21, 2023
c8f5b0b
Apply suggestions from code review
Borda Aug 21, 2023
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -32,6 +32,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added support for multioutput evaluation in `MeanSquaredError` ([#1937](https://github.com/Lightning-AI/torchmetrics/pull/1937))


- Added `EvaluationDistributedSampler` utility for proper distributed evaluation ([#1886](https://github.com/Lightning-AI/torchmetrics/pull/1886))


- Added argument `extended_summary` to `MeanAveragePrecision` such that precision, recall, iou can be easily returned ([#1983](https://github.com/Lightning-AI/torchmetrics/pull/1983))


36 changes: 32 additions & 4 deletions docs/source/references/utilities.rst
@@ -1,9 +1,37 @@
.. role:: hidden
    :class: hidden-section

######################
torchmetrics.utilities
######################

The following lists public utility functions that may be beneficial to use in your own code. These functions are
not part of the public API and may change at any time.

**********************************
torchmetrics.utilities.distributed
**********************************

The `distributed` utilities are used to help with synchronization of metrics across multiple processes.

EvaluationDistributedSampler
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: torchmetrics.utilities.distributed.EvaluationDistributedSampler
    :noindex:

gather_all_tensors
~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.utilities.distributed.gather_all_tensors
    :noindex:

***************************
torchmetrics.utilities.data
***************************

The `data` utilities are used to help with data manipulation, such as converting labels in classification from one format
to another.

select_topk
~~~~~~~~~~~
@@ -20,9 +48,9 @@ to_onehot

.. autofunction:: torchmetrics.utilities.data.to_onehot

*********************************
torchmetrics.utilities.exceptions
*********************************

TorchMetricsUserError
~~~~~~~~~~~~~~~~~~~~~
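As a rough illustration of the gather_all_tensors utility referenced above, the following sketch (not part of this PR) shows how per-process tensors could be collected on every rank; it assumes a torch.distributed process group has already been initialized:

import torch
import torch.distributed as dist
from torchmetrics.utilities.distributed import gather_all_tensors

def gather_local_scores(local_scores: torch.Tensor) -> torch.Tensor:
    """Collect the per-process score tensors on every rank and concatenate them."""
    if dist.is_available() and dist.is_initialized():
        # gather_all_tensors returns a list with one tensor per process,
        # handling uneven tensor shapes across ranks
        return torch.cat(gather_all_tensors(local_scores))
    return local_scores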
195 changes: 195 additions & 0 deletions examples/distributed_evaluation.py
@@ -0,0 +1,195 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An example of how to use the distributed evaluation utilities in both Lightning and PyTorch.

To run using only Pytorch:
python distributed_evaluation.py
To run using Lightning:
python distributed_evaluation.py --use_lightning

By default, this example uses the EvaluationDistributedSampler, which is a custom sampler that ensures that no extra
samples are added to the dataset. This is important for evaluation, as we don't want to evaluate on the same samples
multiple times.

If you want to see the difference between the EvaluationDistributedSampler and the standard DistributedSampler, you
add the flag --use_standard. This will use the standard DistributedSampler, which will add extra samples to the dataset
and thus give incorrect results.

"""
import argparse
import os
from typing import Tuple

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torchmetrics
from lightning_utilities import module_available
from torch import Tensor
from torch.nn import Module
from torch.utils.data import DataLoader, Dataset, DistributedSampler, TensorDataset
from torchmetrics.utilities.distributed import EvaluationDistributedSampler

_ = torch.manual_seed(42)


class DummyModel(Module):
    """Dummy model consisting of a single linear layer."""

    def __init__(self, n_feature: int) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(n_feature, 10)

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass."""
        return self.linear(x)


def calculate_accuracy_manually(dataset: Dataset, model: Module) -> Tensor:
    """Basic function to calculate accuracy manually, without any distributed logic."""
    x, y = dataset.tensors
    preds = model(x)
    return (preds.argmax(dim=1) == y).float().mean()


def use_lightning(
    model: Module, dataset: Dataset, batch_size: int, use_standard: bool, num_processes: int, gpu: bool
) -> None:
    """Use lightning to evaluate a model on a dataset."""
    if module_available("lightning"):
        from lightning.pytorch import LightningModule, Trainer
    else:
        from pytorch_lightning import LightningModule, Trainer

    sampler_class = DistributedSampler if use_standard else EvaluationDistributedSampler

    class DummyLightningModule(LightningModule):
        def __init__(self, model: Module) -> None:
            super().__init__()
            self.model = model
            self.metric = torchmetrics.classification.MulticlassAccuracy(num_classes=10, average="micro")

        def forward(self, x: Tensor) -> Tensor:
            return self.model(x)

        def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
            preds = self(batch[0])
            target = batch[1]
            self.metric.update(preds, target)

        def on_test_epoch_end(self) -> None:
            self.log("test_acc", self.metric.compute())

        def test_dataloader(self) -> DataLoader:
            return DataLoader(
                dataset,
                batch_size=batch_size,
                sampler=sampler_class(dataset),
            )

    model = DummyLightningModule(model)

    trainer = Trainer(
        devices=num_processes,
        accelerator="cpu" if not gpu else "gpu",
    )

    res = trainer.test(model)
    manual_res = calculate_accuracy_manually(dataset, model)
    print(manual_res)
    if torch.allclose(torch.tensor(res[0]["test_acc"]), manual_res):
        print("success! result matched manual calculation")
    else:
        print("failure! result did not match manual calculation")
Review comment (Member): shall this rather be logging or exit()?



def _use_torch_worker_fn(
    rank: int, model: Module, dataset: Dataset, batch_size: int, use_standard: bool, num_processes: int, gpu: bool
) -> None:
    """Worker function for torch.distributed evaluation."""
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("nccl" if gpu else "gloo", rank=rank, world_size=num_processes)

    device = torch.device(f"cuda:{rank}") if gpu else torch.device("cpu")
    model = model.to(device)

    sampler_class = DistributedSampler if use_standard else EvaluationDistributedSampler

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler_class(dataset, num_processes, rank),
    )

    metric = torchmetrics.classification.MulticlassAccuracy(num_classes=10, average="micro")
    metric = metric.to(device)

    batches, num_samples = 0, 0
    for batch in dataloader:
        preds = model(batch[0].to(device))
        target = batch[1].to(device)

        metric.update(preds, target)
        num_samples += len(target)
        batches += 1

    res = metric.compute()

    print(f"Rank {rank} processed {num_samples} samples and {batches} batches and calculated accuracy: {res}")

    manual_res = calculate_accuracy_manually(dataset, model.cpu())
    if torch.allclose(res.cpu(), manual_res):
        print("success! result matched manual calculation")
    else:
        print("failure! result did not match manual calculation")


def use_torch(
    model: Module, dataset: Dataset, batch_size: int, use_standard: bool, num_processes: int, gpu: bool
) -> None:
    """Use torch.distributed to evaluate a model on a dataset."""
    mp.spawn(
        _use_torch_worker_fn,
        nprocs=num_processes,
        args=(model, dataset, batch_size, use_standard, num_processes, gpu),
    )


def main() -> None:
    """Main function."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--use_lightning", action="store_true")
    parser.add_argument("--use_standard", action="store_true")
    parser.add_argument("--num_processes", type=int, default=2)
    parser.add_argument("--batch_size", type=int, default=3)
    parser.add_argument("--gpu", action="store_true")
    args = parser.parse_args()
    print(args)

    dataset = TensorDataset(torch.randn(199, 100), torch.randint(0, 10, (199,)))
    n_feature = 100
    dummy_model = DummyModel(n_feature)

    batch_size = args.batch_size
    if len(dataset) % (args.num_processes * batch_size) == 0:
        raise ValueError(
            "For this example the dataset size should NOT be divisible by the number of processes times the batch size."
        )

    if args.use_lightning:
        use_lightning(dummy_model, dataset, batch_size, args.use_standard, args.num_processes, args.gpu)
    else:
        use_torch(dummy_model, dataset, batch_size, args.use_standard, args.num_processes, args.gpu)


if __name__ == "__main__":
    main()
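For completeness, the same bare-PyTorch evaluation could also be launched with torchrun instead of mp.spawn. The following is only a sketch (not part of this PR; the file name eval_torchrun.py is hypothetical), run e.g. as `torchrun --nproc_per_node=2 eval_torchrun.py`:

import torch
import torch.distributed as dist
import torchmetrics
from torch.utils.data import DataLoader, TensorDataset
from torchmetrics.utilities.distributed import EvaluationDistributedSampler

def main() -> None:
    # torchrun sets RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT for us
    dist.init_process_group("gloo")
    rank, world_size = dist.get_rank(), dist.get_world_size()

    torch.manual_seed(42)  # same seed on every rank -> identical dataset and model
    dataset = TensorDataset(torch.randn(199, 100), torch.randint(0, 10, (199,)))
    model = torch.nn.Linear(100, 10)

    loader = DataLoader(dataset, batch_size=3, sampler=EvaluationDistributedSampler(dataset, world_size, rank))
    metric = torchmetrics.classification.MulticlassAccuracy(num_classes=10, average="micro")

    for x, y in loader:
        metric.update(model(x), y)
    # compute() synchronizes the metric states across all ranks, so every rank prints the full-dataset accuracy
    print(f"rank {rank}: accuracy={metric.compute():.4f}")

if __name__ == "__main__":
    main()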
7 changes: 4 additions & 3 deletions src/torchmetrics/utilities/__init__.py
@@ -13,15 +13,16 @@
# limitations under the License.
from torchmetrics.utilities.checks import check_forward_full_state_property
from torchmetrics.utilities.data import apply_to_collection
from torchmetrics.utilities.distributed import EvaluationDistributedSampler, class_reduce, reduce
from torchmetrics.utilities.prints import rank_zero_debug, rank_zero_info, rank_zero_warn

__all__ = [
    "apply_to_collection",
    "check_forward_full_state_property",
    "class_reduce",
    "EvaluationDistributedSampler",
    "rank_zero_debug",
    "rank_zero_info",
    "rank_zero_warn",
    "reduce",
]
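With the updated __all__, the sampler can also be imported directly from the top-level utilities package; a minimal sketch:

from torchmetrics.utilities import EvaluationDistributedSampler
from torchmetrics.utilities.distributed import EvaluationDistributedSampler as SamplerFromSubmodule

assert EvaluationDistributedSampler is SamplerFromSubmodule  # both names refer to the same class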
63 changes: 63 additions & 0 deletions src/torchmetrics/utilities/distributed.py
Expand Up @@ -16,6 +16,7 @@
import torch
from torch import Tensor
from torch.nn import functional as F # noqa: N812
from torch.utils.data import Dataset
from typing_extensions import Literal


@@ -146,3 +147,65 @@ def gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Tens
        slice_param = [slice(dim_size) for dim_size in item_size]
        gathered_result[idx] = gathered_result[idx][slice_param]
    return gathered_result


class EvaluationDistributedSampler(torch.utils.data.DistributedSampler):
    """A specialized distributed sampler for evaluation (test and validation).

    It is derived from the PyTorch DistributedSampler, with one core difference: it doesn't add extra samples to make
    the data evenly divisible across devices. This is important while evaluating, as adding extra samples will skew
    the results towards those duplicated samples.

    Normally, not adding the extra samples would lead to processes becoming out of sync, but this is handled by the
    custom synchronization in TorchMetrics. Thus, this sampler does not in general guarantee that distributed
    operations work correctly outside of TorchMetrics.

    Arguments are the same as DistributedSampler, and this implementation only overrides the ``__init__`` method.

    Args:
        dataset: Dataset used for sampling.
        num_replicas (int, optional): Number of processes participating in distributed training. By default,
            :attr:`world_size` is retrieved from the current distributed group.
        rank (int, optional): Rank of the current process within :attr:`num_replicas`. By default, :attr:`rank` is
            retrieved from the current distributed group.
        shuffle (bool, optional): If ``True`` (default), sampler will shuffle the indices.
        seed (int, optional): random seed used to shuffle the sampler if :attr:`shuffle=True`. This number should be
            identical across all processes in the distributed group.
        drop_last (bool, optional): if ``True``, then the sampler will drop the tail of the data to make it evenly
            divisible across the number of replicas.

    For a full example of how to use this sampler, with both bare PyTorch and PyTorch Lightning,
    check out the `distributed_evaluation.py` file in the examples folder.

    Example::
        The distributed sampler is always intended to be used in conjunction with a DataLoader:

        >>> import torch
        >>> from torch.utils.data import DataLoader, TensorDataset
        >>> from torchmetrics.utilities.distributed import EvaluationDistributedSampler
        >>> dataset = TensorDataset(torch.arange(10))
        >>> dataloader = DataLoader(
        ...     dataset, sampler=EvaluationDistributedSampler(dataset, num_replicas=2)
        ... )  # doctest: +SKIP

    """

    def __init__(

Review comment (Contributor): In Lightning we have a very similar class:
https://github.com/Lightning-AI/lightning/blob/fbdbe632c67b05158804b52f4345944781ca4f07/src/lightning/pytorch/overrides/distributed.py#L194
I think the main difference is that yours respects the drop_last setting. I'm not sure why we have the __iter__ overridden there, but if you are interested you can compare the two.

        self,
        dataset: Dataset,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
    ) -> None:
        # From:
        # https://github.com/pytorch/pytorch/issues/25162#issuecomment-1227647626
        super().__init__(
            dataset=dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, seed=seed, drop_last=drop_last
        )

        len_dataset = len(self.dataset)  # type: ignore[arg-type]
        if not self.drop_last and len_dataset % self.num_replicas != 0:
Review comment (Member): the issue with this is that it wouldn't necessarily work with validation, since not all ranks would reach the same distributed function calls and would therefore time out, which would kill the entire process. Also, this would never work with FSDP, since some ranks have one batch more and, for FSDP, not all processes would reach the forward syncing points, also resulting in timeouts.

Review comment (Contributor): I agree that in the context of Lightning this wouldn't work well, as it does not support Join (Lightning-AI/pytorch-lightning#3325). FSDP also doesn't support Join afaik (pytorch/pytorch#64683).

But outside Lightning, and taking FSDP out of the equation, I agree this can work and is a good utility to have IMO. It also suits the metric design well, since synchronization is only necessary when all processes have finished collecting their statistics and .compute() can be called.

            # some ranks may have fewer samples, that's fine
            if self.rank >= len_dataset % self.num_replicas:
                self.num_samples -= 1
            self.total_size = len_dataset
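
To make the docstring's point concrete, here is a small sketch (not part of this PR) comparing per-rank sample counts of the standard DistributedSampler and the EvaluationDistributedSampler for 10 samples split across 3 replicas; both samplers accept an explicit num_replicas/rank, so no process group is needed just to inspect the lengths:

import torch
from torch.utils.data import DistributedSampler, TensorDataset
from torchmetrics.utilities.distributed import EvaluationDistributedSampler

dataset = TensorDataset(torch.arange(10))  # 10 samples do not split evenly across 3 replicas

for rank in range(3):
    standard = DistributedSampler(dataset, num_replicas=3, rank=rank, shuffle=False)
    evaluation = EvaluationDistributedSampler(dataset, num_replicas=3, rank=rank, shuffle=False)
    # the standard sampler pads to 4 indices on every rank (12 in total, 2 duplicated samples),
    # while the evaluation sampler yields 4/3/3 indices (10 in total, no duplicates)
    print(rank, len(list(standard)), len(list(evaluation)))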
7 changes: 7 additions & 0 deletions src/torchmetrics/utilities/imports.py
@@ -18,6 +18,7 @@
from distutils.version import LooseVersion
from typing import Optional

from lightning_utilities import module_available
from lightning_utilities.core.imports import compare_version, package_available

_PYTHON_VERSION = ".".join(map(str, [sys.version_info.major, sys.version_info.minor, sys.version_info.micro]))
@@ -29,6 +30,12 @@
_TORCH_GREATER_EQUAL_1_12: Optional[bool] = compare_version("torch", operator.ge, "1.12.0")
_TORCH_GREATER_EQUAL_1_13: Optional[bool] = compare_version("torch", operator.ge, "1.13.0")

_LIGHTNING_GREATER_EQUAL_2_0: Optional[bool] = (
    compare_version("lightning", operator.ge, "2.0.0")
    if module_available("lightning")
    else compare_version("pytorch_lightning", operator.ge, "2.0.0")
)

_JIWER_AVAILABLE: bool = package_available("jiwer")
_NLTK_AVAILABLE: bool = package_available("nltk")
_ROUGE_SCORE_AVAILABLE: bool = package_available("rouge_score")