Skip to content

Commit

Permalink
test: Added unittests with Alert API (#108)
Browse files Browse the repository at this point in the history
* refactor: Renamed constructor arg

* test: Added unittests for Alert API

* ci: Fixed secrets forwarding

* test: Extended unittests
  • Loading branch information
frgfm authored Aug 6, 2022
1 parent fa594c6 commit 33a2904
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 35 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ jobs:
python -m pip install --upgrade pip
pip install -e ".[test]" --upgrade
- name: Run unittests
env:
API_URL: ${{ secrets.API_URL }}
API_LOGIN: ${{ secrets.API_LOGIN }}
API_PWD: ${{ secrets.API_PWD }}
LAT: 48.88
LON: 2.38
run: |
coverage run -m pytest tests/
coverage xml
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ test = [
"pytest>=5.3.2",
"coverage[toml]>=4.5.4",
"requests>=2.20.0,<3.0.0",
"python-dotenv>=0.15.0",
]
quality = [
"flake8>=3.9.0",
Expand Down
54 changes: 31 additions & 23 deletions pyroengine/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from PIL import Image
from pyroclient import client
from requests.exceptions import ConnectionError
from requests.models import Response

from .vision import Classifier

Expand All @@ -30,7 +31,7 @@ class Engine:
hub_repo: repository on HF Hub to load the ONNX model from
conf_thresh: confidence threshold to send an alert
api_url: url of the pyronear API
client_creds: api credectials for each pizero, the dictionary should be as the one in the example
cam_creds: api credectials for each camera, the dictionary should be as the one in the example
latitude: device latitude
longitude: device longitude
alert_relaxation: number of consecutive positive detections required to send the first alert, and also
Expand All @@ -43,19 +44,19 @@ class Engine:
Examples:
>>> from pyroengine import Engine
>>> client_creds ={
>>> cam_creds ={
"cam_id_1": {'login':'log1', 'password':'pwd1'},
"cam_id_2": {'login':'log2', 'password':'pwd2'},
}
>>> pyroEngine = Engine("pyronear/rexnet1_3x", 0.5, 'https://api.pyronear.org', client_creds, 48.88, 2.38)
>>> pyroEngine = Engine("pyronear/rexnet1_3x", 0.5, 'https://api.pyronear.org', cam_creds, 48.88, 2.38)
"""

def __init__(
self,
hub_repo: str,
conf_thresh: float = 0.5,
api_url: Optional[str] = None,
client_creds: Optional[Dict[str, Dict[str, str]]] = None,
cam_creds: Optional[Dict[str, Dict[str, str]]] = None,
latitude: Optional[float] = None,
longitude: Optional[float] = None,
alert_relaxation: int = 3,
Expand All @@ -74,13 +75,13 @@ def __init__(

# API Setup
if isinstance(api_url, str):
assert isinstance(latitude, float) and isinstance(longitude, float) and isinstance(client_creds, dict)
assert isinstance(latitude, float) and isinstance(longitude, float) and isinstance(cam_creds, dict)
self.latitude = latitude
self.longitude = longitude
self.api_client = {}
if isinstance(api_url, str) and isinstance(client_creds, dict):
if isinstance(api_url, str) and isinstance(cam_creds, dict):
# Instantiate clients for each camera
for _id, vals in client_creds.items():
for _id, vals in cam_creds.items():
self.api_client[_id] = client.Client(api_url, vals["login"], vals["password"])

# Cache & relaxation
Expand All @@ -90,12 +91,12 @@ def __init__(
self.cache_backup_period = cache_backup_period

# Var initialization
self._states: Dict[str, Dict[str, Any]] = {}
if isinstance(client_creds, dict):
for cam_id in client_creds:
self._states: Dict[str, Dict[str, Any]] = {
"-1": {"consec": 0, "frame_count": 0, "ongoing": False},
}
if isinstance(cam_creds, dict):
for cam_id in cam_creds:
self._states[cam_id] = {"consec": 0, "frame_count": 0, "ongoing": False}
else:
self._states["-1"] = {"consec": 0, "frame_count": 0, "ongoing": False}

# Restore pending alerts cache
self._alerts: deque = deque([], cache_size)
Expand Down Expand Up @@ -152,9 +153,9 @@ def _load_cache(self) -> None:
frame = Image.open(entry["frame_path"], mode="r")
self._alerts.append({"frame": frame, "cam_id": entry["cam_id"], "ts": entry["ts"]})

def heartbeat(self, cam_id: str) -> None:
def heartbeat(self, cam_id: str) -> Response:
"""Updates last ping of device"""
self.api_client[cam_id].heartbeat()
return self.api_client[cam_id].heartbeat()

def _update_states(self, conf: float, cam_key: str) -> bool:
"""Updates the detection states"""
Expand Down Expand Up @@ -195,7 +196,7 @@ def predict(self, frame: Image.Image, cam_id: Optional[str] = None) -> float:
# Inference with ONNX
pred = float(self.model(frame.convert("RGB")))
# Log analysis result
device_str = f"Camera {cam_id} - " if isinstance(cam_id, str) else ""
device_str = f"Camera '{cam_id}' - " if isinstance(cam_id, str) else ""
pred_str = "Wildfire detected" if pred >= self.conf_thresh else "No wildfire"
logging.info(f"{device_str}{pred_str} (confidence: {pred:.2%})")

Expand Down Expand Up @@ -236,13 +237,17 @@ def predict(self, frame: Image.Image, cam_id: Optional[str] = None) -> float:

return pred

def _upload_frame(self, cam_id: str, media_data: bytes) -> None:
def _upload_frame(self, cam_id: str, media_data: bytes) -> Response:
"""Save frame"""
logging.info("Uploading media...")
logging.info(f"Camera '{cam_id}' - Uploading media...")
# Create a media
media_id = self.api_client[cam_id].create_media_from_device().json()["id"]
# Send media
self.api_client[cam_id].upload_media(media_id=media_id, media_data=media_data)
response = self.api_client[cam_id].create_media_from_device()
if response.status_code // 100 == 2:
media = response.json()
# Upload media
self.api_client[cam_id].upload_media(media_id=media["id"], media_data=media_data)

return response

def _stage_alert(self, frame: Image.Image, cam_id: str) -> None:
# Store information in the queue
Expand All @@ -262,7 +267,7 @@ def _process_alerts(self) -> None:
# try to upload the oldest element
frame_info = self._alerts[0]
cam_id = frame_info["cam_id"]
logging.info("Sending alert...")
logging.info(f"Camera {cam_id} - Sending alert from {frame_info['ts']}...")

try:
# Media creation
Expand All @@ -283,10 +288,13 @@ def _process_alerts(self) -> None:
# Media upload
stream = io.BytesIO()
frame_info["frame"].save(stream, format="JPEG")
self.api_client[cam_id].upload_media(self._alerts[0]["media_id"], media_data=stream.getvalue())
self.api_client[cam_id].upload_media(
self._alerts[0]["media_id"],
media_data=stream.getvalue(),
).json()["id"]
# Clear
self._alerts.popleft()
logging.info(f"Camera {frame_info['cam_id']} - alert sent")
logging.info(f"Camera {cam_id} - alert sent")
stream.seek(0) # "Rewind" the stream to the beginning so we can read its content
except (KeyError, ConnectionError):
logging.warning(f"Camera {cam_id} - unable to upload cache")
Expand Down
20 changes: 18 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,22 @@


@pytest.fixture(scope="session")
def mock_classification_image(tmpdir_factory):
def mock_wildfire_stream(tmpdir_factory):
url = "https://github.com/pyronear/pyro-vision/releases/download/v0.1.2/fire_sample_image.jpg"
return Image.open(BytesIO(requests.get(url).content))
return requests.get(url).content


@pytest.fixture(scope="session")
def mock_wildfire_image(tmpdir_factory, mock_wildfire_stream):
return Image.open(BytesIO(mock_wildfire_stream))


@pytest.fixture(scope="session")
def mock_forest_stream(tmpdir_factory):
url = "https://github.com/pyronear/pyro-engine/releases/download/v0.1.1/forest_sample.jpg"
return requests.get(url).content


@pytest.fixture(scope="session")
def mock_forest_image(tmpdir_factory, mock_forest_stream):
return Image.open(BytesIO(mock_forest_stream))
70 changes: 63 additions & 7 deletions tests/test_engine_core.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
import json
import os
from datetime import datetime
from pathlib import Path

from dotenv import load_dotenv

from pyroengine.core import Engine


def test_engine(tmpdir_factory, mock_classification_image):
def test_engine_offline(tmpdir_factory, mock_wildfire_image, mock_forest_image):

# Cache
folder = str(tmpdir_factory.mktemp("engine_cache"))

# No API
engine = Engine("pyronear/rexnet1_3x", cache_folder=folder)

# Cache saving
_ts = datetime.utcnow().isoformat()
engine._stage_alert(mock_classification_image, 0)
engine._stage_alert(mock_wildfire_image, 0)
assert len(engine._alerts) == 1
assert engine._alerts[0]["ts"] < datetime.utcnow().isoformat() and _ts < engine._alerts[0]["ts"]
assert engine._alerts[0]["media_id"] is None
Expand All @@ -31,20 +34,73 @@ def test_engine(tmpdir_factory, mock_classification_image):
"cam_id": 0,
"ts": engine._alerts[0]["ts"],
}
# Overrites cache files
engine._dump_cache()

# Cache dump loading
engine = Engine("pyronear/rexnet1_3x", cache_folder=folder)
assert len(engine._alerts) == 1
engine.clear_cache()

# inference
engine = Engine("pyronear/rexnet1_3x", cache_folder=folder)
out = engine.predict(mock_classification_image, 0)
engine = Engine("pyronear/rexnet1_3x", alert_relaxation=3, cache_folder=folder)
out = engine.predict(mock_forest_image)
assert isinstance(out, float) and 0 <= out <= 1
assert engine._states["-1"]["consec"] == 0
out = engine.predict(mock_wildfire_image)
assert isinstance(out, float) and 0 <= out <= 1
assert engine._states["-1"]["consec"] == 1
# Alert relaxation
assert not engine._states["-1"]["ongoing"]
out = engine.predict(mock_classification_image, 0)
out = engine.predict(mock_classification_image, 0)
out = engine.predict(mock_wildfire_image)
assert engine._states["-1"]["consec"] == 2
out = engine.predict(mock_wildfire_image)
assert engine._states["-1"]["consec"] == 3
assert engine._states["-1"]["ongoing"]


def test_engine_online(tmpdir_factory, mock_wildfire_stream, mock_wildfire_image):
# Cache
folder = str(tmpdir_factory.mktemp("engine_cache"))
# With API
load_dotenv(Path(__file__).parent.parent.joinpath(".env").absolute())
api_url = os.environ.get("API_URL")
lat = os.environ.get("LAT")
lon = os.environ.get("LON")
cam_creds = {"dummy_cam": {"login": os.environ.get("API_LOGIN"), "password": os.environ.get("API_PWD")}}
# Skip the API-related tests if the URL is not specified
if isinstance(api_url, str):
engine = Engine(
"pyronear/rexnet1_3x",
api_url=api_url,
cam_creds=cam_creds,
latitude=float(lat),
longitude=float(lon),
alert_relaxation=2,
frame_saving_period=3,
cache_folder=folder,
frame_size=(224, 224),
)
# Heartbeat
start_ts = datetime.utcnow().isoformat()
response = engine.heartbeat("dummy_cam")
assert response.status_code // 100 == 2
ts = datetime.utcnow().isoformat()
json_respone = response.json()
assert start_ts < json_respone["last_ping"] < ts
# Send an alert
engine.predict(mock_wildfire_image, "dummy_cam")
assert len(engine._alerts) == 0 and engine._states["dummy_cam"]["consec"] == 1
assert engine._states["dummy_cam"]["frame_count"] == 1
engine.predict(mock_wildfire_image, "dummy_cam")
assert engine._states["dummy_cam"]["consec"] == 2 and engine._states["dummy_cam"]["ongoing"]
assert engine._states["dummy_cam"]["frame_count"] == 2
# Check that a media and an alert have been registered
assert len(engine._alerts) == 0
# Upload a frame
response = engine._upload_frame("dummy_cam", mock_wildfire_stream)
assert response.status_code // 100 == 2
# Upload frame in process
engine.predict(mock_wildfire_image, "dummy_cam")
# Check that a new media has been created & uploaded
assert engine._states["dummy_cam"]["frame_count"] == 0
6 changes: 3 additions & 3 deletions tests/test_engine_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@
from pyroengine.vision import Classifier


def test_classifier(mock_classification_image):
def test_classifier(mock_wildfire_image):

# Instantiae the ONNX model
model = Classifier("pyronear/rexnet1_3x")
# Check preprocessing
out = model.preprocess_image(mock_classification_image)
out = model.preprocess_image(mock_wildfire_image)
assert isinstance(out, np.ndarray) and out.dtype == np.float32
assert out.shape == (1, 3, 224, 224)
# Check inference
out = model(mock_classification_image)
out = model(mock_wildfire_image)
assert isinstance(out, np.ndarray) and out.dtype == np.float32
assert out.shape == (1,)
assert out >= 0 and out <= 1
Expand Down

0 comments on commit 33a2904

Please sign in to comment.