[benchmarks] Default to bfloat16 (inference) and AMP (training) precision. (#6518)
ysiraichi authored Feb 16, 2024
1 parent b7c760d commit 20692cb
Showing 3 changed files with 61 additions and 87 deletions.
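
As a quick orientation before the per-file diffs: the new default is to cast models to bfloat16 for inference and to rely on AMP (autocast) for training, with two small exception lists for models that only work under AMP or only under float16. Below is a minimal standalone sketch of that policy; pick_precision and the truncated model sets are illustrative only, while the real logic lives in conversion_dtype/use_amp/use_fp16 in benchmarks/torchbench_model.py further down.

import torch

# Illustrative subsets; the full sets are defined in benchmarks/torchbench_model.py.
FORCE_AMP_FOR_FP16_BF16_MODELS = {"Super_SloMo", "tts_angular"}
FORCE_FP16_FOR_BF16_MODELS = {"vision_maskrcnn"}

def pick_precision(model_name, test):
    # Returns (conversion_dtype, use_amp) for a benchmark run.
    is_training = test == "train"
    use_amp = is_training or model_name in FORCE_AMP_FOR_FP16_BF16_MODELS
    if use_amp:
        # No eager dtype conversion; autocast handles mixed precision at run time.
        return None, True
    if model_name in FORCE_FP16_FOR_BF16_MODELS:
        return torch.float16, False
    return torch.bfloat16, False

print(pick_precision("resnet50", "eval"))         # (torch.bfloat16, False)
print(pick_precision("vision_maskrcnn", "eval"))  # (torch.float16, False)
print(pick_precision("resnet50", "train"))        # (None, True)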
11 changes: 10 additions & 1 deletion benchmarks/benchmark_model.py
@@ -5,7 +5,7 @@
import torch
import torch.nn as nn
from torch._dynamo.testing import collect_results
from util import move_to_device
from util import cast_to_dtype, move_to_device

logger = logging.getLogger(__name__)

@@ -104,8 +104,17 @@ def _prepare_for_train(self):
# optimizer to use. So only initialize it when there is none existing.
self.optimizer = self.optimizer_class(self.module.parameters(), lr=0.01)

def conversion_dtype(self):
return None

def prepare_for_experiment(self, dynamo_compilation_opts):
self.device = self.benchmark_experiment.get_device()
self.dtype = self.conversion_dtype()

if self.dtype is not None:
self.module = self.module.to(self.dtype)
self.example_inputs = cast_to_dtype(self.example_inputs, self.dtype)

self.module = self.module.to(self.device)
self.example_inputs = move_to_device(self.example_inputs, self.device)

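To show how the new conversion_dtype hook interacts with prepare_for_experiment, here is a small self-contained sketch. ToyBenchmarkModel and its inline cast helper are hypothetical stand-ins for BenchmarkModel and the cast_to_dtype added to benchmarks/util.py; as in the diff above, only floating-point tensors are converted.

import torch
import torch.nn as nn
from torch.utils import _pytree as pytree

def cast_to_dtype(item, dtype):
    # Same idea as the helper added to benchmarks/util.py: cast only
    # floating-point tensors, leave integer tensors untouched.
    return pytree.tree_map_only(
        torch.Tensor, lambda t: t.to(dtype) if t.is_floating_point() else t, item)

class ToyBenchmarkModel:
    def __init__(self):
        self.module = nn.Linear(4, 4)
        self.example_inputs = (torch.randn(2, 4), torch.tensor([0, 1]))

    def conversion_dtype(self):
        # The base BenchmarkModel returns None (no conversion); the TorchBench
        # subclass overrides this to pick bfloat16/float16 for inference.
        return torch.bfloat16

    def prepare_for_experiment(self, device="cpu"):
        dtype = self.conversion_dtype()
        if dtype is not None:
            self.module = self.module.to(dtype)
            self.example_inputs = cast_to_dtype(self.example_inputs, dtype)
        self.module = self.module.to(device)

m = ToyBenchmarkModel()
m.prepare_for_experiment()
print(m.example_inputs[0].dtype)  # torch.bfloat16
print(m.example_inputs[1].dtype)  # torch.int64 (integers left as-is)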
128 changes: 42 additions & 86 deletions benchmarks/torchbench_model.py
@@ -144,6 +144,18 @@
"hf_T5_generate",
}

FORCE_AMP_FOR_FP16_BF16_MODELS = {
"DALLE2_pytorch",
"doctr_det_predictor",
"doctr_reco_predictor",
"Super_SloMo",
"tts_angular",
"pyhpc_turbulent_kinetic_energy",
"detectron2_fcos_r_50_fpn",
}

FORCE_FP16_FOR_BF16_MODELS = {"vision_maskrcnn"}


class TorchBenchModelLoader(ModelLoader):

@@ -295,10 +307,6 @@ def set_up(self):
self.example_inputs = move_to_device(self.example_inputs, "cpu")
self._cleanup()

device = self.benchmark_experiment.get_device()
self.module = self.module.to(device)
self.example_inputs = move_to_device(self.example_inputs, device)

# Torchbench has quite different setup for yolov3, so directly passing
# the right example_inputs
if self.model_name == "yolov3":
@@ -334,88 +342,13 @@ def load_benchmark(self):
if (device := self.benchmark_experiment.accelerator) == 'tpu':
device = str(self.benchmark_experiment.get_device())

kwargs = {
"test": self.benchmark_experiment.test,
"device": device,
"batch_size": self.benchmark_experiment.batch_size,
}

# Force FP32 when precision is either FP16 or BF16 (only for XLA:CUDA).
# If the model is, e.g., FP16 and XLA_USE_FP16 is set, XLA will unexpectedly up-cast
# return values to FP32.
# Issue: https://github.com/pytorch/xla/issues/6348
if self.benchmark_experiment.accelerator == "cuda" and self.benchmark_experiment.xla:
if self.is_cuda_precision_fp16() or self.is_cuda_precision_bf16():
# PyTorch/benchmark will use these 'extra_args' for converting the model.
kwargs["extra_args"] = ["--precision", "fp32"]

return self.benchmark_cls()(**kwargs)

def get_cuda_precision(self):
test = self.benchmark_experiment.test.upper()
attr = f"DEFAULT_{test}_CUDA_PRECISION"
return getattr(self.benchmark_cls(), attr, None)

def is_cuda_precision_fp16(self):
return self.get_cuda_precision() == "fp16"

def is_cuda_precision_fp32(self):
return self.get_cuda_precision() == "fp32"

def is_cuda_precision_bf16(self):
return self.get_cuda_precision() == "bf16"

def is_cuda_precision_amp(self):
return self.get_cuda_precision() == "amp"

@property
def default_precision_flag(self):
"""
Get the default precision config to XLA, if present.
Whenever a model has a default precision for cuda set
we need to set proper environment flags so XLA catches
the required precision.
This function is a workaround. Proper solution requires
changes to the PT/XLA bridge so that the input shape
is properly inferred after issuing converts to `torch.nn.Module`.
"""
# At this moment, this method checks the precision flags only if both
# of the items below are true:
#
# 1. Device is CUDA: only check for 'DEFAULT_CUDA_<test>_PRECISION'
#
# 2. Dynamo backend is not inductor: PyTorch/benchmark scripts already
# take care of converting the model to the right precision.
#
if (self.benchmark_experiment.accelerator != "cuda" or
self.benchmark_experiment.dynamo == "inductor"):
return None

if self.get_cuda_precision() is None:
return None

if self.is_cuda_precision_fp16():
return 'XLA_USE_FP16'

if self.is_cuda_precision_bf16():
return 'XLA_USE_BF16'

if self.is_cuda_precision_amp():
return None

if self.is_cuda_precision_fp32():
logger.warning("Sticking with the default fp32 precision.")
return None

raise ValueError(f"Unknown precision: {self.get_cuda_precision()}")
return self.benchmark_cls()(
test=self.benchmark_experiment.test,
device=device,
batch_size=self.benchmark_experiment.batch_size,
)

def update_process_env(self, process_env):
precision_flag = self.default_precision_flag
if precision_flag is not None:
process_env[precision_flag] = '1'

if self.model_name in NEED_LARGER_CACHE:
process_env["XLA_COMPILATION_CACHE_SIZE"] = "2048"

@@ -425,9 +358,32 @@ def pick_grad(self):
return torch.enable_grad()
return super().pick_grad()

def is_inference(self):
return self.benchmark_experiment.test == "eval"

def is_training(self):
return self.benchmark_experiment.test == "train"

def use_amp(self):
return self.is_training() or self.model_name in FORCE_AMP_FOR_FP16_BF16_MODELS

def use_fp16(self):
return self.is_inference() and self.model_name in FORCE_FP16_FOR_BF16_MODELS

def conversion_dtype(self):
if self.is_training() or self.use_amp():
return super().conversion_dtype()

# From here, we are running inference without AMP, for sure.
# Do we have to use float16, instead of bfloat16?
if self.use_fp16():
return torch.float16

return torch.bfloat16

def _get_autocast_with_kwargs(self):
if (self.benchmark_experiment.accelerator == "cuda" and
self.is_cuda_precision_amp()):
if self.use_amp():
kwargs = {"dtype": torch.bfloat16}
if self.benchmark_experiment.xla:
# Should call device specific autocast implementations.
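The autocast path above only changes its trigger: instead of inspecting the model's DEFAULT_*_CUDA_PRECISION, it now asks use_amp(). A hedged sketch of what the wrapped forward pass might look like on CUDA follows; run_forward is a hypothetical helper, and the XLA-specific autocast variant hinted at in the "device specific autocast implementations" comment is omitted.

import torch

def run_forward(model, example_inputs, use_amp, device_type="cuda"):
    # When AMP is selected (training, or a model in FORCE_AMP_FOR_FP16_BF16_MODELS),
    # run the forward pass under autocast with bfloat16. Otherwise the model has
    # already been cast eagerly and no autocast context is needed.
    if use_amp:
        with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
            return model(*example_inputs)
    return model(*example_inputs)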
9 changes: 9 additions & 0 deletions benchmarks/util.py
@@ -80,6 +80,15 @@ def move_to_device(item, device):
return pytree.tree_map_only(torch.Tensor, lambda t: t.to(device), item)


def cast_to_dtype(item, dtype):
return pytree.tree_map_only(
torch.Tensor,
lambda t: t.to(dtype)
if isinstance(t, torch.Tensor) and t.is_floating_point() else t,
item,
)


def randomize_input(inputs):
if isinstance(inputs, torch.Tensor):
if inputs.dtype in (torch.float32, torch.float64):
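For reference, a short usage example of the new cast_to_dtype helper, assuming it is imported the same way benchmark_model.py does (from util import cast_to_dtype, i.e. running from the benchmarks/ directory). It walks arbitrary nested containers via pytree and converts only floating-point tensors.

import torch
from util import cast_to_dtype  # as in benchmarks/benchmark_model.py

example_inputs = {
    "pixels": torch.randn(1, 3, 224, 224),            # float32 -> bfloat16
    "labels": torch.tensor([1, 2, 3]),                 # int64, left untouched
    "nested": (torch.rand(4, dtype=torch.float64),),   # float64 -> bfloat16
}

cast = cast_to_dtype(example_inputs, torch.bfloat16)
print(cast["pixels"].dtype)     # torch.bfloat16
print(cast["labels"].dtype)     # torch.int64
print(cast["nested"][0].dtype)  # torch.bfloat16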
