-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathfeeder_ntu.py
122 lines (107 loc) · 4.82 KB
/
feeder_ntu.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import numpy as np
import pickle
from torch.utils.data import Dataset
from feeders import tools
class Feeder(Dataset):
    """Skeleton-sequence dataset feeder reading NTU RGB+D style .npz archives."""

    def __init__(self, data_path, label_path=None, p_interval=1, split='train', repeat=1, random_choose=False, random_shift=False,
                 random_move=False, random_rot=False, window_size=64, normalization=False, debug=False, use_mmap=False,
                 vel=False, sort=False, A=None):
        """
        :param data_path: path to an .npz archive holding 'x_train'/'y_train' and/or 'x_test'/'y_test'
        :param label_path: unused by this feeder; kept for interface compatibility
        :param p_interval: valid-frame crop interval(s) forwarded to tools.valid_crop_resize
        :param split: 'train' or 'test'
        :param repeat: accepted for interface compatibility; not used here
        :param random_choose: if True, randomly choose a portion of the input sequence
        :param random_shift: if True, randomly pad zeros at the beginning or end of the sequence
        :param random_move: stored but not applied in __getitem__
        :param random_rot: if True, randomly rotate the skeleton around the xyz axes
        :param window_size: length of the output sequence
        :param normalization: if True, precompute mean/std maps; NOTE(review): get_mean_map
            assumes (N, C, T, V, M) data, which only holds after the A-reshape — confirm
            callers always pass A together with normalization
        :param debug: debug flag stored on the instance (no sub-sampling is applied here)
        :param use_mmap: intended to load data in mmap mode to save memory (currently not
            forwarded to np.load)
        :param vel: if True, emit the motion (frame-difference) modality
        :param sort: if True, sort the samples by class label after loading
        :param A: optional joint-mixing matrix applied to every 25x3 skeleton frame
        """
        self.debug = debug
        self.data_path = data_path
        self.label_path = label_path
        self.split = split
        self.random_choose = random_choose
        self.random_shift = random_shift
        self.random_move = random_move
        self.window_size = window_size
        self.normalization = normalization
        self.use_mmap = use_mmap
        self.p_interval = p_interval
        self.random_rot = random_rot
        self.vel = vel
        self.A = A
        self.load_data()
        if sort:
            self.get_n_per_class()
            self.sort()
        if normalization:
            self.get_mean_map()

    def load_data(self):
        """Load the requested split from the .npz archive and drop all-NaN samples."""
        # On-disk layout per sample: T x (M*V*C) — presumably M=2 bodies, V=25 joints,
        # C=3 coords, matching the reshape below.
        npz_data = np.load(self.data_path)
        if self.split == 'train':
            self.data = npz_data['x_train']
            # Labels are stored one-hot; argmax recovers the class index.
            self.label = np.argmax(npz_data['y_train'], axis=-1)
        elif self.split == 'test':
            self.data = npz_data['x_test']
            self.label = np.argmax(npz_data['y_test'], axis=-1)
        else:
            raise NotImplementedError('data split only supports train/test')
        # Keep only samples whose per-sample mean is finite (drops NaN recordings).
        nan_out = ~np.isnan(self.data.mean(-1).mean(-1))
        self.data = self.data[nan_out]
        self.label = self.label[nan_out]
        self.sample_name = [self.split + '_' + str(i) for i in range(len(self.data))]
        N, T, _ = self.data.shape
        if self.A is not None:
            # Mix joints with A frame-by-frame, then rearrange to (N, C, T, V, M).
            self.data = self.data.reshape((N*T*2, 25, 3))
            self.data = np.array(self.A) @ self.data
            self.data = self.data.reshape(N, T, 2, 25, 3).transpose(0, 4, 1, 3, 2)

    def get_n_per_class(self):
        """Count samples per class label (the array is over-allocated to len(label) slots)."""
        self.n_per_cls = np.zeros(len(self.label), dtype=int)
        for label in self.label:
            self.n_per_cls[label] += 1
        # Prefix sums with a leading 0: csum[k] is the start index of class k once sorted.
        self.csum_n_per_cls = np.insert(np.cumsum(self.n_per_cls), 0, 0)

    def sort(self):
        """Reorder data and labels so samples are grouped by ascending class label."""
        sorted_idx = self.label.argsort()
        self.data = self.data[sorted_idx]
        self.label = self.label[sorted_idx]

    def get_mean_map(self):
        """Precompute per-channel/joint mean and std maps over the whole split.

        Assumes self.data is laid out as (N, C, T, V, M) — see __init__ note.
        """
        data = self.data
        N, C, T, V, M = data.shape
        self.mean_map = data.mean(axis=2, keepdims=True).mean(axis=4, keepdims=True).mean(axis=0)
        self.std_map = data.transpose((0, 2, 4, 1, 3)).reshape((N * T * M, C * V)).std(axis=0).reshape((C, 1, V, 1))

    def __len__(self):
        return len(self.label)

    def __iter__(self):
        # Fix: the previous implementation returned `self` without defining
        # __next__, so iter(feeder) raised TypeError. Yield samples in order.
        for index in range(len(self)):
            yield self[index]

    def __getitem__(self, index):
        """Return (data, label, mask, index) for one sample, with augmentations applied."""
        data_numpy = self.data[index]
        label = self.label[index]
        data_numpy = np.array(data_numpy)
        # A frame is "valid" if any coordinate over channels/joints/bodies is non-zero.
        valid_frame = data_numpy.sum(0, keepdims=True).sum(2, keepdims=True)
        valid_frame_num = np.sum(np.squeeze(valid_frame).sum(-1) != 0)
        # Crop to the valid segment and resize to window_size frames.
        data_numpy = tools.valid_crop_resize(data_numpy, valid_frame_num, self.p_interval, self.window_size)
        mask = (abs(data_numpy.sum(0, keepdims=True).sum(2, keepdims=True)) > 0)
        if self.random_rot:
            data_numpy = tools.random_rot(data_numpy)
        if self.vel:
            # Motion modality: first-order temporal difference, last frame zeroed.
            data_numpy[:, :-1] = data_numpy[:, 1:] - data_numpy[:, :-1]
            data_numpy[:, -1] = 0
        return data_numpy, label, mask, index

    def top_k(self, score, top_k):
        """Return top-k accuracy of `score` (one row of class scores per sample)."""
        rank = score.argsort()
        hit_top_k = [l in rank[i, -top_k:] for i, l in enumerate(self.label)]
        return sum(hit_top_k) * 1.0 / len(hit_top_k)
def import_class(name):
    """Resolve a dotted path like 'pkg.module.Class' to the object it names."""
    head, _, tail = name.partition('.')
    obj = __import__(head)
    # Walk the remaining dotted segments one attribute at a time.
    while tail:
        attr, _, tail = tail.partition('.')
        obj = getattr(obj, attr)
    return obj