[benchmarks] Fix execution with AMP precision. #6512

Merged · 6 commits · Feb 12, 2024
12 changes: 9 additions & 3 deletions benchmarks/benchmark_model.py
@@ -1,4 +1,5 @@
 from collections import OrderedDict
+import contextlib
 import logging
 import re
 import torch
@@ -66,6 +67,8 @@ def __init__(self, suite_name, model_name, benchmark_experiment):
     self.suite_name = suite_name
     self.model_name = model_name
     self.benchmark_experiment = benchmark_experiment
+    self.autocast = contextlib.nullcontext
+    self.autocast_kwargs = {}

   def set_up(self):
     """Set up module, actual batch_size, example_inputs, and optimizer_class
@@ -125,6 +128,7 @@ def pick_grad(self):
       return torch.no_grad()
     elif self.benchmark_experiment.test == "train":
       return torch.enable_grad()
+    raise NotImplementedError

   def _optimizer_zero_grad(self):
     if self.optimizer is not None:
@@ -141,8 +145,9 @@ def compute_loss(self, pred):

   def train(self, inputs, collect_full_output=False):
     self._optimizer_zero_grad()
-    pred = self.module(*inputs)
-    loss = self.compute_loss(pred)
+    with self.autocast(**self.autocast_kwargs):
+      pred = self.module(*inputs)
+      loss = self.compute_loss(pred)
     loss.backward()
     self._optimizer_step()
     if collect_full_output:
@@ -152,7 +157,8 @@ def train(self, inputs, collect_full_output=False):
     return None

   def eval(self, inputs, collect_full_output=False):
-    pred = self.module(*inputs)
+    with self.autocast(**self.autocast_kwargs):
+      pred = self.module(*inputs)
     return pred

   @property
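For orientation, here is a minimal, self-contained sketch (not code from this PR; the class and flag names are made up) of how the two new attributes compose with train()/eval(): the base class defaults to a no-op context, a set-up step may swap in a real autocast implementation, and only the forward pass (plus loss computation) runs inside it. Storing the context-manager callable and its kwargs, rather than an already-entered context, lets a fresh context be created on every call.

import contextlib

import torch


class ToyBenchmarkModel:
  """Illustrative stand-in for BenchmarkModel, not the benchmark code itself."""

  def __init__(self, module, use_cuda_amp=False):
    self.module = module
    # Defaults mirror the PR: a no-op context manager and empty kwargs,
    # so non-AMP runs behave exactly as before.
    self.autocast = contextlib.nullcontext
    self.autocast_kwargs = {}
    if use_cuda_amp:
      # Hypothetical set-up step: a CUDA AMP experiment would install a
      # real autocast implementation here.
      self.autocast = torch.cuda.amp.autocast
      self.autocast_kwargs = {"dtype": torch.bfloat16}

  def eval(self, inputs):
    # Same shape as the patched eval(): only the forward pass is wrapped.
    with self.autocast(**self.autocast_kwargs):
      return self.module(*inputs)


if __name__ == "__main__":
  toy = ToyBenchmarkModel(torch.nn.Linear(4, 2))
  print(toy.eval((torch.randn(1, 4),)))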
2 changes: 1 addition & 1 deletion benchmarks/experiment_runner.py
@@ -248,7 +248,7 @@ def run_single_config(self):

     # Repeat the experiment and accumulate metrics.
     last_output = None
-    with benchmark_model.pick_context():
+    with benchmark_model.pick_grad():
       accumulated_metrics = OrderedDict()
       for repeat_iteration in range(self._args.repeat):
         metrics, last_output = self.run_once_and_gather_metrics(
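A compact illustration of the division of responsibilities after this rename (pick_grad is re-implemented locally only to keep the snippet self-contained; module, inputs, and the repeat count stand in for the runner's real state): the runner enters a single gradient-mode context for the whole repeat loop, while mixed precision is now applied inside the model's own train()/eval().

import torch


def pick_grad(test):
  # Illustrative copy of the base-class logic the runner now relies on.
  if test == "eval":
    return torch.no_grad()
  elif test == "train":
    return torch.enable_grad()
  raise NotImplementedError


module = torch.nn.Linear(4, 2)
inputs = (torch.randn(1, 4),)
with pick_grad("eval"):              # one grad-mode context per config
  for repeat_iteration in range(3):  # stands in for self._args.repeat
    pred = module(*inputs)           # the real loop calls benchmark_model.eval(),
                                     # which applies autocast internally
print(pred.requires_grad)            # False: no_grad was in effect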
36 changes: 20 additions & 16 deletions benchmarks/torchbench_model.py
@@ -7,11 +7,11 @@
 from os.path import abspath, exists
 import sys
 import torch
+import torch.amp
 import torch.nn as nn
 from torch._dynamo.testing import collect_results, reduce_to_scalar_loss
 from torch._dynamo.utils import clone_inputs
 import torch_xla
-import torch_xla.amp
 import torch_xla.core.xla_model as xm
 import types
 import yaml
@@ -260,11 +260,17 @@ def set_up(self):

     This is model suite specific.
     """
+    # Set the optimizer class.
+    # Check if we should use SGD instead of Adam for memory reasons.
     if self.benchmark_experiment.test == "train" and self.model_name in TRAIN_WITH_SGD:
       self.optimizer_class = torch.optim.SGD
     else:
       self.optimizer_class = torch.optim.Adam

+    # Setup the autocast environment if we are running on AMP precision.
+    self.autocast, self.autocast_kwargs = self._get_autocast_with_kwargs()
+
+    # Load the actual benchmark instance.
     benchmark = self.load_benchmark()

     self.module, self.example_inputs = benchmark.get_module()
@@ -405,26 +411,24 @@ def pick_grad(self):
     # special case
     if self.model_name in ("maml",):
       return torch.enable_grad()
+    return super().pick_grad()

-    if self.benchmark_experiment.test == "eval":
-      return torch.no_grad()
-    elif self.benchmark_experiment.test == "train":
-      return torch.enable_grad()
-
-  def pick_amp(self):
+  def _get_autocast_with_kwargs(self):
     if (self.benchmark_experiment.accelerator == "cuda" and
         self.is_cuda_precision_amp()):
+      kwargs = {"dtype": torch.bfloat16}
       if self.benchmark_experiment.xla:
-        return torch_xla.amp.autocast(xm.xla_device())
+        # Should call device specific autocast implementations.
+        # PyTorch/XLA autocast does not run with dynamo, though:
+        # https://github.com/pytorch/xla/issues/6511
+        autocast = torch.amp.autocast
+        kwargs["device_type"] = "xla"
       else:
-        return torch.cuda.amp.autocast()
-    return contextlib.nullcontext()
-
-  def pick_context(self):
-    stack = contextlib.ExitStack()
-    stack.enter_context(self.pick_amp())
-    stack.enter_context(self.pick_grad())
-    return stack
+        autocast = torch.cuda.amp.autocast
+    else:
+      kwargs = {}
+      autocast = contextlib.nullcontext
+    return (autocast, kwargs)

   def compute_loss(self, pred):
     """Reduce the output of a model to get scalar loss"""
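To make the selection logic concrete, here is a standalone sketch of what _get_autocast_with_kwargs resolves to and how the returned pair is applied later; accelerator, use_amp, and use_xla are stand-ins for the benchmark_experiment fields, and only the non-AMP path is exercised below so the snippet runs without a GPU or XLA device.

import contextlib

import torch


def get_autocast_with_kwargs(accelerator, use_amp, use_xla):
  # Mirrors the method above: pick an autocast implementation and its kwargs
  # now, enter it later as `with autocast(**kwargs):`.
  if accelerator == "cuda" and use_amp:
    kwargs = {"dtype": torch.bfloat16}
    if use_xla:
      # torch.amp.autocast with device_type="xla", rather than
      # torch_xla.amp.autocast, which does not run under dynamo (#6511).
      autocast = torch.amp.autocast
      kwargs["device_type"] = "xla"
    else:
      autocast = torch.cuda.amp.autocast
  else:
    kwargs = {}
    autocast = contextlib.nullcontext
  return autocast, kwargs


autocast, kwargs = get_autocast_with_kwargs("cpu", use_amp=False, use_xla=False)
with autocast(**kwargs):
  print(torch.nn.Linear(4, 2)(torch.randn(1, 4)))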