forked from ultralytics/yolov5
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Flask REST API Example (ultralytics#2732)
* add files * Update README.md * Update README.md * Update restapi.py pretrained=True and model.eval() are used by default when loading a model now, so no need to call them manually. * PEP8 reformat * PEP8 reformat Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
- Loading branch information
1 parent
d074b73
commit 5eec89f
Showing
3 changed files
with
102 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
# Flask REST API | ||
[REST](https://en.wikipedia.org/wiki/Representational_state_transfer) [API](https://en.wikipedia.org/wiki/API)s are commonly used to expose Machine Learning (ML) models to other services. This folder contains an example REST API created using Flask to expose the `yolov5s` model from [PyTorch Hub](https://pytorch.org/hub/ultralytics_yolov5/). | ||
|
||
## Requirements | ||
|
||
[Flask](https://palletsprojects.com/p/flask/) is required. Install with: | ||
```shell | ||
$ pip install Flask | ||
``` | ||
|
||
## Run | ||
|
||
After Flask installation run: | ||
|
||
```shell | ||
$ python3 restapi.py --port 5000 | ||
``` | ||
|
||
Then use [curl](https://curl.se/) to perform a request: | ||
|
||
```shell | ||
$ curl -X POST -F image=@zidane.jpg 'http://localhost:5000/v1/object-detection/yolov5s'` | ||
``` | ||
|
||
The model inference results are returned: | ||
|
||
```shell | ||
[{'class': 0, | ||
'confidence': 0.8197850585, | ||
'name': 'person', | ||
'xmax': 1159.1403808594, | ||
'xmin': 750.912902832, | ||
'ymax': 711.2583007812, | ||
'ymin': 44.0350036621}, | ||
{'class': 0, | ||
'confidence': 0.5667674541, | ||
'name': 'person', | ||
'xmax': 1065.5523681641, | ||
'xmin': 116.0448303223, | ||
'ymax': 713.8904418945, | ||
'ymin': 198.4603881836}, | ||
{'class': 27, | ||
'confidence': 0.5661227107, | ||
'name': 'tie', | ||
'xmax': 516.7975463867, | ||
'xmin': 416.6880187988, | ||
'ymax': 717.0524902344, | ||
'ymin': 429.2020568848}] | ||
``` | ||
|
||
An example python script to perform inference using [requests](https://docs.python-requests.org/en/master/) is given in `example_request.py` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
"""Perform test request""" | ||
import pprint | ||
|
||
import requests | ||
|
||
DETECTION_URL = "http://localhost:5000/v1/object-detection/yolov5s" | ||
TEST_IMAGE = "zidane.jpg" | ||
|
||
image_data = open(TEST_IMAGE, "rb").read() | ||
|
||
response = requests.post(DETECTION_URL, files={"image": image_data}).json() | ||
|
||
pprint.pprint(response) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
""" | ||
Run a rest API exposing the yolov5s object detection model | ||
""" | ||
import argparse | ||
import io | ||
|
||
import torch | ||
from PIL import Image | ||
from flask import Flask, request | ||
|
||
app = Flask(__name__) | ||
|
||
DETECTION_URL = "/v1/object-detection/yolov5s" | ||
|
||
|
||
@app.route(DETECTION_URL, methods=["POST"]) | ||
def predict(): | ||
if not request.method == "POST": | ||
return | ||
|
||
if request.files.get("image"): | ||
image_file = request.files["image"] | ||
image_bytes = image_file.read() | ||
|
||
img = Image.open(io.BytesIO(image_bytes)) | ||
|
||
results = model(img, size=640) | ||
data = results.pandas().xyxy[0].to_json(orient="records") | ||
return data | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser(description="Flask api exposing yolov5 model") | ||
parser.add_argument("--port", default=5000, type=int, help="port number") | ||
args = parser.parse_args() | ||
|
||
model = torch.hub.load("ultralytics/yolov5", "yolov5s", force_reload=True).autoshape() # force_reload to recache | ||
app.run(host="0.0.0.0", port=args.port) # debug=True causes Restarting with stat |