diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py
index 46e12e383be..6fccfe32800 100644
--- a/test/spmd/test_xla_sharding.py
+++ b/test/spmd/test_xla_sharding.py
@@ -714,7 +714,7 @@ def test_xla_sharded_hlo_dump(self):
                            partition_spec)
     xst2 = xst1 + 5
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([xst2.global_tensor])
-    self.assertIn('%p1.4 = f32[1,8]{1,0} parameter(1), sharding', hlo)
+    self.assertIn('%p1.3 = f32[1,8]{1,0} parameter(1), sharding', hlo)
     # scalar 5 should be implicitly replicated, so the pre-optimization HLO
     # shouldn't mark it with sharding.
     self.assertNotIn('%p0.2 = f32[] parameter(0), sharding={replicated}', hlo)