Skip to content

Commit

Permalink
Update set_default_tensor_type to set_default_dtype (#5734)
Browse files Browse the repository at this point in the history
  • Loading branch information
JackCaoG authored and jonb377 committed Oct 31, 2023
1 parent a2b3b30 commit 5cb605d
Show file tree
Hide file tree
Showing 11 changed files with 11 additions and 11 deletions.
2 changes: 1 addition & 1 deletion test/bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def run_benchmarks(args):
args, benchs = parser.parse_known_args()
args.benchs = benchs

torch.set_default_tensor_type('torch.FloatTensor')
torch.set_default_dtype(torch.float32)
torch_xla._XLAC._xla_set_use_full_mat_mul_precision(
use_full_mat_mul_precision=True)
run_benchmarks(args)
2 changes: 1 addition & 1 deletion test/spmd/test_train_spmd_imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ def test_loop_fn(loader, epoch):
if FLAGS.profile:
server = xp.start_server(FLAGS.profiler_port)

torch.set_default_tensor_type('torch.FloatTensor')
torch.set_default_dtype(torch.float32)
accuracy = train_imagenet()
if accuracy < FLAGS.target_accuracy:
print('Accuracy {} is below target {}'.format(accuracy,
Expand Down
2 changes: 1 addition & 1 deletion test/test_operations_hlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def test_dropout_by_u8_mask(self):


if __name__ == '__main__':
torch.set_default_tensor_type('torch.FloatTensor')
torch.set_default_dtype(torch.float32)
torch.manual_seed(42)
torch_xla._XLAC._xla_set_use_full_mat_mul_precision(
use_full_mat_mul_precision=True)
Expand Down
2 changes: 1 addition & 1 deletion test/test_profile_mp_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def test_loop_fn(loader):


def _mp_fn(index, flags):
torch.set_default_tensor_type('torch.FloatTensor')
torch.set_default_dtype(torch.float32)
accuracy = train_mnist(flags, dynamic_graph=True, fetch_often=True)
if flags.tidy and os.path.isdir(flags.datadir):
shutil.rmtree(flags.datadir)
Expand Down
2 changes: 1 addition & 1 deletion test/test_train_mp_imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def test_loop_fn(loader, epoch):
def _mp_fn(index, flags):
global FLAGS
FLAGS = flags
torch.set_default_tensor_type('torch.FloatTensor')
torch.set_default_dtype(torch.float32)
accuracy = train_imagenet()
if accuracy < FLAGS.target_accuracy:
print('Accuracy {} is below target {}'.format(accuracy,
Expand Down
2 changes: 1 addition & 1 deletion test/test_train_mp_imagenet_amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ def test_loop_fn(loader, epoch):
def _mp_fn(index, flags):
global FLAGS
FLAGS = flags
torch.set_default_tensor_type('torch.FloatTensor')
torch.set_default_dtype(torch.float32)
accuracy = train_imagenet()
if accuracy < FLAGS.target_accuracy:
print('Accuracy {} is below target {}'.format(accuracy,
Expand Down
2 changes: 1 addition & 1 deletion test/test_train_mp_imagenet_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ def test_loop_fn(loader, epoch):
def _mp_fn(index, flags):
global FLAGS
FLAGS = flags
torch.set_default_tensor_type('torch.FloatTensor')
torch.set_default_dtype(torch.float32)
accuracy = train_imagenet()
if accuracy < FLAGS.target_accuracy:
print('Accuracy {} is below target {}'.format(accuracy,
Expand Down
2 changes: 1 addition & 1 deletion test/test_train_mp_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def test_loop_fn(loader):


def _mp_fn(index, flags):
torch.set_default_tensor_type('torch.FloatTensor')
torch.set_default_dtype(torch.float32)
accuracy = train_mnist(flags)
if flags.tidy and os.path.isdir(flags.datadir):
shutil.rmtree(flags.datadir)
Expand Down
2 changes: 1 addition & 1 deletion test/test_train_mp_mnist_amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def test_loop_fn(loader):


def _mp_fn(index, flags):
torch.set_default_tensor_type('torch.FloatTensor')
torch.set_default_dtype(torch.float32)
accuracy = train_mnist(flags)
if flags.tidy and os.path.isdir(flags.datadir):
shutil.rmtree(flags.datadir)
Expand Down
2 changes: 1 addition & 1 deletion test/test_train_mp_mnist_fsdp_with_ckpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ def test_loop_fn(model, loader):


def _mp_fn(index, flags):
torch.set_default_tensor_type('torch.FloatTensor')
torch.set_default_dtype(torch.float32)
accuracy = train_mnist(flags)
if flags.tidy and os.path.isdir(flags.datadir):
shutil.rmtree(flags.datadir)
Expand Down
2 changes: 1 addition & 1 deletion test/test_train_mp_mnist_zero1.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def test_loop_fn(loader):


def _mp_fn(index, flags):
torch.set_default_tensor_type('torch.FloatTensor')
torch.set_default_dtype(torch.float32)
accuracy = train_mnist(flags)
if flags.tidy and os.path.isdir(flags.datadir):
shutil.rmtree(flags.datadir)
Expand Down

0 comments on commit 5cb605d

Please sign in to comment.