diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 5b18a711db..2426cff78f 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -168,7 +168,7 @@ class Foo(ImageClassifier): def test_optimization(tmpdir): - model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.Softmax()) + model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax()) optim = torch.optim.Adam(model.parameters()) task = ClassificationTask(model, optimizer=optim, scheduler=None)