Skip to content

Commit

Permalink
Changes To Caching System
Browse files Browse the repository at this point in the history
Fully integrated Loaded Model into Model Class
  • Loading branch information
marcus-neo committed Aug 2, 2021
1 parent b8cf3d9 commit 66aeaa2
Show file tree
Hide file tree
Showing 7 changed files with 54 additions and 39 deletions.
12 changes: 12 additions & 0 deletions src/engine/server/models/abstract/BaseModel.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from server.services.errors import PortalError, Errors

"""Base model class that should be inherited by all other models."""


Expand All @@ -19,6 +21,7 @@ def __init__(
self._height_ = height
self._width_ = width
self._label_map_ = {}
self._model_ = None

def get_info(self):
"""Returns the name, type, directory and description of the model."""
Expand All @@ -29,6 +32,15 @@ def get_info(self):
"type": self._type_,
}

def get_model(self):
"""Returns self._model_ if it is not None
Throws PortalError Errors.NOTFOUND if model is not found.
"""
if self._model_ is None:
raise PortalError(Errors.NOTFOUND, "Model not found")
return self._model_

def get_key(self):
"""Returns the model key."""
return self._key_
Expand Down
5 changes: 3 additions & 2 deletions src/engine/server/models/darknet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,11 @@ def load(self):
loaded_model = cv2.dnn.readNetFromDarknet(
self._configname_, self._weightsname_
)
return loaded_model
self._model_ = loaded_model

def predict(self, model, image_array):
def predict(self, image_array):
try:
model = self._model_
(H, W) = image_array.shape[:2]
ln = model.getLayerNames()
ln = [ln[i[0] - 1] for i in model.getUnconnectedOutLayers()]
Expand Down
7 changes: 5 additions & 2 deletions src/engine/server/models/tensorflow_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,12 @@ def load(self):
loaded_model = tf.saved_model.load(
os.path.join(self._directory_, "saved_model")
)
return loaded_model
self._model_ = loaded_model

