Skip to content

Commit

Permalink
Merge pull request #30 from robmarkcole/support-custom-models
Browse files Browse the repository at this point in the history
Support custom models
  • Loading branch information
robmarkcole authored Dec 16, 2020
2 parents 063f722 + 92e3164 commit 6aed9c1
Show file tree
Hide file tree
Showing 7 changed files with 397 additions and 271 deletions.
265 changes: 161 additions & 104 deletions deepstack/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,36 @@
from typing import Union, List, Set, Dict

## Const
HTTP_OK = 200
DEFAULT_API_KEY = ""
DEFAULT_TIMEOUT = 10 # seconds
DEFAULT_IP = "localhost"
DEFAULT_PORT = 80

## HTTP codes
HTTP_OK = 200
BAD_URL = 404

## API urls
URL_OBJECT_DETECTION = "http://{}:{}/v1/vision/detection"
URL_FACE_DETECTION = "http://{}:{}/v1/vision/face"
URL_FACE_REGISTRATION = "http://{}:{}/v1/vision/face/register"
URL_FACE_RECOGNITION = "http://{}:{}/v1/vision/face/recognize"
URL_SCENE_DETECTION = "http://{}:{}/v1/vision/scene"
URL_BASE_VISION = "http://{ip}:{port}/v1/vision"
URL_CUSTOM = "/custom/{custom_model}"
URL_OBJECT_DETECTION = "/detection"
URL_FACE_DETECTION = "/face"
URL_FACE_REGISTER = "/face/register"
URL_FACE_RECOGNIZE = "/face/recognize"
URL_SCENE_RECOGNIZE = "/scene"


class DeepstackException(Exception):
pass


def format_confidence(confidence: Union[str, float]) -> float:
"""Takes a confidence from the API like
0.55623 and returne 55.6 (%).
"""
return round(float(confidence) * 100, 1)
Takes a confidence from the API like
0.55623 and returns 55.6 (%).
"""
DECIMALS = 1
return round(float(confidence) * 100, DECIMALS)


def get_confidences_above_threshold(
Expand All @@ -31,9 +45,9 @@ def get_confidences_above_threshold(
return [val for val in confidences if val >= confidence_threshold]


def get_recognised_faces(predictions: List[Dict]) -> List[Dict]:
def get_recognized_faces(predictions: List[Dict]) -> List[Dict]:
"""
Get the recognised faces.
Get the recognized faces.
"""
try:
matched_faces = {
Expand All @@ -51,15 +65,17 @@ def get_objects(predictions: List[Dict]) -> List[str]:
Get a list of the unique objects predicted.
"""
labels = [pred["label"] for pred in predictions]
return list(set(labels))
return sorted(list(set(labels)))


def get_object_confidences(predictions: List[Dict], target_object: str):
def get_object_confidences(predictions: List[Dict], target_object: str) -> List[float]:
"""
Return the list of confidences of instances of target label.
"""
confidences = [
pred["confidence"] for pred in predictions if pred["label"] == target_object
float(pred["confidence"])
for pred in predictions
if pred["label"] == target_object
]
return confidences

Expand All @@ -77,160 +93,201 @@ def get_objects_summary(predictions: List[Dict]):

