diff --git a/test/test_models.py b/test/test_models.py index a3c7ec40342..d40b72e55ff 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -7,6 +7,7 @@ import functools import operator import torch +import torch.fx import torch.nn as nn import torchvision from torchvision import models @@ -140,6 +141,13 @@ def get_export_import_copy(m): assert_export_import_module(sm, args) +def _check_fx_compatible(model, inputs): + model_fx = torch.fx.symbolic_trace(model) + out = model(inputs) + out_fx = model_fx(inputs) + torch.testing.assert_close(out, out_fx) + + # If 'unwrapper' is provided it will be called with the script model outputs # before they are compared to the eager model outputs. This is useful if the # model outputs are different between TorchScript / Eager mode @@ -408,6 +416,7 @@ def test_classification_model(model_name, dev): _assert_expected(out.cpu(), model_name, prec=0.1) assert out.shape[-1] == 50 _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None)) + _check_fx_compatible(model, x) if dev == torch.device("cuda"): with torch.cuda.amp.autocast():