
Commit

[benchmarks] Fix execution with AMP precision. (#6512)
ysiraichi committed Feb 12, 2024
1 parent a5692c2 commit 408b376
Showing 3 changed files with 30 additions and 20 deletions.
12 changes: 9 additions & 3 deletions benchmarks/benchmark_model.py
@@ -1,4 +1,5 @@
from collections import OrderedDict
import contextlib
import logging
import re
import torch
@@ -66,6 +67,8 @@ def __init__(self, suite_name, model_name, benchmark_experiment):
self.suite_name = suite_name
self.model_name = model_name
self.benchmark_experiment = benchmark_experiment
self.autocast = contextlib.nullcontext
self.autocast_kwargs = {}

def set_up(self):
"""Set up module, actual batch_size, example_inputs, and optimizer_class
@@ -125,6 +128,7 @@ def pick_grad(self):
return torch.no_grad()
elif self.benchmark_experiment.test == "train":
return torch.enable_grad()
raise NotImplementedError

def _optimizer_zero_grad(self):
if self.optimizer is not None:
@@ -141,8 +145,9 @@ def compute_loss(self, pred):

def train(self, inputs, collect_full_output=False):
self._optimizer_zero_grad()
pred = self.module(*inputs)
loss = self.compute_loss(pred)
with self.autocast(**self.autocast_kwargs):
pred = self.module(*inputs)
loss = self.compute_loss(pred)
loss.backward()
self._optimizer_step()
if collect_full_output:
@@ -152,7 +157,8 @@ def train(self, inputs, collect_full_output=False):
return None

def eval(self, inputs, collect_full_output=False):
pred = self.module(*inputs)
with self.autocast(**self.autocast_kwargs):
pred = self.module(*inputs)
return pred

@property
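
For context, the new train() and eval() bodies rely on the no-op defaults set in __init__ above: when AMP is not enabled, self.autocast stays contextlib.nullcontext and self.autocast_kwargs stays empty, so the same with-statement works for both full precision and AMP. A minimal standalone sketch of that pattern (hypothetical model and inputs, not part of the benchmark suite):

import contextlib
import torch

# Defaults mirroring BenchmarkModel.__init__: a no-op context manager.
autocast = contextlib.nullcontext
autocast_kwargs = {}
# A suite could swap these in for AMP, e.g. (illustrative assumption):
# autocast, autocast_kwargs = torch.cuda.amp.autocast, {"dtype": torch.bfloat16}

model = torch.nn.Linear(4, 2)
inputs = (torch.randn(8, 4),)

with autocast(**autocast_kwargs):
  pred = model(*inputs)
  loss = pred.sum()
loss.backward()
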
2 changes: 1 addition & 1 deletion benchmarks/experiment_runner.py
@@ -248,7 +248,7 @@ def run_single_config(self):

# Repeat the experiment and accumulate metrics.
last_output = None
with benchmark_model.pick_context():
with benchmark_model.pick_grad():
accumulated_metrics = OrderedDict()
for repeat_iteration in range(self._args.repeat):
metrics, last_output = self.run_once_and_gather_metrics(
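
Note that the runner now enters only the gradient-mode context; autocast is applied inside the model's train() and eval() methods rather than around the whole repeat loop. A small standalone sketch of the selection pick_grad() performs (illustrative function, assuming the same test names as above):

import torch

def pick_grad(test):
  # no_grad for inference, enable_grad for training.
  if test == "eval":
    return torch.no_grad()
  elif test == "train":
    return torch.enable_grad()
  raise NotImplementedError

with pick_grad("eval"):
  print(torch.is_grad_enabled())  # False
with pick_grad("train"):
  print(torch.is_grad_enabled())  # True
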
36 changes: 20 additions & 16 deletions benchmarks/torchbench_model.py
@@ -7,11 +7,11 @@
from os.path import abspath, exists
import sys
import torch
import torch.amp
import torch.nn as nn
from torch._dynamo.testing import collect_results, reduce_to_scalar_loss
from torch._dynamo.utils import clone_inputs
import torch_xla
import torch_xla.amp
import torch_xla.core.xla_model as xm
import types
import yaml
@@ -269,11 +269,17 @@ def set_up(self):
This is model suite specific.
"""
# Set the optimizer class.
# Check if we should use SGD instead of Adam for memory reasons.
if self.benchmark_experiment.test == "train" and self.model_name in TRAIN_WITH_SGD:
self.optimizer_class = torch.optim.SGD
else:
self.optimizer_class = torch.optim.Adam

# Setup the autocast environment if we are running on AMP precision.
self.autocast, self.autocast_kwargs = self._get_autocast_with_kwargs()

# Load the actual benchmark instance.
benchmark = self.load_benchmark()

self.module, self.example_inputs = benchmark.get_module()
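
Aside: the comments above explain the selection, and the chosen optimizer class is presumably instantiated against the module parameters later in set-up. A tiny sketch of the selection logic (the TRAIN_WITH_SGD entry below is hypothetical, for illustration only):

import torch

TRAIN_WITH_SGD = {"hf_T5_large"}  # hypothetical membership

def pick_optimizer_class(test, model_name):
  # Memory-heavy models train with plain SGD; everything else uses Adam.
  if test == "train" and model_name in TRAIN_WITH_SGD:
    return torch.optim.SGD
  return torch.optim.Adam

params = torch.nn.Linear(4, 2).parameters()
optimizer = pick_optimizer_class("train", "hf_T5_large")(params, lr=0.01)
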
@@ -417,26 +423,24 @@ def pick_grad(self):
# special case
if self.model_name in ("maml",):
return torch.enable_grad()
return super().pick_grad()

if self.benchmark_experiment.test == "eval":
return torch.no_grad()
elif self.benchmark_experiment.test == "train":
return torch.enable_grad()

def pick_amp(self):
def _get_autocast_with_kwargs(self):
if (self.benchmark_experiment.accelerator == "cuda" and
self.is_cuda_precision_amp()):
kwargs = {"dtype": torch.bfloat16}
if self.benchmark_experiment.xla:
return torch_xla.amp.autocast(xm.xla_device())
# Should call device specific autocast implementations.
# PyTorch/XLA autocast does not run with dynamo, though:
# https://github.com/pytorch/xla/issues/6511
autocast = torch.amp.autocast
kwargs["device_type"] = "xla"
else:
return torch.cuda.amp.autocast()
return contextlib.nullcontext()

def pick_context(self):
stack = contextlib.ExitStack()
stack.enter_context(self.pick_amp())
stack.enter_context(self.pick_grad())
return stack
autocast = torch.cuda.amp.autocast
else:
kwargs = {}
autocast = contextlib.nullcontext
return (autocast, kwargs)

def compute_loss(self, pred):
"""Reduce the output of a model to get scalar loss"""
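
As a usage note, _get_autocast_with_kwargs() returns a context-manager callable plus its keyword arguments, which set_up() stores as self.autocast and self.autocast_kwargs for the train()/eval() paths shown earlier. A guarded sketch of the non-XLA CUDA branch (it assumes a CUDA device is available):

import torch

if torch.cuda.is_available():
  model = torch.nn.Linear(16, 16).cuda()
  x = torch.randn(4, 16, device="cuda")
  # bfloat16 autocast, matching the kwargs selected above.
  with torch.cuda.amp.autocast(dtype=torch.bfloat16):
    y = model(x)
  # Tensors produced under autocast keep the reduced-precision dtype.
  print(y.dtype)  # torch.bfloat16
  # On XLA, the equivalent would be
  # torch.amp.autocast(device_type="xla", dtype=torch.bfloat16).
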
