GPU prediction race condition results in BasePredictionWriter observing incorrect (zero) values #11287
Comments
@edpizzi Thanks for reporting this. I agree that we should set non_blocking to False here. Also: do you have any idea how we can test this in a preferably lightweight way? I'm afraid that running a heavy model is not suitable within CI pipelines...
Non-blocking GPU->CPU transfers can create race windows where tensor contents are observed to have incorrect values. Lightning-AI#11287
Yes, I can create a PR to omit non_blocking when copying to CPU. I have a draft, which looks like it's already linked here. As for testing -- this is a race, which makes testing in a cheap way a bit difficult. I haven't experimented with the limits of reproducibility, but changing the input resolution to 64x64 made the issue go away in my codebase. I personally think that omitting non_blocking for CPU tensors requires less testing than anything involving explicit synchronization. I'd be satisfied with demonstrating that the reproduce script passes as a one-off.
Non-blocking GPU->CPU transfers can create race windows where tensor contents are observed to have incorrect values. Lightning-AI#11287 Tests appear to rely on device=None (contrary to type annotations), so treat None as a CPU device.
🐛 Bug
When writing prediction results using a trivial BasePredictionWriter subclass with write_interval "epoch", the final batch is written as zeros when predictions run on a GPU and the per-batch computation is heavy enough to expose the race described below.
I believe that this is a race condition resulting from non-blocking copies from GPU to CPU without explicit CUDA synchronization. This results in CPU tensors being incorrectly observed as zeros before the GPU computation completes. Details below.
To Reproduce
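The original reproduce script is not included in this copy. The following is a minimal sketch reconstructed from the description above: ZeroCheckingWriter, HeavyModel, and all sizes and iteration counts are illustrative placeholders, and the amount of per-batch work may need to be increased before the race is actually lost.

```python
# Sketch only: names, sizes, and iteration counts are illustrative, not the
# original script. Requires a CUDA GPU and pytorch-lightning 1.5.x.
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import BasePredictionWriter
from torch.utils.data import DataLoader, TensorDataset


class ZeroCheckingWriter(BasePredictionWriter):
    def __init__(self):
        super().__init__(write_interval="epoch")

    def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
        # With a single predict dataloader, predictions[0] holds the per-batch outputs.
        for i, batch_output in enumerate(predictions[0]):
            # The model output is strictly positive, so an all-zero batch means the
            # CPU tensor was observed before the non-blocking copy completed.
            print(f"batch {i}: min = {batch_output.min().item()}")


class HeavyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(2048, 2048))

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        (x,) = batch
        # Enough GPU work that the final batch is still in flight when the
        # prediction writer runs.
        for _ in range(200):
            x = torch.tanh(x @ self.weight)
        return x.abs() + 1.0  # strictly positive: zeros are impossible


if __name__ == "__main__":
    data = DataLoader(TensorDataset(torch.randn(128, 2048)), batch_size=64)
    trainer = pl.Trainer(gpus=1, callbacks=[ZeroCheckingWriter()])
    trainer.predict(HeavyModel(), dataloaders=data)
```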
Output:
Expected behavior
In the reproduce script, both the first and the second batch (logged in the script above) should contain nonzero values from the same distribution; the computation is such that zeros are not possible outputs.
In general: Tensors returned to the user on API surfaces should contain valid values when their contents are observed.
Cause (I think)
If I understand the semantics of tensors copied from GPU to CPU with non_blocking=True, these tensors are unsafe to read without explicit CUDA synchronization (something like torch.cuda.synchronize() covering the copy) to ensure that the CPU tensors hold valid values.
A non_blocking GPU -> CPU transfer creates a CPU tensor that is effectively a future, similar to CUDA tensors whose values have not resolved. But unlike CUDA tensors, the CPU tensor does not act like a future, and does not know that the value has not been written. As a result, operations that observe the tensor values do not wait for the tensor to be valid, and incorrect (zero) values are observed.
Since these futures are sharp edges, I suggest that we don't expose them to users unless we synchronize appropriately.
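As a plain-PyTorch illustration of this sharp edge (independent of Lightning; the sizes and iteration count are arbitrary):

```python
import torch

x = torch.ones(4096, 4096, device="cuda")
for _ in range(100):
    x = (x @ x).clamp(max=1.0)  # busy-work to keep the CUDA stream occupied
y = x + 1.0                     # every element is 2.0 on the GPU

y_cpu = y.to("cpu", non_blocking=True)
# Unsafe: the copy was only queued on the CUDA stream, so y_cpu may still show
# the buffer's initial (often zero) contents if we read it right away.
print("before sync:", y_cpu.min().item())

torch.cuda.synchronize()  # wait for all queued GPU work, including the copy
print("after sync:", y_cpu.min().item())  # guaranteed to be 2.0
```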
Similar bugs resulting from non-blocking GPU -> CPU copies have been reported elsewhere; for instance, there are two similar reports on this thread.
Possible fixes
I think this could be done using CUDA synchronization before any callback that might see outputs moved to the CPU with non-blocking copies: calling torch.cuda.synchronize() before exposing the affected tensors (CPU tensors produced by non-blocking copies of GPU results) to the user might fix this. However, I think it would be better to not set non_blocking=True when copying tensors to CPU, which avoids this case entirely: it will be hard to find all of the affected code paths, I don't see a strong argument for non-blocking copies from GPU to CPU, and avoiding incorrect results due to the surprising behavior of these tensors seems like a compelling argument against them. I expect that changing move_data_to_device to not set non_blocking when the target device is the CPU should work; a sketch of that idea follows.
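A minimal sketch of that change as a standalone helper (this is not Lightning's actual move_data_to_device implementation, and move_tensor_to_device is a hypothetical name; it also treats device=None as CPU, which the linked commit message suggests the tests rely on):

```python
import torch


def move_tensor_to_device(tensor: torch.Tensor, device=None) -> torch.Tensor:
    # Some callers appear to pass device=None (contrary to the type annotations);
    # treat that as the CPU.
    device = torch.device(device) if device is not None else torch.device("cpu")
    # Only copies onto a CUDA device get non_blocking=True: a not-yet-filled CUDA
    # tensor behaves like a future, but a not-yet-filled CPU tensor does not, so
    # GPU -> CPU copies stay blocking and the result is always safe to read.
    return tensor.to(device, non_blocking=(device.type == "cuda"))
```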
Fixes I've tested
- Copying tensors to CPU in predict_step: y = y.cpu(). This uses a blocking transfer, making Lightning's non-blocking .to() call a no-op (see the sketch after this list).
- Inspecting the result of GPU operations also fixes this, by forcing us to wait for the GPU. This has to be done on the GPU tensors, before the CPU copy (e.g. y.mean().item() in predict_step in the reproduce example; also sketched below).
- Using CPU compute also fixes this, but it is slow, since the examples need enough computation to expose the race with GPUs.
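A sketch of the first two workarounds as hypothetical predict_step variants; heavy_forward is a placeholder for the model's GPU computation, and either function body would replace predict_step on the LightningModule:

```python
def predict_step_with_blocking_copy(self, batch, batch_idx, dataloader_idx=None):
    (x,) = batch
    y = self.heavy_forward(x)  # placeholder for the heavy GPU computation
    # Blocking GPU -> CPU transfer: Lightning's later non-blocking .to("cpu")
    # then has nothing left to move.
    return y.cpu()


def predict_step_with_forced_sync(self, batch, batch_idx, dataloader_idx=None):
    (x,) = batch
    y = self.heavy_forward(x)
    # .item() makes the host wait for the GPU computation to finish before the
    # tensor is handed back to Lightning for the CPU copy.
    _ = y.mean().item()
    return y
```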
Additionally, various changes that "win" the race also make the issue disappear (for example, reducing the amount of computation per batch), but they do not address the underlying problem.
Environment
- CUDA:
  - GPU:
    - A100-SXM4-40GB
    - A100-SXM4-40GB
    - A100-SXM4-40GB
    - A100-SXM4-40GB
    - A100-SXM4-40GB
    - A100-SXM4-40GB
    - A100-SXM4-40GB
    - A100-SXM4-40GB
  - available: True
  - version: 11.2
- Packages:
  - numpy: 1.21.5
  - pyTorch_debug: False
  - pyTorch_version: 1.10.0
  - pytorch-lightning: 1.5.7
  - tqdm: 4.62.3
- System:
  - OS: Linux
  - architecture:
    - 64bit
    - ELF
  - processor: x86_64
  - python: 3.9.9
cc @tchaton @rohitgr7