diff --git a/benchmarks/benchmark_model.py b/benchmarks/benchmark_model.py
index 3890a506ddd..1a62d986c36 100644
--- a/benchmarks/benchmark_model.py
+++ b/benchmarks/benchmark_model.py
@@ -1,4 +1,5 @@
 from collections import OrderedDict
+import contextlib
 import logging
 import re
 import torch
@@ -66,6 +67,8 @@ def __init__(self, suite_name, model_name, benchmark_experiment):
     self.suite_name = suite_name
     self.model_name = model_name
     self.benchmark_experiment = benchmark_experiment
+    self.autocast = contextlib.nullcontext
+    self.autocast_kwargs = {}

   def set_up(self):
     """Set up module, actual batch_size, example_inputs, and optimizer_class
@@ -125,6 +128,7 @@ def pick_grad(self):
       return torch.no_grad()
     elif self.benchmark_experiment.test == "train":
       return torch.enable_grad()
+    raise NotImplementedError

   def _optimizer_zero_grad(self):
     if self.optimizer is not None:
@@ -141,8 +145,9 @@ def compute_loss(self, pred):

   def train(self, inputs, collect_full_output=False):
     self._optimizer_zero_grad()
-    pred = self.module(*inputs)
-    loss = self.compute_loss(pred)
+    with self.autocast(**self.autocast_kwargs):
+      pred = self.module(*inputs)
+      loss = self.compute_loss(pred)
     loss.backward()
     self._optimizer_step()
     if collect_full_output:
@@ -152,7 +157,8 @@ def train(self, inputs, collect_full_output=False):
     return None

   def eval(self, inputs, collect_full_output=False):
-    pred = self.module(*inputs)
+    with self.autocast(**self.autocast_kwargs):
+      pred = self.module(*inputs)
     return pred

   @property
diff --git a/benchmarks/experiment_runner.py b/benchmarks/experiment_runner.py
index 336e36c4287..93309429e64 100644
--- a/benchmarks/experiment_runner.py
+++ b/benchmarks/experiment_runner.py
@@ -248,7 +248,7 @@ def run_single_config(self):

     # Repeat the experiment and accumulate metrics.
     last_output = None
-    with benchmark_model.pick_context():
+    with benchmark_model.pick_grad():
       accumulated_metrics = OrderedDict()
       for repeat_iteration in range(self._args.repeat):
         metrics, last_output = self.run_once_and_gather_metrics(
diff --git a/benchmarks/torchbench_model.py b/benchmarks/torchbench_model.py
index 6fa64b0329a..7d1cd35be86 100644
--- a/benchmarks/torchbench_model.py
+++ b/benchmarks/torchbench_model.py
@@ -7,11 +7,11 @@
 from os.path import abspath, exists
 import sys
 import torch
+import torch.amp
 import torch.nn as nn
 from torch._dynamo.testing import collect_results, reduce_to_scalar_loss
 from torch._dynamo.utils import clone_inputs
 import torch_xla
-import torch_xla.amp
 import torch_xla.core.xla_model as xm
 import types
 import yaml
@@ -269,11 +269,17 @@ def set_up(self):

     This is model suite specific.
     """
+    # Set the optimizer class.
+    # Check if we should use SGD instead of Adam for memory reasons.
     if self.benchmark_experiment.test == "train" and self.model_name in TRAIN_WITH_SGD:
       self.optimizer_class = torch.optim.SGD
     else:
       self.optimizer_class = torch.optim.Adam

+    # Set up the autocast environment if we are running with AMP precision.
+    self.autocast, self.autocast_kwargs = self._get_autocast_with_kwargs()
+
+    # Load the actual benchmark instance.
     benchmark = self.load_benchmark()

     self.module, self.example_inputs = benchmark.get_module()
@@ -417,26 +423,24 @@ def pick_grad(self):
     # special case
     if self.model_name in ("maml",):
       return torch.enable_grad()
+    return super().pick_grad()

-    if self.benchmark_experiment.test == "eval":
-      return torch.no_grad()
-    elif self.benchmark_experiment.test == "train":
-      return torch.enable_grad()
-
-  def pick_amp(self):
+  def _get_autocast_with_kwargs(self):
     if (self.benchmark_experiment.accelerator == "cuda" and
         self.is_cuda_precision_amp()):
+      kwargs = {"dtype": torch.bfloat16}
       if self.benchmark_experiment.xla:
-        return torch_xla.amp.autocast(xm.xla_device())
+        # Should call the device-specific autocast implementations.
+        # PyTorch/XLA autocast does not run with dynamo, though:
+        # https://github.com/pytorch/xla/issues/6511
+        autocast = torch.amp.autocast
+        kwargs["device_type"] = "xla"
       else:
-        return torch.cuda.amp.autocast()
-    return contextlib.nullcontext()
-
-  def pick_context(self):
-    stack = contextlib.ExitStack()
-    stack.enter_context(self.pick_amp())
-    stack.enter_context(self.pick_grad())
-    return stack
+        autocast = torch.cuda.amp.autocast
+    else:
+      kwargs = {}
+      autocast = contextlib.nullcontext
+    return (autocast, kwargs)

   def compute_loss(self, pred):
     """Reduce the output of a model to get scalar loss"""
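
Usage sketch (not part of the patch): a minimal, self-contained illustration of the
control flow this diff sets up. get_autocast_with_kwargs below is a hypothetical
free-function version of TorchBenchModel._get_autocast_with_kwargs; the flags it
takes are stand-ins for self.benchmark_experiment, and the demo deliberately hits
the nullcontext fallback so it runs on a plain CPU build.

    import contextlib
    import torch

    def get_autocast_with_kwargs(accelerator, is_amp, is_xla):
      # Mirrors the selection logic above: pick an autocast context manager
      # plus the kwargs it should be entered with.
      if accelerator == "cuda" and is_amp:
        kwargs = {"dtype": torch.bfloat16}
        if is_xla:
          # The generic torch.amp.autocast dispatches to the XLA backend
          # through device_type.
          autocast = torch.amp.autocast
          kwargs["device_type"] = "xla"
        else:
          autocast = torch.cuda.amp.autocast
      else:
        kwargs = {}
        autocast = contextlib.nullcontext
      return autocast, kwargs

    autocast, autocast_kwargs = get_autocast_with_kwargs("cpu", False, False)
    model = torch.nn.Linear(4, 2)
    with torch.enable_grad():            # what pick_grad() returns for "train"
      with autocast(**autocast_kwargs):  # nullcontext here; AMP on CUDA/XLA
        loss = model(torch.randn(8, 4)).sum()
      loss.backward()                    # backward stays outside autocast

Note that only the forward pass and loss computation run under autocast;
loss.backward() and the optimizer step stay outside it, matching the new
BenchmarkModel.train above.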