-
Notifications
You must be signed in to change notification settings - Fork 3
/
train_test.py
65 lines (51 loc) · 1.55 KB
/
train_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import math
import torch
from torchinfo import summary
from model import create_net
from dataset import create_dataset
from config import parser
from apis import run
from utils import set_seed, save_csv
def main():
    """Configure, build, train/evaluate and record one experiment run.

    All run state is threaded through the ``args`` namespace returned by
    ``parser.parse_args()``: this function fills in derived fields (input
    resolution, device, run name, output file names), then hands ``args`` to
    ``create_dataset``, ``create_net``, ``run`` and ``save_csv``. After
    ``create_net`` it reads ``args.model``; after ``run`` it reads
    ``args.best_acc`` and ``args.best_epoch``, so those helpers are expected
    to set them.
    """
    args = parser.parse_args()
    set_seed(args.seed)

    # Keep the base model directory as the record path; it is captured here,
    # before the attention suffix is appended to modelPath further down.
    # (Presumably consumed by save_csv — confirm.)
    args.recordPath = args.modelPath

    # Input resolution, downsampled by args.ds. The "action"/"recogition"
    # datasets use a 346x260 frame (width rounded up, height rounded down);
    # every other dataset uses 128x128.
    # NOTE(review): "recogition" may be a misspelling of "recognition"; it must
    # match the dataset name used by config/dataset — confirm before changing.
    if args.dataset in ("action", "recogition"):
        width, height = math.ceil(346 / args.ds), 260 // args.ds
    else:
        width, height = 128 // args.ds, 128 // args.ds
    args.im_width, args.im_height = width, height

    use_cuda = torch.cuda.is_available()
    args.device = torch.device("cuda") if use_cuda else torch.device("cpu")
    args.device_ids = range(torch.cuda.device_count())

    # Plain string concatenation (no separator): models for each attention
    # variant go under a path suffixed with the attention name.
    args.modelPath += args.attention

    # Run name encodes the main hyper-parameters; it is reused as the stem of
    # both the checkpoint (.pth) and the record (.csv) file names.
    name_parts = [
        args.dataset,
        "_dt=", str(args.dt), "ms",
        "_T=", str(args.T),
        "_attn=", args.attention,
        "_reduc=", str(args.reduction),
        "_lam=", str(args.lam),
        "_seed=", str(args.seed),
        "_arch=", str(args.arch),
    ]
    args.name = "".join(name_parts)
    args.modelNames = f"{args.name}.pth"
    args.recordNames = f"{args.name}.csv"
    print(args.name)

    create_dataset(args=args)
    create_net(args=args)

    # torchinfo expects the full input size including the batch dimension:
    # (batch=2, T, in_channels, height, width).
    input_size = (2, args.T, args.in_channels, args.im_height, args.im_width)
    summary(args.model, input_size, depth=3)

    run(args=args)
    print("best acc:", args.best_acc, "best_epoch:", args.best_epoch)
    save_csv(args=args)
if __name__ == "__main__":
    # Only start a run when executed as a script, not when imported.
    main()