Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
AnandInguva committed Feb 6, 2023
1 parent e371f37 commit 3bff28c
Showing 1 changed file with 51 additions and 0 deletions.
51 changes: 51 additions & 0 deletions sdks/python/apache_beam/ml/inference/pytorch_inference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,57 @@ def test_gpu_auto_convert_to_cpu(self):
"are not available. Switching to CPU.",
log.output)

def test_load_torch_script_model(self):
torch_model = PytorchLinearRegression(2, 1)
torch_script_model = torch.jit.script(torch_model)

torch_script_path = os.path.join(self.tmpdir, 'torch_script_model.pt')

torch.jit.save(torch_script_model, torch_script_path)

model_handler = PytorchModelHandlerTensor(
state_dict_path=torch_script_path, use_torch_script_format=True)

torch_script_model = model_handler.load_model()

self.assertTrue(isinstance(torch_script_model, torch.jit.ScriptModule))

def test_inference_torch_script_model(self):
torch_model = PytorchLinearRegression(2, 1)
torch_model.load_state_dict(
OrderedDict([('linear.weight', torch.Tensor([[2.0, 3]])),
('linear.bias', torch.Tensor([0.5]))]))

torch_script_model = torch.jit.script(torch_model)

torch_script_path = os.path.join(self.tmpdir, 'torch_script_model.pt')

torch.jit.save(torch_script_model, torch_script_path)

model_handler = PytorchModelHandlerTensor(
state_dict_path=torch_script_path, use_torch_script_format=True)

with TestPipeline() as pipeline:
pcoll = pipeline | 'start' >> beam.Create(TWO_FEATURES_EXAMPLES)
predictions = pcoll | RunInference(model_handler)
assert_that(
predictions,
equal_to(
TWO_FEATURES_PREDICTIONS, equals_fn=_compare_prediction_result))

def test_torch_model_class_none(self):
torch_model = PytorchLinearRegression(2, 1)
torch_path = os.path.join(self.tmpdir, 'torch_model.pt')

torch.save(torch_model, torch_path)

with self.assertRaisesRegex(
RuntimeError,
"Please pass both `model_class` and `model_params` to the torch "
"model handler when using it with PyTorch. "
"If you opt to load the entire that was saved using TorchScript"):
_ = PytorchModelHandlerTensor(state_dict_path=torch_path)


if __name__ == '__main__':
unittest.main()

0 comments on commit 3bff28c

Please sign in to comment.