-
Notifications
You must be signed in to change notification settings - Fork 259
/
Copy pathcustom_dataset_transforms_loader.py
459 lines (361 loc) Β· 18.5 KB
/
custom_dataset_transforms_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
"""
μ¬μ©μ μ μ PyTorch Dataloader μμ±νκΈ°
===========================================================
λ¨Έμ λ¬λ μκ³ λ¦¬μ¦μ κ°λ°νκΈ° μν΄μλ λ°μ΄ν° μ μ²λ¦¬μ λ§μ λ
Έλ ₯μ΄ νμν©λλ€. PyTorchλ λ°μ΄ν°λ₯Ό λ‘λνλλ° μ½κ³ κ°λ₯νλ€λ©΄
λ μ’μ κ°λ
μ±μ κ°μ§ μ½λλ₯Ό λ§λ€κΈ°μν΄ λ§μ λꡬλ€μ μ 곡ν©λλ€. μ΄ λ μνΌμμλ λ€μ μΈ κ°μ§λ₯Ό λ°°μΈ μ μμ΅λλ€.
1. PyTorch λ°μ΄ν°μ
APIλ€μ μ΄μ©νμ¬ μ¬μ©μ μ μ λ°μ΄ν°μ
λ§λ€κΈ°.
2. ꡬμ±κ°λ₯νλ©° νΈμΆ λ μ μλ μ¬μ©μ μ μ transform λ§λ€κΈ°.
3. μ΄λ¬ν μ»΄ν¬λνΈλ€μ ν©μ³μ μ¬μ©μ μ μ dataloader λ§λ€κΈ°.
μ΄ νν 리μΌμ μ€ννκΈ° μν΄μλ λ€μμ ν¨ν€μ§λ€μ΄ μ€μΉ λμλμ§ νμΈν΄ μ£ΌμΈμ.
- ``scikit-image``: μ΄λ―Έμ§ I/Oμ μ΄λ―Έμ§ λ³νμ νμν©λλ€.
- ``pandas``: CSVλ₯Ό λ μ½κ² νμ±νκΈ° μν΄ νμν©λλ€.
μμ±λκ³ μλ μ΄ μμ μμ, μ΄ λ μνΌλ `Sasank Chilamkurthy <https://chsasank.github.io>`__ μ μ€λ¦¬μ§λ νν 리μΌμ λ°νμΌλ‘ νλ©°
λμ€μλ `Joe Spisak <https://github.com/jspisak>`__ μ μν΄ μμ λμμ΅λλ€.
νκ΅μ΄λ‘ `Jae Joong Lee <https://https://github.com/JaeLee18>`__ μ μν΄ λ²μλμμ΅λλ€.
"""
######################################################################
# μ€μ
# ----------------------
#
# λ¨Όμ μ΄ λ μνΌμ νμν λͺ¨λ λΌμ΄λΈλ¬λ¦¬λ€μ λΆλ¬μ€λλ‘ νκ² μ΅λλ€.
#
#
from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
# κ²½κ³ λ©μμ§ λ¬΄μνκΈ°
import warnings
warnings.filterwarnings("ignore")
plt.ion() # λ°μν λͺ¨λ μ€μ
######################################################################
# 첫 λ²μ§Έ: λ°μ΄ν°μ
# ----------------------
#
######################################################################
#
# μ°λ¦¬κ° λ€λ£° λ°μ΄ν°μ
μ μΌκ΅΄ ν¬μ¦μ
λλ€.
# μ λ°μ μΌλ‘, ν μΌκ΅΄μλ 68κ°μ λλλ§ν¬λ€μ΄ νμλμ΄ μμ΅λλ€.
#
# λ€μ λ¨κ³λ‘λ, `μ¬κΈ° <https://download.pytorch.org/tutorial/faces.zip>`_ μμ
# λ°μ΄ν°μ
μ λ€μ΄ λ°μ μ΄λ―Έμ§λ€μ΄ βdata/faces/β μ κ²½λ‘μ μμΉνκ² ν΄μ£ΌμΈμ.
#
# **μλ¦Ό:** μ¬μ€ μ΄ λ°μ΄ν°μ
μ imagenet λ°μ΄ν°μ
μμ βfaceβ νκ·Έλ₯Ό ν¬ν¨νκ³ μλ μ΄λ―Έμ§μ
# `dlib` μ ν¬μ¦ μμΈ‘ `<https://blog.dlib.net/2014/08/real-time-face-pose-estimation.html>`__ μ μ μ©νμ¬ μμ±νμμ΅λλ€.
#
# ::
#
# !wget https://download.pytorch.org/tutorial/faces.zip
# !mkdir data/faces/
# import zipfile
# with zipfile.ZipFile("faces.zip","r") as zip_ref:
# zip_ref.extractall("/data/faces/")
# %cd /data/faces/
######################################################################
# μ΄ λ°μ΄ν°μ
μ λ€μκ³Ό κ°μ μ€λͺ
μ΄ λ¬λ €μλ CSVνμΌμ΄ ν¬ν¨λμ΄ μμ΅λλ€.
#
# ::
#
# image_name,part_0_x,part_0_y,part_1_x,part_1_y,part_2_x, ... ,part_67_x,part_67_y
# 0805personali01.jpg,27,83,27,98, ... 84,134
# 1084239450_e76e00b7e7.jpg,70,236,71,257, ... ,128,312
#
# μ΄μ CSVνμΌμ λΉ λ₯΄κ² μ½κ³ νμΌ μμ μλ μ€λͺ
λ€μ (N, 2) λ°°μ΄λ‘ μ½μ΄λ΄
μλ€.
# μ¬κΈ°μ Nμ λλλ§ν¬μ κ°―μμ
λλ€.
#
landmarks_frame = pd.read_csv('faces/face_landmarks.csv')
n = 65
img_name = landmarks_frame.iloc[n, 0]
landmarks = landmarks_frame.iloc[n, 1:]
landmarks = np.asarray(landmarks)
landmarks = landmarks.astype('float').reshape(-1, 2)
print('Image name: {}'.format(img_name))
print('Landmarks shape: {}'.format(landmarks.shape))
print('First 4 Landmarks: {}'.format(landmarks[:4]))
######################################################################
# 1.1 μ΄λ―Έμ§λ₯Ό νμνκΈ° μν΄ κ°λ¨ν λμ ν¨μ μμ±νκΈ°
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# λ€μμΌλ‘λ μ΄λ―Έμ§λ₯Ό 보μ¬μ£ΌκΈ° μν΄ κ°λ¨ν λμ ν¨μλ₯Ό μμ±νμ¬ μ΄λ―Έμ§κ° κ°μ§κ³ μλ λλλ§ν¬λ€κ³Ό
# μ΄λ―Έμ§ μνμ 보μ¬μ£Όλλ‘ νκ² μ΅λλ€.
#
def show_landmarks(image, landmarks):
""" λλλ§ν¬μ ν¨κ» μ΄λ―Έμ§ 보μ¬μ£ΌκΈ° """
plt.imshow(image)
plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r')
plt.pause(0.001) # μ μ λ©μΆμ΄ λνκ° μ
λ°μ΄νΈ λκ² ν©λλ€
plt.figure()
show_landmarks(io.imread(os.path.join('faces/', img_name)),
landmarks)
plt.show()
######################################################################
# 1.2 λ°μ΄ν°μ
ν΄λμ€ λ§λ€κΈ°
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# μ΄μ PyTorch λ°μ΄ν°μ
ν΄λμ€μ λν΄ μμλ΄
μλ€.
#
#
######################################################################
# ``torch.utils.data.Dataset`` μ μΆμ ν΄λμ€λ‘μ λ°μ΄ν°μ
μ λ§‘κ³ μμ΅λλ€
# ``Dataset`` μ μμλ°μμΌ νλ©° λ€μμ λ©μλλ€μ μ€λ²λΌμ΄λ ν΄μΌν©λλ€.
#
# - ``__len__`` μλ ``len(dataset)`` λ°μ΄ν°μ
μ μ¬μ΄μ¦λ₯Ό λ°νν©λλ€.
# - ``__getitem__`` λ μ΄λ¬ν μΈλ±μ±μ μ§μνκ³ ``dataset[i]``
# :math:``i``Β λ²μ§Έ μνμ μ»κΈ° μν΄ μ¬μ©λ©λλ€.
#
# μ°λ¦¬μ μΌκ΅΄ λλλ§ν¬ λ°μ΄ν°μ
μ μν λ°μ΄ν°μ
ν΄λμ€λ₯Ό λ§λ€μ΄ λ΄
μλ€.
# μ°λ¦¬λ csvνμΌμ ``__init__`` μμ μ½κ³ μ΄λ―Έμ§λ€μ ``__getitem__`` μμ μ½λλ‘ λ¨κ²¨λκ² μ΅λλ€.
# μ΄λ¬ν λ°©λ²μ λ©λͺ¨λ¦¬λ₯Ό ν¨μ¨μ μΌλ‘ μ¬μ©νλλ‘ νλλ° κ·Έ μ΄μ λ λͺ¨λ μ΄λ―Έμ§λ₯Ό ν λ²μ λ©λͺ¨λ¦¬μ μ μ₯νμ§ μκ³
# νμν λλ§λ€ λΆλ¬μ€κΈ° λλ¬Έμ
λλ€.
#
# μ°λ¦¬ λ°μ΄ν°μ
μ μνμ dict ννλ‘ μ΄λ κ² ``{'image': image, 'landmarks': landmarks}`` λμ΄μμ΅λλ€.
# λ°μ΄ν°μ
μ μ νμ 맀κ°λ³μμΈ ``transform`` μ κ°μ§κ³ μμ΄μ
# νμν νλ‘μΈμ± μ΄λκ²μ΄λ μνμ μ μ© λ μ μμ΅λλ€.
# ``transform`` μ΄ μΌλ§λ μ μ©νμ§λ λ€λ₯Έ λ μνΌμμ νμΈ ν΄ λ³Ό μ μμ΅λλ€.
#
class FaceLandmarksDataset(Dataset):
""" μΌκ΅΄ λλλ§ν¬ λ°μ΄ν°μ
. """
def __init__(self, csv_file, root_dir, transform=None):
"""
맀κ°λ³μ :
csv_file (λ¬Έμμ΄): μ€λͺ
μ΄ ν¬ν¨λ csv νμΌ κ²½λ‘.
root_dir (λ¬Έμμ): λͺ¨λ μ΄λ―Έμ§κ° μλ ν΄λ κ²½λ‘.
transform (νΈμΆκ°λ₯ν ν¨μ, μ νμ 맀κ°λ³μ): μνμ μ μ© λ μ μλ μ νμ λ³ν.
"""
self.landmarks_frame = pd.read_csv(csv_file)
self.root_dir = root_dir
self.transform = transform
def __len__(self):
return len(self.landmarks_frame)
def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()
img_name = os.path.join(self.root_dir,
self.landmarks_frame.iloc[idx, 0])
image = io.imread(img_name)
landmarks = self.landmarks_frame.iloc[idx, 1:]
landmarks = np.array([landmarks])
landmarks = landmarks.astype('float').reshape(-1, 2)
sample = {'image': image, 'landmarks': landmarks}
if self.transform:
sample = self.transform(sample)
return sample
######################################################################
# 1.3 λ°λ³΅λ¬Έμ ν΅ν λ°μ΄ν° μν μ¬μ©
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
######################################################################
# λ€μμΌλ‘λ μ΄ ν΄λμ€λ₯Ό μΈμ€ν΄μ€ννκ³ λ°μ΄ν° μνμ λ°λ³΅λ¬Έμ μ΄μ©νμ¬ μ¬μ©ν΄λ΄
μλ€.
# μ°λ¦¬λ 첫 4κ°μ μνλ€λ§ μΆλ ₯νκ³ κ·Έ 4κ° μνλ€μ λλλ§ν¬λ₯Ό 보μ¬μ£Όκ² μ΅λλ€.
#
face_dataset = FaceLandmarksDataset(csv_file='faces/face_landmarks.csv',
root_dir='faces/')
fig = plt.figure()
for i in range(len(face_dataset)):
sample = face_dataset[i]
print(i, sample['image'].shape, sample['landmarks'].shape)
ax = plt.subplot(1, 4, i + 1)
plt.tight_layout()
ax.set_title('Sample #{}'.format(i))
ax.axis('off')
show_landmarks(**sample)
if i == 3:
plt.show()
break
######################################################################
# λ λ²μ§Έ: λ°μ΄ν° λ³ν
# ---------------------------
#
######################################################################
# μ°λ¦¬λ μ§κΈκΉμ§ μ΄λμ λ μ¬μ©μ μ μ λ°μ΄ν°μ
μ λ§λ€μ΄ 보μλλ° μ΄μ λ μ¬μ©μ μ μ λ³νμ λ§λ€ μ°¨λ‘ μ
λλ€.
# μ»΄ν¨ν° λΉμ μμλ μ¬μ©μ μ μ λ³νμ μκ³ λ¦¬μ¦μ μΌλ°νμν€κ³ μ νλλ₯Ό μ¬λ¦¬λλ° λμμ μ€λλ€.
# λ³νλ€μ νλ ¨μμ μ¬μ©μ΄ λλ©° μ£Όλ‘ λ°μ΄ν° μ¦κ°μΌλ‘ μ°Έμ‘°λλ©° μ΅κ·Όμ λͺ¨λΈ κ°λ°μμ νν μ¬μ©λ©λλ€.
#
# λ°μ΄ν°μ
μ λ€λ£°λ μμ£Ό μΌμ΄λλ λ¬Έμ μ€ νλλ λͺ¨λ μνλ€μ΄ κ°μ ν¬κΈ°λ₯Ό κ°μ§κ³ μμ§ μμ κ²½μ°μ
λλ€.
# λλΆλΆμ μ κ²½λ§λ€μ 미리 μ ν΄μ§ ν¬κΈ°μ μ΄λ―Έμ§λ€μ λ°μλ€μ
λλ€.
# κ·Έλ κΈ° λλ¬Έμ μ°λ¦¬λ μ μ²λ¦¬ μ½λλ₯Ό μμ±ν΄μΌν νμκ° μμ΅λλ€.
# μ΄μ μΈκ°μ λ³νμ λ§λ€μ΄ λ΄
μλ€.
#
# - ``Rescale``: μ΄λ―Έμ§ ν¬κΈ°λ₯Ό λ³κ²½ν λ μ¬μ©λ©λλ€.
# - ``RandomCrop``: 무μμλ‘ μ΄λ―Έμ§λ₯Ό μλΌλ΄λ©° λ°μ΄ν° μ¦κ°μ μ°μ
λλ€.
# - ``ToTensor``: Numpy μ΄λ―Έμ§λ€μ νμ΄ν μΉ μ΄λ―Έμ§λ‘ λ³νν λ μ¬μ©λ©λλ€. (κ·Έλ¬κΈ° μν΄μλ μ΄λ―Έμ§ μ°¨μμ μμλ₯Ό λ°κΏμΌν©λλ€.)
#
# μ°λ¦¬λ μμ μΈκ°μ λ³νλ€μ λ¨μν ν¨μ λμ μ νΈμΆκ°λ₯ν ν΄λμ€λ‘ λ§λ€μ΄μ λ§€λ² λ³νμ΄ νΈμΆλ λ νμ 맀κ°λ³μκ° λ겨μ§μ§ μλλ‘ ν κ²λλ€.
# κ·Έλ¬κΈ° μν΄μλ μ°λ¦¬λ λ¨μν ``__call__`` λ©μλλ₯Ό λ§λ€κ³ νμνλ€λ©΄ ``__init__`` λ₯Ό λ§λ€λ©΄ λ©λλ€.
# κ·Έλ¬λ©΄ μ°λ¦¬λ λ³νμ μ΄λ°μμΌλ‘ μ¬μ©ν μ μμ΅λλ€.
#
# ::
#
# tsfm = Transform(params)
# transformed_sample = tsfm(sample)
#
# μ΄λ»κ² μ΄λ° λ³νλ€μ΄ μ΄λ―Έμ§μ λλλ§ν¬μ μ μ©μ΄ λμλμ§ μλλ₯Ό λ΄μ£ΌμκΈΈ λ°λλλ€.
#
######################################################################
# 2.1 νΈμΆ κ°λ₯ν ν΄λμ€λ€ μμ±νκΈ°
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# κ°κ°μ λ³νμ λ§λ νΈμΆ κ°λ₯ν ν΄λμ€ μμ±μ μμν΄ λ΄
μλ€.
#
#
class Rescale(object):
""" μ£Όμ΄μ§ ν¬κΈ°λ‘ μνμμ μλ μ΄λ―Έμ§λ₯Ό μ¬λ³ν ν©λλ€.
Args:
output_size (tuple λλ int): μνλ κ²°κ³Όκ°μ ν¬κΈ°μ
λλ€.
tupleλ‘ μ£Όμ΄μ§λ€λ©΄ κ²°κ³Όκ°μ output_size μ λμΌν΄μΌνλ©°
intμΌλλ μ€μ λ κ°λ³΄λ€ μμ μ΄λ―Έμ§λ€μ κ°λ‘μ μΈλ‘λ output_size μ μ μ ν λΉμ¨λ‘ λ³νλ©λλ€.
"""
def __init__(self, output_size):
assert isinstance(output_size, (int, tuple))
self.output_size = output_size
def __call__(self, sample):
image, landmarks = sample['image'], sample['landmarks']
h, w = image.shape[:2]
if isinstance(self.output_size, int):
if h > w:
new_h, new_w = self.output_size * h / w, self.output_size
else:
new_h, new_w = self.output_size, self.output_size * w / h
else:
new_h, new_w = self.output_size
new_h, new_w = int(new_h), int(new_w)
img = transform.resize(image, (new_h, new_w))
# h μ w λ μ΄λ―Έμ§μ λλλ§ν¬λ€ λλ¬Έμ μλ‘ λ°λλλ€.
# x μ y μΆλ€μ κ°κ° 1κ³Ό 0 κ°μ κ°μ§λλ€.
landmarks = landmarks * [new_w / w, new_h / h]
return {'image': img, 'landmarks': landmarks}
class RandomCrop(object):
""" μνμ μλ μ΄λ―Έμ§λ₯Ό 무μμλ‘ μλ₯΄κΈ°.
Args:
output_size (tuple λλ int): μνλ κ²°κ³Όκ°μ ν¬κΈ°μ
λλ€.
intλ‘ μ€μ νμλ©΄ μ μ¬κ°ν ννλ‘ μλ₯΄κ² λ©λλ€.
"""
def __init__(self, output_size):
assert isinstance(output_size, (int, tuple))
if isinstance(output_size, int):
self.output_size = (output_size, output_size)
else:
assert len(output_size) == 2
self.output_size = output_size
def __call__(self, sample):
image, landmarks = sample['image'], sample['landmarks']
h, w = image.shape[:2]
new_h, new_w = self.output_size
top = np.random.randint(0, h - new_h)
left = np.random.randint(0, w - new_w)
image = image[top: top + new_h,
left: left + new_w]
landmarks = landmarks - [left, top]
return {'image': image, 'landmarks': landmarks}
class ToTensor(object):
""" μν μμ μλ nμ°¨μ λ°°μ΄μ Tensorλ‘ λ³ν₯νλλ€. """
def __call__(self, sample):
image, landmarks = sample['image'], sample['landmarks']
# μκΉ μΆλ€μ λ°κΏμΉκΈ°ν΄μΌνλλ° κ·Έ μ΄μ λ numpyμ torchμ μ΄λ―Έμ§ ννλ°©μμ΄ λ€λ₯΄κΈ° λλ¬Έμ
λλ€.
# numpy μ΄λ―Έμ§: H x W x C
# torch μ΄λ―Έμ§: C X H X W
image = image.transpose((2, 0, 1))
return {'image': torch.from_numpy(image),
'landmarks': torch.from_numpy(landmarks)}
######################################################################
# 2.2 λ³νλ€μ ꡬμ±νκ³ μνμ μ μ©ν΄λ³΄κΈ°.
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# λ€μμλ μμ±ν΄μλ λ³νλ€μ ꡬμ±νκ³ μνμ μ μ©ν΄λ΄
μλ€.
#
#
# μ°λ¦¬κ° ν μ΄λ―Έμ§μ κ°λ‘λ μΈλ‘μ€μμ λμμ μͺ½μ 256μΌλ‘ ν¬κΈ°λ₯Ό λ°κΎΈκ³ μΆκ³
# λ°λ μ΄λ―Έμ§μμ 무μμνκ² κ°λ‘ μΈλ‘ μ λΆ 224λ‘ μλ₯΄κ³ μΆλ€κ³ μν©μ κ°μ ν΄λ΄
μλ€.
# μλ₯Όλ€λ©΄, μ°λ¦¬λ ``Rescale`` κ³Ό ``RandomCrop`` λ³νμ ꡬμ±ν΄μΌ ν©λλ€.
# ``torchvision.transforms.Compose`` λ κ°λ¨ν νΈμΆκ°λ₯ν ν΄λμ€λ‘ μ΄λ¬νκ²λ€μ μ°λ¦¬μκ² κ°λ₯νκ² ν΄μ€λλ€.
#
scale = Rescale(256)
crop = RandomCrop(128)
composed = transforms.Compose([Rescale(256),
RandomCrop(224)])
# μμ μλ λ³νλ€μ κ°κ° μνμ μ μ© μν΅λλ€.
fig = plt.figure()
sample = face_dataset[65]
for i, tsfrm in enumerate([scale, crop, composed]):
transformed_sample = tsfrm(sample)
ax = plt.subplot(1, 3, i + 1)
plt.tight_layout()
ax.set_title(type(tsfrm).__name__)
show_landmarks(**transformed_sample)
plt.show()
######################################################################
# 2.3 λ°μ΄ν°μ
μ λ°λ³΅λ¬Έμ ν΅ν΄ μ¬μ©νκΈ°
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# λ€μμΌλ‘ μ°λ¦¬λ λ°μ΄ν°μ
μ λ°λ³΅λ¬Έμ ν΅ν΄ μ¬μ©ν΄λ³΄λλ‘ νκ² μ΅λλ€.
#
#
# μ΄μ μ΄ λͺ¨λ κ²μ λ€ κΊΌλ΄μ΄μ λ³νμ ꡬμ±νκ³ λ°μ΄ν°μ
μ λ§λ€μ΄λ΄
μλ€.
# μμ½νμλ©΄ νμ μ΄ λ°μ΄ν°μ
μ λ€μκ³Ό κ°μ΄ λΆλ¬μμ§λλ€.
#
# - μ΄λ―Έμ§λ μ½μΌλ €κ³ ν λλ§λ€ λΆλ¬μ΅λλ€.
# - λ³νλ€μ μ½μ μ΄λ―Έμ§μ μ μ©μ΄ λ©λλ€.
# - λ³νλ€μ€ νλλ 무μμλ₯Ό μ΄μ©νκΈ° λλ¬Έμ, λ°μ΄ν°λ μνλ§μ λ°λΌ μ¦κ°λ©λλ€.
#
# μ λ²μ ν΄λ³Έκ²μ²λΌ μμ±λ λ°μ΄ν°μ
μ ``for i in range`` μ΄λΌλ λ°λ³΅λ¬Έμ ν΅ν΄ μ¬μ©ν μ μμ΅λλ€.
#
transformed_dataset = FaceLandmarksDataset(csv_file='faces/face_landmarks.csv',
root_dir='faces/',
transform=transforms.Compose([
Rescale(256),
RandomCrop(224),
ToTensor()
]))
for i in range(len(transformed_dataset)):
sample = transformed_dataset[i]
print(i, sample['image'].size(), sample['landmarks'].size())
if i == 3:
break
######################################################################
# μΈλ²μ§Έ: Dataloader
# ----------------------
#
######################################################################
# μ§μ μ μΌλ‘ λ°μ΄ν°μ
μ ``for`` λ°λ³΅λ¬ΈμΌλ‘ λ°μ΄ν°λ₯Ό μ΄μ©νλ건 λ§μ νΉμ±λ€μ λμΉ μ λ°μ μμ΅λλ€.
# νΉν, μ°λ¦¬λ λ€μκ³Ό κ°μ νΉμ±λ€μ λμΉλ€κ³ ν μ μμ΅λλ€.
#
# - λ°μ΄ν° λ°°μΉ
# - λ°μ΄ν° μκΈ°
# - ``multiprocessing`` λ₯Ό μ΄μ©νμ¬ λ³λ ¬μ μΌλ‘ λ°μ΄ν° λΆλ¬μ€κΈ°
#
# ``torch.utils.data.DataLoader`` λ λ°λ³΅μλ‘μ μμ λμμλ λͺ¨λ νΉμ±λ€μ μ 곡ν©λλ€.
# μλμ μ μλ μ¬μ©λλ 맀κ°λ³μλ€μ μ½κ² μ΄ν΄κ° λ κ²λλ€. ν₯λ―Έλ‘μ΄ λ°°κ°λ³μλ ``collate_fn`` μΈλ°
# μ΄κ²μ μ ννκ² ``collate_fn`` μ ν΅ν΄ λͺκ°μ μνλ€μ΄ λ°°μΉκ° λμ΄μΌνλμ§ μ§μ ν μ μμ΅λλ€.
# νμ§λ§ κ΅³μ΄ μμ νμ§ μμλ λλΆλΆμ κ²½μ°μλ μ μλν κ²λλ€.
#
dataloader = DataLoader(transformed_dataset, batch_size=4,
shuffle=True, num_workers=4)
# λ°°μΉλ₯Ό 보μ¬μ£ΌκΈ°μν λμ ν¨μ
def show_landmarks_batch(sample_batched):
""" μνλ€μ λ°°μΉμμ μ΄λ―Έμ§μ ν¨κ» λλλ§ν¬λ₯Ό 보μ¬μ€λλ€. """
images_batch, landmarks_batch = \
sample_batched['image'], sample_batched['landmarks']
batch_size = len(images_batch)
im_size = images_batch.size(2)
grid = utils.make_grid(images_batch)
plt.imshow(grid.numpy().transpose((1, 2, 0)))
for i in range(batch_size):
plt.scatter(landmarks_batch[i, :, 0].numpy() + i * im_size,
landmarks_batch[i, :, 1].numpy(),
s=10, marker='.', c='r')
plt.title('Batch from dataloader')
for i_batch, sample_batched in enumerate(dataloader):
print(i_batch, sample_batched['image'].size(),
sample_batched['landmarks'].size())
# 4λ²μ§Έ λ°°μΉλ₯Ό 보μ¬μ£Όκ³ λ°λ³΅λ¬Έμ λ©μΆ₯λλ€.
if i_batch == 3:
plt.figure()
show_landmarks_batch(sample_batched)
plt.axis('off')
plt.ioff()
plt.show()
break
######################################################################
# μ΄μ PyTorchλ₯Ό μ΄μ©ν΄μ μ΄λ»κ² μ¬μ©μ μ μ dataloaderλ₯Ό λ§λλμ§ λ°°μ μ΅λλ€.
# μ ν¬λ μ’ λ κ΄λ ¨λ λ¬Έμλ€μ κΉκ² μ½μΌμ
μ λμ± λ§μΆ€νλ μμ
νλ¦Όμ κ°μ§κΈΈ μΆμ² λ립λλ€.
# λ λ°°μ보μλ €λ©΄ ``torch.utils.data`` λ¬Έμλ₯Ό `μ¬κΈ° <https://pytorch.org/docs/stable/data.html>`__ μμ μ½μ΄ λ³΄μ€ μ μμ΅λλ€.