Skip to content

Commit

Permalink
[Cherry-pick] Add is_distributed field in sharding reshard param_meta (
Browse files Browse the repository at this point in the history
  • Loading branch information
sneaxiy authored and Mangodadada committed Sep 10, 2024
1 parent 542b0b2 commit ee0446e
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion paddlenlp/trainer/utils/sharding_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,8 @@ def _gather_sharding_metas(self):
param_meta = {}
for k, v in model.state_dict().items():
structure_name_mapping[k] = v.name
param_meta[k] = (v.shape, int(v.dtype))
is_distributed = getattr(v, "is_distributed", False)
param_meta[k] = (v.shape, int(v.dtype), is_distributed)

sharding_metas = {}
sharding_meta = {}
Expand Down

0 comments on commit ee0446e

Please sign in to comment.