Set model status using torchserve api #1878

Merged: 12 commits, Jan 31, 2024
61 changes: 60 additions & 1 deletion kubernetes/kserve/kserve_wrapper/TorchserveModel.py
@@ -1,13 +1,16 @@
""" The torchserve side inference end-points request are handled to
return a KServe side response """
import logging
import os
import pathlib
import time
from enum import Enum
from typing import Dict, Union

import grpc
import inference_pb2_grpc
import kserve
import requests
from gprc_utils import from_ts_grpc, to_ts_grpc
from inference_pb2 import PredictionResponse
from kserve.errors import ModelMissingError
@@ -25,6 +28,7 @@
EXPLAINER_URL_FORMAT = EXPLAINER_v2_URL_FORMAT = "http://{0}/explanations/{1}"
REGISTER_URL_FORMAT = "{0}/models?initial_workers=1&url={1}"
UNREGISTER_URL_FORMAT = "{0}/models/{1}"
READINESS_URL_FORMAT = "{0}/models/{1}?customized={2}"
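# Describe-model endpoint on the TorchServe management API; the "customized"
# query parameter asks TorchServe to also return handler-published metadata
# ("customizedMetadata" below).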


class PredictorProtocol(Enum):
@@ -150,5 +154,60 @@ def load(self) -> bool:
        ]
        if len(existing_paths) == 0:
            raise ModelMissingError(model_path)

        num_try = 0
        model_load_customized = os.environ.get("MODEL_LOAD_CUSTOMIZED", "false")
        model_load_max_try = int(os.environ.get("MODEL_LOAD_MAX_TRY", 10))
        model_load_delay = int(os.environ.get("MODEL_LOAD_DELAY", 30))
        model_load_timeout = int(os.environ.get("MODEL_LOAD_TIMEOUT", 5))
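        # Poll the management API until a worker reports READY (and, when
        # MODEL_LOAD_CUSTOMIZED is "true", customized metadata is present),
        # bounded by MODEL_LOAD_MAX_TRY attempts spaced MODEL_LOAD_DELAY
        # seconds apart.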
        while num_try < model_load_max_try and not self.ready:
            num_try += 1
            logging.info(
                f"Loading {self.name} .. {num_try} of {model_load_max_try} tries.."
            )

            try:
                response = requests.get(
                    READINESS_URL_FORMAT.format(
                        self.management_address, self.name, model_load_customized
                    ),
                    timeout=model_load_timeout,
                ).json()

                default_version = response[0]

                workers = default_version["workers"]
                workers_status = [
                    worker["id"] for worker in workers if worker["status"] == "READY"
                ]

                worker_ready = len(workers_status) > 0

                self.ready = (
                    worker_ready
                    if model_load_customized == "false"
                    else worker_ready and "customizedMetadata" in default_version
                )

            except (
                requests.ConnectionError,
                requests.Timeout,
                requests.ConnectTimeout,
                requests.ReadTimeout,
            ):
                logging.info(f"The model {self.name} is not ready")

            except Exception as e:
                logging.info(e)
                logging.info(f"Failed loading model {self.name}")
                break

            if not self.ready:
                logging.info(
                    f"Sleeping {model_load_delay} seconds before retrying {self.name}.."
                )
                time.sleep(model_load_delay)

        if self.ready:
            logging.info(f"The model {self.name} is ready")

        return self.ready
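
For context, the readiness check above parses the describe-model response from
the TorchServe management API (GET {management_address}/models/{name}, with the
customized query parameter). A minimal sketch of the response shape the code
relies on; the field names come from the parsing logic above, and all values
are illustrative:

    # First list element describes the default version of the model.
    response = [
        {
            "workers": [
                {"id": "9000", "status": "READY"},
            ],
            # Handler-published metadata, returned when customized=true is
            # requested; required only when MODEL_LOAD_CUSTOMIZED=true.
            "customizedMetadata": {"status": "ready"},  # illustrative payload
        }
    ]
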
6 changes: 6 additions & 0 deletions kubernetes/kserve/kserve_wrapper/__main__.py
@@ -115,6 +115,12 @@ def parse_config():
        # By default, model.load() is called on the first request. With
        # load_models=all set in the TorchServe config.properties, all models
        # are loaded at startup, and the call below marks them as ready.
        # However, setting model.ready=True before the TorchServe handler has
        # finished all of its load-time preparation (e.g. downloading
        # pretrained model weights from remote storage) can cause problems.
        # Therefore, the ready status is determined by polling the API
        # provided by TorchServe instead.

        model.load()
        models.append(model)
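
Assuming the environment variables read in TorchserveModel.load(), the polling
behaviour can be tuned per deployment (normally exported into the wrapper
container's environment; shown here as Python for illustration). All four
variables are optional and fall back to the defaults in load(); the values
below are illustrative only:

    import os

    os.environ["MODEL_LOAD_CUSTOMIZED"] = "true"  # also require customizedMetadata
    os.environ["MODEL_LOAD_MAX_TRY"] = "20"       # poll up to 20 times
    os.environ["MODEL_LOAD_DELAY"] = "15"         # seconds between polls
    os.environ["MODEL_LOAD_TIMEOUT"] = "5"        # per-request timeout, seconds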
