From a37bb3c31b13c6406536002de5db3d8f2ffa8542 Mon Sep 17 00:00:00 2001
From: Zach OBrien
Date: Wed, 30 Aug 2023 15:06:24 -0400
Subject: [PATCH] Typehints for model_loader and model_service_worker (#2540)

* typehint model_loader.py
* typehint model_service_worker.py
* unit test for change to handle_connection
* fix typo

---------

Co-authored-by: Ankith Gunapal
---
 mypy.ini                                    |  2 +-
 ts/model_loader.py                          |  2 +-
 ts/model_service_worker.py                  | 33 ++++++++++++-------
 .../unit_tests/test_model_service_worker.py | 13 ++++++++
 4 files changed, 36 insertions(+), 14 deletions(-)

diff --git a/mypy.ini b/mypy.ini
index 2e1165bdc1..7a8e90d7c4 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -1,7 +1,7 @@
 [mypy]
 ; A good-first-issue is to add types to a file
 ; As you do start adding them in files and slowly make the excluded files empty
-files = ts/context.py, ts/model_server.py
+files = ts/context.py, ts/model_server.py, ts/model_loader.py, ts/model_service_worker.py
 
 exclude = examples, binaries, ts_scripts, test, kubernetes, benchmarks, model-archiver, workflow-archiver, ts/tests, ts/utils
 
diff --git a/ts/model_loader.py b/ts/model_loader.py
index 88afcccc68..fb8dcd161b 100644
--- a/ts/model_loader.py
+++ b/ts/model_loader.py
@@ -73,7 +73,7 @@ def load(
         batch_size: Optional[int] = None,
         envelope: Optional[str] = None,
         limit_max_image_pixels: Optional[bool] = True,
-        metrics_cache: MetricsCacheYamlImpl = None,
+        metrics_cache: Optional[MetricsCacheYamlImpl] = None,
     ) -> Service:
         """
         Load TorchServe 1.0 model from file.
diff --git a/ts/model_service_worker.py b/ts/model_service_worker.py
index 57e8bf9a7c..e3722f5305 100644
--- a/ts/model_service_worker.py
+++ b/ts/model_service_worker.py
@@ -10,6 +10,7 @@
 import platform
 import socket
 import sys
+from typing import Optional
 
 from ts.arg_parser import ArgParser
 from ts.metrics.metric_cache_yaml_impl import MetricsCacheYamlImpl
@@ -19,8 +20,7 @@
 MAX_FAILURE_THRESHOLD = 5
 SOCKET_ACCEPT_TIMEOUT = 30.0
 DEBUG = False
-BENCHMARK = os.getenv("TS_BENCHMARK")
-BENCHMARK = BENCHMARK in ["True", "true", "TRUE"]
+BENCHMARK = os.getenv("TS_BENCHMARK") in ["True", "true", "TRUE"]
 LOCAL_RANK = int(os.getenv("LOCAL_RANK", 0))
 WORLD_SIZE = int(os.getenv("WORLD_SIZE", 0))
 WORLD_RANK = int(os.getenv("RANK", 0))
@@ -34,11 +34,11 @@ class TorchModelServiceWorker(object):
 
     def __init__(
         self,
-        s_type=None,
-        s_name=None,
-        host_addr=None,
-        port_num=None,
-        metrics_config=None,
+        s_type: Optional[str] = None,
+        s_name: Optional[str] = None,
+        host_addr: Optional[str] = None,
+        port_num: Optional[int] = None,
+        metrics_config: Optional[str] = None,
     ):
         self.sock_type = s_type
@@ -178,8 +178,13 @@ def handle_connection(self, cl_socket):
             if BENCHMARK:
                 pr.enable()
             if cmd == b"I":
-                resp = service.predict(msg)
-                cl_socket.sendall(resp)
+                if service is not None:
+                    resp = service.predict(msg)
+                    cl_socket.sendall(resp)
+                else:
+                    raise RuntimeError(
+                        "Received command: {}, but service is not loaded".format(cmd)
+                    )
             elif cmd == b"L":
                 service, result, code = self.load_model(msg)
                 resp = bytearray()
@@ -227,8 +232,8 @@ def run_server(self):
     while ts_path in sys.path:
         sys.path.remove(ts_path)
 
-    sock_type = None
-    socket_name = None
+    sock_type: Optional[str] = None
+    socket_name: Optional[str] = None
 
     # noinspection PyBroadException
     try:
@@ -262,7 +267,11 @@
     except Exception:  # pylint: disable=broad-except
         logging.error("Backend worker process died.", exc_info=True)
     finally:
-        if sock_type == "unix" and os.path.exists(socket_name):
+        if (
+            sock_type == "unix"
+            and socket_name is not None
+            and os.path.exists(socket_name)
+        ):
             os.remove(socket_name)
         sys.exit(1)
diff --git a/ts/tests/unit_tests/test_model_service_worker.py b/ts/tests/unit_tests/test_model_service_worker.py
index a17ede9650..0423419a01 100644
--- a/ts/tests/unit_tests/test_model_service_worker.py
+++ b/ts/tests/unit_tests/test_model_service_worker.py
@@ -185,3 +185,16 @@ def test_handle_connection(self, patches, model_service_worker):
         model_service_worker.handle_connection(cl_socket)
 
         cl_socket.sendall.assert_called()
+
+    def test_handle_connection_recv_inference_before_load(
+        self, patches, model_service_worker
+    ):
+        patches.retrieve_msg.side_effect = [(b"I", "")]
+        service = Mock()
+        service.context = None
+        cl_socket = Mock()
+
+        with pytest.raises(
+            RuntimeError, match=r"Received command: .*, but service is not loaded"
+        ):
+            model_service_worker.handle_connection(cl_socket)
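
A note for anyone reproducing the type check locally: the mypy.ini hunk opts the two modules into checking via the files key, so no paths are needed on the command line. Assuming mypy is installed in the active environment and the command is run from the repository root, the config is picked up automatically:

    mypy --config-file mypy.ini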
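
The model_loader.py hunk fixes an implicit-Optional annotation: metrics_cache: MetricsCacheYamlImpl = None claims the parameter is always a MetricsCacheYamlImpl while defaulting it to None, which recent mypy releases (0.990 and later) reject by default under no_implicit_optional. A minimal sketch of the pattern, with a hypothetical Cache class standing in for MetricsCacheYamlImpl:

    from typing import Optional


    class Cache:
        """Stand-in for MetricsCacheYamlImpl."""


    # Rejected by default-strict mypy: the annotation says Cache, the default is None.
    # def load(metrics_cache: Cache = None) -> None: ...

    # Accepted: Optional[Cache] means "Cache or None", matching the default.
    def load(metrics_cache: Optional[Cache] = None) -> None:
        if metrics_cache is not None:  # callers may omit the cache entirely
            print(type(metrics_cache).__name__)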
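
The model_service_worker.py hunks follow from the same annotations: once service and socket_name are Optional, mypy insists on a None check before each use, and here the checks double as real guards, since an inference command can arrive before any load command and os.path.exists() does not accept None. A condensed sketch of the narrowing pattern, with a hypothetical Service class standing in for the loaded model service:

    import os
    from typing import Optional


    class Service:
        """Stand-in for the loaded model service."""

        def predict(self, msg: bytes) -> bytes:
            return msg


    def handle(cmd: bytes, msg: bytes, service: Optional[Service]) -> bytes:
        if service is not None:  # narrows Optional[Service] to Service
            return service.predict(msg)
        raise RuntimeError(
            "Received command: {}, but service is not loaded".format(cmd)
        )


    def cleanup(sock_type: Optional[str], socket_name: Optional[str]) -> None:
        # The "is not None" check narrows socket_name so os.path.exists type-checks.
        if sock_type == "unix" and socket_name is not None and os.path.exists(socket_name):
            os.remove(socket_name)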