feat: Added an alert cache mechanism to avoid losing information without internet #61

Merged · 10 commits · Sep 19, 2021
Changes from all commits
241 changes: 167 additions & 74 deletions pyroengine/engine/engine.py
@@ -4,7 +4,16 @@
# See LICENSE or go to <https://www.gnu.org/licenses/agpl-3.0.txt> for full license details.

import io
import os
import json
import logging
from PIL import Image
from pathlib import Path
from requests.exceptions import ConnectionError
from datetime import datetime, timedelta
from collections import deque
from typing import Optional, Dict

from pyroclient import client
from .predictor import PyronearPredictor

@@ -16,130 +25,214 @@ class PyronearEngine:
not based on this image.

Args:
detection_threshold (float): wildfire detection threshold in [0, 1]
api_url (str): url of the pyronear API
pi_zero_credentials (Dict): api credectials for each pizero, the dictionary should as the one in the example
save_every_n_frame (int): Send one frame over N to the api for our dataset
latitude (float): device latitude
longitude (float): device longitude
detection_thresh: wildfire detection threshold in [0, 1]
api_url: url of the pyronear API
client_creds: API credentials for each Pi Zero; the dictionary should be structured as in the example
frame_saving_period: Send one frame over N to the api for our dataset
latitude: device latitude
longitude: device longitude
cache_size: maximum number of alerts to save in cache
alert_relaxation: number of consecutive positive detections required to send the first alert, and also
the number of consecutive negative detections before stopping the alert
cache_backup_period: number of minutes between each cache backup to disk

Examples:
>>> pi_zero_credentials ={}
>>> pi_zero_credentials['pi_zero_id_1']={'login':'log1', 'password':'pwd1'}
>>> pi_zero_credentials['pi_zero_id_2']={'login':'log2', 'password':'pwd2'}
>>> pyroEngine = PyronearEngine(0.6, 'https://api.pyronear.org', pi_zero_credentials, 50)
>>> pyroEngine.run()
>>> client_creds ={}
>>> client_creds['pi_zero_id_1']={'login':'log1', 'password':'pwd1'}
>>> client_creds['pi_zero_id_2']={'login':'log2', 'password':'pwd2'}
>>> pyroEngine = PyronearEngine(0.6, 'https://api.pyronear.org', client_creds, 50)
"""
def __init__(
self,
detection_threshold=0.5,
api_url=None,
pi_zero_credentials=None,
save_evry_n_frame=None,
latitude=None,
longitude=None
):
detection_thresh: float = 0.5,
api_url: Optional[str] = None,
client_creds: Optional[Dict[str, str]] = None,
frame_saving_period: Optional[int] = None,
latitude: Optional[float] = None,
longitude: Optional[float] = None,
cache_size: int = 100,
alert_relaxation: int = 3,
cache_backup_period: int = 60,
) -> None:
"""Init engine"""
# Engine Setup
self.pyronearPredictor = PyronearPredictor()
self.detection_threshold = detection_threshold
self.detection_counter = {}
self.event_appening = {}
self.frames_counter = {}
self.save_evry_n_frame = save_evry_n_frame
if pi_zero_credentials is not None:
for pi_zero_id in pi_zero_credentials.keys():
self.detection_counter[pi_zero_id] = 0
self.event_appening[pi_zero_id] = False
self.frames_counter[pi_zero_id] = 0
else:
self.detection_counter['-1'] = 0
self.event_appening['-1'] = False
self.detection_thresh = detection_thresh
self.frame_saving_period = frame_saving_period
self.alert_relaxation = alert_relaxation

# API Setup
self.use_api = False
self.api_url = api_url
self.latitude = latitude
self.longitude = longitude
if self.api_url is not None:
self.use_api = True
self.init_api(pi_zero_credentials)

# Var initialization
self.stream = io.BytesIO()
self.consec_dets = {}
self.ongoing_alert = {}
self.frames_counter = {}
if isinstance(client_creds, dict):
for pi_zero_id in client_creds.keys():
self.consec_dets[pi_zero_id] = 0
self.frames_counter[pi_zero_id] = 0
self.ongoing_alert[pi_zero_id] = False
else:
self.consec_dets['-1'] = 0
self.ongoing_alert['-1'] = False

def predict(self, frame, pi_zero_id=None):
if self.api_url is not None:
# Instantiate clients for each camera
self.api_client = {}
for _id, vals in client_creds.items():
self.api_client[_id] = client.Client(self.api_url, vals['login'], vals['password'])

# Restore pending alerts cache
self.pending_alerts = deque([], cache_size)
self._backup_folder = Path("data/") # with Docker, the path has to be a bind volume
self.load_cache_from_disk()
self.cache_backup_period = cache_backup_period
self.last_cache_dump = datetime.utcnow()

def predict(self, frame: Image.Image, pi_zero_id: Optional[int] = None) -> float:
""" run prediction on comming frame"""
res = self.pyronearPredictor.predict(frame.convert('RGB')) # run prediction
prob = self.pyronearPredictor.predict(frame.convert('RGB')) # run prediction
if pi_zero_id is None:
logging.info(f"Wildfire detection score ({res:.2%})")
logging.info(f"Wildfire detection score ({prob:.2%})")
else:
self.heartbeat(pi_zero_id)
logging.info(f"Wildfire detection score ({res:.2%}), on device {pi_zero_id}")
logging.info(f"Wildfire detection score ({prob:.2%}), on device {pi_zero_id}")

if res > self.detection_threshold:
# Alert
if prob > self.detection_thresh:
if pi_zero_id is None:
pi_zero_id = '-1' # add default key value

if not self.event_appening[pi_zero_id]:
self.detection_counter[pi_zero_id] += 1
# Ensure counter max value is 3
if self.detection_counter[pi_zero_id] > 3:
self.detection_counter[pi_zero_id] = 3
# Don't increment beyond relaxation
if not self.ongoing_alert[pi_zero_id] and self.consec_dets[pi_zero_id] < self.alert_relaxation:
self.consec_dets[pi_zero_id] += 1

# If counter reach 3, start sending alerts
if self.detection_counter[pi_zero_id] == 3:
self.event_appening[pi_zero_id] = True
if self.consec_dets[pi_zero_id] == self.alert_relaxation:
self.ongoing_alert[pi_zero_id] = True

if self.use_api and self.event_appening[pi_zero_id]:
frame.save(self.stream, format='JPEG')
# Send alert to the api
self.send_alert(pi_zero_id)
self.stream.seek(0) # "Rewind" the stream to the beginning so we can read its content
if isinstance(self.api_url, str) and self.ongoing_alert[pi_zero_id]:
# Save the alert in cache to avoid connection issues
self.save_to_cache(frame, pi_zero_id)

# No wildfire
else:
if self.detection_counter[pi_zero_id] > 0:
self.detection_counter[pi_zero_id] -= 1

if self.detection_counter[pi_zero_id] == 0 and self.event_appening[pi_zero_id]:
# Stop event
self.event_appening[pi_zero_id] = False
if self.consec_dets[pi_zero_id] > 0:
self.consec_dets[pi_zero_id] -= 1
# Consider event as finished
if self.consec_dets[pi_zero_id] == 0:
self.ongoing_alert[pi_zero_id] = False

# Uploading pending alerts
if len(self.pending_alerts) > 0:
self.upload_pending_alerts()

# Check if it's time to backup pending alerts
ts = datetime.utcnow()
if ts > self.last_cache_dump + timedelta(minutes=self.cache_backup_period):
self.save_cache_to_disk()
self.last_cache_dump = ts

# save frame
if self.use_api and self.save_evry_n_frame:
if isinstance(self.api_url, str) and isinstance(self.frame_saving_period, int) and isinstance(pi_zero_id, int):
self.frames_counter[pi_zero_id] += 1
if self.frames_counter[pi_zero_id] == self.save_evry_n_frame:
if self.frames_counter[pi_zero_id] == self.frame_saving_period:
# Reset frame counter
self.frames_counter[pi_zero_id] = 0
# Send frame to the api
frame.save(self.stream, format='JPEG')
self.upload_frame(pi_zero_id)
self.stream.seek(0) # "Rewind" the stream to the beginning so we can read its content

return res
return prob

def init_api(self, pi_zero_credentials):
"""Setup api"""
self.api_client = {}
for pi_zero_id in pi_zero_credentials.keys():
self.api_client[pi_zero_id] = client.Client(self.api_url, pi_zero_credentials[pi_zero_id]['login'],
pi_zero_credentials[pi_zero_id]['password'])

def send_alert(self, pi_zero_id):
def send_alert(self, pi_zero_id: int) -> None:
"""Send alert"""
logging.info("Send alert !")
logging.info("Sending alert...")
# Create a media
media_id = self.api_client[pi_zero_id].create_media_from_device().json()["id"]
# Create an alert linked to the media and the event
self.api_client[pi_zero_id].send_alert_from_device(lat=self.latitude, lon=self.longitude, media_id=media_id)
self.api_client[pi_zero_id].upload_media(media_id=media_id, media_data=self.stream.getvalue())

def save_frame(self, pi_zero_id):
def upload_frame(self, pi_zero_id: int) -> None:
"""Save frame"""
logging.info("Upload media for dataset")
logging.info("Uploading media...")
# Create a media
media_id = self.api_client[pi_zero_id].create_media_from_device().json()["id"]
# Send media
self.api_client[pi_zero_id].upload_media(media_id=media_id, media_data=self.stream.getvalue())

def heartbeat(self, pi_zero_id):
def heartbeat(self, pi_zero_id: int) -> None:
"""Updates last ping of device"""
self.api_client[pi_zero_id].heartbeat()

def save_to_cache(self, frame: Image.Image, pi_zero_id: int) -> None:
# Store information in the queue
self.pending_alerts.append(
{"frame": frame, "pi_zero_id": pi_zero_id, "ts": datetime.utcnow()}
)

def upload_pending_alerts(self) -> None:

for _ in range(len(self.pending_alerts)):
# try to upload the oldest element
frame_info = self.pending_alerts[0]

try:
frame_info['frame'].save(self.stream, format='JPEG')
# Send alert to the api
self.send_alert(frame_info['pi_zero_id'])
# No need to upload it anymore
self.pending_alerts.popleft()
logging.info(f"Alert sent by device {frame_info['pi_zero_id']}")
except ConnectionError:
logging.warning(f"Unable to upload cache for device {frame_info['pi_zero_id']}")
self.stream.seek(0) # "Rewind" the stream to the beginning so we can read its content
break

def save_cache_to_disk(self) -> None:

# Remove previous dump
json_path = self._backup_folder.joinpath('pending_alerts.json')
if json_path.is_file():
with open(json_path, 'rb') as f:
data = json.load(f)

for entry in data:
os.remove(entry['frame_path'])
os.remove(json_path)

data = []
for idx, info in enumerate(self.pending_alerts):
# Save frame to disk
info['frame'].save(self._backup_folder.joinpath(f"pending_frame{idx}.jpg"))

# Save path in JSON
data.append({
"frame_path": str(self._backup_folder.joinpath(f"pending_frame{idx}.jpg")),
"pi_zero_id": info["pi_zero_id"],
"ts": info['ts']
})

# JSON dump
if len(data) > 0:
with open(json_path, 'w') as f:
json.dump(data, f)

def load_cache_from_disk(self) -> None:
# Read json
json_path = self._backup_folder.joinpath('pending_alerts.json')
if json_path.is_file():
with open(json_path, 'rb') as f:
data = json.load(f)

for entry in data:
# Open image
frame = Image.open(entry['frame_path'], mode='r')
self.pending_alerts.append(
{"frame": frame, "pi_zero_id": entry['pi_zero_id'], "ts": entry['ts']}
)
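
For reviewers less familiar with the two building blocks used in this file, the standalone sketch below illustrates the pattern it relies on: a `deque` with `maxlen` silently evicts the oldest alert once the cache is full, and the drain loop peeks at the oldest entry, only pops it after a successful upload, and stops at the first connection error so nothing is lost. This is an illustrative example, not code from the PR; the `upload()` stub and the entry fields are invented for the demonstration.

```python
# Illustrative sketch (not part of this PR) of the pending-alert pattern:
# a bounded deque for the in-memory cache and an oldest-first drain loop.
from collections import deque
from requests.exceptions import ConnectionError


def upload(entry: dict) -> None:
    """Stand-in for the API call; fails once to simulate a network outage."""
    if entry["pi_zero_id"] == 2:
        raise ConnectionError


# cache_size=3 -> the oldest alert is evicted when a fourth one arrives
pending = deque([], maxlen=3)
for i in range(4):
    pending.append({"pi_zero_id": i})
print([e["pi_zero_id"] for e in pending])  # [1, 2, 3]

# Drain loop: peek at the head, pop only once the upload succeeded,
# stop at the first connection error and keep the rest for the next pass.
for _ in range(len(pending)):
    entry = pending[0]
    try:
        upload(entry)
        pending.popleft()
    except ConnectionError:
        break

print([e["pi_zero_id"] for e in pending])  # [2, 3] -> nothing was lost
```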
11 changes: 6 additions & 5 deletions pyroengine/engine/predictor.py
@@ -23,11 +23,12 @@ def __init__(self):
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
img_size = 448

self.tf = transforms.Compose([transforms.Resize(size=img_size),
transforms.CenterCrop(size=img_size),
transforms.ToTensor(),
normalize
])
self.tf = transforms.Compose([
transforms.Resize(size=img_size),
transforms.CenterCrop(size=img_size),
transforms.ToTensor(),
normalize
])

def predict(self, im):
"""Run prediction"""
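
The restyled transform block above is behavior-neutral; for context, this is how such a torchvision pipeline is typically applied to a camera frame before inference. The snippet is a generic illustration, not part of the PR, and the blank test image is a stand-in for a real frame.

```python
# Generic illustration (not in the PR): applying the preprocessing pipeline
# from predictor.py to a PIL frame before feeding it to the model.
from PIL import Image
from torchvision import transforms

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
img_size = 448
tf = transforms.Compose([
    transforms.Resize(size=img_size),
    transforms.CenterCrop(size=img_size),
    transforms.ToTensor(),
    normalize,
])

frame = Image.new("RGB", (640, 480))   # stand-in for a camera frame
batch = tf(frame).unsqueeze(0)         # add the batch dimension
print(batch.shape)                     # torch.Size([1, 3, 448, 448])
```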
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,5 +1,6 @@
pyrovision >= 0.1.2
python-dotenv >= 0.15.0
requests>=2.25.1
pyroclient@git+https://github.com/pyronear/pyro-api.git#egg=pyroclient&subdirectory=client
pandas>=0.25.2
psutil
3 changes: 3 additions & 0 deletions test/test_engine.py
@@ -27,6 +27,9 @@ def test_engine(self):

self.assertGreater(res, 0.5)

# Check backup
engine.save_cache_to_disk()


if __name__ == '__main__':
unittest.main()
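
A possible follow-up for the test above would be to exercise the full cache round trip. The sketch below is only a suggestion, not part of the PR: it assumes the engine can be default-constructed as in the existing test, that the `data/` backup folder exists (it is created here to be safe), and that the import path matches the one already used in `test/test_engine.py`.

```python
# Suggested follow-up test (not part of this PR): cache round trip.
from pathlib import Path

from PIL import Image
from pyroengine.engine import PyronearEngine  # import path assumed from the existing test

Path("data").mkdir(exist_ok=True)  # the engine backs up its cache under data/

engine = PyronearEngine()
frame = Image.new("RGB", (448, 448))  # stand-in frame
engine.save_to_cache(frame, pi_zero_id='-1')
engine.save_cache_to_disk()

# A fresh engine should restore the pending alert from disk on init
restored = PyronearEngine()
assert len(restored.pending_alerts) == 1
```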