-
Notifications
You must be signed in to change notification settings - Fork 259
/
Copy pathflask_rest_api_tutorial.py
329 lines (285 loc) Β· 14.3 KB
/
flask_rest_api_tutorial.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
# -*- coding: utf-8 -*-
"""
Flaskλ₯Ό μ¬μ©νμ¬ Pythonμμ PyTorchλ₯Ό REST APIλ‘ λ°°ν¬νκΈ°
===========================================================
**Author**: `Avinash Sajjanshetty <https://avi.im>`_
**λ²μ**: `λ°μ ν <http://github.com/9bow>`_
μ΄ νν 리μΌμμλ Flaskλ₯Ό μ¬μ©νμ¬ PyTorch λͺ¨λΈμ λ°°ν¬νκ³ λͺ¨λΈ μΆλ‘ (inference)μ
ν μ μλ REST APIλ₯Ό λ§λ€μ΄λ³΄κ² μ΅λλ€. 미리 νλ ¨λ DenseNet 121 λͺ¨λΈμ λ°°ν¬νμ¬
μ΄λ―Έμ§λ₯Ό μΈμν΄λ³΄κ² μ΅λλ€.
.. tip:: μ¬κΈ°μ μ¬μ©ν λͺ¨λ μ½λλ MIT λΌμ΄μ μ€λ‘ λ°°ν¬λλ©°,
`GitHub <https://github.com/avinassh/pytorch-flask-api>`_ μμ νμΈνμ€ μ μμ΅λλ€.
μ΄κ²μ PyTorch λͺ¨λΈμ μμ©(production)μΌλ‘ λ°°ν¬νλ νν λ¦¬μΌ μ리μ¦μ 첫λ²μ§Έ
νΈμ
λλ€. Flaskλ₯Ό μ¬κΈ°μ μκ°λ κ²μ²λΌ μ¬μ©νλ κ²μ΄ PyTorch λͺ¨λΈμ μ 곡νλ
κ°μ₯ μ¬μ΄ λ°©λ²μ΄μ§λ§, κ³ μ±λ₯μ μꡬνλ λμλ μ ν©νμ§ μμ΅λλ€. κ·Έμ λν΄μλ:
- TorchScriptμ μ΄λ―Έ μ΅μνλ€λ©΄, λ°λ‘ `Loading a TorchScript Model in C++ <https://tutorials.pytorch.kr/advanced/cpp_export.html>`_ λ¬ΈμλΆν° μ½μ΄λ³΄μΈμ.
- TorchScriptκ° λ¬΄μμΈμ§ μμ보λ κ²μ΄ νμνλ€λ©΄ `TorchScript μκ° <https://tutorials.pytorch.kr/beginner/Intro_to_TorchScript_tutorial.html>`_ λΆν° μ½μ΄λ³΄μλ κ²μ μΆμ²ν©λλ€.
"""
######################################################################
# API μ μ
# --------------
#
# λ¨Όμ API μλν¬μΈνΈ(endpoint)μ μμ²(request)μ μλ΅(response)μ μ μνλ κ²λΆν°
# μμν΄λ³΄κ² μ΅λλ€. μλ‘ λ§λ€ API μλν¬μΈνΈλ μ΄λ―Έμ§κ° ν¬ν¨λ ``file`` 맀κ°λ³μλ₯Ό
# HTTP POSTλ‘ ``/predict`` μ μμ²ν©λλ€. μλ΅μ JSON ννμ΄λ©° λ€μκ³Ό κ°μ μμΈ‘ κ²°κ³Όλ₯Ό
# ν¬ν¨ν©λλ€:
#
# .. code-block:: sh
#
# {"class_id": "n02124075", "class_name": "Egyptian_cat"}
#
#
######################################################################
# μμ‘΄μ±(Dependencies)
# -------------------------
#
# μλ λͺ
λ Ήμ΄λ₯Ό μ€ννμ¬ νμν ν¨ν€μ§λ€μ μ€μΉν©λλ€:
#
# .. code-block:: sh
#
# $ pip install Flask==2.0.1 torchvision==0.10.0
######################################################################
# κ°λ¨ν μΉ μλ²
# -----------------
#
# Flaskμ λ¬Έμλ₯Ό μ°Έκ³ νμ¬ μλμ κ°μ μ½λλ‘ κ°λ¨ν μΉ μλ²λ₯Ό ꡬμ±ν©λλ€.
from flask import Flask
app = Flask(__name__)
@app.route('/')
def hello():
return 'Hello World!'
###############################################################################
# λν, ImageNet λΆλ₯ IDμ μ΄λ¦μ ν¬ν¨νλ JSONμ νμ νλλ‘ μλ΅ νμμ λ³κ²½νκ² μ΅λλ€.
# μ΄μ ``app.py`` λ μλμ κ°μ΄ λ³κ²½λμμ΅λλ€:
from flask import Flask, jsonify
app = Flask(__name__)
@app.route('/predict', methods=['POST'])
def predict():
return jsonify({'class_id': 'IMAGE_NET_XXX', 'class_name': 'Cat'})
######################################################################
# μΆλ‘ (Inference)
# -----------------
#
# λ€μ μΉμ
μμλ μΆλ‘ μ½λ μμ±μ μ§μ€νκ² μ΅λλ€. λ¨Όμ μ΄λ―Έμ§λ₯Ό DenseNetμ 곡κΈ(feed)ν μ
# μλλ‘ μ€λΉνλ λ°©λ²μ μ΄ν΄λ³Έ λ€, λͺ¨λΈλ‘λΆν° μμΈ‘ κ²°κ³Όλ₯Ό μ»λ λ°©λ²μ μ΄ν΄λ³΄κ² μ΅λλ€.
#
# μ΄λ―Έμ§ μ€λΉνκΈ°
# ~~~~~~~~~~~~~~~~~~~
#
# DenseNet λͺ¨λΈμ 224 x 224μ 3μ±λ RGB μ΄λ―Έμ§λ₯Ό νμλ‘ ν©λλ€.
# λν μ΄λ―Έμ§ ν
μλ₯Ό νκ· λ° νμ€νΈμ°¨ κ°μΌλ‘ μ κ·νν©λλ€. μμΈν λ΄μ©μ
# `μ¬κΈ° <https://pytorch.org/vision/stable/models.html>`_ λ₯Ό μ°Έκ³ νμΈμ.
#
# ``torchvision`` λΌμ΄λΈλ¬λ¦¬μ ``transforms`` λ₯Ό μ¬μ©νμ¬ λ³ν νμ΄νλΌμΈ
# (transform pipeline)μ ꡬμΆν©λλ€. Transformsμ κ΄λ ¨ν λ μμΈν λ΄μ©μ
# `μ¬κΈ° <https://pytorch.org/vision/stable/transforms.html>`_ μμ
# μ½μ΄λ³Ό μ μμ΅λλ€.
import io
import torchvision.transforms as transforms
from PIL import Image
def transform_image(image_bytes):
my_transforms = transforms.Compose([transforms.Resize(255),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
[0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])])
image = Image.open(io.BytesIO(image_bytes))
return my_transforms(image).unsqueeze(0)
######################################################################
# μ λ©μλλ μ΄λ―Έμ§λ₯Ό byte λ¨μλ‘ μ½μ ν, μΌλ ¨μ λ³νμ μ μ©νκ³ Tensorλ₯Ό
# λ°νν©λλ€. μ λ©μλλ₯Ό ν
μ€νΈνκΈ° μν΄μλ μ΄λ―Έμ§λ₯Ό byte λͺ¨λλ‘ μ½μ ν
# Tensorλ₯Ό λ°ννλμ§ νμΈνλ©΄ λ©λλ€. (λ¨Όμ `../_static/img/sample_file.jpeg` μ
# μ»΄ν¨ν° μμ μ€μ κ²½λ‘λ‘ λ°κΏμΌ ν©λλ€.)
with open("../_static/img/sample_file.jpeg", 'rb') as f:
image_bytes = f.read()
tensor = transform_image(image_bytes=image_bytes)
print(tensor)
######################################################################
# μμΈ‘(Prediction)
# ~~~~~~~~~~~~~~~~~~~
#
# 미리 νμ΅λ DenseNet 121 λͺ¨λΈμ μ¬μ©νμ¬ μ΄λ―Έμ§ λΆλ₯(class)λ₯Ό μμΈ‘ν©λλ€.
# ``torchvision`` λΌμ΄λΈλ¬λ¦¬μ λͺ¨λΈμ μ¬μ©νμ¬ λͺ¨λΈμ μ½μ΄μ€κ³ μΆλ‘ μ ν©λλ€.
# μ΄ μμ μμλ 미리 νμ΅λ λͺ¨λΈμ μ¬μ©νμ§λ§, μ§μ λ§λ λͺ¨λΈμ λν΄μλ
# μ΄μ λμΌν λ°©λ²μ μ¬μ©ν μ μμ΅λλ€. λͺ¨λΈμ μ½μ΄μ€λ κ²μ μ΄
# :doc:`νν λ¦¬μΌ </beginner/saving_loading_models>` μ μ°Έκ³ νμΈμ.
from torchvision import models
# μ΄λ―Έ νμ΅λ κ°μ€μΉλ₯Ό μ¬μ©νκΈ° μν΄ `weights` μ `IMAGENET1K_V1` κ°μ μ λ¬ν©λλ€:
model = models.densenet121(weights='IMAGENET1K_V1')
# λͺ¨λΈμ μΆλ‘ μλ§ μ¬μ©ν κ²μ΄λ―λ‘, `eval` λͺ¨λλ‘ λ³κ²½ν©λλ€:
model.eval()
def get_prediction(image_bytes):
tensor = transform_image(image_bytes=image_bytes)
outputs = model.forward(tensor)
_, y_hat = outputs.max(1)
return y_hat
######################################################################
# ``y_hat`` Tensorλ μμΈ‘λ λΆλ₯ IDμ μΈλ±μ€λ₯Ό ν¬ν¨ν©λλ€. νμ§λ§ μ¬λμ΄ μ½μ μ
# μλ λΆλ₯λͺ
μ΄ μμ΄μΌ νκΈ° λλ¬Έμ, μ΄λ₯Ό μν΄ μ΄λ¦κ³Ό λΆλ₯ IDλ₯Ό 맀ννλ κ²μ΄ νμν©λλ€.
# `μ΄ νμΌ <https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json>`_
# μ λ€μ΄λ‘λ λ°μ ``imagenet_class_index.json`` μ΄λΌλ μ΄λ¦μΌλ‘ μ μ₯ ν, μ μ₯ν κ³³μ
# μμΉλ₯Ό κΈ°μ΅ν΄λμΈμ. (λλ, μ΄ νν 리μΌκ³Ό λκ°μ΄ μ§ννλ κ²½μ°μλ `tutorials/_static`
# μ μ μ₯νμΈμ.) μ΄ νμΌμ ImageNet λΆλ₯ IDμ ImageNet λΆλ₯λͺ
μ μμ ν¬ν¨νκ³ μμ΅λλ€.
# μ΄μ μ΄ JSON νμΌμ λΆλ¬μ μμΈ‘ κ²°κ³Όμ μΈλ±μ€μ ν΄λΉνλ λΆλ₯λͺ
μ κ°μ Έμ€κ² μ΅λλ€.
import json
imagenet_class_index = json.load(open('../_static/imagenet_class_index.json'))
def get_prediction(image_bytes):
tensor = transform_image(image_bytes=image_bytes)
outputs = model.forward(tensor)
_, y_hat = outputs.max(1)
predicted_idx = str(y_hat.item())
return imagenet_class_index[predicted_idx]
######################################################################
# ``imagenet_class_index`` μ¬μ (dictionary)μ μ¬μ©νκΈ° μ μ,
# ``imagenet_class_index`` μ ν€ κ°μ΄ λ¬Έμμ΄μ΄λ―λ‘ Tensorμ κ°λ
# λ¬Έμμ΄λ‘ λ³νν΄μΌ ν©λλ€.
# μ λ©μλλ₯Ό ν
μ€νΈν΄λ³΄κ² μ΅λλ€:
with open("../_static/img/sample_file.jpeg", 'rb') as f:
image_bytes = f.read()
print(get_prediction(image_bytes=image_bytes))
######################################################################
# λ€μκ³Ό κ°μ μλ΅μ λ°κ² λ κ²μ
λλ€:
['n02124075', 'Egyptian_cat']
######################################################################
# λ°°μ΄μ 첫λ²μ§Έ νλͺ©μ ImageNet λΆλ₯ IDμ΄κ³ , λλ²μ§Έ νλͺ©μ μ¬λμ΄ μ½μ μ μλ
# μ΄λ¦μ
λλ€.
#
######################################################################
# λͺ¨λΈμ API μλ²μ ν΅ν©νκΈ°
# ---------------------------------------
#
# λ§μ§λ§μΌλ‘ μμμ λ§λ Flask API μλ²μ λͺ¨λΈμ μΆκ°νκ² μ΅λλ€.
# API μλ²λ μ΄λ―Έμ§ νμΌμ λ°λ κ²μ κ°μ νκ³ μμΌλ―λ‘, μμ²μΌλ‘λΆν° νμΌμ μ½λλ‘
# ``predict`` λ©μλλ₯Ό μμ ν΄μΌ ν©λλ€:
#
# .. code-block:: python
#
# from flask import request
#
# @app.route('/predict', methods=['POST'])
# def predict():
# if request.method == 'POST':
# # we will get the file from the request
# file = request.files['file']
# # convert that to bytes
# img_bytes = file.read()
# class_id, class_name = get_prediction(image_bytes=img_bytes)
# return jsonify({'class_id': class_id, 'class_name': class_name})
#
#
######################################################################
# ``app.py`` νμΌμ μ΄μ μμ±λμμ΅λλ€. μλκ° μ 체 μ½λμ
λλ€;
# μλ `<PATH/TO/.json/FILE>` μ κ²½λ‘λ₯Ό json νμΌμ μ μ₯ν΄λ κ²½λ‘λ‘ λ°κΎΈλ©΄ λμν©λλ€:
#
# .. code-block:: python
#
# import io
# import json
#
# from torchvision import models
# import torchvision.transforms as transforms
# from PIL import Image
# from flask import Flask, jsonify, request
#
#
# app = Flask(__name__)
# imagenet_class_index = json.load(open('<PATH/TO/.json/FILE>/imagenet_class_index.json'))
# model = models.densenet121(weights='IMAGENET1K_V1')
# model.eval()
#
#
# def transform_image(image_bytes):
# my_transforms = transforms.Compose([transforms.Resize(255),
# transforms.CenterCrop(224),
# transforms.ToTensor(),
# transforms.Normalize(
# [0.485, 0.456, 0.406],
# [0.229, 0.224, 0.225])])
# image = Image.open(io.BytesIO(image_bytes))
# return my_transforms(image).unsqueeze(0)
#
#
# def get_prediction(image_bytes):
# tensor = transform_image(image_bytes=image_bytes)
# outputs = model.forward(tensor)
# _, y_hat = outputs.max(1)
# predicted_idx = str(y_hat.item())
# return imagenet_class_index[predicted_idx]
#
#
# @app.route('/predict', methods=['POST'])
# def predict():
# if request.method == 'POST':
# file = request.files['file']
# img_bytes = file.read()
# class_id, class_name = get_prediction(image_bytes=img_bytes)
# return jsonify({'class_id': class_id, 'class_name': class_name})
#
#
# if __name__ == '__main__':
# app.run()
#
#
######################################################################
# μ΄μ μΉ μλ²λ₯Ό ν
μ€νΈν΄λ³΄κ² μ΅λλ€! λ€μκ³Ό κ°μ΄ μ€νν΄λ³΄μΈμ:
#
# .. code-block:: sh
#
# FLASK_ENV=development FLASK_APP=app.py flask run
#
#######################################################################
# `requests <https://pypi.org/project/requests/>`_ λΌμ΄λΈλ¬λ¦¬λ₯Ό μ¬μ©νμ¬
# POST μμ²μ λ§λ€μ΄λ³΄κ² μ΅λλ€:
#
# .. code-block:: python
#
# import requests
#
# resp = requests.post("http://localhost:5000/predict",
# files={"file": open('<PATH/TO/.jpg/FILE>/cat.jpg','rb')})
#######################################################################
# `resp.json()` μ νΈμΆνλ©΄ λ€μκ³Ό κ°μ κ²°κ³Όλ₯Ό μΆλ ₯ν©λλ€:
#
# .. code-block:: sh
#
# {"class_id": "n02124075", "class_name": "Egyptian_cat"}
#
######################################################################
# λ€μ λ¨κ³
# --------------
#
# μ§κΈκΉμ§ λ§λ μλ²λ λ§€μ° κ°λ¨νμ¬ μμ© νλ‘κ·Έλ¨(production application)μΌλ‘μ¨
# κ°μΆ°μΌν κ²λ€μ λͺ¨λ κ°μΆμ§ λͺ»νμ΅λλ€. λ°λΌμ, λ€μκ³Ό κ°μ΄ κ°μ ν΄λ³Ό μ μμ΅λλ€:
#
# - ``/predict`` μλν¬μΈνΈλ μμ² μμ λ°λμ μ΄λ―Έμ§ νμΌμ΄ μ λ¬λλ κ²μ κ°μ νκ³
# μμ΅λλ€. νμ§λ§ λͺ¨λ μμ²μ΄ κ·Έλ μ§λ μμ΅λλ€. μ¬μ©μλ λ€λ₯Έ 맀κ°λ³μλ‘ μ΄λ―Έμ§λ₯Ό
# 보λ΄κ±°λ, μ΄λ―Έμ§λ₯Ό μμ 보λ΄μ§ μμμλ μμ΅λλ€.
#
# - μ¬μ©μκ° μ΄λ―Έμ§κ° μλ μ νμ νμΌμ 보λΌμλ μμ΅λλ€. μ¬κΈ°μλ μλ¬λ₯Ό μ²λ¦¬νμ§
# μκ³ μμΌλ―λ‘, μ΄λ¬ν κ²½μ°μ μλ²λ λ€μ΄(break)λ©λλ€. μμΈ μ λ¬(exception throe)μ
# μν λͺ
μμ μΈ μλ¬ νΈλ€λ§ κ²½λ‘λ₯Ό μΆκ°νλ©΄ μλͺ»λ μ
λ ₯μ λ μ μ²λ¦¬ν μ μμ΅λλ€.
#
# - λͺ¨λΈμ λ§μ μ’
λ₯μ μ΄λ―Έμ§ λΆλ₯λ₯Ό μΈμν μ μμ§λ§, λͺ¨λ μ΄λ―Έμ§λ₯Ό μΈμν μ μλ
# κ²μ μλλλ€. λͺ¨λΈμ΄ μ΄λ―Έμ§μμ μ무κ²λ μΈμνμ§ λͺ»νλ κ²½μ°λ₯Ό μ²λ¦¬νλλ‘
# κ°μ ν©λλ€.
#
# - μμμλ Flask μλ²λ₯Ό κ°λ° λͺ¨λμμ μ€ννμμ§λ§, μ΄λ μμ©μΌλ‘ λ°°ν¬νκΈ°μλ
# μ λΉνμ§ μμ΅λλ€. Flask μλ²λ₯Ό μμ©μΌλ‘ λ°°ν¬νλ κ²μ
# `μ΄ νν λ¦¬μΌ <https://flask.palletsprojects.com/en/1.1.x/tutorial/deploy/>`_
# μ μ°Έκ³ ν΄λ³΄μΈμ.
#
# - λν μ΄λ―Έμ§λ₯Ό κ°μ Έμ€λ μμ(form)κ³Ό μμΈ‘ κ²°κ³Όλ₯Ό νμνλ νμ΄μ§λ₯Ό λ§λ€μ΄
# UIλ₯Ό μΆκ°ν μλ μμ΅λλ€. λΉμ·ν νλ‘μ νΈμ `λ°λͺ¨ <https://pytorch-imagenet.herokuapp.com/>`_
# μ μ΄ λ°λͺ¨μ `μμ€ μ½λ <https://github.com/avinassh/pytorch-flask-api-heroku>`_
# λ₯Ό μ°Έκ³ ν΄λ³΄μΈμ.
#
# - μ΄ νν 리μΌμμλ ν λ²μ νλμ μ΄λ―Έμ§μ λν μμΈ‘ κ²°κ³Όλ₯Ό λ°ννλ μλΉμ€λ₯Ό
# λ§λλ λ°©λ²λ§ μ΄ν΄λ³΄μλλ°μ, ν λ²μ μ¬λ¬ μ΄λ―Έμ§μ λν μμΈ‘ κ²°κ³Όλ₯Ό λ°ννλλ‘
# μμ ν΄λ³Ό μ μμ΅λλ€. μΆκ°λ‘, `service-streamer <https://github.com/ShannonAI/service-streamer>`_
# λΌμ΄λΈλ¬λ¦¬λ μλμΌλ‘ μμ²μ νμ λ£μ λ€ λͺ¨λΈμ 곡κΈ(feed)ν μ μλ λ―Έλ-λ°°μΉλ‘
# μνλ§ν©λλ€. `μ΄ νν λ¦¬μΌ <https://github.com/ShannonAI/service-streamer/wiki/Vision-Recognition-Service-with-Flask-and-service-streamer>`_
# μ μ°Έκ³ ν΄λ³΄μΈμ.
#
# - λ§μ§λ§μΌλ‘ μ΄ λ¬Έμ μλ¨μ λ§ν¬λ, PyTorch λͺ¨λΈμ λ°°ν¬νλ λ€λ₯Έ νν 리μΌλ€μ
# μ½μ΄λ³΄λ κ²μ κΆμ₯ν©λλ€.