Skip to content

Commit

Permalink
Fix unit test (UT) bugs
Browse files: browse the repository at this point in the history
  • Loading branch information
gaoyang07 committed Oct 14, 2022
1 parent 0f28970 commit 430e53f
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 55 deletions.
2 changes: 1 addition & 1 deletion mmrazor/models/architectures/heads/darts_subnet_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,4 +80,4 @@ def loss(self, feats: Tuple[torch.Tensor],

losses.update(add_prefix(aux_losses, 'aux_head.'))

return losses
return losses
73 changes: 21 additions & 52 deletions tests/test_models/test_algorithms/test_darts.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
from typing import Dict, List, Tuple
from typing import Dict
from unittest import TestCase

import pytest
Expand All @@ -18,15 +18,13 @@
from mmrazor.models.algorithms.nas.darts import DartsDDP
from mmrazor.registry import MODELS

# from unittest.mock import Mock

MODELS.register_module(name='torchConv2d', module=nn.Conv2d, force=True)
MODELS.register_module(name='torchMaxPool2d', module=nn.MaxPool2d, force=True)
MODELS.register_module(name='torchAvgPool2d', module=nn.AvgPool2d, force=True)


@MODELS.register_module()
class ToyDiffModule(BaseModel):
class ToyDiffModule2(BaseModel):

def __init__(self, data_preprocessor=None):
super().__init__(data_preprocessor=data_preprocessor, init_cfg=None)
Expand Down Expand Up @@ -72,15 +70,6 @@ def forward(self, batch_inputs, data_samples=None, mode='tensor'):
return out


class ToyDataPreprocessor(torch.nn.Module):

def forward(
self,
data: Dict,
training: bool = True) -> Tuple[torch.Tensor, List[ClsDataSample]]:
return data['inputs'], data['data_samples']


class TestDarts(TestCase):

def setUp(self) -> None:
Expand All @@ -97,14 +86,14 @@ def setUp(self) -> None:

def test_init(self) -> None:
# initiate darts when `norm_training` is True.
model = ToyDiffModule()
model = ToyDiffModule2()
mutator = DiffModuleMutator()
algo = Darts(architecture=model, mutator=mutator, norm_training=True)
algo.eval()
self.assertTrue(model.bn.training)

# initiate darts with built mutator
model = ToyDiffModule()
model = ToyDiffModule2()
mutator = DiffModuleMutator()
algo = Darts(model, mutator)
self.assertIs(algo.mutator, mutator)
Expand All @@ -125,7 +114,7 @@ def test_init(self) -> None:

def test_forward_loss(self) -> None:
inputs = torch.randn(1, 3, 8, 8)
model = ToyDiffModule()
model = ToyDiffModule2()

# supernet
mutator = DiffModuleMutator()
Expand All @@ -150,7 +139,7 @@ def _prepare_fake_data(self) -> Dict:
return {'inputs': imgs, 'data_samples': data_samples}

def test_search_subnet(self) -> None:
model = ToyDiffModule()
model = ToyDiffModule2()

mutator = DiffModuleMutator()
mutator.prepare_from_supernet(model)
Expand All @@ -159,7 +148,7 @@ def test_search_subnet(self) -> None:
self.assertIsInstance(subnet, dict)

def test_darts_train_step(self) -> None:
model = ToyDiffModule(data_preprocessor=ToyDataPreprocessor())
model = ToyDiffModule2()
mutator = DiffModuleMutator()
mutator.prepare_from_supernet(model)

Expand All @@ -182,7 +171,7 @@ def test_darts_train_step(self) -> None:
self.assertIsNotNone(loss)

def test_darts_with_unroll(self) -> None:
model = ToyDiffModule(data_preprocessor=ToyDataPreprocessor())
model = ToyDiffModule2()
mutator = DiffModuleMutator()
mutator.prepare_from_supernet(model)

Expand All @@ -205,19 +194,17 @@ def setUpClass(cls) -> None:
os.environ['MASTER_PORT'] = '12345'

# initialize the process group
if torch.cuda.is_available():
backend = 'nccl'
cls.device = 'cuda'
else:
backend = 'gloo'
backend = 'nccl' if torch.cuda.is_available() else 'gloo'
dist.init_process_group(backend, rank=0, world_size=1)

def prepare_model(self, device_ids=None) -> Darts:
model = ToyDiffModule().to(self.device)
mutator = DiffModuleMutator().to(self.device)
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = ToyDiffModule2()
mutator = DiffModuleMutator()
mutator.prepare_from_supernet(model)

algo = Darts(model, mutator)
algo = Darts(model, mutator).to(self.device)

return DartsDDP(
module=algo, find_unused_parameters=True, device_ids=device_ids)
Expand All @@ -230,52 +217,34 @@ def tearDownClass(cls) -> None:
not torch.cuda.is_available(), reason='cuda device is not avaliable')
def test_init(self) -> None:
ddp_model = self.prepare_model()
# ddp_model = DartsDDP(module=model, device_ids=[0])
self.assertIsInstance(ddp_model, DartsDDP)

