Skip to content

Commit

Permalink
Standardize the unit tests of DARTS
Browse files Browse the repository at this point in the history
  • Loading branch information
pprp committed Aug 5, 2022
1 parent 61fd1fa commit 2bfecd0
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions tests/test_models/test_algorithms/test_darts.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from mmengine.model import BaseModel
from mmengine.optim import build_optim_wrapper
from mmengine.optim.optimizer import OptimWrapper, OptimWrapperDict
from torch import Tensor
from torch.optim import SGD

from mmrazor.models import Darts, DiffModuleMutator, DiffMutableOP
Expand Down Expand Up @@ -154,7 +155,8 @@ def test_search_subnet(self) -> None:
mutator = DiffModuleMutator()
mutator.prepare_from_supernet(model)
algo = Darts(model, mutator)
print(algo.search_subnet())
subnet = algo.search_subnet()
self.assertIsInstance(subnet, dict)

def test_darts_train_step(self) -> None:
model = ToyDiffModule(data_preprocessor=ToyDataPreprocessor())
Expand All @@ -167,7 +169,7 @@ def test_darts_train_step(self) -> None:
optim_wrapper = build_optim_wrapper(algo, self.OPTIM_WRAPPER_CFG)
loss = algo.train_step(data, optim_wrapper)

print(loss)
self.assertTrue(isinstance(loss['loss'], Tensor))

# data is tuple or list
algo = Darts(model, mutator)
Expand All @@ -177,7 +179,7 @@ def test_darts_train_step(self) -> None:
mutator=OptimWrapper(SGD(model.parameters(), lr=0.01)))
loss = algo.train_step(data, optim_wrapper_dict)

print(loss)
self.assertIsNotNone(loss)

def test_darts_with_unroll(self) -> None:
model = ToyDiffModule(data_preprocessor=ToyDataPreprocessor())
Expand All @@ -192,7 +194,7 @@ def test_darts_with_unroll(self) -> None:
mutator=OptimWrapper(SGD(model.parameters(), lr=0.01)))
loss = algo.train_step(data, optim_wrapper_dict)

print(loss)
self.assertIsNotNone(loss)


class TestDartsDDP(TestDarts):
Expand Down Expand Up @@ -243,7 +245,7 @@ def test_dartsddp_train_step(self) -> None:
optim_wrapper = build_optim_wrapper(ddp_model, self.OPTIM_WRAPPER_CFG)
loss = ddp_model.train_step(data, optim_wrapper)

print(loss)
self.assertIsNotNone(loss)

# data is tuple or list
algo = Darts(model, mutator)
Expand All @@ -254,7 +256,7 @@ def test_dartsddp_train_step(self) -> None:
mutator=OptimWrapper(SGD(model.parameters(), lr=0.01)))
loss = ddp_model.train_step(data, optim_wrapper_dict)

print(loss)
self.assertIsNotNone(loss)

def test_dartsddp_with_unroll(self) -> None:
model = ToyDiffModule(data_preprocessor=ToyDataPreprocessor())
Expand All @@ -271,7 +273,7 @@ def test_dartsddp_with_unroll(self) -> None:
mutator=OptimWrapper(SGD(model.parameters(), lr=0.01)))
loss = ddp_model.train_step(data, optim_wrapper_dict)

print(loss)
self.assertIsNotNone(loss)


if __name__ == '__main__':
Expand Down

0 comments on commit 2bfecd0

Please sign in to comment.