-
Notifications
You must be signed in to change notification settings - Fork 4
/
data.py
51 lines (45 loc) · 2.18 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from __future__ import print_function
import zipfile
import os
import torchvision.transforms as transforms
# Pre-processing applied to every image after it is loaded and before it is
# passed to the network: resize to 48 x 48 pixels, convert to a tensor, then
# normalize each channel with the mean / standard-deviation values below so
# the inputs have roughly mean = 0 and standard-deviation = 1. The values are
# statistics collected from the training set.
# NOTE(review): the three values per tuple are presumably the R, G, B channel
# statistics - confirm against the script that computed them.

# ``Resize`` superseded the deprecated ``Scale`` transform, which was later
# removed from torchvision; fall back to ``Scale`` only on releases that
# predate ``Resize`` (for a (h, w) tuple both resize to exactly that size).
_resize_transform = getattr(transforms, 'Resize', None) or transforms.Scale

data_transforms = transforms.Compose([
    _resize_transform((48, 48)),
    transforms.ToTensor(),
    transforms.Normalize((0.3337, 0.3064, 0.3171), (0.2672, 0.2564, 0.2629)),
])
def _extract_if_missing(zip_path, target_folder, extract_to):
    """Extract ``zip_path`` into ``extract_to`` unless ``target_folder`` exists.

    NOTE(review): assumes the archive's top-level directory is the basename of
    ``target_folder`` (that is how the caller detects a finished extraction);
    confirm against the downloaded archives.
    """
    if os.path.isdir(target_folder):
        return
    print(target_folder + ' not found, extracting ' + zip_path)
    # The context manager closes the archive even if extraction fails.
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(extract_to)


def _make_validation_split(train_folder, val_folder):
    """Move files 00000*, 00001* and 00002* of every class into ``val_folder``.

    Only entries of ``train_folder`` that are directories whose name starts
    with '000' are treated as class folders; everything else is skipped.
    Files are moved (not copied), so they are removed from the training set.
    """
    print(val_folder + ' not found, making a validation set')
    os.mkdir(val_folder)
    val_prefixes = ('00000', '00001', '00002')
    for class_dir in os.listdir(train_folder):
        src_dir = os.path.join(train_folder, class_dir)
        # Skip anything that is not a class directory (a plain file named
        # 000... would otherwise make os.listdir below fail).
        if not class_dir.startswith('000') or not os.path.isdir(src_dir):
            continue
        dst_dir = os.path.join(val_folder, class_dir)
        os.mkdir(dst_dir)
        for name in os.listdir(src_dir):
            if name.startswith(val_prefixes):
                # move file to validation folder
                os.rename(os.path.join(src_dir, name),
                          os.path.join(dst_dir, name))


def initialize_data(folder):
    """Prepare the train / validation / test image folders inside ``folder``.

    Steps:
      1. Extract ``train_images.zip`` and ``test_images.zip`` into ``folder``
         unless ``train_images`` / ``test_images`` already exist there.
      2. If ``val_images`` does not exist, create it by moving every file whose
         name starts with 00000, 00001 or 00002 out of each class folder of
         ``train_images`` into the matching class folder of ``val_images``.

    Args:
        folder: directory that holds the downloaded archives and receives the
            extracted image folders.

    Raises:
        RuntimeError: if an archive is needed (its extracted folder is
            missing) but is not present in ``folder``.
    """
    train_zip = os.path.join(folder, 'train_images.zip')
    test_zip = os.path.join(folder, 'test_images.zip')
    train_folder = os.path.join(folder, 'train_images')
    test_folder = os.path.join(folder, 'test_images')
    val_folder = os.path.join(folder, 'val_images')

    # An archive is only required when its extracted folder is absent, so the
    # zips may be deleted once extraction has succeeded. Check everything up
    # front so nothing is extracted when a needed archive is missing.
    missing = [zip_path
               for zip_path, target in ((train_zip, train_folder),
                                        (test_zip, test_folder))
               if not os.path.isdir(target) and not os.path.exists(zip_path)]
    if missing:
        raise RuntimeError('Could not find ' + ' and '.join(missing)
                           + ', please download them from '
                           'https://www.kaggle.com/c/nyu-cv-fall-2017/data ')

    _extract_if_missing(train_zip, train_folder, folder)
    _extract_if_missing(test_zip, test_folder, folder)

    # NOTE: the split is not atomic. If it is interrupted, val_images exists
    # but is incomplete and will not be rebuilt. Delete val_images and
    # train_images, then re-run, to recover.
    if not os.path.isdir(val_folder):
        _make_validation_split(train_folder, val_folder)