import os
from pathlib import Path

import cv2
import numpy as np
from torch.utils.data import DataLoader, Dataset


class VideoDataset(Dataset):
    r"""A Dataset for a folder of videos. Expects the directory structure to be
    directory->[train/val/test]->[class labels]->[videos]. Initializes with a list
    of all file names, along with an array of labels, with each label automatically
    inferred from the name of the folder the video sits in.

    Args:
        directory (str): The path to the directory containing the train/val/test datasets
        mode (str, optional): Determines which folder of the directory the dataset will read from. Defaults to 'train'.
        clip_len (int, optional): Determines how many frames each clip contains. Defaults to 8.
    """

    def __init__(self, directory, mode='train', clip_len=8):
        folder = Path(directory) / mode  # get the directory of the specified split
        self.clip_len = clip_len

        # the following three parameters are chosen as described in section 4.1 of the paper
        self.resize_height = 128
        self.resize_width = 171
        self.crop_size = 112

        # obtain all the filenames of files inside all the class folders,
        # going through each class folder one at a time
        self.fnames, labels = [], []
        for label in sorted(os.listdir(folder)):
            for fname in os.listdir(os.path.join(folder, label)):
                self.fnames.append(os.path.join(folder, label, fname))
                labels.append(label)

        # prepare a mapping between the label names (strings) and indices (ints)
        self.label2index = {label: index for index, label in enumerate(sorted(set(labels)))}
        # convert the list of label names into an array of label indices
        self.label_array = np.array([self.label2index[label] for label in labels], dtype=int)
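        # For example, with (hypothetical) class folders 'boxing' and 'jogging',
        # label2index would be {'boxing': 0, 'jogging': 1}, and label_array
        # would hold one of those ints for each video file in self.fnames.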

    def __getitem__(self, index):
        # loading and preprocessing. TODO: move these into transform classes
        buffer = self.loadvideo(self.fnames[index])
        buffer = self.crop(buffer, self.clip_len, self.crop_size)
        buffer = self.normalize(buffer)
        return buffer, self.label_array[index]

    def loadvideo(self, fname):
        # initialize a VideoCapture object to read video data into a numpy array
        capture = cv2.VideoCapture(fname)
        frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # create a buffer. Must have dtype float, so it gets converted to a FloatTensor by PyTorch later
        buffer = np.empty((frame_count, self.resize_height, self.resize_width, 3), np.dtype('float32'))

        count = 0
        retaining = True
        # read in each frame, one at a time, into the numpy buffer array
        while count < frame_count and retaining:
            retaining, frame = capture.read()
            if not retaining:
                # some containers report more frames than can actually be decoded;
                # stop on a failed read instead of passing a None frame to OpenCV
                break
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            # will resize frames if not already final size
            # NOTE: strongly recommended to resize them during the download process. This script
            # will process videos of any size, but will take longer the larger the video file.
            if (frame_height != self.resize_height) or (frame_width != self.resize_width):
                frame = cv2.resize(frame, (self.resize_width, self.resize_height))
            buffer[count] = frame
            count += 1

        # release the VideoCapture once it is no longer needed
        capture.release()

        # drop any trailing slots that were never filled because a read failed early
        buffer = buffer[:count]

        # convert from [D, H, W, C] format to [C, D, H, W] (what PyTorch uses)
        # D = Depth (in this case, time), H = Height, W = Width, C = Channels
        buffer = buffer.transpose((3, 0, 1, 2))
        return buffer

    def crop(self, buffer, clip_len, crop_size):
        # randomly select a time index for temporal jittering
        # (assumes the video holds more than clip_len frames)
        time_index = np.random.randint(buffer.shape[1] - clip_len)

        # randomly select start indices in order to crop the video
        height_index = np.random.randint(buffer.shape[2] - crop_size)
        width_index = np.random.randint(buffer.shape[3] - crop_size)

        # crop and jitter the video using indexing. The spatial crop is performed on
        # the entire array, so each frame is cropped in the same location. The temporal
        # jitter takes place via the selection of consecutive frames
        buffer = buffer[:, time_index:time_index + clip_len,
                        height_index:height_index + crop_size,
                        width_index:width_index + crop_size]
        return buffer

    def normalize(self, buffer):
        # normalize the buffer to roughly [-1, 1]
        # NOTE: default RGB normalization values are used, as precomputed mean
        # and std-dev values (akin to ImageNet's) were unavailable for Kinetics.
        # Feel free to push an edit to this section replacing them if found.
        buffer = (buffer - 128) / 128
        return buffer

    def __len__(self):
        return len(self.fnames)
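

# A minimal usage sketch for VideoDataset (the 'data/kinetics' path below is
# hypothetical; it stands in for any directory laid out as described in the
# class docstring):
#
#   dataset = VideoDataset('data/kinetics', mode='train', clip_len=8)
#   clip, label = dataset[0]
#   # clip is a float32 array of shape (3, 8, 112, 112), i.e. [C, D, H, W],
#   # with values in roughly [-1, 1]; label is an int class index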


class VideoDataset1M(VideoDataset):
    r"""Dataset that extends VideoDataset, producing exactly one million augmented
    training samples every epoch.

    Args:
        directory (str): The path to the directory containing the train/val/test datasets
        mode (str, optional): Determines which folder of the directory the dataset will read from. Defaults to 'train'.
        clip_len (int, optional): Determines how many frames each clip contains. Defaults to 8.
    """

    def __init__(self, directory, mode='train', clip_len=8):
        # initialize an instance of the original dataset class
        super(VideoDataset1M, self).__init__(directory, mode, clip_len)

    def __getitem__(self, index):
        # if we are to serve 1M samples per pass, we need to remap the incoming
        # index to a number in the original range, or else we'd get an index
        # error. This is a legitimate operation: even when the same video is
        # drawn multiple times, it is randomly cropped and temporally jittered
        # differently on each pass, properly augmenting the data.
        index = np.random.randint(len(self.fnames))

        buffer = self.loadvideo(self.fnames[index])
        buffer = self.crop(buffer, self.clip_len, self.crop_size)
        buffer = self.normalize(buffer)
        return buffer, self.label_array[index]

    def __len__(self):
        return 1000000  # manually set the length to 1 million
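

if __name__ == '__main__':
    # A quick smoke test, assuming a dataset directory at 'data/kinetics'
    # (a hypothetical path) laid out as directory/[train/val/test]/[class]/[video].
    # This is also why DataLoader is imported above: batching clips for a model.
    dataset = VideoDataset('data/kinetics', mode='train', clip_len=8)
    loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

    # the default collate function stacks the numpy clips into a FloatTensor
    clips, labels = next(iter(loader))
    # with the defaults above, clips has shape [N, C, D, H, W] = (4, 3, 8, 112, 112)
    print(clips.shape, labels.shape)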