Skip to content

Commit

Permalink
Add test for num_class in test_model.py (#815)
Browse files Browse the repository at this point in the history
* Add test for loading pretrained models

The update modifies the test to check whether the model can successfully load the pretrained weights. Will raise an error if the model parameters are incorrectly defined or named.

* Add test on 'num_class'

Passing num_class equal to a number other than 1000 helps in making the test more enforcing in nature.
  • Loading branch information
ekagra-ranjan authored and fmassa committed Mar 26, 2019
1 parent 6334466 commit 83d3770
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@ def get_available_models():

class Tester(unittest.TestCase):
def _test_model(self, name, input_shape):
model = models.__dict__[name]()
# passing num_class equal to a number other than 1000 helps in making the test more enforcing in nature
model = models.__dict__[name](num_classes=50)
model.eval()
x = torch.rand(input_shape)
out = model(x)
self.assertEqual(out.shape[-1], 1000)
self.assertEqual(out.shape[-1], 50)


for model_name in get_available_models():
Expand Down

0 comments on commit 83d3770

Please sign in to comment.