def post_image(
url: str, image_bytes: bytes, api_key: str, timeout: int, data: dict = {}
):
"""Post an image to Deepstack."""
) -> requests.models.Response:
"""Post an image to Deepstack. Only handles exceptions."""
try:
data["api_key"] = api_key
response = requests.post(
data["api_key"] = api_key # Insert the api_key
return requests.post(
url, files={"image": image_bytes}, data=data, timeout=timeout
)
return response
except requests.exceptions.Timeout:
raise DeepstackException(
f"Timeout connecting to Deepstack, current timeout is {timeout} seconds"
f"Timeout connecting to Deepstack, the current timeout is {timeout} seconds, try increasing this value"
)
except requests.exceptions.ConnectionError or requests.exceptions.MissingSchema as exc:
raise DeepstackException(
f"Deepstack connection error, check your IP and port: {exc}"
)
except requests.exceptions.ConnectionError as exc:
raise DeepstackException(f"Connection error: {exc}")


class DeepstackException(Exception):
pass
def process_image(
url: str, image_bytes: bytes, api_key: str, timeout: int, data: dict = {}
) -> Dict:
"""Process image_bytes and detect. Handles common status codes"""
response = post_image(
url=url, image_bytes=image_bytes, api_key=api_key, timeout=timeout, data=data
)
if response.status_code == HTTP_OK:
return response.json()
elif response.status_code == BAD_URL:
raise DeepstackException(f"Bad url supplied, url {url} raised error {BAD_URL}")
else:
raise DeepstackException(
f"Error from Deepstack request, status code: {response.status_code}"
)


class Deepstack(object):
"""Base class for deepstack."""
class DeepstackVision:
"""Base class for Deepstack vision."""

def __init__(
self,
ip_address: str,
port: str,
api_key: str = "",
ip: str = DEFAULT_IP,
port: int = DEFAULT_PORT,
api_key: str = DEFAULT_API_KEY,
timeout: int = DEFAULT_TIMEOUT,
url_detection: str = "",
url_detect: str = "",
url_recognize: str = "",
url_register: str = "",
):

self._ip_address = ip_address
self._port = port
self._url_detection = url_detection
self._url_base = URL_BASE_VISION.format(ip=ip, port=port)
self._url_detect = self._url_base + url_detect
self._url_recognize = self._url_base + url_recognize
self._url_register = self._url_base + url_register
self._api_key = api_key
self._timeout = timeout
self._response = None

def detect(self, image_bytes: bytes):
"""Process image_bytes, performing detection."""
self._response = None
url = self._url_detection.format(self._ip_address, self._port)

response = post_image(url, image_bytes, self._api_key, self._timeout)

if not response.status_code == HTTP_OK:
raise DeepstackException(
f"Error from request, status code: {response.status_code}"
)
return
def detect(self):
"""Process image_bytes and detect."""
raise NotImplementedError

self._response = response.json()
if not self._response["success"]:
error = self._response["error"]
raise DeepstackException(f"Error from Deepstack: {error}")
def recognize(self):
"""Process image_bytes and recognize."""
raise NotImplementedError

@property
def predictions(self):
"""Return the predictions."""
def register(self):
"""Perform a registration."""
raise NotImplementedError


class DeepstackObject(Deepstack):
class DeepstackObject(DeepstackVision):
"""Work with objects"""

def __init__(
self,
ip_address: str,
port: str,
api_key: str = "",
ip: str = DEFAULT_IP,
port: int = DEFAULT_PORT,
api_key: str = DEFAULT_API_KEY,
timeout: int = DEFAULT_TIMEOUT,
custom_model: str = None,
):
super().__init__(
ip_address, port, api_key, timeout, url_detection=URL_OBJECT_DETECTION
)
if not custom_model:
super().__init__(
ip=ip,
port=port,
api_key=api_key,
timeout=timeout,
url_detect=URL_OBJECT_DETECTION,
)
else:
super().__init__(
ip=ip,
port=port,
api_key=api_key,
timeout=timeout,
url_detect=URL_CUSTOM.format(custom_model=custom_model),
)

@property
def predictions(self):
"""Return the predictions."""
return self._response["predictions"]
def detect(self, image_bytes: bytes):
"""Process image_bytes and detect."""
response = process_image(
url=self._url_detect,
image_bytes=image_bytes,
api_key=self._api_key,
timeout=self._timeout,
)
return response["predictions"]


class DeepstackScene(Deepstack):
class DeepstackScene(DeepstackVision):
"""Work with scenes"""

def __init__(
self,
ip_address: str,
port: str,
api_key: str = "",
ip: str = DEFAULT_IP,
port: int = DEFAULT_PORT,
api_key: str = DEFAULT_API_KEY,
timeout: int = DEFAULT_TIMEOUT,
):
super().__init__(
ip_address, port, api_key, timeout, url_detection=URL_SCENE_DETECTION
ip=ip,
port=port,
api_key=api_key,
timeout=timeout,
url_recognize=URL_SCENE_RECOGNIZE,
)

@property
def predictions(self):
"""Return the predictions."""
return self._response
def recognize(self, image_bytes: bytes):
"""Process image_bytes and detect."""
response = process_image(
url=self._url_recognize,
image_bytes=image_bytes,
api_key=self._api_key,
timeout=self._timeout,
)
del response["success"]
return response


class DeepstackFace(Deepstack):
class DeepstackFace(DeepstackVision):
"""Work with objects"""

def __init__(
self,
ip_address: str,
port: str,
api_key: str = "",
ip: str = DEFAULT_IP,
port: int = DEFAULT_PORT,
api_key: str = DEFAULT_API_KEY,
timeout: int = DEFAULT_TIMEOUT,
):
super().__init__(
ip_address, port, api_key, timeout, url_detection=URL_FACE_DETECTION
ip=ip,
port=port,
api_key=api_key,
timeout=timeout,
url_detect=URL_FACE_DETECTION,
url_register=URL_FACE_REGISTER,
url_recognize=URL_FACE_RECOGNIZE,
)

@property
def predictions(self):
"""Return the classifier attributes."""
return self._response["predictions"]
def detect(self, image_bytes: bytes):
"""Process image_bytes and detect."""
response = process_image(
url=self._url_detect,
image_bytes=image_bytes,
api_key=self._api_key,
timeout=self._timeout,
)
return response["predictions"]

def register_face(self, name: str, image_bytes: bytes):
def register(self, name: str, image_bytes: bytes):
"""
Register a face name to a file.
"""

response = post_image(
url=URL_FACE_REGISTRATION.format(self._ip_address, self._port),
response = process_image(
url=self._url_register,
image_bytes=image_bytes,
api_key=self._api_key,
timeout=self._timeout,
data={"userid": name},
)

if response.status_code == 200 and response.json()["success"] == True:
return
elif response.status_code == 200 and response.json()["success"] == False:
error = response.json()["error"]
raise DeepstackException(f"Error from Deepstack: {error}")

def recognise(self, image_bytes: bytes):
"""Process image_bytes, performing recognition."""
url = URL_FACE_RECOGNITION.format(self._ip_address, self._port)

response = post_image(url, image_bytes, self._api_key, self._timeout)
if response["success"] == True:
return response["message"]

if not response.status_code == HTTP_OK:
elif response["success"] == False:
error = response["error"]
raise DeepstackException(
f"Error from request, status code: {response.status_code}"
f"Deepstack raised an error registering a face: {error}"
)
return

self._response = response.json()
if not self._response["success"]:
error = self._response["error"]
raise DeepstackException(f"Error from Deepstack: {error}")
def recognize(self, image_bytes: bytes):
"""Process image_bytes, performing recognition."""
response = process_image(
url=self._url_recognize,
image_bytes=image_bytes,
api_key=self._api_key,
timeout=self._timeout,
)

return response["predictions"]
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from setuptools import setup, find_packages

VERSION = "0.6"
VERSION = "0.7"

REQUIRES = ["requests"]

Expand Down
Binary file added tests/images/masked.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit 6aed9c1

Please sign in to comment.