Skip to content

Commit

Permalink
add testcase for dartsddp
Browse files Browse the repository at this point in the history
  • Loading branch information
pprp committed Aug 3, 2022
1 parent c72aa30 commit 3590e38
Showing 1 changed file with 47 additions and 3 deletions.
50 changes: 47 additions & 3 deletions tests/test_models/test_algorithms/test_darts.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,10 @@ def forward(


class TestDarts(TestCase):
device: str = 'cpu'

def setUp(self) -> None:
self.device: str = 'cpu'

OPTIMIZER_CFG = dict(
type='SGD',
lr=0.5,
Expand Down Expand Up @@ -209,14 +210,15 @@ def setUpClass(cls) -> None:
backend = 'gloo'
dist.init_process_group(backend, rank=0, world_size=1)

def prepare_model(self) -> Darts:
def prepare_model(self, device_ids=None) -> Darts:
model = ToyDiffModule().to(self.device)
mutator = DiffModuleMutator().to(self.device)
mutator.prepare_from_supernet(model)

algo = Darts(model, mutator)

return DartsDDP(module=algo, find_unused_parameters=True)
return DartsDDP(
module=algo, find_unused_parameters=True, device_ids=device_ids)

@classmethod
def tearDownClass(cls) -> None:
Expand All @@ -229,6 +231,48 @@ def test_init(self) -> None:
# ddp_model = DartsDDP(module=model, device_ids=[0])
self.assertIsInstance(ddp_model, DartsDDP)

def test_dartsddp_train_step(self) -> None:
model = ToyDiffModule(data_preprocessor=ToyDataPreprocessor())
mutator = DiffModuleMutator()
mutator.prepare_from_supernet(model)

# data is tensor
algo = Darts(model, mutator)
ddp_model = DartsDDP(module=algo, find_unused_parameters=True)
data = self._prepare_fake_data()
optim_wrapper = build_optim_wrapper(ddp_model, self.OPTIM_WRAPPER_CFG)
loss = ddp_model.train_step(data, optim_wrapper)

print(loss)

# data is tuple or list
algo = Darts(model, mutator)
ddp_model = DartsDDP(module=algo, find_unused_parameters=True)
data = [self._prepare_fake_data() for _ in range(2)]
optim_wrapper_dict = OptimWrapperDict(
architecture=OptimWrapper(SGD(model.parameters(), lr=0.1)),
mutator=OptimWrapper(SGD(model.parameters(), lr=0.01)))
loss = ddp_model.train_step(data, optim_wrapper_dict)

print(loss)

def test_dartsddp_with_unroll(self) -> None:
model = ToyDiffModule(data_preprocessor=ToyDataPreprocessor())
mutator = DiffModuleMutator()
mutator.prepare_from_supernet(model)

# data is tuple or list
algo = Darts(model, mutator, unroll=True)
ddp_model = DartsDDP(module=algo, find_unused_parameters=True)

data = [self._prepare_fake_data() for _ in range(2)]
optim_wrapper_dict = OptimWrapperDict(
architecture=OptimWrapper(SGD(model.parameters(), lr=0.1)),
mutator=OptimWrapper(SGD(model.parameters(), lr=0.01)))
loss = ddp_model.train_step(data, optim_wrapper_dict)

print(loss)


if __name__ == '__main__':
import unittest
Expand Down

0 comments on commit 3590e38

Please sign in to comment.