Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: Added unittests with Alert API #108

Merged
merged 4 commits into from
Aug 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ jobs:
python -m pip install --upgrade pip
pip install -e ".[test]" --upgrade
- name: Run unittests
env:
API_URL: ${{ secrets.API_URL }}
API_LOGIN: ${{ secrets.API_LOGIN }}
API_PWD: ${{ secrets.API_PWD }}
LAT: 48.88
LON: 2.38
run: |
coverage run -m pytest tests/
coverage xml
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ test = [
"pytest>=5.3.2",
"coverage[toml]>=4.5.4",
"requests>=2.20.0,<3.0.0",
"python-dotenv>=0.15.0",
]
quality = [
"flake8>=3.9.0",
Expand Down
54 changes: 31 additions & 23 deletions pyroengine/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from PIL import Image
from pyroclient import client
from requests.exceptions import ConnectionError
from requests.models import Response

from .vision import Classifier

Expand All @@ -30,7 +31,7 @@ class Engine:
hub_repo: repository on HF Hub to load the ONNX model from
conf_thresh: confidence threshold to send an alert
api_url: url of the pyronear API
client_creds: api credectials for each pizero, the dictionary should be as the one in the example
cam_creds: api credectials for each camera, the dictionary should be as the one in the example
latitude: device latitude
longitude: device longitude
alert_relaxation: number of consecutive positive detections required to send the first alert, and also
Expand All @@ -43,19 +44,19 @@ class Engine:

Examples:
>>> from pyroengine import Engine
>>> client_creds ={
>>> cam_creds ={
"cam_id_1": {'login':'log1', 'password':'pwd1'},
"cam_id_2": {'login':'log2', 'password':'pwd2'},
}
>>> pyroEngine = Engine("pyronear/rexnet1_3x", 0.5, 'https://api.pyronear.org', client_creds, 48.88, 2.38)
>>> pyroEngine = Engine("pyronear/rexnet1_3x", 0.5, 'https://api.pyronear.org', cam_creds, 48.88, 2.38)
"""

def __init__(
self,
hub_repo: str,
conf_thresh: float = 0.5,
api_url: Optional[str] = None,
client_creds: Optional[Dict[str, Dict[str, str]]] = None,
cam_creds: Optional[Dict[str, Dict[str, str]]] = None,
latitude: Optional[float] = None,
longitude: Optional[float] = None,
alert_relaxation: int = 3,
Expand All @@ -74,13 +75,13 @@ def __init__(

# API Setup
if isinstance(api_url, str):
assert isinstance(latitude, float) and isinstance(longitude, float) and isinstance(client_creds, dict)
assert isinstance(latitude, float) and isinstance(longitude, float) and isinstance(cam_creds, dict)
self.latitude = latitude
self.longitude = longitude
self.api_client = {}
if isinstance(api_url, str) and isinstance(client_creds, dict):
if isinstance(api_url, str) and isinstance(cam_creds, dict):
# Instantiate clients for each camera
for _id, vals in client_creds.items():
for _id, vals in cam_creds.items():
self.api_client[_id] = client.Client(api_url, vals["login"], vals["password"])

# Cache & relaxation
Expand All @@ -90,12 +91,12 @@ def __init__(
self.cache_backup_period = cache_backup_period

# Var initialization
self._states: Dict[str, Dict[str, Any]] = {}
if isinstance(client_creds, dict):
for cam_id in client_creds:
self._states: Dict[str, Dict[str, Any]] = {
"-1": {"consec": 0, "frame_count": 0, "ongoing": False},
}
if isinstance(cam_creds, dict):
for cam_id in cam_creds:
self._states[cam_id] = {"consec": 0, "frame_count": 0, "ongoing": False}
else:
self._states["-1"] = {"consec": 0, "frame_count": 0, "ongoing": False}

# Restore pending alerts cache
self._alerts: deque = deque([], cache_size)
Expand Down Expand Up @@ -152,9 +153,9 @@ def _load_cache(self) -> None:
frame = Image.open(entry["frame_path"], mode="r")
self._alerts.append({"frame": frame, "cam_id": entry["cam_id"], "ts": entry["ts"]})

def heartbeat(self, cam_id: str) -> None:
def heartbeat(self, cam_id: str) -> Response:
"""Updates last ping of device"""
self.api_client[cam_id].heartbeat()
return self.api_client[cam_id].heartbeat()

def _update_states(self, conf: float, cam_key: str) -> bool:
"""Updates the detection states"""
Expand Down Expand Up @@ -195,7 +196,7 @@ def predict(self, frame: Image.Image, cam_id: Optional[str] = None) -> float:
# Inference with ONNX
pred = float(self.model(frame.convert("RGB")))
# Log analysis result
device_str = f"Camera {cam_id} - " if isinstance(cam_id, str) else ""
device_str = f"Camera '{cam_id}' - " if isinstance(cam_id, str) else ""
pred_str = "Wildfire detected" if pred >= self.conf_thresh else "No wildfire"
logging.info(f"{device_str}{pred_str} (confidence: {pred:.2%})")

Expand Down Expand Up @@ -236,13 +237,17 @@ def predict(self, frame: Image.Image, cam_id: Optional[str] = None) -> float:

return pred

def _upload_frame(self, cam_id: str, media_data: bytes) -> None:
def _upload_frame(self, cam_id: str, media_data: bytes) -> Response:
"""Save frame"""
logging.info("Uploading media...")
logging.info(f"Camera '{cam_id}' - Uploading media...")
# Create a media
media_id = self.api_client[cam_id].create_media_from_device().json()["id"]
# Send media
self.api_client[cam_id].upload_media(media_id=media_id, media_data=media_data)
response = self.api_client[cam_id].create_media_from_device()
if response.status_code // 100 == 2:
media = response.json()
# Upload media
self.api_client[cam_id].upload_media(media_id=media["id"], media_data=media_data)

return response

def _stage_alert(self, frame: Image.Image, cam_id: str) -> None:
# Store information in the queue
Expand All @@ -262,7 +267,7 @@ def _process_alerts(self) -> None:
# try to upload the oldest element
frame_info = self._alerts[0]
cam_id = frame_info["cam_id"]
logging.info("Sending alert...")
logging.info(f"Camera {cam_id} - Sending alert from {frame_info['ts']}...")

try:
# Media creation
Expand All @@ -283,10 +288,13 @@ def _process_alerts(self) -> None:
# Media upload
stream = io.BytesIO()
frame_info["frame"].save(stream, format="JPEG")
self.api_client[cam_id].upload_media(self._alerts[0]["media_id"], media_data=stream.getvalue())
self.api_client[cam_id].upload_media(
self._alerts[0]["media_id"],
media_data=stream.getvalue(),
).json()["id"]
# Clear
self._alerts.popleft()
logging.info(f"Camera {frame_info['cam_id']} - alert sent")
logging.info(f"Camera {cam_id} - alert sent")
stream.seek(0) # "Rewind" the stream to the beginning so we can read its content
except (KeyError, ConnectionError):
logging.warning(f"Camera {cam_id} - unable to upload cache")
Expand Down
20 changes: 18 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,22 @@


@pytest.fixture(scope="session")
def mock_classification_image(tmpdir_factory):
def mock_wildfire_stream(tmpdir_factory):
url = "https://github.com/pyronear/pyro-vision/releases/download/v0.1.2/fire_sample_image.jpg"
return Image.open(BytesIO(requests.get(url).content))
return requests.get(url).content


@pytest.fixture(scope="session")
def mock_wildfire_image(tmpdir_factory, mock_wildfire_stream):
return Image.open(BytesIO(mock_wildfire_stream))


@pytest.fixture(scope="session")
def mock_forest_stream(tmpdir_factory):
url = "https://github.com/pyronear/pyro-engine/releases/download/v0.1.1/forest_sample.jpg"
return requests.get(url).content


@pytest.fixture(scope="session")
def mock_forest_image(tmpdir_factory, mock_forest_stream):
return Image.open(BytesIO(mock_forest_stream))
70 changes: 63 additions & 7 deletions tests/test_engine_core.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
import json
import os
from datetime import datetime
from pathlib import Path

from dotenv import load_dotenv

from pyroengine.core import Engine


def test_engine(tmpdir_factory, mock_classification_image):
def test_engine_offline(tmpdir_factory, mock_wildfire_image, mock_forest_image):

# Cache
folder = str(tmpdir_factory.mktemp("engine_cache"))

# No API
engine = Engine("pyronear/rexnet1_3x", cache_folder=folder)

# Cache saving
_ts = datetime.utcnow().isoformat()
engine._stage_alert(mock_classification_image, 0)
engine._stage_alert(mock_wildfire_image, 0)
assert len(engine._alerts) == 1
assert engine._alerts[0]["ts"] < datetime.utcnow().isoformat() and _ts < engine._alerts[0]["ts"]
assert engine._alerts[0]["media_id"] is None
Expand All @@ -31,20 +34,73 @@ def test_engine(tmpdir_factory, mock_classification_image):
"cam_id": 0,
"ts": engine._alerts[0]["ts"],
}
# Overrites cache files
engine._dump_cache()

# Cache dump loading
engine = Engine("pyronear/rexnet1_3x", cache_folder=folder)
assert len(engine._alerts) == 1
engine.clear_cache()

# inference
engine = Engine("pyronear/rexnet1_3x", cache_folder=folder)
out = engine.predict(mock_classification_image, 0)
engine = Engine("pyronear/rexnet1_3x", alert_relaxation=3, cache_folder=folder)
out = engine.predict(mock_forest_image)
assert isinstance(out, float) and 0 <= out <= 1
assert engine._states["-1"]["consec"] == 0
out = engine.predict(mock_wildfire_image)
assert isinstance(out, float) and 0 <= out <= 1
assert engine._states["-1"]["consec"] == 1
# Alert relaxation
assert not engine._states["-1"]["ongoing"]
out = engine.predict(mock_classification_image, 0)
out = engine.predict(mock_classification_image, 0)
out = engine.predict(mock_wildfire_image)
assert engine._states["-1"]["consec"] == 2
out = engine.predict(mock_wildfire_image)
assert engine._states["-1"]["consec"] == 3
assert engine._states["-1"]["ongoing"]


def test_engine_online(tmpdir_factory, mock_wildfire_stream, mock_wildfire_image):
# Cache
folder = str(tmpdir_factory.mktemp("engine_cache"))
# With API
load_dotenv(Path(__file__).parent.parent.joinpath(".env").absolute())
api_url = os.environ.get("API_URL")
lat = os.environ.get("LAT")
lon = os.environ.get("LON")
cam_creds = {"dummy_cam": {"login": os.environ.get("API_LOGIN"), "password": os.environ.get("API_PWD")}}
# Skip the API-related tests if the URL is not specified
if isinstance(api_url, str):
engine = Engine(
"pyronear/rexnet1_3x",
api_url=api_url,
cam_creds=cam_creds,
latitude=float(lat),
longitude=float(lon),
alert_relaxation=2,
frame_saving_period=3,
cache_folder=folder,
frame_size=(224, 224),
)
# Heartbeat
start_ts = datetime.utcnow().isoformat()
response = engine.heartbeat("dummy_cam")
assert response.status_code // 100 == 2
ts = datetime.utcnow().isoformat()
json_respone = response.json()
assert start_ts < json_respone["last_ping"] < ts
# Send an alert
engine.predict(mock_wildfire_image, "dummy_cam")
assert len(engine._alerts) == 0 and engine._states["dummy_cam"]["consec"] == 1
assert engine._states["dummy_cam"]["frame_count"] == 1
engine.predict(mock_wildfire_image, "dummy_cam")
assert engine._states["dummy_cam"]["consec"] == 2 and engine._states["dummy_cam"]["ongoing"]
assert engine._states["dummy_cam"]["frame_count"] == 2
# Check that a media and an alert have been registered
assert len(engine._alerts) == 0
# Upload a frame
response = engine._upload_frame("dummy_cam", mock_wildfire_stream)
assert response.status_code // 100 == 2
# Upload frame in process
engine.predict(mock_wildfire_image, "dummy_cam")
# Check that a new media has been created & uploaded
assert engine._states["dummy_cam"]["frame_count"] == 0
6 changes: 3 additions & 3 deletions tests/test_engine_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@
from pyroengine.vision import Classifier


def test_classifier(mock_classification_image):
def test_classifier(mock_wildfire_image):

# Instantiae the ONNX model
model = Classifier("pyronear/rexnet1_3x")
# Check preprocessing
out = model.preprocess_image(mock_classification_image)
out = model.preprocess_image(mock_wildfire_image)
assert isinstance(out, np.ndarray) and out.dtype == np.float32
assert out.shape == (1, 3, 224, 224)
# Check inference
out = model(mock_classification_image)
out = model(mock_wildfire_image)
assert isinstance(out, np.ndarray) and out.dtype == np.float32
assert out.shape == (1,)
assert out >= 0 and out <= 1
Expand Down