diff --git a/test/test_input_output_aliases.py b/test/test_input_output_aliases.py
index 8e7b0a7ed6e..cae60b7889d 100644
--- a/test/test_input_output_aliases.py
+++ b/test/test_input_output_aliases.py
@@ -1,3 +1,4 @@
+import os
 import sys
 
 import torch
@@ -162,6 +163,28 @@ def test_separate_graphs(self):
 
     self.assertEqual(t1.item(), 3)
 
+  def test_xm_save_no_aliasing(self):
+    """
+    Test that xm.save() does not perform aliasing.
+    """
+    xla_device = xm.xla_device()
+    t0 = torch.tensor([1], device=xla_device)
+    t1 = torch.tensor([2], device=xla_device)
+    xm.mark_step()
+
+    t2 = t0 + t1
+    t1.add_(1)
+
+    # Saving the new value of t1 should not result in the old value
+    # being donated...
+    xm.save(t1, os.devnull)
+
+    # otherwise this mark_step could crash, or compute the wrong value
+    # for t2.
+    xm.mark_step()
+
+    self.assertEqual(t2.item(), 3)
+
 
 if __name__ == '__main__':
   test = unittest.main()
diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py
index 6e1936c258a..4f3da8d650a 100644
--- a/torch_xla/core/xla_model.py
+++ b/torch_xla/core/xla_model.py
@@ -1278,7 +1278,7 @@ def _maybe_convert_to_cpu(data: Any, convert: bool = True) -> ToXlaTensorArena:
 
   def convert_fn(tensors):
     torch_xla._XLAC._xla_sync_multi(
-        tensors, devices=[], wait=True, sync_xla_data=True)
+        tensors, devices=[], wait=True, sync_xla_data=False)
     if not convert:
       return tensors
     return torch_xla._XLAC._xla_get_cpu_tensors(tensors)
diff --git a/torch_xla/utils/serialization.py b/torch_xla/utils/serialization.py
index ed3797ad945..05cfa93e2ea 100644
--- a/torch_xla/utils/serialization.py
+++ b/torch_xla/utils/serialization.py
@@ -25,7 +25,7 @@ def _rewrite_data(path, data, save_tensors):
 
   def convert_fn(tensors):
     torch_xla._XLAC._xla_sync_multi(
-        tensors, devices=[], wait=True, sync_xla_data=True)
+        tensors, devices=[], wait=True, sync_xla_data=False)
     rewritten_tensors = []
     for i, t in enumerate(tensors):
       if save_tensors: