-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.py
109 lines (96 loc) · 4.02 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import argparse
from torch_geometric.data import Dataset
from models.multi_view import MultiViewActionRecognizer
from models.single_view import SingleViewActionRecognizer
from data_mgmt.dataloaders.multi_dataloader import DataLoader as MultiDataLoader
from data_mgmt.dataloaders.single_dataloader import DataLoader as SingleDataLoader
from typing import Tuple, Dict
def get_multi_view(
    config: Dict,
    args: argparse.Namespace,
    dataset: Tuple[Dataset, Dataset, Dataset],
) -> Tuple[
    MultiViewActionRecognizer, Tuple[MultiDataLoader, MultiDataLoader, MultiDataLoader]
]:
    """
    Returns the model and the dataloaders for the multi view case

    Parameters
    ----------
    config : Dict
        Model hyper-parameters. Must contain the keys ``gcn_num_features``,
        ``gcn_hidden_dim1``, ``gcn_hidden_dim2``, ``gcn_output_dim``,
        ``transformer_d_model``, ``transformer_nhead``,
        ``transformer_num_layers``, ``transformer_num_features``,
        ``transformer_dropout``, ``transformer_dim_feedforward`` and
        ``transformer_num_classes``.
    args : argparse.Namespace
        Arguments passed to the program; ``args.batch_size`` and
        ``args.aggregator`` are read.
    dataset : Tuple[Dataset, Dataset, Dataset]
        The ``(train, val, test)`` datasets, in that order.

    Returns
    -------
    Tuple[MultiViewActionRecognizer, Tuple[MultiDataLoader, MultiDataLoader, MultiDataLoader]]
        The freshly constructed model and the ``(train, val, test)``
        dataloaders. Only the training loader shuffles.

    Raises
    ------
    ValueError
        If ``dataset`` does not hold exactly three elements.
    KeyError
        If ``config`` lacks one of the keys listed above.
    """
    train_dataset, val_dataset, test_dataset = dataset
    batch_size = args.batch_size

    # Only the training split benefits from shuffling; the evaluation splits
    # are requested unshuffled so validation/test results do not depend on a
    # random iteration order.
    train_loader = MultiDataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = MultiDataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = MultiDataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    model = MultiViewActionRecognizer(
        gcn_num_features=config["gcn_num_features"],
        gcn_hidden_dim1=config["gcn_hidden_dim1"],
        gcn_hidden_dim2=config["gcn_hidden_dim2"],
        gcn_output_dim=config["gcn_output_dim"],
        transformer_d_model=config["transformer_d_model"],
        transformer_nhead=config["transformer_nhead"],
        transformer_num_layers=config["transformer_num_layers"],
        transformer_num_features=config["transformer_num_features"],
        transformer_dropout=config["transformer_dropout"],
        transformer_dim_feedforward=config["transformer_dim_feedforward"],
        transformer_num_classes=config["transformer_num_classes"],
        # The multi-view model additionally takes the view-aggregation
        # strategy selected on the command line.
        aggregator=args.aggregator,
    )
    return model, (train_loader, val_loader, test_loader)
def get_single_view(
    config: Dict,
    args: argparse.Namespace,
    dataset: Tuple[Dataset, Dataset, Dataset],
) -> Tuple[
    SingleViewActionRecognizer,
    Tuple[SingleDataLoader, SingleDataLoader, SingleDataLoader],
]:
    """
    Returns the model and the dataloaders for the single view case

    Parameters
    ----------
    config : Dict
        Model hyper-parameters. Must contain the keys ``gcn_num_features``,
        ``gcn_hidden_dim1``, ``gcn_hidden_dim2``, ``gcn_output_dim``,
        ``transformer_d_model``, ``transformer_nhead``,
        ``transformer_num_layers``, ``transformer_num_features``,
        ``transformer_dropout``, ``transformer_dim_feedforward`` and
        ``transformer_num_classes``.
    args : argparse.Namespace
        Arguments passed to the program; ``args.batch_size`` is read.
    dataset : Tuple[Dataset, Dataset, Dataset]
        The ``(train, val, test)`` datasets, in that order.

    Returns
    -------
    Tuple[SingleViewActionRecognizer, Tuple[SingleDataLoader, SingleDataLoader, SingleDataLoader]]
        The freshly constructed model and the ``(train, val, test)``
        dataloaders. Only the training loader shuffles.

    Raises
    ------
    ValueError
        If ``dataset`` does not hold exactly three elements.
    KeyError
        If ``config`` lacks one of the keys listed above.
    """
    train_dataset, val_dataset, test_dataset = dataset
    batch_size = args.batch_size

    # Only the training split benefits from shuffling; the evaluation splits
    # are requested unshuffled so validation/test results do not depend on a
    # random iteration order.
    train_loader = SingleDataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = SingleDataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = SingleDataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    model = SingleViewActionRecognizer(
        gcn_num_features=config["gcn_num_features"],
        gcn_hidden_dim1=config["gcn_hidden_dim1"],
        gcn_hidden_dim2=config["gcn_hidden_dim2"],
        gcn_output_dim=config["gcn_output_dim"],
        transformer_d_model=config["transformer_d_model"],
        transformer_nhead=config["transformer_nhead"],
        transformer_num_layers=config["transformer_num_layers"],
        transformer_num_features=config["transformer_num_features"],
        transformer_dropout=config["transformer_dropout"],
        transformer_dim_feedforward=config["transformer_dim_feedforward"],
        transformer_num_classes=config["transformer_num_classes"],
    )
    return model, (train_loader, val_loader, test_loader)