text_model.py
# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Class for text data."""
import string
import numpy as np
import torch


class SimpleVocab(object):
  """Simple word-level vocabulary built from a list of texts."""

  def __init__(self):
    super(SimpleVocab, self).__init__()
    self.word2id = {}
    self.wordcount = {}
    # Reserve id 0 for unknown / out-of-vocabulary words.
    self.word2id['<UNK>'] = 0
    self.wordcount['<UNK>'] = 9e9

  def tokenize_text(self, text):
    """Lowercases text, strips punctuation and splits it into word tokens."""
    text = text.encode('ascii', 'ignore').decode('ascii')
    tokens = str(text).lower()
    tokens = tokens.translate(str.maketrans('', '', string.punctuation))
    tokens = tokens.strip().split()
    return tokens

  def build(self, texts):
    """Counts word frequencies and assigns an id to every token seen."""
    for text in texts:
      tokens = self.tokenize_text(text)
      for token in tokens:
        if token not in self.wordcount:
          self.wordcount[token] = 0
        self.wordcount[token] += 1
    for token in sorted(list(self.wordcount.keys())):
      if token not in self.word2id:
        self.word2id[token] = len(self.word2id)

  def threshold_rare_words(self, wordcount_threshold=5):
    """Maps words seen fewer than wordcount_threshold times to <UNK> (id 0)."""
    for w in self.word2id:
      if self.wordcount[w] < wordcount_threshold:
        self.word2id[w] = 0

  def encode_text(self, text):
    """Converts a string into a list of word ids; unknown words map to 0."""
    tokens = self.tokenize_text(text)
    x = [self.word2id.get(t, 0) for t in tokens]
    return x

  def get_size(self):
    return len(self.word2id)
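

# A minimal usage sketch for SimpleVocab (not part of the original module);
# the corpus strings below are made up for illustration only:
#
#   vocab = SimpleVocab()
#   vocab.build(['a red dress', 'a blue dress with dots'])
#   vocab.threshold_rare_words(wordcount_threshold=1)
#   print(vocab.get_size())                  # vocabulary size, including <UNK>
#   print(vocab.encode_text('a red shirt'))  # 'shirt' is unseen, so it maps to 0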


class TextLSTMModel(torch.nn.Module):
  """LSTM text encoder: word ids -> embeddings -> LSTM -> dropout + linear."""

  def __init__(self,
               texts_to_build_vocab,
               word_embed_dim=512,
               lstm_hidden_dim=512):
    super(TextLSTMModel, self).__init__()

    # Build the vocabulary from the provided training texts.
    self.vocab = SimpleVocab()
    self.vocab.build(texts_to_build_vocab)
    vocab_size = self.vocab.get_size()

    self.word_embed_dim = word_embed_dim
    self.lstm_hidden_dim = lstm_hidden_dim
    self.embedding_layer = torch.nn.Embedding(vocab_size, word_embed_dim)
    self.lstm = torch.nn.LSTM(word_embed_dim, lstm_hidden_dim)
    self.fc_output = torch.nn.Sequential(
        torch.nn.Dropout(p=0.1),
        torch.nn.Linear(lstm_hidden_dim, lstm_hidden_dim),
    )

  def forward(self, x):
    """Encodes a batch of texts; x is a list of strings or of word-id lists."""
    if type(x) is list:
      if type(x[0]) is str:
        x = [self.vocab.encode_text(text) for text in x]

    assert type(x) is list
    assert type(x[0]) is list
    assert type(x[0][0]) is int
    return self.forward_encoded_texts(x)

  def forward_encoded_texts(self, texts):
    # Pad the batch into a (max_length, batch_size) tensor of word ids.
    lengths = [len(t) for t in texts]
    itexts = torch.zeros((np.max(lengths), len(texts))).long()
    for i in range(len(texts)):
      itexts[:lengths[i], i] = torch.tensor(texts[i])

    # Embed words.
    itexts = torch.autograd.Variable(itexts).cuda()
    etexts = self.embedding_layer(itexts)

    # Run the LSTM over the embedded sequences.
    lstm_output, _ = self.forward_lstm_(etexts)

    # Take the output at each sequence's last (non-padded) timestep.
    text_features = []
    for i in range(len(texts)):
      text_features.append(lstm_output[lengths[i] - 1, i, :])

    # Project the features through the dropout + linear head.
    text_features = torch.stack(text_features)
    text_features = self.fc_output(text_features)
    return text_features

  def forward_lstm_(self, etexts):
    # Run the LSTM from a zero initial hidden and cell state.
    batch_size = etexts.shape[1]
    first_hidden = (torch.zeros(1, batch_size, self.lstm_hidden_dim),
                    torch.zeros(1, batch_size, self.lstm_hidden_dim))
    first_hidden = (first_hidden[0].cuda(), first_hidden[1].cuda())
    lstm_output, last_hidden = self.lstm(etexts, first_hidden)
    return lstm_output, last_hidden
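

# A minimal usage sketch for TextLSTMModel (not part of the original module).
# It assumes a CUDA-capable GPU, since the model moves its inputs and hidden
# state to .cuda() internally, and the example texts are made up for
# illustration only:
#
#   texts = ['replace it with a red dress', 'make it sleeveless']
#   model = TextLSTMModel(texts_to_build_vocab=texts,
#                         word_embed_dim=512,
#                         lstm_hidden_dim=512).cuda()
#   features = model(texts)  # shape (batch_size, lstm_hidden_dim) == (2, 512)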