fix sharding stage3 bug (#60085) (#60106)
tianhaodongbd authored Dec 18, 2023
1 parent 1c0ffeb commit e8ee704
Showing 1 changed file with 27 additions and 1 deletion.
@@ -31,6 +31,32 @@
 from .group_sharded_utils import GroupShardedClipGrad, Type, device_guard
 
 
+class OrderedSet:
+    def __init__(self, iterable=None):
+        self._data = OrderedDict.fromkeys(iterable or [])
+
+    def __contains__(self, item):
+        return item in self._data
+
+    def __iter__(self):
+        return iter(self._data)
+
+    def __len__(self):
+        return len(self._data)
+
+    def add(self, item):
+        self._data[item] = None
+
+    def discard(self, item):
+        self._data.pop(item, None)
+
+    def update(self, iterable):
+        self._data.update((item, None) for item in iterable)
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({list(self._data)})"
+
+
 def _all_gather(tensor, buffer_size, group):
     """
     The main difference with paddle.distributed.all_gather:
@@ -148,7 +174,7 @@ def __init__(
             {}
         )  # {param.name: [(start0, end0),(start1, end1), ...]}
         self._trainable_params = {}  # {id(layer): [trainable_params]}
-        self._unslice_params = set()  # param's numel <= segment_size
+        self._unslice_params = OrderedSet()  # param's numel <= segment_size
         self._unslice_params2align = {}  # {param.name: param's align}
         self._grad_storages = {}  # {param.dtype: GradStorage}
