Skip to content

Commit

Permalink
Clean-up CUDA cache. (#6325)
Browse files Browse the repository at this point in the history
  • Loading branch information
ysiraichi authored Jan 19, 2024
1 parent e93597d commit 423bb0b
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 17 deletions.
28 changes: 21 additions & 7 deletions benchmarks/torchbench_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,14 @@ class TorchBenchModel(BenchmarkModel):
def __init__(self, suite_name, model_name, benchmark_experiment):
super().__init__(suite_name, model_name, benchmark_experiment)

def _cleanup(self):
# Garbage-collect right now.
gc.collect()

# If we are using CUDA, clean-up its cache left-over.
if self.benchmark_experiment.accelerator == "cuda":
torch.cuda.empty_cache()

def set_up(self):
"""Set up module, actual batch_size, example_inputs, and optimizer_class
Expand All @@ -181,12 +189,16 @@ def set_up(self):

# Move the initialized model to XLA device.
if self.benchmark_experiment.xla:
import torch.utils._pytree as pytree
# First, move the model and the inputs to CPU.
# This avoids having dupplicated data on CUDA.
if self.benchmark_experiment.accelerator == "cuda":
self.module = self.module.to("cpu")
self.example_inputs = move_to_device(self.example_inputs, "cpu")
self._cleanup()

device = self.benchmark_experiment.get_device()
self.module = self.module.to(device)
self.example_inputs = pytree.tree_map_only(torch.Tensor,
lambda t: t.to(device),
self.example_inputs)
self.example_inputs = move_to_device(self.example_inputs, device)

# Torchbench has quite different setup for yolov3, so directly passing
# the right example_inputs
Expand All @@ -197,7 +209,7 @@ def set_up(self):
self.optimizer = benchmark.optimizer

del benchmark
gc.collect()
self._cleanup()

def load_benchmark(self):
try:
Expand Down Expand Up @@ -247,13 +259,15 @@ def default_precision_flag(self):
elif test == "train" and hasattr(benchmark, 'DEFAULT_TRAIN_CUDA_PRECISION'):
precision = benchmark.DEFAULT_TRAIN_CUDA_PRECISION
else:
precision = None
logger.warning("No default precision set. No patching needed.")
return None

del benchmark
gc.collect()
self._cleanup()

precision_flag = None
if precision is None:
return None
if precision == "fp16":
precision_flag = 'XLA_USE_FP16'
elif precision == "amp":
Expand Down
12 changes: 2 additions & 10 deletions benchmarks/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import random
import subprocess
import torch
import torch.utils._pytree as pytree
import sys
import torch_xla.core.xla_model as xm
from torch_xla._internal import tpu
Expand Down Expand Up @@ -76,16 +77,7 @@ def is_xla_device_available(devkind):


def move_to_device(item, device):
if isinstance(item, torch.Tensor):
return item.to(device=device)
elif isinstance(item, list):
return [move_to_device(t, device) for t in item]
elif isinstance(item, tuple):
return tuple(move_to_device(t, device) for t in item)
elif isinstance(item, dict):
return dict((k, move_to_device(t, device)) for k, t in item.items())
else:
return item
return pytree.tree_map_only(torch.Tensor, lambda t: t.to(device), item)


def randomize_input(inputs):
Expand Down

0 comments on commit 423bb0b

Please sign in to comment.