From 5cb605d62d26407abc4304aa55ff9a719772c3ff Mon Sep 17 00:00:00 2001
From: JackCaoG <59073027+JackCaoG@users.noreply.github.com>
Date: Mon, 30 Oct 2023 10:15:52 -0700
Subject: [PATCH] Update set_default_tensor_type to set_default_dtype (#5734)

---
 test/bench.py                              | 2 +-
 test/spmd/test_train_spmd_imagenet.py      | 2 +-
 test/test_operations_hlo.py                | 2 +-
 test/test_profile_mp_mnist.py              | 2 +-
 test/test_train_mp_imagenet.py             | 2 +-
 test/test_train_mp_imagenet_amp.py         | 2 +-
 test/test_train_mp_imagenet_fsdp.py        | 2 +-
 test/test_train_mp_mnist.py                | 2 +-
 test/test_train_mp_mnist_amp.py            | 2 +-
 test/test_train_mp_mnist_fsdp_with_ckpt.py | 2 +-
 test/test_train_mp_mnist_zero1.py          | 2 +-
 11 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/test/bench.py b/test/bench.py
index 8d37aafab538..a8908eb9b5e3 100644
--- a/test/bench.py
+++ b/test/bench.py
@@ -128,7 +128,7 @@ def run_benchmarks(args):
   args, benchs = parser.parse_known_args()
   args.benchs = benchs

-  torch.set_default_tensor_type('torch.FloatTensor')
+  torch.set_default_dtype(torch.float32)
   torch_xla._XLAC._xla_set_use_full_mat_mul_precision(
       use_full_mat_mul_precision=True)
   run_benchmarks(args)
diff --git a/test/spmd/test_train_spmd_imagenet.py b/test/spmd/test_train_spmd_imagenet.py
index 2461658c801c..cde37989ef85 100644
--- a/test/spmd/test_train_spmd_imagenet.py
+++ b/test/spmd/test_train_spmd_imagenet.py
@@ -372,7 +372,7 @@ def test_loop_fn(loader, epoch):
   if FLAGS.profile:
     server = xp.start_server(FLAGS.profiler_port)

-  torch.set_default_tensor_type('torch.FloatTensor')
+  torch.set_default_dtype(torch.float32)
   accuracy = train_imagenet()
   if accuracy < FLAGS.target_accuracy:
     print('Accuracy {} is below target {}'.format(accuracy,
diff --git a/test/test_operations_hlo.py b/test/test_operations_hlo.py
index 3a0b0d168124..9fe8b7b4aae0 100644
--- a/test/test_operations_hlo.py
+++ b/test/test_operations_hlo.py
@@ -69,7 +69,7 @@ def test_dropout_by_u8_mask(self):


 if __name__ == '__main__':
-  torch.set_default_tensor_type('torch.FloatTensor')
+  torch.set_default_dtype(torch.float32)
   torch.manual_seed(42)
   torch_xla._XLAC._xla_set_use_full_mat_mul_precision(
       use_full_mat_mul_precision=True)
diff --git a/test/test_profile_mp_mnist.py b/test/test_profile_mp_mnist.py
index 5e092b6c3947..f70a380132ef 100644
--- a/test/test_profile_mp_mnist.py
+++ b/test/test_profile_mp_mnist.py
@@ -198,7 +198,7 @@ def test_loop_fn(loader):


 def _mp_fn(index, flags):
-  torch.set_default_tensor_type('torch.FloatTensor')
+  torch.set_default_dtype(torch.float32)
   accuracy = train_mnist(flags, dynamic_graph=True, fetch_often=True)
   if flags.tidy and os.path.isdir(flags.datadir):
     shutil.rmtree(flags.datadir)
diff --git a/test/test_train_mp_imagenet.py b/test/test_train_mp_imagenet.py
index 65a359ca9bee..543eeb85abb6 100644
--- a/test/test_train_mp_imagenet.py
+++ b/test/test_train_mp_imagenet.py
@@ -358,7 +358,7 @@ def test_loop_fn(loader, epoch):
 def _mp_fn(index, flags):
   global FLAGS
   FLAGS = flags
-  torch.set_default_tensor_type('torch.FloatTensor')
+  torch.set_default_dtype(torch.float32)
   accuracy = train_imagenet()
   if accuracy < FLAGS.target_accuracy:
     print('Accuracy {} is below target {}'.format(accuracy,
diff --git a/test/test_train_mp_imagenet_amp.py b/test/test_train_mp_imagenet_amp.py
index 3ed923897156..ffcf6ee1386a 100644
--- a/test/test_train_mp_imagenet_amp.py
+++ b/test/test_train_mp_imagenet_amp.py
@@ -298,7 +298,7 @@ def test_loop_fn(loader, epoch):
 def _mp_fn(index, flags):
   global FLAGS
   FLAGS = flags
-  torch.set_default_tensor_type('torch.FloatTensor')
+  torch.set_default_dtype(torch.float32)
   accuracy = train_imagenet()
   if accuracy < FLAGS.target_accuracy:
     print('Accuracy {} is below target {}'.format(accuracy,
diff --git a/test/test_train_mp_imagenet_fsdp.py b/test/test_train_mp_imagenet_fsdp.py
index fdfdc8a698c1..351e19aad75d 100644
--- a/test/test_train_mp_imagenet_fsdp.py
+++ b/test/test_train_mp_imagenet_fsdp.py
@@ -385,7 +385,7 @@ def test_loop_fn(loader, epoch):
 def _mp_fn(index, flags):
   global FLAGS
   FLAGS = flags
-  torch.set_default_tensor_type('torch.FloatTensor')
+  torch.set_default_dtype(torch.float32)
   accuracy = train_imagenet()
   if accuracy < FLAGS.target_accuracy:
     print('Accuracy {} is below target {}'.format(accuracy,
diff --git a/test/test_train_mp_mnist.py b/test/test_train_mp_mnist.py
index 8ae434acfcb8..ec510e982882 100644
--- a/test/test_train_mp_mnist.py
+++ b/test/test_train_mp_mnist.py
@@ -203,7 +203,7 @@ def test_loop_fn(loader):


 def _mp_fn(index, flags):
-  torch.set_default_tensor_type('torch.FloatTensor')
+  torch.set_default_dtype(torch.float32)
   accuracy = train_mnist(flags)
   if flags.tidy and os.path.isdir(flags.datadir):
     shutil.rmtree(flags.datadir)
diff --git a/test/test_train_mp_mnist_amp.py b/test/test_train_mp_mnist_amp.py
index 3c9363f8d09c..990ea9bc91a3 100644
--- a/test/test_train_mp_mnist_amp.py
+++ b/test/test_train_mp_mnist_amp.py
@@ -211,7 +211,7 @@ def test_loop_fn(loader):


 def _mp_fn(index, flags):
-  torch.set_default_tensor_type('torch.FloatTensor')
+  torch.set_default_dtype(torch.float32)
   accuracy = train_mnist(flags)
   if flags.tidy and os.path.isdir(flags.datadir):
     shutil.rmtree(flags.datadir)
diff --git a/test/test_train_mp_mnist_fsdp_with_ckpt.py b/test/test_train_mp_mnist_fsdp_with_ckpt.py
index 2bb549e72a4e..96d9a9b8bbb1 100644
--- a/test/test_train_mp_mnist_fsdp_with_ckpt.py
+++ b/test/test_train_mp_mnist_fsdp_with_ckpt.py
@@ -313,7 +313,7 @@ def test_loop_fn(model, loader):


 def _mp_fn(index, flags):
-  torch.set_default_tensor_type('torch.FloatTensor')
+  torch.set_default_dtype(torch.float32)
   accuracy = train_mnist(flags)
   if flags.tidy and os.path.isdir(flags.datadir):
     shutil.rmtree(flags.datadir)
diff --git a/test/test_train_mp_mnist_zero1.py b/test/test_train_mp_mnist_zero1.py
index 6f8d3964b52a..02a6db04a172 100644
--- a/test/test_train_mp_mnist_zero1.py
+++ b/test/test_train_mp_mnist_zero1.py
@@ -184,7 +184,7 @@ def test_loop_fn(loader):


 def _mp_fn(index, flags):
-  torch.set_default_tensor_type('torch.FloatTensor')
+  torch.set_default_dtype(torch.float32)
   accuracy = train_mnist(flags)
   if flags.tidy and os.path.isdir(flags.datadir):
     shutil.rmtree(flags.datadir)
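
For reference, a minimal sketch (not part of the patch) of the behavior the swap relies on: both calls make newly created floating-point tensors default to float32, and torch.set_default_tensor_type has been deprecated in recent PyTorch releases in favor of torch.set_default_dtype.

    import torch

    # Old call (deprecated): sets the default tensor type, which implicitly
    # sets the default floating-point dtype as well.
    #   torch.set_default_tensor_type('torch.FloatTensor')

    # New call used throughout this patch: sets only the default
    # floating-point dtype for newly created tensors.
    torch.set_default_dtype(torch.float32)

    # New floating-point tensors now default to float32.
    assert torch.tensor([1.0]).dtype == torch.float32
    assert torch.empty(3).dtype == torch.float32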