[zero-3] add bwd support for list/dict types returned in fwd (#1857)
jeffra committed Apr 26, 2022
1 parent b4fcd98 commit a52cbf8
Showing 2 changed files with 84 additions and 2 deletions.
22 changes: 20 additions & 2 deletions deepspeed/runtime/zero/stage3.py
@@ -73,20 +73,38 @@ def move_to_cpu(tensor_list):
         tensor.data = tensor.data.cpu()
 
 
+def is_builtin_type(obj):
+    # https://stackoverflow.com/a/17795199
+    return obj.__class__.__module__ == '__builtin__' or obj.__class__.__module__ == "builtins"
+
+
 #apply torch.autograd.Function that calls a backward_function to tensors in output
 def _apply_to_tensors_only(module, functional, backward_function, outputs):
-    if type(outputs) is tuple:
+    if isinstance(outputs, (tuple, list)):
         touched_outputs = []
         for output in outputs:
             touched_output = _apply_to_tensors_only(module,
                                                     functional,
                                                     backward_function,
                                                     output)
             touched_outputs.append(touched_output)
-        return tuple(touched_outputs)
+        return outputs.__class__(touched_outputs)
+    elif isinstance(outputs, dict):
+        # apply inplace to avoid recreating dict inherited objects
+        for key in outputs.keys():
+            outputs[key] = _apply_to_tensors_only(module,
+                                                  functional,
+                                                  backward_function,
+                                                  outputs[key])
+        return outputs
     elif type(outputs) is torch.Tensor:
         return functional.apply(module, backward_function, outputs)
     else:
+        if not is_builtin_type(outputs):
+            logger.warning(
+                f"A module has unknown inputs or outputs type ({type(outputs)}) and the tensors embedded in it cannot be detected. "
+                "The ZeRO-3 hooks designed to trigger before or after backward pass of the module relies on knowing the input and "
+                "output tensors and therefore may not get triggered properly.")
         return outputs
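
The practical effect of the change: module forward outputs of type tuple, list, or dict are now recursed into so that every embedded tensor gets the backward-trigger autograd Function attached, instead of only plain tuples. Below is a minimal, self-contained sketch of that mechanism; it is not part of the commit, and HookFn plus the local apply_to_tensors_only are hypothetical stand-ins that mirror the patched logic above.

# Illustrative sketch only -- not DeepSpeed code. HookFn stands in for the
# pre/post-backward autograd Functions that ZeRO-3 attaches via `functional`.
import torch


class HookFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, module, backward_function, x):
        ctx.backward_function = backward_function
        # detach so the returned tensor's grad_fn becomes this Function
        return x.detach()

    @staticmethod
    def backward(ctx, grad):
        ctx.backward_function()  # the hook fires as the gradient flows back
        return None, None, grad


def apply_to_tensors_only(module, functional, backward_function, outputs):
    # same recursion as the patched _apply_to_tensors_only above
    if isinstance(outputs, (tuple, list)):
        return outputs.__class__(
            apply_to_tensors_only(module, functional, backward_function, o)
            for o in outputs)
    elif isinstance(outputs, dict):
        for key in outputs.keys():
            outputs[key] = apply_to_tensors_only(module,
                                                 functional,
                                                 backward_function,
                                                 outputs[key])
        return outputs
    elif type(outputs) is torch.Tensor:
        return functional.apply(module, backward_function, outputs)
    return outputs


x = torch.ones(3, requires_grad=True)
out = apply_to_tensors_only(None, HookFn,
                            lambda: print("backward hook fired"),
                            {'logits': x * 2, 'meta': 1, 'mask': None})
out['logits'].sum().backward()  # prints "backward hook fired"; x.grad is set

Non-tensor values ('meta', 'mask') fall through untouched; that is also why the commit adds is_builtin_type, so the new warning only fires for genuinely opaque output types rather than for benign builtins like int or None.
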
64 changes: 64 additions & 0 deletions tests/unit/test_zero.py
@@ -1222,3 +1222,67 @@ def _go(model, hidden_dim):
             model.step()
 
     _go(model=model, hidden_dim=hidden_dim)
+
+
+@pytest.mark.parametrize('return_type', [tuple, list, dict])
+def test_z3_dict_fwd(return_type):
+    config_dict = {
+        "train_batch_size": 4,
+        "steps_per_print": 1,
+        "optimizer": {
+            "type": "Adam",
+            "params": {
+                "lr": 1e-4
+            }
+        },
+        "fp16": {
+            "enabled": True
+        },
+        "zero_optimization": {
+            "stage": 3
+        }
+    }
+    hidden_dim = 10
+
+    class MyModel(torch.nn.Module):
+        def __init__(self, hidden_dim):
+            super(MyModel, self).__init__()
+            self.l1 = torch.nn.Linear(hidden_dim, hidden_dim)
+            self.cel = torch.nn.CrossEntropyLoss()
+
+        def forward(self, x, y):
+            x = self.l1(x)
+            loss = self.cel(x, y)
+            if return_type == dict:
+                val = {'a': x, 'loss': loss, 'b': 1, 'c': None}
+            elif return_type == list:
+                val = [x, loss]
+            elif return_type == tuple:
+                val = (x, loss)
+            else:
+                raise NotImplementedError
+            return val
+
+    @distributed_test(world_size=[1])
+    def _go(hidden_dim):
+        with deepspeed.zero.Init():
+            model = MyModel(hidden_dim)
+
+        model, _, _, _ = deepspeed.initialize(model=model,
+                                              model_parameters=model.parameters(),
+                                              config=config_dict)
+        data_loader = random_dataloader(model=model,
+                                        total_samples=50,
+                                        hidden_dim=hidden_dim,
+                                        device=model.device)
+        torch.distributed.barrier()
+        for n, batch in enumerate(data_loader):
+            loss = model(batch[0], batch[1])
+            if return_type == dict:
+                loss = loss['loss']
+            else:
+                loss = loss[1]
+            model.backward(loss)
+            model.step()
+
+    _go(hidden_dim)
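
To run just the new test, an invocation along these lines should work from the repository root (assuming a DeepSpeed development install with pytest; exact flags vary by setup):

    pytest tests/unit/test_zero.py -k test_z3_dict_fwd

The parametrization covers all three return types, and the dict case deliberately includes non-tensor values ('b': 1, 'c': None) to exercise the pass-through branch of _apply_to_tensors_only.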
