From 5bd3cd5f712b65d38812b27cf957261bb06b83c5 Mon Sep 17 00:00:00 2001
From: Edward Brown <30390944+EdwardJB@users.noreply.github.com>
Date: Thu, 15 Apr 2021 02:22:11 +0100
Subject: [PATCH] Bugfix/cuda oom detection and handling (#6934)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Adrian Wälchli
---
 pytorch_lightning/utilities/memory.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/utilities/memory.py b/pytorch_lightning/utilities/memory.py
index d67739c3b3fc2..6c01390a8c81e 100644
--- a/pytorch_lightning/utilities/memory.py
+++ b/pytorch_lightning/utilities/memory.py
@@ -53,7 +53,8 @@ def is_oom_error(exception):
 def is_cuda_out_of_memory(exception):
     return isinstance(exception, RuntimeError) \
         and len(exception.args) == 1 \
-        and "CUDA out of memory." in exception.args[0]
+        and "CUDA" in exception.args[0] \
+        and "out of memory" in exception.args[0]
 
 
 # based on https://github.com/BlackHC/toma/blob/master/toma/torch_cuda_memory.py
@@ -76,4 +77,10 @@ def garbage_collection_cuda():
     """Garbage collection Torch (CUDA) memory."""
     gc.collect()
     if torch.cuda.is_available():
-        torch.cuda.empty_cache()
+        try:
+            # This is the last thing that should cause an OOM error, but seemingly it can.
+            torch.cuda.empty_cache()
+        except RuntimeError as exception:
+            if not is_oom_error(exception):
+                # Only handle OOM errors
+                raise