Skip to content

Commit

Permalink
Merge pull request #116 from datature/fixes/caching-system
Browse files Browse the repository at this point in the history
Fixes/caching system
  • Loading branch information
marcus-neo authored Aug 2, 2021
2 parents 7b7189d + 66aeaa2 commit 5bd7202
Show file tree
Hide file tree
Showing 9 changed files with 97 additions and 57 deletions.
4 changes: 3 additions & 1 deletion src/engine/server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ def run(self):
# pylint: disable=invalid-name
app = Flask(__name__)
server = ServerThread(app)
global_store = GlobalStore(MODEL_LOAD_LIMIT, IDLE_MINUTES, caching_system=CACHE_OPTION)
global_store = GlobalStore(
MODEL_LOAD_LIMIT, IDLE_MINUTES, caching_system=CACHE_OPTION
)


def wait_for_process() -> None:
Expand Down
17 changes: 16 additions & 1 deletion src/engine/server/models/abstract/BaseModel.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,46 @@
from server.services.errors import PortalError, Errors

"""Base model class that should be inherited by all other models."""


class BaseModel:
def __init__(
self,
model_type: str,
directory: str,
name: str,
description: str,
height: int = None,
width: int = None,
):
self._type_ = model_type
self._directory_ = directory
self._name_ = name
self._description_ = description
self._key_ = None
self._height_ = height
self._width_ = width
self._label_map_ = {}
self._model_ = None

def get_info(self):
"""Returns the name, directory and description of the model."""
"""Returns the name, type, directory and description of the model."""
return {
"directory": self._directory_,
"description": self._description_,
"name": self._name_,
"type": self._type_,
}

def get_model(self):
"""Returns self._model_ if it is not None
Throws PortalError Errors.NOTFOUND if model is not found.
"""
if self._model_ is None:
raise PortalError(Errors.NOTFOUND, "Model not found")
return self._model_

def get_key(self):
"""Returns the model key."""
return self._key_
Expand Down
2 changes: 1 addition & 1 deletion src/engine/server/models/abstract/Model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def Model(
):
"""Factory function that routes the model to the specific class."""

args = [directory, name, description]
args = [model_type, directory, name, description]

