Skip to content

Commit

Permalink
[fbsync] Add test to check that classification models are FX-compatib…
Browse files Browse the repository at this point in the history
…le (#3662)

Summary:
* Add test to check that classification models are FX-compatible

* Replace torch.equal with torch.allclose

* remove skipling

Reviewed By: fmassa

Differential Revision: D29264313

fbshipit-source-id: 4e57e255c6ce680fc6deee6a9980a7d189e23597

Co-authored-by: Nicolas Hug <nicolashug@fb.com>
  • Loading branch information
NicolasHug authored and facebook-github-bot committed Jun 22, 2021
1 parent eaea921 commit 0495b05
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import functools
import operator
import torch
import torch.fx
import torch.nn as nn
import torchvision
from torchvision import models
Expand Down Expand Up @@ -140,6 +141,13 @@ def get_export_import_copy(m):
assert_export_import_module(sm, args)


def _check_fx_compatible(model, inputs):
model_fx = torch.fx.symbolic_trace(model)
out = model(inputs)
out_fx = model_fx(inputs)
torch.testing.assert_close(out, out_fx)


# If 'unwrapper' is provided it will be called with the script model outputs
# before they are compared to the eager model outputs. This is useful if the
# model outputs are different between TorchScript / Eager mode
Expand Down Expand Up @@ -408,6 +416,7 @@ def test_classification_model(model_name, dev):
_assert_expected(out.cpu(), model_name, prec=0.1)
assert out.shape[-1] == 50
_check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None))
_check_fx_compatible(model, x)

if dev == torch.device("cuda"):
with torch.cuda.amp.autocast():
Expand Down

0 comments on commit 0495b05

Please sign in to comment.