GPU prediction race condition results in BasePredictionWriter observing incorrect (zero) values #11287

Closed
edpizzi opened this issue Dec 31, 2021 · 2 comments · Fixed by #11288
Labels
bug (Something isn't working) · callback: prediction writer · priority: 0 (High priority task)

Comments

@edpizzi
Contributor

edpizzi commented Dec 31, 2021

🐛 Bug

When writing prediction results using a trivial BasePredictionWriter subclass with write_interval "epoch", the final batch is written as zeros when the following conditions are met:

  • GPU accelerator (haven't tested with other non-CPU accelerators)
  • Sufficiently expensive model that the CUDA stream has not completed by the time the prediction writer is called
    • (This is why BoringModel is not appropriate to reproduce here.)

I believe that this is a race condition resulting from non-blocking copies from GPU to CPU without explicit CUDA synchronization. This results in CPU tensors being incorrectly observed as zeros before the GPU computation completes. Details below.

To Reproduce

import argparse
import logging
import os

from torchvision.datasets import FakeData
from torchvision.models import resnet50
from torchvision.transforms import ToTensor

import pytorch_lightning as pl
from pytorch_lightning.plugins import DDPSpawnPlugin
from pytorch_lightning.callbacks import BasePredictionWriter
from torch.utils.data import DataLoader
import torch


logging.basicConfig(
    format='%(asctime)s %(levelname)-8s %(message)s',
    level=logging.INFO,
    datefmt='%Y-%m-%d %H:%M:%S')


class Model(pl.LightningModule):

    def __init__(self):
        super().__init__()
        # Arbitrary model, but this is a race, and the race needs the model to
        # be expensive enough to cause the CUDA stream to queue.
        self.model = resnet50(pretrained=False)

    def forward(self, x):
        x = self.model(x)
        # Return scalar values (easier to print), and ensure that 0s are not possible.
        # It's important that the result depends on the slow self.model(x) step.
        return (x.abs() + 0.1).sum(dim=1)

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        x, _ = batch
        y = self(x)
        return y


class InferenceWriter(BasePredictionWriter):
    """Writes predictions on epoch end."""

    def __init__(self, output_path: str):
        super().__init__("epoch")
        self.output_file = os.path.join(output_path, "predictions.pt")

    def write_on_epoch_end(
        self, trainer, module, predictions, batch_indices
    ):
        assert len(predictions) == 1
        predictions = predictions[0]
        outputs = torch.cat(predictions)
        logging.info("Saving output to %s.", self.output_file)
        torch.save(outputs, self.output_file)

    def read(self):
        return torch.load(self.output_file)


parser = argparse.ArgumentParser()
parser.add_argument("--output_path", required=True)
parser.add_argument("--batch_size", default=256, type=int)
parser.add_argument("--image_size", default=288, type=int, help="Image size for inference")


def main(args):
    dataset = FakeData(
        args.batch_size * 2,
        image_size=(3, args.image_size, args.image_size),
        transform=ToTensor(),
    )
    model = Model()
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=2,
        persistent_workers=True,  # unnecessary here, but silences warning
    )
    writer = InferenceWriter(args.output_path)
    trainer = pl.Trainer(
        devices=1,
        num_nodes=1,
        accelerator="gpu",
        default_root_dir=args.output_path,
        strategy=DDPSpawnPlugin(find_unused_parameters=False),
        callbacks=[writer],
    )
    logging.info("Starting inference")
    trainer.predict(model, dataloaders=dataloader)
    logging.info("Loading features")
    outputs = writer.read()

    logging.info("First batch values: %s", outputs[:args.batch_size])
    logging.info("Second batch values: %s", outputs[args.batch_size:])


if __name__ == '__main__':
    args = parser.parse_args()
    main(args)

Output:

2021-12-31 05:27:38 INFO     First batch values: tensor([6145.3164, 6103.9824, 6139.4873, 6153.1626, 6133.7827, 6134.3999, ...
2021-12-31 05:27:39 INFO     Second batch values: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., ...

Expected behavior

Reproduce script: Both the first and the second batch (logged in the script above) should contain nonzero values, from the same distribution. The computation is such that zeroes are not possible outputs.

In general: Tensors returned to the user on API surfaces should contain valid values when their contents are observed.

Cause (I think)

If I understand the semantics of tensors copied from GPU to CPU with non_blocking=True correctly, these tensors are unsafe to read until explicit CUDA synchronization (something like torch.cuda.synchronize()) guarantees that the copy has completed and the CPU tensors hold valid values.

A non_blocking GPU -> CPU transfer creates a CPU tensor that is semantically a future, similar to a CUDA tensor whose values have not yet resolved. But unlike a CUDA tensor, the CPU tensor has no machinery to act like a future: it does not know that its values have not been written yet. As a result, operations that observe the tensor's values do not wait for them to become valid, and incorrect (zero) values are observed.
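
To illustrate outside of Lightning (a minimal sketch; the model and tensor sizes are arbitrary, chosen only so that the CUDA stream stays busy long enough to expose the race):

import torch

# Queue enough GPU work that the copy below races with the computation.
model = torch.nn.Linear(4096, 4096).cuda()
x = torch.randn(8192, 4096, device="cuda")
y = model(x)  # enqueued on the default CUDA stream; returns immediately

# Non-blocking GPU -> CPU copy: returns a CPU tensor that may not have been
# populated yet. Reading it now can observe stale (zero) values.
y_cpu = y.to("cpu", non_blocking=True)

# Only after synchronizing is y_cpu guaranteed to hold the real results.
torch.cuda.synchronize()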

Since these futures are sharp edges, I suggest that we don't expose them to users unless we synchronize appropriately.

Similar bugs resulting from non-blocking GPU -> CPU copies have been reported elsewhere; for instance, there are two similar reports on this thread.

Possible fixes

One option is CUDA synchronization before callbacks that might see outputs that were moved to the CPU with non-blocking copies. Calling torch.cuda.synchronize() before exposing the affected tensors (CPU tensors produced by non-blocking copies of GPU results) to the user should close the race.

However, I think it would be better not to set non_blocking=True when copying tensors to the CPU, which avoids this case entirely; it will be hard to find all of the affected code paths otherwise. I don't see a strong argument for using non-blocking copies from the GPU, and avoiding incorrect results due to the surprising behavior of these tensors seems like a compelling argument against. I expect that changing move_data_to_device to not set non_blocking when the target device is the CPU would work.
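
A rough sketch of that idea (illustrative only, not the actual move_data_to_device code; the helper name here is made up):

import torch

def to_device(tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
    # non_blocking only helps for host -> CUDA copies of pinned memory.
    # For CUDA -> CPU copies it creates the race described above, so only
    # request a non-blocking transfer when the destination is a CUDA device.
    non_blocking = device.type == "cuda"
    return tensor.to(device, non_blocking=non_blocking)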

Fixes I've tested

Copying tensors to CPU in predict_step: y = y.cpu(). This uses a blocking transfer, making Lightning's non-blocking .to() call a no-op.
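
For example, the repro's predict_step with this workaround applied (a sketch):

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        x, _ = batch
        y = self(x)
        # Blocking GPU -> CPU copy: waits for the CUDA stream to finish, so
        # the values are valid before Lightning's non-blocking .to() runs.
        return y.cpu()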

Inspecting the result of GPU operations also fixes this, by forcing us to wait for the GPU. This has to be done on the GPU tensors, before the CPU copy (e.g. y.mean().item() in predict_step in the reproduce script).

Using CPU compute also fixes this, but is slow, since the examples need enough computation to expose the race with GPUs.

Additionally, various things to "win" the race condition here also work:

  • A simple sleep in write_on_epoch_end "fixes" this
  • Using a smaller inference size (reducing GPU compute) "fixes" this

Environment

  • CUDA:
    - GPU:
    - A100-SXM4-40GB
    - A100-SXM4-40GB
    - A100-SXM4-40GB
    - A100-SXM4-40GB
    - A100-SXM4-40GB
    - A100-SXM4-40GB
    - A100-SXM4-40GB
    - A100-SXM4-40GB
    - available: True
    - version: 11.2
  • Packages:
    - numpy: 1.21.5
    - pyTorch_debug: False
    - pyTorch_version: 1.10.0
    - pytorch-lightning: 1.5.7
    - tqdm: 4.62.3
  • System:
    - OS: Linux
    - architecture:
    - 64bit
    - ELF
    - processor: x86_64
    - python: 3.9.9

cc @tchaton @rohitgr7

edpizzi added the bug label on Dec 31, 2021
@justusschock
Member

@edpizzi Thanks for reporting this. I agree that we should set non_blocking to False when the target device is CPU. Would you mind sending a PR for that?

Also: do you have any idea how we can test this in a preferably light way? I'm afraid that running a heavy model is not suitable within CI pipelines...

edpizzi added a commit to edpizzi/pytorch-lightning that referenced this issue Dec 31, 2021
Non-blocking GPU->CPU transfers can create race windows where tensor contents
are observed to have incorrect values. Lightning-AI#11287
@edpizzi
Contributor Author

edpizzi commented Dec 31, 2021

Yes, I can create a PR to omit non_blocking when copying to CPU. I have a draft, which looks like it's already linked here.

As for testing -- this is a race, which makes testing it cheaply a bit difficult, and I haven't experimented with the limits of reproducibility. Changing the input resolution to 64x64 made the issue disappear in my codebase. I personally think that omitting non_blocking for CPU tensors requires less testing than anything involving explicit synchronization. I'd be satisfied with demonstrating that the reproduce script passes as a one-off.
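
One lightweight option (just a sketch -- the test name and the mocking approach here are my own, not something already in the test suite) would be to mock Tensor.to and assert that moving data to the CPU never requests a non-blocking copy:

from unittest import mock

import torch
from pytorch_lightning.utilities import move_data_to_device


def test_no_non_blocking_copy_to_cpu():
    tensor = torch.zeros(2)
    # Wrap the real .to so the call still behaves normally, but record the
    # arguments that move_data_to_device passes to it.
    with mock.patch.object(torch.Tensor, "to", wraps=tensor.to) as to_mock:
        move_data_to_device(tensor, torch.device("cpu"))
    _, kwargs = to_mock.call_args
    assert kwargs.get("non_blocking", False) is False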

edpizzi added a commit to edpizzi/pytorch-lightning that referenced this issue Jan 1, 2022
Non-blocking GPU->CPU transfers can create race windows where tensor contents
are observed to have incorrect values. Lightning-AI#11287

Tests appear to rely on device=None (contrary to type annotations), so treat
None as a CPU device.
edpizzi added a commit to edpizzi/pytorch-lightning that referenced this issue Jan 3, 2022
Non-blocking GPU->CPU transfers can create race windows where tensor contents
are observed to have incorrect values. Lightning-AI#11287