diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 1f91ed30..c1893ab8 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -32,6 +32,12 @@ jobs: python -m pip install --upgrade pip pip install -e ".[test]" --upgrade - name: Run unittests + env: + API_URL: ${{ secrets.API_URL }} + API_LOGIN: ${{ secrets.API_LOGIN }} + API_PWD: ${{ secrets.API_PWD }} + LAT: 48.88 + LON: 2.38 run: | coverage run -m pytest tests/ coverage xml diff --git a/pyproject.toml b/pyproject.toml index f09293fe..a1a6a16c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,7 @@ test = [ "pytest>=5.3.2", "coverage[toml]>=4.5.4", "requests>=2.20.0,<3.0.0", + "python-dotenv>=0.15.0", ] quality = [ "flake8>=3.9.0", diff --git a/pyroengine/core.py b/pyroengine/core.py index 962ff2e6..2b544889 100644 --- a/pyroengine/core.py +++ b/pyroengine/core.py @@ -15,6 +15,7 @@ from PIL import Image from pyroclient import client from requests.exceptions import ConnectionError +from requests.models import Response from .vision import Classifier @@ -30,7 +31,7 @@ class Engine: hub_repo: repository on HF Hub to load the ONNX model from conf_thresh: confidence threshold to send an alert api_url: url of the pyronear API - client_creds: api credectials for each pizero, the dictionary should be as the one in the example + cam_creds: api credectials for each camera, the dictionary should be as the one in the example latitude: device latitude longitude: device longitude alert_relaxation: number of consecutive positive detections required to send the first alert, and also @@ -43,11 +44,11 @@ class Engine: Examples: >>> from pyroengine import Engine - >>> client_creds ={ + >>> cam_creds ={ "cam_id_1": {'login':'log1', 'password':'pwd1'}, "cam_id_2": {'login':'log2', 'password':'pwd2'}, } - >>> pyroEngine = Engine("pyronear/rexnet1_3x", 0.5, 'https://api.pyronear.org', client_creds, 48.88, 2.38) + >>> pyroEngine = Engine("pyronear/rexnet1_3x", 0.5, 'https://api.pyronear.org', cam_creds, 48.88, 2.38) """ def __init__( @@ -55,7 +56,7 @@ def __init__( hub_repo: str, conf_thresh: float = 0.5, api_url: Optional[str] = None, - client_creds: Optional[Dict[str, Dict[str, str]]] = None, + cam_creds: Optional[Dict[str, Dict[str, str]]] = None, latitude: Optional[float] = None, longitude: Optional[float] = None, alert_relaxation: int = 3, @@ -74,13 +75,13 @@ def __init__( # API Setup if isinstance(api_url, str): - assert isinstance(latitude, float) and isinstance(longitude, float) and isinstance(client_creds, dict) + assert isinstance(latitude, float) and isinstance(longitude, float) and isinstance(cam_creds, dict) self.latitude = latitude self.longitude = longitude self.api_client = {} - if isinstance(api_url, str) and isinstance(client_creds, dict): + if isinstance(api_url, str) and isinstance(cam_creds, dict): # Instantiate clients for each camera - for _id, vals in client_creds.items(): + for _id, vals in cam_creds.items(): self.api_client[_id] = client.Client(api_url, vals["login"], vals["password"]) # Cache & relaxation @@ -90,12 +91,12 @@ def __init__( self.cache_backup_period = cache_backup_period # Var initialization - self._states: Dict[str, Dict[str, Any]] = {} - if isinstance(client_creds, dict): - for cam_id in client_creds: + self._states: Dict[str, Dict[str, Any]] = { + "-1": {"consec": 0, "frame_count": 0, "ongoing": False}, + } + if isinstance(cam_creds, dict): + for cam_id in cam_creds: self._states[cam_id] = {"consec": 0, "frame_count": 0, "ongoing": False} - else: - self._states["-1"] = {"consec": 0, "frame_count": 0, "ongoing": False} # Restore pending alerts cache self._alerts: deque = deque([], cache_size) @@ -152,9 +153,9 @@ def _load_cache(self) -> None: frame = Image.open(entry["frame_path"], mode="r") self._alerts.append({"frame": frame, "cam_id": entry["cam_id"], "ts": entry["ts"]}) - def heartbeat(self, cam_id: str) -> None: + def heartbeat(self, cam_id: str) -> Response: """Updates last ping of device""" - self.api_client[cam_id].heartbeat() + return self.api_client[cam_id].heartbeat() def _update_states(self, conf: float, cam_key: str) -> bool: """Updates the detection states""" @@ -195,7 +196,7 @@ def predict(self, frame: Image.Image, cam_id: Optional[str] = None) -> float: # Inference with ONNX pred = float(self.model(frame.convert("RGB"))) # Log analysis result - device_str = f"Camera {cam_id} - " if isinstance(cam_id, str) else "" + device_str = f"Camera '{cam_id}' - " if isinstance(cam_id, str) else "" pred_str = "Wildfire detected" if pred >= self.conf_thresh else "No wildfire" logging.info(f"{device_str}{pred_str} (confidence: {pred:.2%})") @@ -236,13 +237,17 @@ def predict(self, frame: Image.Image, cam_id: Optional[str] = None) -> float: return pred - def _upload_frame(self, cam_id: str, media_data: bytes) -> None: + def _upload_frame(self, cam_id: str, media_data: bytes) -> Response: """Save frame""" - logging.info("Uploading media...") + logging.info(f"Camera '{cam_id}' - Uploading media...") # Create a media - media_id = self.api_client[cam_id].create_media_from_device().json()["id"] - # Send media - self.api_client[cam_id].upload_media(media_id=media_id, media_data=media_data) + response = self.api_client[cam_id].create_media_from_device() + if response.status_code // 100 == 2: + media = response.json() + # Upload media + self.api_client[cam_id].upload_media(media_id=media["id"], media_data=media_data) + + return response def _stage_alert(self, frame: Image.Image, cam_id: str) -> None: # Store information in the queue @@ -262,7 +267,7 @@ def _process_alerts(self) -> None: # try to upload the oldest element frame_info = self._alerts[0] cam_id = frame_info["cam_id"] - logging.info("Sending alert...") + logging.info(f"Camera {cam_id} - Sending alert from {frame_info['ts']}...") try: # Media creation @@ -283,10 +288,13 @@ def _process_alerts(self) -> None: # Media upload stream = io.BytesIO() frame_info["frame"].save(stream, format="JPEG") - self.api_client[cam_id].upload_media(self._alerts[0]["media_id"], media_data=stream.getvalue()) + self.api_client[cam_id].upload_media( + self._alerts[0]["media_id"], + media_data=stream.getvalue(), + ).json()["id"] # Clear self._alerts.popleft() - logging.info(f"Camera {frame_info['cam_id']} - alert sent") + logging.info(f"Camera {cam_id} - alert sent") stream.seek(0) # "Rewind" the stream to the beginning so we can read its content except (KeyError, ConnectionError): logging.warning(f"Camera {cam_id} - unable to upload cache") diff --git a/tests/conftest.py b/tests/conftest.py index 923dd478..5175933f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,6 +6,22 @@ @pytest.fixture(scope="session") -def mock_classification_image(tmpdir_factory): +def mock_wildfire_stream(tmpdir_factory): url = "https://github.com/pyronear/pyro-vision/releases/download/v0.1.2/fire_sample_image.jpg" - return Image.open(BytesIO(requests.get(url).content)) + return requests.get(url).content + + +@pytest.fixture(scope="session") +def mock_wildfire_image(tmpdir_factory, mock_wildfire_stream): + return Image.open(BytesIO(mock_wildfire_stream)) + + +@pytest.fixture(scope="session") +def mock_forest_stream(tmpdir_factory): + url = "https://github.com/pyronear/pyro-engine/releases/download/v0.1.1/forest_sample.jpg" + return requests.get(url).content + + +@pytest.fixture(scope="session") +def mock_forest_image(tmpdir_factory, mock_forest_stream): + return Image.open(BytesIO(mock_forest_stream)) diff --git a/tests/test_engine_core.py b/tests/test_engine_core.py index 4d3526a1..aab11249 100644 --- a/tests/test_engine_core.py +++ b/tests/test_engine_core.py @@ -1,20 +1,23 @@ import json +import os from datetime import datetime +from pathlib import Path + +from dotenv import load_dotenv from pyroengine.core import Engine -def test_engine(tmpdir_factory, mock_classification_image): +def test_engine_offline(tmpdir_factory, mock_wildfire_image, mock_forest_image): # Cache folder = str(tmpdir_factory.mktemp("engine_cache")) - # No API engine = Engine("pyronear/rexnet1_3x", cache_folder=folder) # Cache saving _ts = datetime.utcnow().isoformat() - engine._stage_alert(mock_classification_image, 0) + engine._stage_alert(mock_wildfire_image, 0) assert len(engine._alerts) == 1 assert engine._alerts[0]["ts"] < datetime.utcnow().isoformat() and _ts < engine._alerts[0]["ts"] assert engine._alerts[0]["media_id"] is None @@ -31,6 +34,8 @@ def test_engine(tmpdir_factory, mock_classification_image): "cam_id": 0, "ts": engine._alerts[0]["ts"], } + # Overrites cache files + engine._dump_cache() # Cache dump loading engine = Engine("pyronear/rexnet1_3x", cache_folder=folder) @@ -38,13 +43,64 @@ def test_engine(tmpdir_factory, mock_classification_image): engine.clear_cache() # inference - engine = Engine("pyronear/rexnet1_3x", cache_folder=folder) - out = engine.predict(mock_classification_image, 0) + engine = Engine("pyronear/rexnet1_3x", alert_relaxation=3, cache_folder=folder) + out = engine.predict(mock_forest_image) + assert isinstance(out, float) and 0 <= out <= 1 + assert engine._states["-1"]["consec"] == 0 + out = engine.predict(mock_wildfire_image) assert isinstance(out, float) and 0 <= out <= 1 + assert engine._states["-1"]["consec"] == 1 # Alert relaxation assert not engine._states["-1"]["ongoing"] - out = engine.predict(mock_classification_image, 0) - out = engine.predict(mock_classification_image, 0) + out = engine.predict(mock_wildfire_image) + assert engine._states["-1"]["consec"] == 2 + out = engine.predict(mock_wildfire_image) + assert engine._states["-1"]["consec"] == 3 assert engine._states["-1"]["ongoing"] + +def test_engine_online(tmpdir_factory, mock_wildfire_stream, mock_wildfire_image): + # Cache + folder = str(tmpdir_factory.mktemp("engine_cache")) # With API + load_dotenv(Path(__file__).parent.parent.joinpath(".env").absolute()) + api_url = os.environ.get("API_URL") + lat = os.environ.get("LAT") + lon = os.environ.get("LON") + cam_creds = {"dummy_cam": {"login": os.environ.get("API_LOGIN"), "password": os.environ.get("API_PWD")}} + # Skip the API-related tests if the URL is not specified + if isinstance(api_url, str): + engine = Engine( + "pyronear/rexnet1_3x", + api_url=api_url, + cam_creds=cam_creds, + latitude=float(lat), + longitude=float(lon), + alert_relaxation=2, + frame_saving_period=3, + cache_folder=folder, + frame_size=(224, 224), + ) + # Heartbeat + start_ts = datetime.utcnow().isoformat() + response = engine.heartbeat("dummy_cam") + assert response.status_code // 100 == 2 + ts = datetime.utcnow().isoformat() + json_respone = response.json() + assert start_ts < json_respone["last_ping"] < ts + # Send an alert + engine.predict(mock_wildfire_image, "dummy_cam") + assert len(engine._alerts) == 0 and engine._states["dummy_cam"]["consec"] == 1 + assert engine._states["dummy_cam"]["frame_count"] == 1 + engine.predict(mock_wildfire_image, "dummy_cam") + assert engine._states["dummy_cam"]["consec"] == 2 and engine._states["dummy_cam"]["ongoing"] + assert engine._states["dummy_cam"]["frame_count"] == 2 + # Check that a media and an alert have been registered + assert len(engine._alerts) == 0 + # Upload a frame + response = engine._upload_frame("dummy_cam", mock_wildfire_stream) + assert response.status_code // 100 == 2 + # Upload frame in process + engine.predict(mock_wildfire_image, "dummy_cam") + # Check that a new media has been created & uploaded + assert engine._states["dummy_cam"]["frame_count"] == 0 diff --git a/tests/test_engine_vision.py b/tests/test_engine_vision.py index d821004a..ee4c90c0 100644 --- a/tests/test_engine_vision.py +++ b/tests/test_engine_vision.py @@ -4,16 +4,16 @@ from pyroengine.vision import Classifier -def test_classifier(mock_classification_image): +def test_classifier(mock_wildfire_image): # Instantiae the ONNX model model = Classifier("pyronear/rexnet1_3x") # Check preprocessing - out = model.preprocess_image(mock_classification_image) + out = model.preprocess_image(mock_wildfire_image) assert isinstance(out, np.ndarray) and out.dtype == np.float32 assert out.shape == (1, 3, 224, 224) # Check inference - out = model(mock_classification_image) + out = model(mock_wildfire_image) assert isinstance(out, np.ndarray) and out.dtype == np.float32 assert out.shape == (1,) assert out >= 0 and out <= 1