diff --git a/src/engine/server/models/abstract/BaseModel.py b/src/engine/server/models/abstract/BaseModel.py index 8648c368..be3a6db9 100644 --- a/src/engine/server/models/abstract/BaseModel.py +++ b/src/engine/server/models/abstract/BaseModel.py @@ -1,3 +1,5 @@ +from server.services.errors import PortalError, Errors + """Base model class that should be inherited by all other models.""" @@ -19,6 +21,7 @@ def __init__( self._height_ = height self._width_ = width self._label_map_ = {} + self._model_ = None def get_info(self): """Returns the name, type, directory and description of the model.""" @@ -29,6 +32,15 @@ def get_info(self): "type": self._type_, } + def get_model(self): + """Returns self._model_ if it is not None + + Throws PortalError Errors.NOTFOUND if model is not found. + """ + if self._model_ is None: + raise PortalError(Errors.NOTFOUND, "Model not found") + return self._model_ + def get_key(self): """Returns the model key.""" return self._key_ diff --git a/src/engine/server/models/darknet_model.py b/src/engine/server/models/darknet_model.py index 98f1e280..04cadf8a 100644 --- a/src/engine/server/models/darknet_model.py +++ b/src/engine/server/models/darknet_model.py @@ -87,10 +87,11 @@ def load(self): loaded_model = cv2.dnn.readNetFromDarknet( self._configname_, self._weightsname_ ) - return loaded_model + self._model_ = loaded_model - def predict(self, model, image_array): + def predict(self, image_array): try: + model = self._model_ (H, W) = image_array.shape[:2] ln = model.getLayerNames() ln = [ln[i[0] - 1] for i in model.getUnconnectedOutLayers()] diff --git a/src/engine/server/models/tensorflow_model.py b/src/engine/server/models/tensorflow_model.py index b8fd6a4b..96bcbf21 100644 --- a/src/engine/server/models/tensorflow_model.py +++ b/src/engine/server/models/tensorflow_model.py @@ -62,9 +62,12 @@ def load(self): loaded_model = tf.saved_model.load( os.path.join(self._directory_, "saved_model") ) - return loaded_model + self._model_ = loaded_model - def predict(self, model, image_array): + def predict(self, image_array): + if self._model_ is None: + raise PortalError(Errors.NOTFOUND, "Model is not Loaded") + model = self._model_ image_tensor = tf.convert_to_tensor( cv2.resize( image_array, diff --git a/src/engine/server/routes/routes.py b/src/engine/server/routes/routes.py index 0c885bd4..f8cd71b6 100644 --- a/src/engine/server/routes/routes.py +++ b/src/engine/server/routes/routes.py @@ -488,10 +488,10 @@ def predict_single_image(model_id: str) -> tuple: if model_id not in global_store.get_loaded_model_keys(): raise PortalError(Errors.NOTFOUND, "model_id not loaded.") - model_dict = global_store.get_model_dict(model_id) + model_class = global_store.get_model_class(model_id) output = predict_image( - model_dict, format_arg, iou, image_directory + model_class, format_arg, iou, image_directory ) global_store.add_predictions(prediction_key, output) @@ -585,10 +585,10 @@ def predict_video_fn(model_id: str) -> tuple: raise PortalError(Errors.UNINITIALIZED, "No Models loaded.") if model_id not in global_store.get_loaded_model_keys(): raise PortalError(Errors.NOTFOUND, "model_id not loaded.") - model_dict = global_store.get_model_dict(model_id) + model_class = global_store.get_model_class(model_id) output = predict_video( - model_dict, + model_class, iou=iou, video_directory=video_directory, frame_interval=frame_interval, diff --git a/src/engine/server/services/global_store.py b/src/engine/server/services/global_store.py index 9c094ddd..48378466 100644 --- a/src/engine/server/services/global_store.py +++ b/src/engine/server/services/global_store.py @@ -320,14 +320,13 @@ def check_model_limit(self): """Check if current registered models exceeds the model limit.""" return len(self._loaded_model_list_) >= self._model_load_limit_ - def load_model(self, key: str, model_dict: dict) -> None: + def load_model(self, key: str, model_class: BaseModel) -> None: """Add a model into the loaded model list. :param key: The model key. - :param model_dict: A dictionary of the loaded model - and its model class. + :param model_class: The model class that the model key represents. """ - self._loaded_model_list_[key] = model_dict + self._loaded_model_list_[key] = model_class def get_loaded_model_keys(self) -> list: """Retrieve all model keys in the loaded model list.""" @@ -345,11 +344,11 @@ def unload_model(self, key: str) -> None: self._save_store_() - def get_model_dict(self, key: str) -> tuple: + def get_model_class(self, key: str) -> tuple: """Retrieve the model, label map, height and width given the model key. :param key: The model key. - :return: A dictionary of the loaded model and its model class. + :return: The model class that the model key represents. """ return self._loaded_model_list_[key] diff --git a/src/engine/server/services/model_loader.py b/src/engine/server/services/model_loader.py index 0b8b28ae..a6d31224 100644 --- a/src/engine/server/services/model_loader.py +++ b/src/engine/server/services/model_loader.py @@ -26,9 +26,8 @@ def model_loader(model_id: str) -> Response: registered_model: BaseModel = global_store.get_registered_model( model_id ) - loaded_model = registered_model.load() - model_dict = {"model": loaded_model, "model_class": registered_model} - global_store.load_model(model_id, model_dict) + registered_model.load() + global_store.load_model(model_id, registered_model) return Response(status=200) except KeyError as e: raise PortalError( diff --git a/src/engine/server/services/predictions.py b/src/engine/server/services/predictions.py index 07c2ee45..4d5265fa 100644 --- a/src/engine/server/services/predictions.py +++ b/src/engine/server/services/predictions.py @@ -1,6 +1,9 @@ """Module containing the prediction function""" import os import cv2 +from typing import Optional + +import numpy as np # pylint: disable=E0401, E0611 from server.utilities.prediction_utilities import ( @@ -17,27 +20,25 @@ # pylint: disable=R0913 def _predict_single_image( - model_dict, - format_arg, - iou, - image_array, - confidence=0.001, + model_class: BaseModel, + format_arg: str, + iou: float, + image_array: np.ndarray, + confidence: Optional[float] = 0.001, ): """Make predictions on a single image. - :param model_dict: A dictionary of the loaded model and its model class. + :param model_class: A dictionary of the loaded model and its model class. :param format_arg: The output format. :param iou: The intersection of union threshold. :param image_array: The single image as an array. :param confidence: The confidence threshold. :return: The predictions in the format requested by format_arg. """ - model = model_dict["model"] - model_class: BaseModel = model_dict["model_class"] + print("model_class", model_class) label_map = model_class.get_label_map() image_array = cv2.cvtColor(image_array, cv2.COLOR_BGRA2RGB) detections = model_class.predict( - model=model, image_array=image_array, ) suppressed_output = get_suppressed_output( @@ -63,14 +64,14 @@ def _predict_single_image( def predict_image( - model_dict, - format_arg, - iou, - image_directory, + model_class: BaseModel, + format_arg: str, + iou: float, + image_directory: str, ): """Make predictions on a single image. - :param model_dict: A dictionary of the loaded model and its model class. + :param model_class: A dictionary of the loaded model and its model class. :param format_arg: The output format. :param iou: The intersection of union threshold. :param image_directory: The directory of the single image. @@ -78,7 +79,7 @@ def predict_image( """ image_arr = cv2.imread(image_directory) return _predict_single_image( - model_dict=model_dict, + model_class=model_class, format_arg=format_arg, iou=iou, image_array=image_arr, @@ -87,15 +88,15 @@ def predict_image( # pylint: disable=R0913 def predict_video( - model_dict, - iou, - video_directory, - frame_interval, - confidence, + model_class: BaseModel, + iou: float, + video_directory: str, + frame_interval: int, + confidence: float, ): """Make predictions on a multiple images within the video. - :param model_dict: A dictionary of the loaded model and its model class. + :param model_class: A dictionary of the loaded model and its model class. :param iou: The intersection of union threshold. :param video_directory: The directory of the video. :param frame_interval: The sampling interval of the video. @@ -124,7 +125,7 @@ def predict_video( cap.set(1, count) # make inference the frame single_output = _predict_single_image( - model_dict=model_dict, + model_class=model_class, format_arg="json", iou=iou, image_array=frame,