Skip to content

Commit

Permalink
Update existing unit tests for aten::add HLO dumps
Browse files Browse the repository at this point in the history
  • Loading branch information
wonjoolee95 committed Mar 13, 2024
1 parent 607156a commit a24d373
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions test/spmd/test_xla_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,7 +714,7 @@ def test_xla_sharded_hlo_dump(self):
partition_spec)
xst2 = xst1 + 5
hlo = torch_xla._XLAC._get_xla_tensors_hlo([xst2.global_tensor])
self.assertIn('%p1.4 = f32[1,8]{1,0} parameter(1), sharding', hlo)
self.assertIn('%p1.3 = f32[1,8]{1,0} parameter(1), sharding', hlo)
# scalar 5 should be replicated
self.assertIn('%p0.2 = f32[] parameter(0), sharding={replicated}', hlo)

Expand Down Expand Up @@ -830,7 +830,7 @@ def test_mark_sharding_ir(self):
actual += 0
hlo = torch_xla._XLAC._get_xla_tensors_hlo([actual.global_tensor])
self.assertIn(
'%add.12 = f32[1,128]{1,0} add(f32[1,128]{1,0} %custom-call.10, f32[1,128]{1,0} %broadcast.11)',
'%add.12 = f32[1,128]{1,0} add(f32[1,128]{1,0} %custom-call.9, f32[1,128]{1,0} %broadcast.11)',
hlo)

self.assertTrue(torch.allclose(expected, actual.cpu()))
Expand Down

0 comments on commit a24d373

Please sign in to comment.