From 3d21839d098b38f463a21e1bf7f1bb3b290c5dbe Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Sat, 17 Feb 2024 14:09:57 -0300 Subject: [PATCH] [benchmarks] Fix AMP data-type. (#6550) --- benchmarks/torchbench_model.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/benchmarks/torchbench_model.py b/benchmarks/torchbench_model.py index d3f644b4b73..a01344dd558 100644 --- a/benchmarks/torchbench_model.py +++ b/benchmarks/torchbench_model.py @@ -383,8 +383,16 @@ def conversion_dtype(self): return torch.bfloat16 def _get_autocast_with_kwargs(self): + kwargs = {} + + # Set the default data-type based on the accelerator. + if self.benchmark_experiment.accelerator == "cuda": + kwargs["dtype"] = torch.float16 + else: + # Both CPU and TPU autocast mode defaults to bfloat16. + kwargs["dtype"] = torch.bfloat16 + if self.use_amp(): - kwargs = {"dtype": torch.bfloat16} if self.benchmark_experiment.xla: # Should call device specific autocast implementations. # PyTorch/XLA autocast does not run with dynamo, though: @@ -394,7 +402,6 @@ def _get_autocast_with_kwargs(self): else: autocast = torch.cuda.amp.autocast else: - kwargs = {} autocast = contextlib.nullcontext return (autocast, kwargs)