Add EvaluationDistributedSampler and examples on distributed evaluation #1886
New file: `distributed_evaluation.py` (@@ -0,0 +1,195 @@)

```python
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An example of how to use the distributed evaluation utilities in both Lightning and plain PyTorch.

To run using only PyTorch:
    python distributed_evaluation.py
To run using Lightning:
    python distributed_evaluation.py --use_lightning

By default, this example uses the EvaluationDistributedSampler, a custom sampler that ensures that no extra
samples are added to the dataset. This is important for evaluation, as we do not want to evaluate the same
samples multiple times.

To see the difference between the EvaluationDistributedSampler and the standard DistributedSampler, add the
flag --use_standard. This will use the standard DistributedSampler, which adds extra samples to the dataset
and thus gives incorrect results.
"""
import argparse
import os
from typing import Tuple

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torchmetrics
from lightning_utilities import module_available
from torch import Tensor
from torch.nn import Module
from torch.utils.data import DataLoader, Dataset, DistributedSampler, TensorDataset
from torchmetrics.utilities.distributed import EvaluationDistributedSampler

_ = torch.manual_seed(42)


class DummyModel(Module):
    """Dummy model consisting of a single linear layer."""

    def __init__(self, n_feature: int) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(n_feature, 10)

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass."""
        return self.linear(x)


def calculate_accuracy_manually(dataset: Dataset, model: Module) -> Tensor:
    """Calculate accuracy on the full dataset in one pass, without any distributed logic."""
    x, y = dataset.tensors
    preds = model(x)
    return (preds.argmax(dim=1) == y).float().mean()


def use_lightning(
    model: Module, dataset: Dataset, batch_size: int, use_standard: bool, num_processes: int, gpu: bool
) -> None:
    """Use Lightning to evaluate a model on a dataset."""
    if module_available("lightning"):
        from lightning.pytorch import LightningModule, Trainer
    else:
        from pytorch_lightning import LightningModule, Trainer

    sampler_class = DistributedSampler if use_standard else EvaluationDistributedSampler

    class DummyLightningModule(LightningModule):
        def __init__(self, model: Module) -> None:
            super().__init__()
            self.model = model
            self.metric = torchmetrics.classification.MulticlassAccuracy(num_classes=10, average="micro")

        def forward(self, x: Tensor) -> Tensor:
            return self.model(x)

        def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
            preds = self(batch[0])  # call the wrapped model, not the outer closure
            target = batch[1]
            self.metric.update(preds, target)

        def on_test_epoch_end(self) -> None:
            self.log("test_acc", self.metric.compute())

        def test_dataloader(self) -> DataLoader:
            return DataLoader(
                dataset,
                batch_size=batch_size,
                sampler=sampler_class(dataset),
            )

    model = DummyLightningModule(model)

    trainer = Trainer(
        devices=num_processes,
        accelerator="cpu" if not gpu else "gpu",
    )

    res = trainer.test(model)
    manual_res = calculate_accuracy_manually(dataset, model)
    print(manual_res)
    if torch.allclose(torch.tensor(res[0]["test_acc"]), manual_res):
        print("success! result matched manual calculation")
    else:
        print("failure! result did not match manual calculation")


def _use_torch_worker_fn(
    rank: int, model: Module, dataset: Dataset, batch_size: int, use_standard: bool, num_processes: int, gpu: bool
) -> None:
    """Worker function for torch.distributed evaluation."""
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("nccl" if gpu else "gloo", rank=rank, world_size=num_processes)

    device = torch.device(f"cuda:{rank}") if gpu else torch.device("cpu")
    model = model.to(device)

    sampler_class = DistributedSampler if use_standard else EvaluationDistributedSampler

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler_class(dataset, num_processes, rank),
    )

    metric = torchmetrics.classification.MulticlassAccuracy(num_classes=10, average="micro")
    metric = metric.to(device)

    batches, num_samples = 0, 0
    for batch in dataloader:
        x, target = batch[0].to(device), batch[1].to(device)
        preds = model(x)

        metric.update(preds, target)
        num_samples += len(target)
        batches += 1

    res = metric.compute()

    print(f"Rank {rank} processed {num_samples} samples in {batches} batches and calculated accuracy: {res}")

    manual_res = calculate_accuracy_manually(dataset, model.cpu())
    if torch.allclose(res.cpu(), manual_res):
        print("success! result matched manual calculation")
    else:
        print("failure! result did not match manual calculation")

    dist.destroy_process_group()


def use_torch(
    model: Module, dataset: Dataset, batch_size: int, use_standard: bool, num_processes: int, gpu: bool
) -> None:
    """Use torch.distributed to evaluate a model on a dataset."""
    mp.spawn(
        _use_torch_worker_fn,
        nprocs=num_processes,
        args=(model, dataset, batch_size, use_standard, num_processes, gpu),
    )


def main() -> None:
    """Main function."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--use_lightning", action="store_true")
    parser.add_argument("--use_standard", action="store_true")
    parser.add_argument("--num_processes", type=int, default=2)
    parser.add_argument("--batch_size", type=int, default=3)
    parser.add_argument("--gpu", action="store_true")
    args = parser.parse_args()
    print(args)

    n_feature = 100
    dataset = TensorDataset(torch.randn(199, n_feature), torch.randint(0, 10, (199,)))
    dummy_model = DummyModel(n_feature)

    batch_size = args.batch_size
    if len(dataset) % (args.num_processes * batch_size) == 0:
        raise ValueError(
            "For this example the dataset size should NOT be divisible by the number of processes times the batch"
            " size."
        )

    if args.use_lightning:
        use_lightning(dummy_model, dataset, batch_size, args.use_standard, args.num_processes, args.gpu)
    else:
        use_torch(dummy_model, dataset, batch_size, args.use_standard, args.num_processes, args.gpu)


if __name__ == "__main__":
    main()
```
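
The padding behaviour that this script demonstrates across processes can also be seen in a single process, since both samplers accept explicit `num_replicas` and `rank` arguments. The following is a minimal sketch (not part of the PR) that assumes a torchmetrics build containing `EvaluationDistributedSampler`:

```python
# Minimal single-process sketch: instantiating the samplers with explicit
# num_replicas/rank avoids the need for an initialized process group.
import torch
from torch.utils.data import DistributedSampler, TensorDataset

from torchmetrics.utilities.distributed import EvaluationDistributedSampler

dataset = TensorDataset(torch.arange(10))  # 10 samples, not divisible by 4 replicas

for sampler_cls in (DistributedSampler, EvaluationDistributedSampler):
    # collect the indices each "rank" would see
    per_rank = [list(sampler_cls(dataset, num_replicas=4, rank=r, shuffle=False)) for r in range(4)]
    total = sum(len(indices) for indices in per_rank)
    print(sampler_cls.__name__, per_rank, f"total={total}")
```

With the standard `DistributedSampler` the four ranks receive 12 indices in total (two samples are duplicated to make the split even), whereas `EvaluationDistributedSampler` yields each of the 10 indices exactly once, with ranks 2 and 3 simply receiving one sample less.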

Updated file: `torchmetrics/utilities/distributed.py`

@@ -16,6 +16,7 @@ (the `Dataset` import is the new line)

```python
import torch
from torch import Tensor
from torch.nn import functional as F  # noqa: N812
from torch.utils.data import Dataset
from typing_extensions import Literal
```

@@ -146,3 +147,65 @@ def gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Tensor]

```python
        slice_param = [slice(dim_size) for dim_size in item_size]
        gathered_result[idx] = gathered_result[idx][slice_param]
    return gathered_result


class EvaluationDistributedSampler(torch.utils.data.DistributedSampler):
    """A specialized distributed sampler for evaluation (test and validation).

    It is derived from the PyTorch DistributedSampler, with one core difference: it does not add extra samples to
    make the data evenly divisible across devices. This matters during evaluation, as adding extra samples would
    skew the results towards the duplicated samples.

    Normally, not adding the extra samples would lead to processes getting out of sync, but this is handled by the
    custom synchronization in TorchMetrics. Thus, this sampler does not in general guarantee that distributed
    operations work correctly outside of TorchMetrics.

    Arguments are the same as for DistributedSampler; this implementation only overrides the ``__init__`` method.

    Args:
        dataset: Dataset used for sampling.
        num_replicas (int, optional): Number of processes participating in distributed training. By default,
            :attr:`world_size` is retrieved from the current distributed group.
        rank (int, optional): Rank of the current process within :attr:`num_replicas`. By default, :attr:`rank` is
            retrieved from the current distributed group.
        shuffle (bool, optional): If ``True`` (default), sampler will shuffle the indices.
        seed (int, optional): Random seed used to shuffle the sampler if :attr:`shuffle=True`. This number should
            be identical across all processes in the distributed group.
        drop_last (bool, optional): If ``True``, then the sampler will drop the tail of the data to make it evenly
            divisible across the number of replicas.

    For a full example of how to use this sampler, with both bare PyTorch and PyTorch Lightning, check out the
    `distributed_evaluation.py` file in the examples folder.

    The distributed sampler is always intended to be used in conjunction with a DataLoader:

    Example::

        >>> import torch
        >>> from torch.utils.data import DataLoader, TensorDataset
        >>> from torchmetrics.utilities.distributed import EvaluationDistributedSampler
        >>> dataset = TensorDataset(torch.arange(10))
        >>> dataloader = DataLoader(
        ...     dataset, sampler=EvaluationDistributedSampler(dataset, num_replicas=2)
        ... )  # doctest: +SKIP

    """

    def __init__(
        self,
        dataset: Dataset,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
    ) -> None:
        # From:
        # https://github.com/pytorch/pytorch/issues/25162#issuecomment-1227647626
        super().__init__(
            dataset=dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, seed=seed, drop_last=drop_last
        )

        len_dataset = len(self.dataset)  # type: ignore[arg-type]
        if not self.drop_last and len_dataset % self.num_replicas != 0:
            # some ranks may get fewer samples, that's fine
            if self.rank >= len_dataset % self.num_replicas:
                self.num_samples -= 1
            self.total_size = len_dataset
```

Review discussion:

> On `def __init__`: In Lightning we have a very similar class:
> https://github.com/Lightning-AI/lightning/blob/fbdbe632c67b05158804b52f4345944781ca4f07/src/lightning/pytorch/overrides/distributed.py#L194
> I think the main difference is that yours respects the setting `drop_last`. I'm not sure why we have the …

> On the `if not self.drop_last` branch: The issue with this is that it wouldn't necessarily work with validation,
> since not all ranks would reach the same distributed function calls and would therefore time out, which would
> kill the entire process. Also, this would never work with FSDP, since some ranks have one batch more and, for
> FSDP, not all processes would reach the forward syncing points, also resulting in timeouts.

> Reply: I agree that in the context of Lightning this wouldn't work well, as it does not support Join
> (Lightning-AI/pytorch-lightning#3325). But outside Lightning, and taking FSDP out of the equation, I agree this
> can work and is a good utility to have IMO. It also suits the metric design well, since synchronization is only
> necessary when all processes have finished collecting their statistics and …

> Shall this rather be logging or `exit()`?
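
To make the override concrete, here is a hypothetical walk-through of the arithmetic for the 199-sample dataset used in the example script, evaluated on 2 ranks (the numbers are illustrative only):

```python
# Hypothetical walk-through of the sampler arithmetic (N = 199 samples, 2 ranks).
import math

N, world_size = 199, 2

# The parent DistributedSampler rounds up and pads with duplicates:
num_samples_parent = math.ceil(N / world_size)        # 100 per rank
total_size_parent = num_samples_parent * world_size   # 200 -> one sample counted twice

# The override instead lets the first (N % world_size) ranks keep one extra sample,
# mirroring the condition `self.rank >= len_dataset % self.num_replicas` above:
extra = N % world_size  # 1
per_rank = [num_samples_parent if r < extra else num_samples_parent - 1 for r in range(world_size)]
print(per_rank, sum(per_rank))  # [100, 99] 199 -> every sample seen exactly once
```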