-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
66 lines (51 loc) · 1.97 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import torch
import warnings
import torch.nn as nn
from test import test
from train import train
import torch.optim as optim
from model import TextClassifier
from torchtext import data, datasets
from torchtext.vocab import Vectors
# --- Hyperparameters ---
EPOCHS = 10        # maximum training epochs
PATIENCE = 2       # early-stopping patience passed to train()
DROPOUT = 0.2      # dropout probability for the classifier
BATCH_SIZE = 64    # minibatch size for train/val/test iterators
HIDDEN_SIZE = 256  # hidden dimension of the TextClassifier
HIDDEN_LAYERS = 2  # number of hidden layers in the TextClassifier
PRETRAINED = True  # use pretrained word vectors (see Vectors load below)
# --- Dataset paths (not referenced in this file; presumably consumed by a
# preprocessing step or another module — verify before removing) ---
DATA_DIR = '.data/sst/trees'
IN_FILES = ['train.txt', 'dev.txt', 'test.txt']
CSV_FILES = ['train.csv', 'dev.csv', 'test.csv']
warnings.filterwarnings('ignore', category=UserWarning) # torchtext deprecation warnings
def process_SST():
    """Load the SST dataset (fine-grained, 5-way), build vocabularies, and
    return bucketed iterators.

    Returns:
        (TEXT, LABEL, train_iter, val_iter, test_iter): the two torchtext
        Fields with built vocabs, plus BucketIterators over the three splits
        with batch size BATCH_SIZE.
    """
    TEXT = data.Field(lower=True, include_lengths=True, batch_first=True)
    # Fix: without unk_token=None the default '<unk>' token is added to the
    # label vocab by build_vocab, so the 5-way task gets a 6-entry label
    # vocab and any model sized from len(LABEL.vocab) wastes an output unit.
    LABEL = data.Field(sequential=False, dtype=torch.long, unk_token=None)
    train_set, val_set, test_set = datasets.SST.splits(
        TEXT, LABEL, fine_grained=True, train_subtrees=False
    )
    print('train_set size:\t', len(train_set))
    print('val_set size:\t', len(val_set))
    print('test_set size:\t', len(test_set))
    # Pretrained embeddings loaded from a local file; build the text vocab
    # over the training split only.
    TEXT.build_vocab(train_set, vectors=Vectors(name='vector.txt', cache='./word-embeddings'))
    LABEL.build_vocab(train_set)
    print('vocab size:\t', len(TEXT.vocab))
    print('embedding dim:\t', TEXT.vocab.vectors.size()[1])
    # sort_within_batch=True keeps each batch sorted by decreasing length,
    # which pack_padded_sequence-style code expects when include_lengths=True.
    # NOTE(review): confirm train() relies on lengths; harmless otherwise.
    train_iter, val_iter, test_iter = data.BucketIterator.splits(
        (train_set, val_set, test_set), batch_size=BATCH_SIZE,
        sort_within_batch=True
    )
    return TEXT, LABEL, train_iter, val_iter, test_iter
def _run():
    """Drive the full pipeline: data prep -> model -> training -> evaluation."""
    print('Processing SST Dataset...')
    TEXT, LABEL, train_iter, val_iter, test_iter = process_SST()

    print('Creating TextClassifier Model...')
    model = TextClassifier(HIDDEN_LAYERS, HIDDEN_SIZE, DROPOUT, TEXT, LABEL, PRETRAINED)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters())

    print('Training...')
    loss_hist, accu_hist = train(
        model, train_iter, val_iter, criterion, optimizer, EPOCHS, BATCH_SIZE, PATIENCE
    )

    print('Testing...')
    accu = test(model, test_iter, BATCH_SIZE)
    print(f'test_acc={accu:.4f}')


if __name__ == '__main__':
    _run()