diff --git a/benchmarks/util.py b/benchmarks/util.py index 88bf452bbdd..8ab1d5a0181 100644 --- a/benchmarks/util.py +++ b/benchmarks/util.py @@ -76,7 +76,7 @@ def is_xla_device_available(devkind): return r.returncode == 0 -def move_to_device(item, device, torch_xla2): +def move_to_device(item, device, torch_xla2=False): if torch_xla2: import torch_xla2 import jax