Skip to content

Commit

Permalink
Add xla.sync (#6914)
Browse files Browse the repository at this point in the history
  • Loading branch information
will-cromar authored Apr 11, 2024
1 parent f0354ec commit 4245271
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
13 changes: 12 additions & 1 deletion test/test_devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,19 @@
import torch
import torch_xla as xla
import torch_xla.runtime as xr
import torch_xla.debug.metrics as met


class TestDevices(parameterized.TestCase):

def setUpClass():
@classmethod
def setUpClass(cls):
xr.set_device_type('CPU')
os.environ['CPU_NUM_DEVICES'] = '4'

def tearDown(self):
met.clear_metrics()

@parameterized.parameters((None, torch.device('xla:0')),
(0, torch.device('xla:0')),
(3, torch.device('xla:3')))
Expand All @@ -29,6 +34,12 @@ def test_real_devices(self):
def test_device_count(self):
self.assertEqual(xla.device_count(), 4)

def test_sync(self):
torch.ones((3, 3), device=xla.device())
xla.sync()

self.assertEqual(met.counter_value('MarkStep'), 1)


if __name__ == "__main__":
absltest.main()
5 changes: 5 additions & 0 deletions torch_xla/torch_xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,8 @@ def real_devices() -> List[str]:
def device_count() -> int:
"""Returns number of addressable devices in the current process."""
return len(real_devices())


def sync():
"""Launches all pending graph operations."""
xm.mark_step()

0 comments on commit 4245271

Please sign in to comment.