diff --git a/test/bench.py b/test/bench.py index 8d37aafab53..a8908eb9b5e 100644 --- a/test/bench.py +++ b/test/bench.py @@ -128,7 +128,7 @@ def run_benchmarks(args): args, benchs = parser.parse_known_args() args.benchs = benchs - torch.set_default_tensor_type('torch.FloatTensor') + torch.set_default_dtype(torch.float32) torch_xla._XLAC._xla_set_use_full_mat_mul_precision( use_full_mat_mul_precision=True) run_benchmarks(args) diff --git a/test/spmd/test_train_spmd_imagenet.py b/test/spmd/test_train_spmd_imagenet.py index 2461658c801..cde37989ef8 100644 --- a/test/spmd/test_train_spmd_imagenet.py +++ b/test/spmd/test_train_spmd_imagenet.py @@ -372,7 +372,7 @@ def test_loop_fn(loader, epoch): if FLAGS.profile: server = xp.start_server(FLAGS.profiler_port) - torch.set_default_tensor_type('torch.FloatTensor') + torch.set_default_dtype(torch.float32) accuracy = train_imagenet() if accuracy < FLAGS.target_accuracy: print('Accuracy {} is below target {}'.format(accuracy, diff --git a/test/test_operations_hlo.py b/test/test_operations_hlo.py index 3a0b0d16812..9fe8b7b4aae 100644 --- a/test/test_operations_hlo.py +++ b/test/test_operations_hlo.py @@ -69,7 +69,7 @@ def test_dropout_by_u8_mask(self): if __name__ == '__main__': - torch.set_default_tensor_type('torch.FloatTensor') + torch.set_default_dtype(torch.float32) torch.manual_seed(42) torch_xla._XLAC._xla_set_use_full_mat_mul_precision( use_full_mat_mul_precision=True) diff --git a/test/test_profile_mp_mnist.py b/test/test_profile_mp_mnist.py index 5e092b6c394..f70a380132e 100644 --- a/test/test_profile_mp_mnist.py +++ b/test/test_profile_mp_mnist.py @@ -198,7 +198,7 @@ def test_loop_fn(loader): def _mp_fn(index, flags): - torch.set_default_tensor_type('torch.FloatTensor') + torch.set_default_dtype(torch.float32) accuracy = train_mnist(flags, dynamic_graph=True, fetch_often=True) if flags.tidy and os.path.isdir(flags.datadir): shutil.rmtree(flags.datadir) diff --git a/test/test_train_mp_imagenet.py b/test/test_train_mp_imagenet.py index 122e8513dbf..224c77782f8 100644 --- a/test/test_train_mp_imagenet.py +++ b/test/test_train_mp_imagenet.py @@ -366,7 +366,7 @@ def test_loop_fn(loader, epoch): def _mp_fn(index, flags): global FLAGS FLAGS = flags - torch.set_default_tensor_type('torch.FloatTensor') + torch.set_default_dtype(torch.float32) accuracy = train_imagenet() if accuracy < FLAGS.target_accuracy: print('Accuracy {} is below target {}'.format(accuracy, diff --git a/test/test_train_mp_imagenet_amp.py b/test/test_train_mp_imagenet_amp.py index 3ed92389715..ffcf6ee1386 100644 --- a/test/test_train_mp_imagenet_amp.py +++ b/test/test_train_mp_imagenet_amp.py @@ -298,7 +298,7 @@ def test_loop_fn(loader, epoch): def _mp_fn(index, flags): global FLAGS FLAGS = flags - torch.set_default_tensor_type('torch.FloatTensor') + torch.set_default_dtype(torch.float32) accuracy = train_imagenet() if accuracy < FLAGS.target_accuracy: print('Accuracy {} is below target {}'.format(accuracy, diff --git a/test/test_train_mp_imagenet_fsdp.py b/test/test_train_mp_imagenet_fsdp.py index fdfdc8a698c..351e19aad75 100644 --- a/test/test_train_mp_imagenet_fsdp.py +++ b/test/test_train_mp_imagenet_fsdp.py @@ -385,7 +385,7 @@ def test_loop_fn(loader, epoch): def _mp_fn(index, flags): global FLAGS FLAGS = flags - torch.set_default_tensor_type('torch.FloatTensor') + torch.set_default_dtype(torch.float32) accuracy = train_imagenet() if accuracy < FLAGS.target_accuracy: print('Accuracy {} is below target {}'.format(accuracy, diff --git a/test/test_train_mp_mnist.py b/test/test_train_mp_mnist.py index 22253fbea73..0f811d3a67e 100644 --- a/test/test_train_mp_mnist.py +++ b/test/test_train_mp_mnist.py @@ -209,7 +209,7 @@ def test_loop_fn(loader): def _mp_fn(index, flags): - torch.set_default_tensor_type('torch.FloatTensor') + torch.set_default_dtype(torch.float32) accuracy = train_mnist(flags) if flags.tidy and os.path.isdir(flags.datadir): shutil.rmtree(flags.datadir) diff --git a/test/test_train_mp_mnist_amp.py b/test/test_train_mp_mnist_amp.py index 3c9363f8d09..990ea9bc91a 100644 --- a/test/test_train_mp_mnist_amp.py +++ b/test/test_train_mp_mnist_amp.py @@ -211,7 +211,7 @@ def test_loop_fn(loader): def _mp_fn(index, flags): - torch.set_default_tensor_type('torch.FloatTensor') + torch.set_default_dtype(torch.float32) accuracy = train_mnist(flags) if flags.tidy and os.path.isdir(flags.datadir): shutil.rmtree(flags.datadir) diff --git a/test/test_train_mp_mnist_fsdp_with_ckpt.py b/test/test_train_mp_mnist_fsdp_with_ckpt.py index 2bb549e72a4..96d9a9b8bbb 100644 --- a/test/test_train_mp_mnist_fsdp_with_ckpt.py +++ b/test/test_train_mp_mnist_fsdp_with_ckpt.py @@ -313,7 +313,7 @@ def test_loop_fn(model, loader): def _mp_fn(index, flags): - torch.set_default_tensor_type('torch.FloatTensor') + torch.set_default_dtype(torch.float32) accuracy = train_mnist(flags) if flags.tidy and os.path.isdir(flags.datadir): shutil.rmtree(flags.datadir) diff --git a/test/test_train_mp_mnist_zero1.py b/test/test_train_mp_mnist_zero1.py index 6f8d3964b52..02a6db04a17 100644 --- a/test/test_train_mp_mnist_zero1.py +++ b/test/test_train_mp_mnist_zero1.py @@ -184,7 +184,7 @@ def test_loop_fn(loader): def _mp_fn(index, flags): - torch.set_default_tensor_type('torch.FloatTensor') + torch.set_default_dtype(torch.float32) accuracy = train_mnist(flags) if flags.tidy and os.path.isdir(flags.datadir): shutil.rmtree(flags.datadir)