Commit

Use collections.abc.Mapping to handle both the dict and the UserDict types (#180)

* Use Mapping to handle dict and UserDict

* Address comments

* Remove ** syntax
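
The first bullet is the core of the change: collections.abc.Mapping is an abstract base class that both dict and collections.UserDict (and UserDict subclasses such as transformers' BatchEncoding) satisfy, so a single isinstance(obj, Mapping) check can replace the two separate UserDict and dict branches. A minimal sketch of that idea, not part of the diff; MyBatch is a made-up stand-in:

# Minimal sketch (not from the diff): one Mapping check covers both plain dicts
# and UserDict subclasses, so the two separate branches can be merged.
from collections import UserDict
from collections.abc import Mapping

class MyBatch(UserDict):   # made-up stand-in for a UserDict-based container
    pass                   # (transformers' BatchEncoding, for example, subclasses UserDict)

for obj in ({"a": 1}, MyBatch({"a": 1})):
    print(type(obj).__name__, isinstance(obj, Mapping), isinstance(obj, dict))
# dict True True
# MyBatch True False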
mariosasko authored Oct 4, 2021
1 parent 5343b4e commit c5c73e0
Showing 1 changed file with 6 additions and 17 deletions.
src/accelerate/utils.py (23 changes: 6 additions, 17 deletions)
@@ -15,7 +15,7 @@
 import importlib
 import os
 import random
-from collections import UserDict
+from collections.abc import Mapping
 from dataclasses import dataclass, field
 from enum import Enum
 from typing import List, Optional, Union
@@ -163,7 +163,7 @@ def recursively_apply(func, data, *args, test_type=is_torch_tensor, error_on_oth
                 for o in data
             ),
         )
-    elif isinstance(data, UserDict):
+    elif isinstance(data, Mapping):
         return type(data)(
             {
                 k: recursively_apply(
@@ -172,15 +172,6 @@ def recursively_apply(func, data, *args, test_type=is_torch_tensor, error_on_oth
                 for k, v in data.items()
             }
         )
-    elif isinstance(data, dict):
-        return type(data)(
-            **{
-                k: recursively_apply(
-                    func, v, *args, test_type=test_type, error_on_other_type=error_on_other_type, **kwargs
-                )
-                for k, v in data.items()
-            }
-        )
     elif test_type(data):
         return func(data, *args, **kwargs)
     elif error_on_other_type:
@@ -310,7 +301,7 @@ def extract_model_from_parallel(model):
 def _tpu_gather(tensor, name="gather tensor"):
     if isinstance(tensor, (list, tuple)):
         return honor_type(tensor, (_tpu_gather(t, name=f"{name}_{i}") for i, t in enumerate(tensor)))
-    elif isinstance(tensor, (dict, UserDict)):
+    elif isinstance(tensor, Mapping):
         return type(tensor)({k: _tpu_gather(v, name=f"{name}_{k}") for k, v in tensor.items()})
     elif not isinstance(tensor, torch.Tensor):
         raise TypeError(f"Can't gather the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors.")
@@ -365,7 +356,7 @@ def _gpu_broadcast_one(tensor, src=0):
 def _tpu_broadcast(tensor, src=0, name="broadcast tensor"):
     if isinstance(tensor, (list, tuple)):
         return honor_type(tensor, (_tpu_broadcast(t, name=f"{name}_{i}") for i, t in enumerate(tensor)))
-    elif isinstance(tensor, (dict, UserDict)):
+    elif isinstance(tensor, Mapping):
         return type(tensor)({k: _tpu_broadcast(v, name=f"{name}_{k}") for k, v in tensor.items()})
     return xm.mesh_reduce(name, tensor, lambda x: x[src])

@@ -448,7 +439,7 @@ def find_batch_size(data):
     """
     if isinstance(data, (tuple, list)):
         return find_batch_size(data[0])
-    elif isinstance(data, (dict, UserDict)):
+    elif isinstance(data, Mapping):
         for k in data.keys():
             return find_batch_size(data[k])
     elif not isinstance(data, torch.Tensor):
@@ -471,10 +462,8 @@ def concatenate(data, dim=0):
     """
     if isinstance(data[0], (tuple, list)):
         return honor_type(data[0], (concatenate([d[i] for d in data], dim=dim) for i in range(len(data[0]))))
-    elif isinstance(data[0], UserDict):
+    elif isinstance(data[0], Mapping):
         return type(data[0])({k: concatenate([d[k] for d in data], dim=dim) for k in data[0].keys()})
-    elif isinstance(data[0], dict):
-        return type(data[0])(**{k: concatenate([d[k] for d in data], dim=dim) for k in data[0].keys()})
     elif not isinstance(data[0], torch.Tensor):
         raise TypeError(f"Can only concatenate tensors but got {type(data[0])}")
     return torch.cat(data, dim=dim)
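
On the "Remove ** syntax" bullet: the old dict branches rebuilt the container with type(data)(**{...}), which expands the keys as keyword arguments, while the new code passes the rebuilt dict positionally as type(data)({...}). A small illustration of why the positional form is the more general one (my own example, not repository code): keyword expansion requires string keys and a constructor that accepts arbitrary keywords, whereas passing a mapping positionally works for dict, UserDict and similar containers.

# Illustrative only (assumed example, not from the repo).
data = {0: "first", 1: "second"}                              # non-string keys
print(type(data)({k: v.upper() for k, v in data.items()}))    # {0: 'FIRST', 1: 'SECOND'}

try:
    type(data)(**{k: v.upper() for k, v in data.items()})     # keyword expansion
except TypeError as err:
    print(err)                                                # keywords must be strings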
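
For context, a hedged usage sketch of the main helper touched here: after this change recursively_apply also walks UserDict-based containers, rebuilding the same container type around the transformed tensors. The sketch assumes torch and an accelerate version that includes this commit; the batch contents are invented.

# Hedged usage sketch (assumes torch and an accelerate version with this change).
import torch
from collections import UserDict
from accelerate.utils import recursively_apply

batch = UserDict({"input_ids": torch.ones(2, 3), "attention_mask": torch.zeros(2, 3)})
doubled = recursively_apply(lambda t: t * 2, batch)   # func is applied to every tensor leaf
print(type(doubled).__name__)                         # UserDict (container type preserved)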
