Fix bugs of apex in DartsDDP
pprp committed Aug 3, 2022
1 parent 3590e38 commit 61fd1fa
Showing 1 changed file with 30 additions and 0 deletions.
30 changes: 30 additions & 0 deletions mmrazor/models/algorithms/nas/darts.py
@@ -291,6 +291,24 @@ def _compute_hessian(self, backup_params, dw, supernet_data,
        return hessian


class BatchNormWrapper(nn.Module):
    """Wrapper that casts inputs to fp32 for BatchNorm during fp16 (apex)
    training. For more information, please refer to
    https://github.com/NVIDIA/apex/issues/121
    """

    def __init__(self, m):
        super(BatchNormWrapper, self).__init__()
        self.m = m
        # Set the batch norm to eval mode
        self.m.eval()

    def forward(self, x):
        """Run BatchNorm on an fp32 copy of the input, then cast back."""
        input_type = x.dtype
        x = self.m(x.float())
        return x.to(input_type)


@MODEL_WRAPPERS.register_module()
class DartsDDP(MMDistributedDataParallel):
"""DDP for Darts and rewrite train_step of MMDDP."""
@@ -304,6 +322,18 @@ def __init__(self,
        device_ids = [int(os.environ['LOCAL_RANK'])]
        super().__init__(device_ids=device_ids, **kwargs)

        # Wrap every BatchNorm2d so that it always computes on fp32 inputs,
        # which avoids apex fp16 issues in the Darts supernet.
        fp16 = True
        if fp16:

            def add_fp16_bn_wrapper(model):
                """Recursively replace BatchNorm2d with BatchNormWrapper."""
                for child_name, child in model.named_children():
                    if isinstance(child, nn.BatchNorm2d):
                        setattr(model, child_name, BatchNormWrapper(child))
                    else:
                        add_fp16_bn_wrapper(child)

            add_fp16_bn_wrapper(self.module)

    def train_step(self, data: List[dict],
                   optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]:
        """The iteration step during training.
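
A minimal standalone sketch of what the added code does (illustrative only, not part of the commit; demo_net, the tensor shapes, and the omitted apex/amp setup are assumptions for the example): BatchNormWrapper runs the wrapped BatchNorm2d on an fp32 copy of its fp16 input and casts the result back, and add_fp16_bn_wrapper swaps every BatchNorm2d in a module tree for that wrapper.

# Illustrative sketch, not part of the commit. demo_net is a toy model and
# no apex/amp initialization is shown; only the dtype handling is demonstrated.
import torch
import torch.nn as nn


class BatchNormWrapper(nn.Module):
    """Run the wrapped BatchNorm in fp32 and cast the output back."""

    def __init__(self, m):
        super().__init__()
        self.m = m
        self.m.eval()  # use running statistics, as in the commit

    def forward(self, x):
        input_type = x.dtype
        x = self.m(x.float())
        return x.to(input_type)


def add_fp16_bn_wrapper(model):
    """Recursively swap every BatchNorm2d for a BatchNormWrapper."""
    for child_name, child in model.named_children():
        if isinstance(child, nn.BatchNorm2d):
            setattr(model, child_name, BatchNormWrapper(child))
        else:
            add_fp16_bn_wrapper(child)


demo_net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
add_fp16_bn_wrapper(demo_net)

# Feed fp16 activations directly to the (still fp32) wrapped BatchNorm.
x = torch.randn(2, 8, 16, 16, dtype=torch.float16)
out = demo_net[1](x)  # demo_net[1] is now a BatchNormWrapper
print(out.dtype)      # torch.float16: the BatchNorm itself computed in fp32

In DartsDDP the same replacement is applied to self.module right after the DDP wrapper is constructed, presumably so the BatchNorm layers keep computing in fp32 while apex runs the rest of the supernet in fp16.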