def test_dartsddp_train_step(self) -> None:
model = ToyDiffModule(data_preprocessor=ToyDataPreprocessor())
mutator = DiffModuleMutator()
mutator.prepare_from_supernet(model)

# data is tensor
algo = Darts(model, mutator)
ddp_model = DartsDDP(module=algo, find_unused_parameters=True)
ddp_model = self.prepare_model()
data = self._prepare_fake_data()
optim_wrapper = build_optim_wrapper(ddp_model, self.OPTIM_WRAPPER_CFG)
loss = ddp_model.train_step(data, optim_wrapper)

self.assertIsNotNone(loss)

# data is tuple or list
algo = Darts(model, mutator)
ddp_model = DartsDDP(module=algo, find_unused_parameters=True)
ddp_model = self.prepare_model()
data = [self._prepare_fake_data() for _ in range(2)]
optim_wrapper_dict = OptimWrapperDict(
architecture=OptimWrapper(SGD(model.parameters(), lr=0.1)),
mutator=OptimWrapper(SGD(model.parameters(), lr=0.01)))
architecture=OptimWrapper(SGD(ddp_model.parameters(), lr=0.1)),
mutator=OptimWrapper(SGD(ddp_model.parameters(), lr=0.01)))
loss = ddp_model.train_step(data, optim_wrapper_dict)

self.assertIsNotNone(loss)

def test_dartsddp_with_unroll(self) -> None:
model = ToyDiffModule(data_preprocessor=ToyDataPreprocessor())
mutator = DiffModuleMutator()
mutator.prepare_from_supernet(model)

# data is tuple or list
algo = Darts(model, mutator, unroll=True)
ddp_model = DartsDDP(module=algo, find_unused_parameters=True)

ddp_model = self.prepare_model()
data = [self._prepare_fake_data() for _ in range(2)]
optim_wrapper_dict = OptimWrapperDict(
architecture=OptimWrapper(SGD(model.parameters(), lr=0.1)),
mutator=OptimWrapper(SGD(model.parameters(), lr=0.01)))
architecture=OptimWrapper(SGD(ddp_model.parameters(), lr=0.1)),
mutator=OptimWrapper(SGD(ddp_model.parameters(), lr=0.01)))
loss = ddp_model.train_step(data, optim_wrapper_dict)

self.assertIsNotNone(loss)


if __name__ == '__main__':
import unittest
unittest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -183,4 +183,4 @@ def test_init(self) -> None:

@classmethod
def tearDownClass(cls) -> None:
dist.destroy_process_group()
dist.destroy_process_group()
2 changes: 1 addition & 1 deletion tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,4 +118,4 @@ def main():


if __name__ == '__main__':
main()
main()

0 comments on commit 430e53f

Please sign in to comment.