diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py index 14a411ae253566..3731332d1e7774 100644 --- a/python/paddle/distributed/collective.py +++ b/python/paddle/distributed/collective.py @@ -1219,6 +1219,7 @@ def _parallel_linear(x, inputs={'X': linear_out}, outputs={'Out': out}, attrs={ + 'rank': inner_rank, 'ring_id': ring_id, 'nranks': nranks, 'use_calc_stream': True,