From a24d373a9e0784215c9cd244dff7a318465fc50d Mon Sep 17 00:00:00 2001
From: Wonjoo Lee
Date: Wed, 13 Mar 2024 04:05:57 +0000
Subject: [PATCH] Update existing unit tests for aten::add HLO dumps

---
 test/spmd/test_xla_sharding.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py
index 0e76f1c842e8..6c488c2f72e8 100644
--- a/test/spmd/test_xla_sharding.py
+++ b/test/spmd/test_xla_sharding.py
@@ -714,7 +714,7 @@ def test_xla_sharded_hlo_dump(self):
                             partition_spec)
     xst2 = xst1 + 5
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([xst2.global_tensor])
-    self.assertIn('%p1.4 = f32[1,8]{1,0} parameter(1), sharding', hlo)
+    self.assertIn('%p1.3 = f32[1,8]{1,0} parameter(1), sharding', hlo)
     # scalar 5 should be replicated
     self.assertIn('%p0.2 = f32[] parameter(0), sharding={replicated}', hlo)

@@ -830,7 +830,7 @@ def test_mark_sharding_ir(self):
     actual += 0
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([actual.global_tensor])
     self.assertIn(
-        '%add.12 = f32[1,128]{1,0} add(f32[1,128]{1,0} %custom-call.10, f32[1,128]{1,0} %broadcast.11)',
+        '%add.12 = f32[1,128]{1,0} add(f32[1,128]{1,0} %custom-call.9, f32[1,128]{1,0} %broadcast.11)',
         hlo)
     self.assertTrue(torch.allclose(expected, actual.cpu()))
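
Note (not part of the patch): the assertions above match against raw HLO text, so the patch is only renumbering instruction ids (%p1.4 -> %p1.3, %custom-call.10 -> %custom-call.9) that shifted after a change to the aten::add lowering. A minimal sketch of how one might regenerate that HLO dump to check the expected ids is below; it assumes an SPMD-enabled PJRT setup, and the (1, num_devices) mesh and (0, 1) partition spec mirror the test's conventions but are assumptions here, not code from the patch.

import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

# Enable SPMD mode and build a 1 x num_devices mesh (assumed layout).
xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.array(range(num_devices)), (1, num_devices))

# Shard a tensor, then add a scalar; the scalar should lower as a
# replicated parameter, which the first assertion in the test checks.
t = torch.randn(1, 8).to(xm.xla_device())
xst1 = xs.mark_sharding(t, mesh, (0, 1))
xst2 = xst1 + 5

# Instruction ids such as %p1.3 or %custom-call.9 come from the lowering
# and can shift between releases, which is exactly what this patch adjusts.
print(torch_xla._XLAC._get_xla_tensors_hlo([xst2.global_tensor]))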