diff --git a/torch_xla/distributed/zero_redundancy_optimizer.py b/torch_xla/distributed/zero_redundancy_optimizer.py
index 2d345081e74..c46642fa448 100644
--- a/torch_xla/distributed/zero_redundancy_optimizer.py
+++ b/torch_xla/distributed/zero_redundancy_optimizer.py
@@ -314,8 +314,8 @@ def step(self, closure=None, **kwargs):
           if grad_shard.dtype != self.optimizer_dtype:
             grad_shard = grad_shard.to(dtype=self.optimizer_dtype)
-            shard.grad = grad_shard
-            index += 1
+          shard.grad = grad_shard
+          index += 1
           if self.grad_clipping:
             # Update unscale/clip with sub partitions
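
The hunk above is a pure indentation fix: `shard.grad = grad_shard` and `index += 1` were nested under the dtype-conversion `if`, so a shard only received its reduced gradient (and the shard counter only advanced) when `grad_shard.dtype` differed from `self.optimizer_dtype`. De-indenting the two lines makes them run for every shard. A minimal sketch of the control-flow change follows; the local names (`optimizer_dtype`, `grad_shards`, `shards`) are hypothetical stand-ins for the optimizer's state, not the real attribute layout of ZeroRedundancyOptimizer:

    import torch

    # Stand-ins (assumptions for illustration only): one gradient shard
    # already matches optimizer_dtype, the other needs a cast.
    optimizer_dtype = torch.float32
    grad_shards = [
        torch.ones(4, dtype=torch.float32),
        torch.ones(4, dtype=torch.bfloat16),
    ]
    shards = [torch.nn.Parameter(torch.zeros(4)) for _ in grad_shards]

    index = 0
    for shard, grad_shard in zip(shards, grad_shards):
      if grad_shard.dtype != optimizer_dtype:
        grad_shard = grad_shard.to(dtype=optimizer_dtype)
      # Before the patch, the next two lines sat inside the `if` above,
      # so shards whose gradients needed no cast were silently skipped.
      # After the patch they execute unconditionally on every iteration.
      shard.grad = grad_shard
      index += 1

    # With the fixed indentation, every shard has a gradient attached.
    assert all(s.grad is not None for s in shards)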