model_class = {
"tensorflow": TensorflowModel,
Expand Down
5 changes: 3 additions & 2 deletions src/engine/server/models/darknet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,11 @@ def load(self):
loaded_model = cv2.dnn.readNetFromDarknet(
self._configname_, self._weightsname_
)
return loaded_model
self._model_ = loaded_model

def predict(self, model, image_array):
def predict(self, image_array):
try:
model = self._model_
(H, W) = image_array.shape[:2]
ln = model.getLayerNames()
ln = [ln[i[0] - 1] for i in model.getUnconnectedOutLayers()]
Expand Down
7 changes: 5 additions & 2 deletions src/engine/server/models/tensorflow_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,12 @@ def load(self):
loaded_model = tf.saved_model.load(
os.path.join(self._directory_, "saved_model")
)
return loaded_model
self._model_ = loaded_model

def predict(self, model, image_array):
def predict(self, image_array):
if self._model_ is None:
raise PortalError(Errors.NOTFOUND, "Model is not Loaded")
model = self._model_
image_tensor = tf.convert_to_tensor(
cv2.resize(
image_array,
Expand Down
8 changes: 4 additions & 4 deletions src/engine/server/routes/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,10 +488,10 @@ def predict_single_image(model_id: str) -> tuple:
if model_id not in global_store.get_loaded_model_keys():
raise PortalError(Errors.NOTFOUND, "model_id not loaded.")

model_dict = global_store.get_model_dict(model_id)
model_class = global_store.get_model_class(model_id)

output = predict_image(
model_dict, format_arg, iou, image_directory
model_class, format_arg, iou, image_directory
)
global_store.add_predictions(prediction_key, output)

Expand Down Expand Up @@ -585,10 +585,10 @@ def predict_video_fn(model_id: str) -> tuple:
raise PortalError(Errors.UNINITIALIZED, "No Models loaded.")
if model_id not in global_store.get_loaded_model_keys():
raise PortalError(Errors.NOTFOUND, "model_id not loaded.")
model_dict = global_store.get_model_dict(model_id)
model_class = global_store.get_model_class(model_id)

output = predict_video(
model_dict,
model_class,
iou=iou,
video_directory=video_directory,
frame_interval=frame_interval,
Expand Down
61 changes: 40 additions & 21 deletions src/engine/server/services/global_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from flask import Response
from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.triggers.interval import IntervalTrigger

# Ignore import-error and no-name-in-module due to Pyshell
# pylint: disable=E0401, E0611
from server.services.errors import Errors, PortalError
Expand All @@ -18,6 +19,7 @@
from server.services.filesystem.folder_target import FolderTargets

from server.models.abstract.BaseModel import BaseModel
from server.models.abstract.Model import Model


def _delete_store_():
Expand Down Expand Up @@ -52,7 +54,6 @@ def __init__(self, model_load_limit, idle_minutes, caching_system) -> None:
}
self._idle_minutes_ = idle_minutes


# Flag to enable or diable the caching system
self.caching_system = caching_system
self._store_ = {
Expand All @@ -78,21 +79,18 @@ def _schedule_shutdown_(self):
:return: void
"""
if (
self._is_shutdown_server_(self._idle_minutes_)
or self._op_atomic_
):
if self._is_shutdown_server_(self._idle_minutes_) or self._op_atomic_:
time.sleep(5)
else:
os._exit(0) # pylint: disable=W0212

def set_start_scheduler(self):
"""Start the scheduler
"""
"""Start the scheduler"""
if self._scheduler_ is None:
self._scheduler_ = BackgroundScheduler(daemon=True)
self._scheduler_.add_job(self._schedule_shutdown_, IntervalTrigger(minutes=1))
self._scheduler_.add_job(
self._schedule_shutdown_, IntervalTrigger(minutes=1)
)
self._scheduler_.start()

# Shut down the scheduler when exiting the app
Expand Down Expand Up @@ -122,6 +120,14 @@ def load_cache(self):
self._targeted_folders_ = jsonpickle.decode(
self._store_["targeted_folders"]
)
for _, value in self._store_["registry"].items():
reg_model = Model(
value["model_type"],
value["model_dir"],
value["model_name"],
"",
)
self.add_registered_model(*reg_model.register())
self._is_cache_called_ = True
else:
raise PortalError(
Expand All @@ -136,8 +142,21 @@ def _save_store_(self):
Transfers data from self._store_ into "./server/cache/store.portalCache"
"""
if self.caching_system:
cache_store = self._store_.copy()
updated_registry = {
registry_key: {
key: value
for key, value in self._store_["registry"][
registry_key
].items()
if key in ["model_type", "model_dir", "model_name"]
}
for registry_key in list(self._store_["registry"].keys())
}
cache_store["registry"] = updated_registry

with open(os.getenv("CACHE_DIR"), "w+") as cache:
json.dump(self._store_, cache)
json.dump(cache_store, cache)

# pylint: disable=R0201
def has_cache(self):
Expand Down Expand Up @@ -251,7 +270,7 @@ def add_registered_model(
"""
model_dir = model.get_info()["directory"]
model_name = model.get_info()["name"]

model_type = model.get_info()["type"]
for item in self._store_["registry"]:
if self._store_["registry"][item]["model_dir"] == model_dir:
self._store_["registry"].pop(item)
Expand All @@ -261,9 +280,10 @@ def add_registered_model(
Errors.INVALIDAPI,
"A model with the same name already exists.",
)
serialized_model_class = jsonpickle.encode(model)
model_class = model
self._store_["registry"][key] = {
"class": serialized_model_class,
"class": model_class,
"model_type": model_type,
"model_dir": model_dir,
"model_name": model_name,
}
Expand All @@ -276,13 +296,13 @@ def get_registered_model(self, key: str) -> BaseModel:
:return: The model as a Model class.
"""
if key in self._store_["registry"]:
return jsonpickle.decode(self._store_["registry"][key]["class"])
return self._store_["registry"][key]["class"]
raise PortalError(Errors.INVALIDMODELKEY, "Model not registered.")

def get_registered_model_info(self) -> str:
"""Retrieve directory, description, name of all registered models"""
return {
model_id: jsonpickle.decode(model_dict["class"]).get_info()
model_id: model_dict["class"].get_info()
for model_id, model_dict in self._store_["registry"].items()
}

Expand All @@ -300,14 +320,13 @@ def check_model_limit(self):
"""Check if current registered models exceeds the model limit."""
return len(self._loaded_model_list_) >= self._model_load_limit_

def load_model(self, key: str, model_dict: dict) -> None:
def load_model(self, key: str, model_class: BaseModel) -> None:
"""Add a model into the loaded model list.
:param key: The model key.
:param model_dict: A dictionary of the loaded model
and its model class.
:param model_class: The model class that the model key represents.
"""
self._loaded_model_list_[key] = model_dict
self._loaded_model_list_[key] = model_class

def get_loaded_model_keys(self) -> list:
"""Retrieve all model keys in the loaded model list."""
Expand All @@ -325,11 +344,11 @@ def unload_model(self, key: str) -> None:

self._save_store_()

def get_model_dict(self, key: str) -> tuple:
def get_model_class(self, key: str) -> tuple:
"""Retrieve the model, label map, height and width given the model key.
:param key: The model key.
:return: A dictionary of the loaded model and its model class.
:return: The model class that the model key represents.
"""
return self._loaded_model_list_[key]

Expand Down
5 changes: 2 additions & 3 deletions src/engine/server/services/model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,8 @@ def model_loader(model_id: str) -> Response:
registered_model: BaseModel = global_store.get_registered_model(
model_id
)
loaded_model = registered_model.load()
model_dict = {"model": loaded_model, "model_class": registered_model}
global_store.load_model(model_id, model_dict)
registered_model.load()
global_store.load_model(model_id, registered_model)
return Response(status=200)
except KeyError as e:
raise PortalError(
Expand Down
45 changes: 23 additions & 22 deletions src/engine/server/services/predictions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
"""Module containing the prediction function"""
import os
import cv2
from typing import Optional

import numpy as np

# pylint: disable=E0401, E0611
from server.utilities.prediction_utilities import (
Expand All @@ -17,27 +20,25 @@

# pylint: disable=R0913
def _predict_single_image(
model_dict,
format_arg,
iou,
image_array,
confidence=0.001,
model_class: BaseModel,
format_arg: str,
iou: float,
image_array: np.ndarray,
confidence: Optional[float] = 0.001,
):
"""Make predictions on a single image.
:param model_dict: A dictionary of the loaded model and its model class.
:param model_class: A dictionary of the loaded model and its model class.
:param format_arg: The output format.
:param iou: The intersection of union threshold.
:param image_array: The single image as an array.
:param confidence: The confidence threshold.
:return: The predictions in the format requested by format_arg.
"""
model = model_dict["model"]
model_class: BaseModel = model_dict["model_class"]
print("model_class", model_class)
label_map = model_class.get_label_map()
image_array = cv2.cvtColor(image_array, cv2.COLOR_BGRA2RGB)
detections = model_class.predict(
model=model,
image_array=image_array,
)
suppressed_output = get_suppressed_output(
Expand All @@ -63,22 +64,22 @@ def _predict_single_image(


def predict_image(
model_dict,
format_arg,
iou,
image_directory,
model_class: BaseModel,
format_arg: str,
iou: float,
image_directory: str,
):
"""Make predictions on a single image.
:param model_dict: A dictionary of the loaded model and its model class.
:param model_class: A dictionary of the loaded model and its model class.
:param format_arg: The output format.
:param iou: The intersection of union threshold.
:param image_directory: The directory of the single image.
:return: The predictions in the format requested by format_arg.
"""
image_arr = cv2.imread(image_directory)
return _predict_single_image(
model_dict=model_dict,
model_class=model_class,
format_arg=format_arg,
iou=iou,
image_array=image_arr,
Expand All @@ -87,15 +88,15 @@ def predict_image(

# pylint: disable=R0913
def predict_video(
model_dict,
iou,
video_directory,
frame_interval,
confidence,
model_class: BaseModel,
iou: float,
video_directory: str,
frame_interval: int,
confidence: float,
):
"""Make predictions on a multiple images within the video.
:param model_dict: A dictionary of the loaded model and its model class.
:param model_class: A dictionary of the loaded model and its model class.
:param iou: The intersection of union threshold.
:param video_directory: The directory of the video.
:param frame_interval: The sampling interval of the video.
Expand Down Expand Up @@ -124,7 +125,7 @@ def predict_video(
cap.set(1, count)
# make inference the frame
single_output = _predict_single_image(
model_dict=model_dict,
model_class=model_class,
format_arg="json",
iou=iou,
image_array=frame,
Expand Down

0 comments on commit 5bd7202

Please sign in to comment.