def predict(self, model, image_array):
def predict(self, image_array):
if self._model_ is None:
raise PortalError(Errors.NOTFOUND, "Model is not Loaded")
model = self._model_
image_tensor = tf.convert_to_tensor(
cv2.resize(
image_array,
Expand Down
8 changes: 4 additions & 4 deletions src/engine/server/routes/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,10 +488,10 @@ def predict_single_image(model_id: str) -> tuple:
if model_id not in global_store.get_loaded_model_keys():
raise PortalError(Errors.NOTFOUND, "model_id not loaded.")

model_dict = global_store.get_model_dict(model_id)
model_class = global_store.get_model_class(model_id)

output = predict_image(
model_dict, format_arg, iou, image_directory
model_class, format_arg, iou, image_directory
)
global_store.add_predictions(prediction_key, output)

Expand Down Expand Up @@ -585,10 +585,10 @@ def predict_video_fn(model_id: str) -> tuple:
raise PortalError(Errors.UNINITIALIZED, "No Models loaded.")
if model_id not in global_store.get_loaded_model_keys():
raise PortalError(Errors.NOTFOUND, "model_id not loaded.")
model_dict = global_store.get_model_dict(model_id)
model_class = global_store.get_model_class(model_id)

output = predict_video(
model_dict,
model_class,
iou=iou,
video_directory=video_directory,
frame_interval=frame_interval,
Expand Down
11 changes: 5 additions & 6 deletions src/engine/server/services/global_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,14 +320,13 @@ def check_model_limit(self):
"""Check if current registered models exceeds the model limit."""
return len(self._loaded_model_list_) >= self._model_load_limit_

def load_model(self, key: str, model_dict: dict) -> None:
def load_model(self, key: str, model_class: BaseModel) -> None:
"""Add a model into the loaded model list.
:param key: The model key.
:param model_dict: A dictionary of the loaded model
and its model class.
:param model_class: The model class that the model key represents.
"""
self._loaded_model_list_[key] = model_dict
self._loaded_model_list_[key] = model_class

def get_loaded_model_keys(self) -> list:
"""Retrieve all model keys in the loaded model list."""
Expand All @@ -345,11 +344,11 @@ def unload_model(self, key: str) -> None:

self._save_store_()

def get_model_dict(self, key: str) -> tuple:
def get_model_class(self, key: str) -> tuple:
"""Retrieve the model, label map, height and width given the model key.
:param key: The model key.
:return: A dictionary of the loaded model and its model class.
:return: The model class that the model key represents.
"""
return self._loaded_model_list_[key]

Expand Down
5 changes: 2 additions & 3 deletions src/engine/server/services/model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,8 @@ def model_loader(model_id: str) -> Response:
registered_model: BaseModel = global_store.get_registered_model(
model_id
)
loaded_model = registered_model.load()
model_dict = {"model": loaded_model, "model_class": registered_model}
global_store.load_model(model_id, model_dict)
registered_model.load()
global_store.load_model(model_id, registered_model)
return Response(status=200)
except KeyError as e:
raise PortalError(
Expand Down
45 changes: 23 additions & 22 deletions src/engine/server/services/predictions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
"""Module containing the prediction function"""
import os
import cv2
from typing import Optional

import numpy as np

# pylint: disable=E0401, E0611
from server.utilities.prediction_utilities import (
Expand All @@ -17,27 +20,25 @@

# pylint: disable=R0913
def _predict_single_image(
model_dict,
format_arg,
iou,
image_array,
confidence=0.001,
model_class: BaseModel,
format_arg: str,
iou: float,
image_array: np.ndarray,
confidence: Optional[float] = 0.001,
):
"""Make predictions on a single image.
:param model_dict: A dictionary of the loaded model and its model class.
:param model_class: A dictionary of the loaded model and its model class.
:param format_arg: The output format.
:param iou: The intersection of union threshold.
:param image_array: The single image as an array.
:param confidence: The confidence threshold.
:return: The predictions in the format requested by format_arg.
"""
model = model_dict["model"]
model_class: BaseModel = model_dict["model_class"]
print("model_class", model_class)
label_map = model_class.get_label_map()
image_array = cv2.cvtColor(image_array, cv2.COLOR_BGRA2RGB)
detections = model_class.predict(
model=model,
image_array=image_array,
)
suppressed_output = get_suppressed_output(
Expand All @@ -63,22 +64,22 @@ def _predict_single_image(


def predict_image(
model_dict,
format_arg,
iou,
image_directory,
model_class: BaseModel,
format_arg: str,
iou: float,
image_directory: str,
):
"""Make predictions on a single image.
:param model_dict: A dictionary of the loaded model and its model class.
:param model_class: A dictionary of the loaded model and its model class.
:param format_arg: The output format.
:param iou: The intersection of union threshold.
:param image_directory: The directory of the single image.
:return: The predictions in the format requested by format_arg.
"""
image_arr = cv2.imread(image_directory)
return _predict_single_image(
model_dict=model_dict,
model_class=model_class,
format_arg=format_arg,
iou=iou,
image_array=image_arr,
Expand All @@ -87,15 +88,15 @@ def predict_image(

# pylint: disable=R0913
def predict_video(
model_dict,
iou,
video_directory,
frame_interval,
confidence,
model_class: BaseModel,
iou: float,
video_directory: str,
frame_interval: int,
confidence: float,
):
"""Make predictions on a multiple images within the video.
:param model_dict: A dictionary of the loaded model and its model class.
:param model_class: A dictionary of the loaded model and its model class.
:param iou: The intersection of union threshold.
:param video_directory: The directory of the video.
:param frame_interval: The sampling interval of the video.
Expand Down Expand Up @@ -124,7 +125,7 @@ def predict_video(
cap.set(1, count)
# make inference the frame
single_output = _predict_single_image(
model_dict=model_dict,
model_class=model_class,
format_arg="json",
iou=iou,
image_array=frame,
Expand Down

0 comments on commit 66aeaa2

Please sign in to comment.