-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapp.py
57 lines (47 loc) · 1.71 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import io
import os

import torch
from flask import Flask, jsonify, request, render_template
from PIL import Image

from main import DEVICE, DTYPE, FOOD101_CLASSES
from src.mobilenet import MyMobileNet
from src.data_utils import get_test_transform, load_classes
# Set constants. Both paths may be overridden with environment variables so a
# different checkpoint / class list can be served without editing this file;
# the defaults are the original hard-coded values.
MODEL_PATH = os.environ.get(
    'FOOD101_MODEL_PATH',
    'checkpoint/my_model/best_model.pth.tar',  # Replace with the appropriate checkpoint file path
)
CLASS_PATH = os.environ.get('FOOD101_CLASS_PATH', 'data/food-101/meta/classes.txt')

app = Flask(__name__)

# Load classes and model once at import time; every request handler below
# reuses these module-level objects.
classes = load_classes(CLASS_PATH)
model = MyMobileNet(
    output_classes=FOOD101_CLASSES,
    device=DEVICE,
    checkpoint_path=MODEL_PATH,
)
# Put the model in evaluation mode for inference (presumably a torch
# nn.Module, where this switches dropout/batch-norm to inference behavior).
model.eval()
def transform_image(image_bytes):
    """Decode raw image bytes into a single-image batch for the model.

    Args:
        image_bytes: Encoded image file contents (any format Pillow can read).

    Returns:
        The output of ``get_test_transform()`` applied to the RGB-converted
        image, with a leading batch dimension of size 1 added via
        ``unsqueeze(0)``.

    Raises:
        OSError: If Pillow cannot identify the data as an image
            (``PIL.UnidentifiedImageError`` is an ``OSError`` subclass) or the
            image data is truncated/corrupt.
    """
    transform = get_test_transform()
    # Normalize the colour mode first: RGBA/LA PNGs, palette (P), grayscale (L)
    # or CMYK inputs would otherwise yield a tensor whose channel count differs
    # from 3-channel RGB. NOTE(review): assumes the model expects RGB input,
    # as is usual for MobileNet on Food-101 — confirm.
    with Image.open(io.BytesIO(image_bytes)) as image:
        rgb_image = image.convert('RGB')
    return transform(rgb_image).unsqueeze(0)
def get_prediction(image_bytes):
    """Classify an encoded image with the module-level model.

    Args:
        image_bytes: Encoded image file contents.

    Returns:
        A ``(class_id, class_name)`` tuple: ``class_id`` is the index of the
        maximum score along dimension 1 of the model output, and
        ``class_name`` is ``classes[class_id]``.
    """
    batch = transform_image(image_bytes).to(device=DEVICE, dtype=DTYPE)
    # Pure inference: avoid building an autograd graph.
    with torch.no_grad():
        logits = model(batch)
    _, top_index = logits.max(1)
    class_id = top_index.item()
    return class_id, classes[class_id]
@app.route('/', methods=['GET'])
def index():
    """Serve the landing page rendered from the ``index.html`` template."""
    template_name = 'index.html'
    return render_template(template_name)
@app.route('/predict', methods=['POST'])
def predict():
    """Classify an uploaded image and return the result as JSON.

    Expects a multipart/form-data POST with the image in the ``file`` field.

    Returns:
        On success, JSON ``{"class_id": ..., "class_name": ...}`` with status
        200. If the ``file`` field is missing, the upload is empty, or its
        contents cannot be decoded as an image, JSON ``{"error": ...}`` with
        status 400.
    """
    # The route only allows POST (Flask answers other methods with 405), so no
    # request.method check is needed; the old check could only ever be True.
    uploaded = request.files.get('file')
    if uploaded is None:
        return jsonify({'error': "No file uploaded in form field 'file'."}), 400

    img_bytes = uploaded.read()
    if not img_bytes:
        # e.g. a browser form submitted without selecting a file.
        return jsonify({'error': 'Uploaded file is empty.'}), 400

    try:
        class_id, class_name = get_prediction(img_bytes)
    except OSError:
        # Pillow raises UnidentifiedImageError (an OSError subclass) for
        # unrecognized data and OSError for truncated/corrupt images; report
        # these as a client error rather than an HTTP 500.
        return jsonify({'error': 'Uploaded file is not a valid image.'}), 400

    return jsonify({'class_id': class_id, 'class_name': class_name})
if __name__ == '__main__':
    # Launch Flask's built-in development server in debug mode.
    # NOTE(review): debug mode enables the interactive Werkzeug debugger,
    # which allows arbitrary code execution — never expose this entry point
    # on a public interface; use a production WSGI server instead.
    debug_mode = True
    app.run(debug=debug_mode)