Skip to content

Commit

Permalink
Typehints for model_loader and model_service_worker (#2540)
Browse files Browse the repository at this point in the history
* typehint model_loader.py

* typehint model_service_worker.py

* unit test for change to handle_connection

* fix typo

---------

Co-authored-by: Ankith Gunapal <agunapal@ischool.Berkeley.edu>
  • Loading branch information
ZachOBrien and agunapal authored Aug 30, 2023
1 parent 30ff033 commit a37bb3c
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 14 deletions.
2 changes: 1 addition & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[mypy]
; A good-first-issue is to add types to a file
; As you do start adding them in files and slowly make the excluded files empty
files = ts/context.py, ts/model_server.py
files = ts/context.py, ts/model_server.py, ts/model_loader.py, ts/model_service_worker.py

exclude = examples, binaries, ts_scripts, test, kubernetes, benchmarks, model-archiver, workflow-archiver, ts/tests, ts/utils

Expand Down
2 changes: 1 addition & 1 deletion ts/model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def load(
batch_size: Optional[int] = None,
envelope: Optional[str] = None,
limit_max_image_pixels: Optional[bool] = True,
metrics_cache: MetricsCacheYamlImpl = None,
metrics_cache: Optional[MetricsCacheYamlImpl] = None,
) -> Service:
"""
Load TorchServe 1.0 model from file.
Expand Down
33 changes: 21 additions & 12 deletions ts/model_service_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import platform
import socket
import sys
from typing import Optional

from ts.arg_parser import ArgParser
from ts.metrics.metric_cache_yaml_impl import MetricsCacheYamlImpl
Expand All @@ -19,8 +20,7 @@
MAX_FAILURE_THRESHOLD = 5
SOCKET_ACCEPT_TIMEOUT = 30.0
DEBUG = False
BENCHMARK = os.getenv("TS_BENCHMARK")
BENCHMARK = BENCHMARK in ["True", "true", "TRUE"]
BENCHMARK = os.getenv("TS_BENCHMARK") in ["True", "true", "TRUE"]
LOCAL_RANK = int(os.getenv("LOCAL_RANK", 0))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", 0))
WORLD_RANK = int(os.getenv("RANK", 0))
Expand All @@ -34,11 +34,11 @@ class TorchModelServiceWorker(object):

def __init__(
self,
s_type=None,
s_name=None,
host_addr=None,
port_num=None,
metrics_config=None,
s_type: Optional[str] = None,
s_name: Optional[str] = None,
host_addr: Optional[str] = None,
port_num: Optional[int] = None,
metrics_config: Optional[str] = None,
):
self.sock_type = s_type

Expand Down Expand Up @@ -178,8 +178,13 @@ def handle_connection(self, cl_socket):
if BENCHMARK:
pr.enable()
if cmd == b"I":
resp = service.predict(msg)
cl_socket.sendall(resp)
if service is not None:
resp = service.predict(msg)
cl_socket.sendall(resp)
else:
raise RuntimeError(
"Received command: {}, but service is not loaded".format(cmd)
)
elif cmd == b"L":
service, result, code = self.load_model(msg)
resp = bytearray()
Expand Down Expand Up @@ -227,8 +232,8 @@ def run_server(self):
while ts_path in sys.path:
sys.path.remove(ts_path)

sock_type = None
socket_name = None
sock_type: Optional[str] = None
socket_name: Optional[str] = None

# noinspection PyBroadException
try:
Expand Down Expand Up @@ -262,7 +267,11 @@ def run_server(self):
except Exception: # pylint: disable=broad-except
logging.error("Backend worker process died.", exc_info=True)
finally:
if sock_type == "unix" and os.path.exists(socket_name):
if (
sock_type == "unix"
and socket_name is not None
and os.path.exists(socket_name)
):
os.remove(socket_name)

sys.exit(1)
13 changes: 13 additions & 0 deletions ts/tests/unit_tests/test_model_service_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,3 +185,16 @@ def test_handle_connection(self, patches, model_service_worker):
model_service_worker.handle_connection(cl_socket)

cl_socket.sendall.assert_called()

def test_handle_connection_recv_inference_before_load(
    self, patches, model_service_worker
):
    """An inference command (b"I") arriving before any model load must raise.

    The worker's ``service`` attribute is still ``None`` at this point, so
    ``handle_connection`` is expected to fail fast with a descriptive
    RuntimeError instead of attempting ``service.predict`` on ``None``.
    """
    # Simulate the frontend sending a single inference request with no
    # preceding b"L" (load-model) command.
    patches.retrieve_msg.side_effect = [(b"I", "")]
    cl_socket = Mock()

    with pytest.raises(
        RuntimeError, match=r"Received command: .*, but service is not loaded"
    ):
        model_service_worker.handle_connection(cl_socket)

0 comments on commit a37bb3c

Please sign in to comment.