Commit a6762e2

override gather method in DP
nsarang committed Aug 2, 2020
1 parent 7013897 commit a6762e2
Showing 1 changed file with 5 additions and 5 deletions.
pytorch_lightning/overrides/data_parallel.py (10 changes: 5 additions & 5 deletions)
@@ -70,7 +70,7 @@ def forward(self, *inputs, **kwargs):
         if isinstance(outputs[0], Result):
             outputs = self.__gather_structured_result(outputs)
         else:
-            outputs = self.gather(outputs, self.output_device)
+            outputs = self.gather(outputs)
         return outputs

     def __gather_structured_result(self, outputs):
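
Dropping self.output_device from the call only works together with the override in the @@ -106 hunk below: the stock torch.nn.DataParallel.gather takes the target device as a parameter, while the overriding gather reads it from the instance. A minimal sketch of the two call shapes, using torch's functional gather helper; the class names here are illustrative, not from the commit:

import torch
# torch's functional helper: gather(outputs, target_device, dim=0)
from torch.nn.parallel.scatter_gather import gather


class StockSignature(torch.nn.DataParallel):
    # Upstream-style gather: the caller must supply the device.
    def gather(self, outputs, output_device):
        return gather(outputs, output_device, dim=self.dim)


class OverriddenSignature(torch.nn.DataParallel):
    # Post-commit style: the device comes from the instance, so
    # forward() can simply call self.gather(outputs).
    def gather(self, outputs):
        return gather(outputs, self.output_device, dim=self.dim)
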
@@ -83,7 +83,7 @@ def __gather_structured_result(self, outputs):
         for i, output in enumerate(outputs):
             del output['meta']

-        outputs = self.gather(outputs, self.output_device)
+        outputs = self.gather(outputs)

         # pass minimize to constructor for TrainResult
         if 'minimize' in outputs:
@@ -106,16 +106,16 @@ def gather_map(outputs):
             if isinstance(elem, torch.Tensor):
                 return Gather.apply(self.output_device, self.dim, *outputs)

-            elif elem is None:
+            if elem is None:
                 return None

-            elif isinstance(elem, Mapping):
+            if isinstance(elem, Mapping):
                 if not all((len(elem) == len(d) for d in outputs)):
                     raise ValueError('All dicts must have the same number of keys')
                 return elem_type(((k, gather_map([d[k] for d in outputs]))
                                   for k in elem))

-            elif isinstance(elem, Iterable) and not isinstance(elem, str):
+            if isinstance(elem, Iterable) and not isinstance(elem, str):
                 return elem_type(map(gather_map, zip(*outputs)))

             return outputs
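
Taken together, the hunks amount to a recursive gather that dispatches on the element type and returns early from every branch. Below is a self-contained sketch reconstructed from those hunks; the class name DemoDataParallel, the elem/elem_type setup lines, and the final return gather_map(outputs) are assumptions filled in for completeness, not quoted from the commit:

from collections.abc import Iterable, Mapping

import torch
# Gather.apply(target_device, dim, *tensors) copies and concatenates
# per-replica tensors onto the target device.
from torch.nn.parallel._functions import Gather


class DemoDataParallel(torch.nn.DataParallel):  # hypothetical name
    def gather(self, outputs):
        # Takes only `outputs`; the device comes from the instance.
        def gather_map(outputs):
            elem = outputs[0]          # assumed setup, not shown in the hunk
            elem_type = type(elem)

            if isinstance(elem, torch.Tensor):
                return Gather.apply(self.output_device, self.dim, *outputs)

            if elem is None:
                return None

            if isinstance(elem, Mapping):
                if not all(len(elem) == len(d) for d in outputs):
                    raise ValueError('All dicts must have the same number of keys')
                # Gather each key across replicas, preserving the mapping type.
                return elem_type((k, gather_map([d[k] for d in outputs]))
                                 for k in elem)

            if isinstance(elem, Iterable) and not isinstance(elem, str):
                # Gather element-wise across replicas, preserving the type.
                return elem_type(map(gather_map, zip(*outputs)))

            # Anything else (ints, floats, strings, ...) is returned as the
            # raw per-replica list.
            return outputs

        return gather_map(outputs)

Since every branch returns, the elif chain and independent if statements behave identically; the commit flattens the chain purely for readability. The Mapping branch is also what lets __gather_structured_result above gather the dict-like Result outputs key by key, which is why 'minimize' can be read back out of the gathered result.
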
