Skip to content

Commit

Permalink
Update of the main branch before the new deployment in Ardeche (#145)
Browse files Browse the repository at this point in the history
* ci on develop (#134)

* update develop tour (#135)

* fix args (#136)

* fix: Fixed script arg name in src/run.py (#138)

* clean backup by size (#141)

* clean backup by size

* black

* missing deps

* drop function

* drop function 2

* Day only (#142)

* sunset_sunrise script

* update output file path

* check if day time

* put back real api

* style

* use datetime timedelta

* Make all params availables from run (#144)

* make all params available in run

* make params availables

* put back api

* style

* put back alert_relaxation to 2

* Yolov5 (#143)

* switch to yolov5

* missing import

* style

* fix tests

* unused import

* update readme

* letterbox transform

* do not resize with pillow before pred

* style

* downgrad opencv

* wrong img name

* long line

* missing deps

* model path

* header

* lib for opencv

* create model folder

* update init

* remove hub deps

* update threshold
  • Loading branch information
MateoLostanlen authored Feb 2, 2023
1 parent fe82a48 commit 4b5406e
Show file tree
Hide file tree
Showing 14 changed files with 208 additions and 110 deletions.
3 changes: 2 additions & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ COPY ./README.md /tmp/README.md
COPY ./setup.py /tmp/setup.py

COPY ./src/requirements.txt /tmp/requirements.txt
RUN pip install --upgrade pip setuptools wheel \
RUN apt-get update && apt-get install ffmpeg libsm6 libxext6 -y\
&& pip install --upgrade pip setuptools wheel \
&& pip install -r /tmp/requirements.txt \
&& pip cache purge \
&& rm -rf /root/.cache/pip
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ You can use the library like any other python package to detect wildfires as fol
from pyroengine.core import Engine
from PIL import Image

engine = Engine("pyronear/rexnet1_3x")
engine = Engine()

im = Image.open("path/to/your/image.jpg").convert('RGB')

Expand Down
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ dependencies = [
"Pillow>=8.4.0",
"onnxruntime>=1.10.0,<2.0.0",
"numpy>=1.19.5,<2.0.0",
"huggingface-hub>=0.4.0,<1.0.0",
"pyroclient>=0.1.2",
"requests>=2.20.0,<3.0.0",
"opencv-python==4.5.5.64",
]

[project.optional-dependencies]
Expand Down Expand Up @@ -107,7 +107,6 @@ module = [
"onnxruntime.*",
"requests.*",
"PIL.*",
"huggingface_hub.*",
"pyroclient.*",
"urllib3.*",
]
Expand Down
2 changes: 1 addition & 1 deletion pyroengine/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .core import *
from . import engine, sensors
from . import engine, sensors, utils
from .version import __version__
86 changes: 61 additions & 25 deletions pyroengine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.

import glob
import io
import json
import logging
Expand All @@ -26,6 +27,25 @@
logging.basicConfig(format="%(asctime)s | %(levelname)s: %(message)s", level=logging.INFO, force=True)


def is_day_time(cache, delta=3600):
"""Read sunset and sunrise hour in sunset_sunrise.txt and check if we are currently on daytime. We don't want to
trigger night alerts for now. We take 1 hour margin
Args:
cache (Path): cache folder where sunset_sunrise.txt is located
delta (int): delta before and after sunset / sunrise in sec
Returns:
bool: is day time
"""
with open(cache.joinpath("sunset_sunrise.txt")) as f:
lines = f.readlines()
sunrise = datetime.strptime(lines[0][:-1], "%H:%M")
sunset = datetime.strptime(lines[1][:-1], "%H:%M")
now = datetime.strptime(datetime.now().isoformat().split("T")[1][:5], "%H:%M")
return (now - sunrise).total_seconds() > -delta and (sunset - now).total_seconds() > -delta


class Engine:
"""This implements an object to manage predictions and API interactions for wildfire alerts.
Expand All @@ -50,13 +70,13 @@ class Engine:
>>> "cam_id_1": {'login':'log1', 'password':'pwd1'},
>>> "cam_id_2": {'login':'log2', 'password':'pwd2'},
>>> }
>>> pyroEngine = Engine("pyronear/rexnet1_3x", 0.5, 'https://api.pyronear.org', cam_creds, 48.88, 2.38)
>>> pyroEngine = Engine("data/model.onnx", 0.25, 'https://api.pyronear.org', cam_creds, 48.88, 2.38)
"""

def __init__(
self,
hub_repo: str,
conf_thresh: float = 0.5,
model_path: Optional[str] = "data/model.onnx",
conf_thresh: Optional[float] = 0.25,
api_url: Optional[str] = None,
cam_creds: Optional[Dict[str, Dict[str, str]]] = None,
latitude: Optional[float] = None,
Expand All @@ -68,12 +88,13 @@ def __init__(
cache_size: int = 100,
cache_folder: str = "data/",
backup_size: int = 30,
jpeg_quality: int = 80,
**kwargs: Any,
) -> None:
"""Init engine"""
# Engine Setup

self.model = Classifier(hub_repo, **kwargs)
self.model = Classifier(model_path)
self.conf_thresh = conf_thresh

# API Setup
Expand All @@ -91,7 +112,7 @@ def __init__(
self.frame_saving_period = frame_saving_period
self.alert_relaxation = alert_relaxation
self.frame_size = frame_size
self.jpeg_quality = 50
self.jpeg_quality = jpeg_quality
self.cache_backup_period = cache_backup_period

# Local backup
Expand Down Expand Up @@ -203,23 +224,28 @@ def predict(self, frame: Image.Image, cam_id: Optional[str] = None) -> float:
except ConnectionError:
logging.warning(f"Unable to reach the pyro-api with {cam_id}")

# Inference with ONNX
pred = float(self.model(frame.convert("RGB")))
# Log analysis result
device_str = f"Camera '{cam_id}' - " if isinstance(cam_id, str) else ""
pred_str = "Wildfire detected" if pred >= self.conf_thresh else "No wildfire"
logging.info(f"{device_str}{pred_str} (confidence: {pred:.2%})")

cam_key = cam_id or "-1"
# Reduce image size to save bandwidth
if isinstance(self.frame_size, tuple):
frame = frame.resize(self.frame_size[::-1], Image.BILINEAR)
frame_resize = frame.resize(self.frame_size[::-1], Image.BILINEAR)

# Alert
cam_key = cam_id or "-1"
to_be_staged = self._update_states(pred, cam_key)
if to_be_staged and len(self.api_client) > 0 and isinstance(cam_id, str):
# Save the alert in cache to avoid connection issues
self._stage_alert(frame, cam_id)
if is_day_time(self._cache):

# Inference with ONNX
pred = float(self.model(frame.convert("RGB")))
# Log analysis result
device_str = f"Camera '{cam_id}' - " if isinstance(cam_id, str) else ""
pred_str = "Wildfire detected" if pred >= self.conf_thresh else "No wildfire"
logging.info(f"{device_str}{pred_str} (confidence: {pred:.2%})")

# Alert

to_be_staged = self._update_states(pred, cam_key)
if to_be_staged and len(self.api_client) > 0 and isinstance(cam_id, str):
# Save the alert in cache to avoid connection issues
self._stage_alert(frame_resize, cam_id)
else:
pred = 0 # return default value

# Uploading pending alerts
if len(self._alerts) > 0:
Expand All @@ -236,10 +262,10 @@ def predict(self, frame: Image.Image, cam_id: Optional[str] = None) -> float:
self._states[cam_key]["frame_count"] += 1
if self._states[cam_key]["frame_count"] == self.frame_saving_period:
# Save frame on device
self._local_backup(frame, cam_id, is_alert=False)
self._local_backup(frame_resize, cam_id, is_alert=False)
# Send frame to the api
stream = io.BytesIO()
frame.save(stream, format="JPEG", quality=self.jpeg_quality)
frame_resize.save(stream, format="JPEG", quality=self.jpeg_quality)
try:
self._upload_frame(cam_id, stream.getvalue())
# Reset frame counter
Expand Down Expand Up @@ -334,13 +360,23 @@ def _local_backup(self, img: Image.Image, cam_id: str, is_alert: bool = False) -
img.save(file)

def _clean_local_backup(self, backup_cache) -> None:
"""Clean local backup after _backup_size days
"""Clean local backup when it's bigger than _backup_size MB
Args:
backup_cache (Path): backup to clean
"""
backup_by_days = list(backup_cache.glob("*"))
backup_by_days.sort()
nb_folder_to_remove = len(backup_by_days) - self._backup_size
for _, folder in zip(range(nb_folder_to_remove), backup_by_days):
shutil.rmtree(folder)
for folder in backup_by_days:
s = (
sum(
os.path.getsize(f)
for f in glob.glob(str(backup_cache) + "/**/*", recursive=True)
if os.path.isfile(f)
)
// 1024**2
)
if s > self._backup_size:
shutil.rmtree(folder)
else:
break
53 changes: 53 additions & 0 deletions pyroengine/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright (C) 2023, Pyronear.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.


import cv2
import numpy as np

__all__ = ["letterbox"]


def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, stride=32):
"""Letterbox image transform for yolo models
Args:
im (np.array): Input image
new_shape (tuple, optional): Image size. Defaults to (640, 640).
color (tuple, optional): Pixel fill value for the area outside the transformed image.
Defaults to (114, 114, 114).
auto (bool, optional): auto padding. Defaults to True.
stride (int, optional): padding stride. Defaults to 32.
Returns:
np.array: Output image
"""
# Resize and pad image while meeting stride-multiple constraints
im = np.array(im)
shape = im.shape[:2] # current shape [height, width]
if isinstance(new_shape, int):
new_shape = (new_shape, new_shape)

# Scale ratio (new / old)
r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])

# Compute padding
new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding

if auto: # minimum rectangle
dw, dh = np.mod(dw, stride), np.mod(dh, stride) # wh padding

dw /= 2 # divide padding into 2 sides
dh /= 2

if shape[::-1] != new_unpad: # resize
im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
# add border
h, w = im.shape[:2]
im_b = np.zeros((h + top + bottom, w + left + right, 3)) + color
im_b[top : top + h, left : left + w, :] = im

return im_b.astype("uint8")
66 changes: 30 additions & 36 deletions pyroengine/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,71 +3,65 @@
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.

import json
from typing import Any, Optional
import os
import urllib
from typing import Optional

import numpy as np
import onnxruntime
from huggingface_hub import hf_hub_download
from PIL import Image

from .utils import letterbox

__all__ = ["Classifier"]

MODEL_URL = "https://github.com/pyronear/pyro-vision/releases/download/v0.2.0/yolov5s_v002.onnx"


class Classifier:
"""Implements an image classification model using ONNX backend.
Examples:
>>> from pyroengine.vision import Classifier
>>> model = Classifier("pyronear/rexnet1_3x")
>>> model = Classifier()
Args:
hub_repo: repository from HuggingFace Hub to load the model from
model_path: overrides the model path
cfg_path: overrides the configuration file from the model
kwargs: keyword args of `huggingface_hub.hf_hub_download`
model_path: model path
"""

def __init__(
self,
hub_repo: str,
model_path: Optional[str] = None,
cfg_path: Optional[str] = None,
**kwargs: Any,
) -> None:
# Download model config & checkpoint
_path = cfg_path or hf_hub_download(hub_repo, filename="config.json", **kwargs)
with open(_path, "rb") as f:
self.cfg = json.load(f)

_path = model_path or hf_hub_download(hub_repo, filename="model.onnx", **kwargs)
self.ort_session = onnxruntime.InferenceSession(_path)

def preprocess_image(self, pil_img: Image.Image) -> np.ndarray:
def __init__(self, model_path: Optional[str] = "data/model.onnx") -> None:
# Download model if not available
if not os.path.isfile(model_path):
os.makedirs(os.path.split(model_path)[0], exist_ok=True)
print(f"Downloading model from {MODEL_URL} ...")
urllib.request.urlretrieve(MODEL_URL, model_path)

self.ort_session = onnxruntime.InferenceSession(model_path)

def preprocess_image(self, pil_img: Image.Image, img_size=(640, 384)) -> np.ndarray:
"""Preprocess an image for inference
Args:
pil_img: a valid pillow image
img_size: image size
Returns:
the resized and normalized image of shape (1, C, H, W)
"""

# Resizing
img = pil_img.resize(self.cfg["input_shape"][-2:][::-1], Image.BILINEAR)
# (H, W, C) --> (C, H, W)
img = np.asarray(img).transpose((2, 0, 1)).astype(np.float32) / 255
# Normalization
img -= np.array(self.cfg["mean"])[:, None, None]
img /= np.array(self.cfg["std"])[:, None, None]
np_img = letterbox(np.array(pil_img)) # letterbox
np_img = np.expand_dims(np_img.astype("float"), axis=0)
np_img = np.ascontiguousarray(np_img.transpose((0, 3, 1, 2))) # BHWC to BCHW
np_img = np_img.astype("float32") / 255

return img[None, ...]
return np_img

def __call__(self, pil_img: Image.Image) -> np.ndarray:
np_img = self.preprocess_image(pil_img)

# ONNX inference
ort_input = {self.ort_session.get_inputs()[0].name: np_img}
ort_out = self.ort_session.run(None, ort_input)
# Sigmoid
return 1 / (1 + np.exp(-ort_out[0][0]))
y = self.ort_session.run(["output0"], {"images": np_img})[0]
# Non maximum suppression need to be added here when we will use the location information
# let's avoid useless compute for now

return np.max(y[0, :, 4])
23 changes: 23 additions & 0 deletions scripts/get_sunset_sunrise.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#!/bin/bash
# This script must be run with a crontab, run every day at 4am
# 0 4 * * * bash /home/pi/pyro-engine/scripts/update_script.sh

# First obtain a location code from: https://weather.codes/search/
# Insert your location. For example FRXX0076 is a location code for Paris, FRANCE

location="FRXX0076"
tmpfile=/tmp/$location.out

# Obtain sunrise and sunset raw data from weather.com
wget -q "https://weather.com/weather/today/l/$location" -O "$tmpfile"

SUNR=$(grep SunriseSunset "$tmpfile" | grep -oE '((1[0-2]|0?[1-9]):([0-5][0-9]) ?([AaPp][Mm]))' | head -1)
SUNS=$(grep SunriseSunset "$tmpfile" | grep -oE '((1[0-2]|0?[1-9]):([0-5][0-9]) ?([AaPp][Mm]))' | tail -1)


sunrise=$(date --date="$SUNR" +%R)
sunset=$(date --date="$SUNS" +%R)

echo $sunrise > /home/pi/pyro-engine/data/sunset_sunrise.txt
echo $sunset >> /home/pi/pyro-engine/data/sunset_sunrise.txt

Loading

0 comments on commit 4b5406e

Please sign in to comment.