diff --git a/benchmarks/torchbench_model.py b/benchmarks/torchbench_model.py
index d3f644b4b733..a01344dd558c 100644
--- a/benchmarks/torchbench_model.py
+++ b/benchmarks/torchbench_model.py
@@ -383,8 +383,16 @@ def conversion_dtype(self):
     return torch.bfloat16
 
   def _get_autocast_with_kwargs(self):
+    kwargs = {}
+
+    # Set the default data-type based on the accelerator.
+    if self.benchmark_experiment.accelerator == "cuda":
+      kwargs["dtype"] = torch.float16
+    else:
+      # Both CPU and TPU autocast mode defaults to bfloat16.
+      kwargs["dtype"] = torch.bfloat16
+
     if self.use_amp():
-      kwargs = {"dtype": torch.bfloat16}
       if self.benchmark_experiment.xla:
         # Should call device specific autocast implementations.
         # PyTorch/XLA autocast does not run with dynamo, though:
@@ -394,7 +402,6 @@ def _get_autocast_with_kwargs(self):
       else:
         autocast = torch.cuda.amp.autocast
     else:
-      kwargs = {}
       autocast = contextlib.nullcontext
     return (autocast, kwargs)
 
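For context, a minimal sketch of the behavior this change selects when AMP is active on CUDA: the helper now returns torch.cuda.amp.autocast together with dtype=torch.float16 (previously bfloat16), while CPU and TPU/XLA runs keep bfloat16. The model and input names below are illustrative placeholders, not code from this repository.

    import torch

    # Illustrative usage only (assumes a CUDA device is available): with AMP
    # enabled on CUDA, eligible ops in the wrapped forward pass run in float16.
    model = torch.nn.Linear(8, 8).cuda()
    inputs = torch.randn(4, 8, device="cuda")

    with torch.cuda.amp.autocast(dtype=torch.float16):
      out = model(inputs)  # matmuls execute in float16 under autocast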