-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain_ConvLSTM.py
105 lines (82 loc) · 3.77 KB
/
main_ConvLSTM.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
#!/usr/bin/env python3
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from Models.ConvLSTM.ConvLSTM import ConvLSTM
from Models.ConvLSTM.EegDataset import EegDataset
# --- Training configuration -------------------------------------------------

# TensorBoard writer; event files go under ./ConvLSTM_tensorboard
writer = SummaryWriter('ConvLSTM_tensorboard')

# Hyperparameters
batch_size = 128
lr = 1e-4      # Adam learning rate
epochs = 50

# Dataset locations (tensors pre-serialized with torch.save — see EegDataset)
train_dir = './data/train'
val_dir = './data/val'
train_data_file = 'train_data_ConvLSTM.pt'
train_label_file = 'train_label_ConvLSTM.pt'
val_data_file = 'val_data_ConvLSTM.pt'
val_label_file = 'val_label_ConvLSTM.pt'

# Checkpoint path for the best-validation-accuracy weights.
save_path = './Models/ConvLSTM/weights/ConvLSTM.pth'
# os.makedirs is portable and race-free, unlike the original
# `if not exists: os.system('mkdir -p ...')`, which shelled out (POSIX-only)
# and silently ignored failures.
os.makedirs('./Models/ConvLSTM/weights/', exist_ok=True)
if __name__ == '__main__':
    # Select GPU when available, otherwise fall back to CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Using {} device".format(device))

    # Cap DataLoader workers at CPU count, batch size, and 8 (common heuristic).
    num_workers = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])
    # num_workers = 0
    print("Using {} dataloader workers every process".format(num_workers))

    train_dataset = EegDataset(os.path.join(train_dir, train_data_file),
                               os.path.join(train_dir, train_label_file))
    val_dataset = EegDataset(os.path.join(val_dir, val_data_file),
                             os.path.join(val_dir, val_label_file))
    train_num = len(train_dataset)
    val_num = len(val_dataset)

    train_dataloader = DataLoader(train_dataset, batch_size=batch_size,
                                  shuffle=True, num_workers=num_workers)
    # Validation order does not affect the accuracy metric; shuffle=False keeps
    # evaluation deterministic (the original shuffled for no benefit).
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size,
                                shuffle=False, num_workers=num_workers)
    print("Load dataset: \n"
          "Train Dataset: %d samples\n"
          "val Dataset: %d samples" % (len(train_dataset), len(val_dataset)))

    # Peek at one sample; assumes each sample is shaped (C, H, W) — TODO confirm
    # against EegDataset (h/w below are hard-coded to 9 rather than derived).
    example_data, example_label = train_dataset[0]
    n_channel, h, w = example_data.shape

    # BUG FIX: the original appended `.cuda()` to the constructor, which crashes
    # on CPU-only machines even though `device` was selected above.  A single
    # `.to(device)` is sufficient and honors the fallback.
    model = ConvLSTM(input_channels=640, hidden_channels=[256, 64, 64, 32, 32],
                     kernel_size=3, step=5, effective_step=[4], h=9, w=9)
    model.to(device=device)

    # define loss function
    loss_function = torch.nn.CrossEntropyLoss()
    # construct an optimizer over trainable parameters only
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(params, lr=lr)

    best_acc = 0.0
    train_steps = len(train_dataloader)
    for epoch in range(epochs):
        # ---- train ----
        model.train()
        running_loss = 0.0
        train_bar = tqdm(train_dataloader)
        for step, data in enumerate(train_bar):
            signals, labels = data
            optimizer.zero_grad()
            logits = model(signals.to(device))
            # CrossEntropyLoss requires integer class indices, hence the cast.
            loss = loss_function(logits, labels.to(device).to(torch.long))
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            # .item() detaches the scalar for display (original formatted the tensor).
            train_bar.desc = "train epoch [{}/{}] loss: {:.3f}".format(epoch + 1, epochs, loss.item())

        # ---- validate ----
        model.eval()
        acc = 0.0  # accumulated count of correct predictions
        with torch.no_grad():
            val_bar = tqdm(val_dataloader)
            for val_data in val_bar:
                val_signals, val_labels = val_data
                outputs = model(val_signals.to(device))
                predict = torch.max(outputs, dim=1)[1]  # argmax over class logits
                acc += torch.eq(predict, val_labels.to(device)).sum().item()
                val_bar.desc = "val epoch[{}/{}]".format(epoch + 1, epochs)
        val_accuracy = acc / val_num
        print("[epoch %d] train_loss: %.3f val_accuracy: %.3f"
              % (epoch + 1, running_loss / train_steps, val_accuracy))
        writer.add_scalar(tag='val_acc:', scalar_value=val_accuracy, global_step=epoch)
        writer.add_scalar(tag='loss', scalar_value=running_loss / train_steps, global_step=epoch)

        # Checkpoint whenever validation accuracy improves.
        if val_accuracy > best_acc:
            best_acc = val_accuracy
            torch.save(model.state_dict(), save_path)