diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py b/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py
index c47b21d2f25a7..c6ee809f78846 100644
--- a/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py
+++ b/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py
@@ -609,6 +609,57 @@ def test_gpu_auto_convert_to_cpu(self):
           "are not available. Switching to CPU.",
           log.output)
 
+  def test_load_torch_script_model(self):
+    torch_model = PytorchLinearRegression(2, 1)
+    torch_script_model = torch.jit.script(torch_model)
+
+    torch_script_path = os.path.join(self.tmpdir, 'torch_script_model.pt')
+
+    torch.jit.save(torch_script_model, torch_script_path)
+
+    model_handler = PytorchModelHandlerTensor(
+        state_dict_path=torch_script_path, use_torch_script_format=True)
+
+    torch_script_model = model_handler.load_model()
+
+    self.assertTrue(isinstance(torch_script_model, torch.jit.ScriptModule))
+
+  def test_inference_torch_script_model(self):
+    torch_model = PytorchLinearRegression(2, 1)
+    torch_model.load_state_dict(
+        OrderedDict([('linear.weight', torch.Tensor([[2.0, 3]])),
+                     ('linear.bias', torch.Tensor([0.5]))]))
+
+    torch_script_model = torch.jit.script(torch_model)
+
+    torch_script_path = os.path.join(self.tmpdir, 'torch_script_model.pt')
+
+    torch.jit.save(torch_script_model, torch_script_path)
+
+    model_handler = PytorchModelHandlerTensor(
+        state_dict_path=torch_script_path, use_torch_script_format=True)
+
+    with TestPipeline() as pipeline:
+      pcoll = pipeline | 'start' >> beam.Create(TWO_FEATURES_EXAMPLES)
+      predictions = pcoll | RunInference(model_handler)
+      assert_that(
+          predictions,
+          equal_to(
+              TWO_FEATURES_PREDICTIONS, equals_fn=_compare_prediction_result))
+
+  def test_torch_model_class_none(self):
+    torch_model = PytorchLinearRegression(2, 1)
+    torch_path = os.path.join(self.tmpdir, 'torch_model.pt')
+
+    torch.save(torch_model, torch_path)
+
+    with self.assertRaisesRegex(
+        RuntimeError,
+        "Please pass both `model_class` and `model_params` to the torch "
+        "model handler when using it with PyTorch. "
+        "If you opt to load the entire that was saved using TorchScript"):
+      _ = PytorchModelHandlerTensor(state_dict_path=torch_path)
+
 
 if __name__ == '__main__':
   unittest.main()
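
For context, the end-to-end usage these tests exercise looks roughly like the sketch below. It is a minimal, hypothetical example assembled only from the test code above: LinearModel, the output path, and the input tensors are illustrative placeholders, and use_torch_script_format is the flag introduced by this change, not an independently documented API.

# Minimal usage sketch (assumptions: the use_torch_script_format flag and
# handler behavior shown in the tests above; LinearModel, the path, and the
# inputs are illustrative placeholders).
import torch
import apache_beam as beam
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor


class LinearModel(torch.nn.Module):
  def __init__(self):
    super().__init__()
    self.linear = torch.nn.Linear(2, 1)

  def forward(self, x):
    return self.linear(x)


# Script and save the whole model; no model_class/model_params are needed later.
torch.jit.save(torch.jit.script(LinearModel()), '/tmp/torch_script_model.pt')

model_handler = PytorchModelHandlerTensor(
    state_dict_path='/tmp/torch_script_model.pt', use_torch_script_format=True)

with beam.Pipeline() as pipeline:
  _ = (
      pipeline
      | beam.Create([torch.Tensor([1.0, 2.0]), torch.Tensor([3.0, 4.0])])
      | RunInference(model_handler)
      | beam.Map(print))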