"""
torchtext λΌμ΄λΈλ¬λ¦¬λ‘ ν
μ€νΈ λΆλ₯νκΈ°
===============================================
**λ²μ**: `κΉκ°λ―Ό <https://github.com/gangsss>`_ , `κΉμ§ν <https://github.com/lewha0>`_
μ΄ νν 리μΌμμλ torchtext λΌμ΄λΈλ¬λ¦¬λ₯Ό μ¬μ©νμ¬ μ΄λ»κ² ν
μ€νΈ λΆλ₯ λΆμμ μν λ°μ΄ν°μ
μ λ§λλμ§λ₯Ό μ΄ν΄λ³΄κ² μ΅λλ€.
λ€μκ³Ό κ°μ λ΄μ©λ€μ μκ² λ©λλ€:
- λ°λ³΅μ(iterator)λ‘ κ°κ³΅λμ§ μμ λ°μ΄ν°(raw data)μ μ κ·ΌνκΈ°
- κ°κ³΅λμ§ μμ ν
μ€νΈ λ¬Έμ₯λ€μ λͺ¨λΈ νμ΅μ μ¬μ©ν μ μλ ``torch.Tensor`` λ‘ λ³ννλ λ°μ΄ν° μ²λ¦¬ νμ΄νλΌμΈ λ§λ€κΈ°
- `torch.utils.data.DataLoader <https://pytorch.org/docs/stable/data.html?highlight=dataloader#torch.utils.data.DataLoader>`__ λ₯Ό μ¬μ©νμ¬ λ°μ΄ν°λ₯Ό μκ³ λ°λ³΅νκΈ°(shuffle and iterate)
μ¬μ μꡬ μ¬ν
~~~~~~~~~~~~~~~~
μ΄ νν 리μΌμ μ€ννκΈ° μν΄μλ λ¨Όμ 2.x λ²μ μ μ΅μ ``portalocker`` ν¨ν€μ§κ° μ€μΉλμ΄ μμ΄μΌ ν©λλ€.
μλ₯Ό λ€μ΄, Colab νκ²½μμλ λ€μκ³Ό κ°μ΄ μ€ν¬λ¦½νΈ 맨 μμ λ€μ μ€μ μΆκ°νμ¬ μ€μΉν μ μμ΅λλ€:
.. code-block:: bash
!pip install -U portalocker>=2.0.0`
"""
######################################################################
# Access the raw dataset iterators
# -------------------------------------------------------------
#
# The torchtext library provides a few raw dataset iterators which yield
# the raw text strings. For example, the ``AG_NEWS`` dataset iterator
# yields the raw data as a tuple of label and text.
#
# Before accessing the torchtext datasets, please install torchdata
# following the instructions at https://github.com/pytorch/data.
#
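######################################################################
# For example (assuming a pip-based environment; see the link above for
# the authoritative instructions), the install is typically:
#
# .. code-block:: bash
#
#    pip install torchdata
#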
import torch
from torchtext.datasets import AG_NEWS
train_iter = iter(AG_NEWS(split="train"))
######################################################################
# .. code-block:: sh
#
#    next(train_iter)
#    >>> (3, "Fears for T N pension after talks Unions representing workers at Turner
#    Newall say they are 'disappointed' after talks with stricken parent firm Federal
#    Mogul.")
#
#    next(train_iter)
#    >>> (4, "The Race is On: Second Private Team Sets Launch Date for Human
#    Spaceflight (SPACE.com) SPACE.com - TORONTO, Canada -- A second\\team of
#    rocketeers competing for the #36;10 million Ansari X Prize, a contest
#    for\\privately funded suborbital space flight, has officially announced
#    the first\\launch date for its manned rocket.")
#
#    next(train_iter)
#    >>> (4, 'Ky. Company Wins Grant to Study Peptides (AP) AP - A company founded
#    by a chemistry researcher at the University of Louisville won a grant to develop
#    a method of producing better peptides, which are short chains of amino acids, the
#    building blocks of proteins.')
#
######################################################################
# Prepare data processing pipelines
# ---------------------------------
#
# We have revisited the very basic components of the torchtext library,
# including the vocabulary, word vectors, and tokenizer. Those are the
# basic data processing building blocks for raw text strings.
#
# Here is an example of typical NLP data processing with a tokenizer and
# vocabulary. The first step is to build a vocabulary with the raw training
# dataset. Here we use the built-in factory function
# ``build_vocab_from_iterator``, which accepts an iterator that yields a
# list or iterator of tokens. Users can also pass any special symbols to be
# added to the vocabulary.
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
tokenizer = get_tokenizer("basic_english")
train_iter = AG_NEWS(split="train")
def yield_tokens(data_iter):
    for _, text in data_iter:
        yield tokenizer(text)
vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])
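######################################################################
# As an aside, a hypothetical sketch (not used later in this tutorial) of
# passing extra special symbols and a frequency cutoff to the same factory
# function might look like this:
#
# .. code-block:: python
#
#    vocab_alt = build_vocab_from_iterator(
#        yield_tokens(AG_NEWS(split="train")),
#        min_freq=2,                   # drop tokens that appear only once
#        specials=["<unk>", "<pad>"],  # reserved symbols, placed first
#    )
#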
######################################################################
# The vocabulary block converts a list of tokens into integers.
#
# .. code-block:: sh
#
#    vocab(['here', 'is', 'an', 'example'])
#    >>> [475, 21, 30, 5297]
#
# Prepare the text processing pipeline with the tokenizer and vocabulary.
# The text and label pipelines will be used to process the raw text strings
# from the dataset iterators.
text_pipeline = lambda x: vocab(tokenizer(x))
label_pipeline = lambda x: int(x) - 1
######################################################################
# The text pipeline converts a text string into a list of integers based
# on the lookup table defined in the vocabulary. The label pipeline
# converts the label into an integer. For example,
#
# .. code-block:: sh
#
#    text_pipeline('here is the an example')
#    >>> [475, 21, 2, 30, 5297]
#    label_pipeline('10')
#    >>> 9
#
######################################################################
# Generate data batches and iterator
# ----------------------------------------
#
# `torch.utils.data.DataLoader <https://pytorch.org/docs/stable/data.html?highlight=dataloader#torch.utils.data.DataLoader>`__
# is recommended for PyTorch users (a tutorial is
# `here <https://tutorials.pytorch.kr/beginner/data_loading_tutorial.html>`__).
# It works with a map-style dataset that implements the ``getitem()`` and
# ``len()`` protocols, and represents a map from indices/keys to data
# samples. It also works with an iterable dataset when the shuffle argument
# is ``False``.
#
# Before sending the data to the model, the ``collate_fn`` function works
# on a batch of samples generated from ``DataLoader``. The input to
# ``collate_fn`` is a batch of data with the batch size set in
# ``DataLoader``, and ``collate_fn`` processes it according to the data
# processing pipelines declared previously. Make sure that ``collate_fn``
# is declared as a top-level def; this ensures that the function is
# available in each worker.
#
# In this example, the text entries in the original data batch input are
# packed into a list and concatenated as a single tensor for the input of
# ``nn.EmbeddingBag``. The offsets tensor is a tensor of delimiters
# representing the beginning index of each individual sequence in the text
# tensor. The label is a tensor holding the label of each individual text
# entry.
from torch.utils.data import DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def collate_batch(batch):
    label_list, text_list, offsets = [], [], [0]
    for _label, _text in batch:
        label_list.append(label_pipeline(_label))
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        text_list.append(processed_text)
        offsets.append(processed_text.size(0))
    label_list = torch.tensor(label_list, dtype=torch.int64)
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
    text_list = torch.cat(text_list)
    return label_list.to(device), text_list.to(device), offsets.to(device)


train_iter = AG_NEWS(split="train")
dataloader = DataLoader(
    train_iter, batch_size=8, shuffle=False, collate_fn=collate_batch
)
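######################################################################
# As a worked example of the offset arithmetic above (the lengths here are
# illustrative assumptions): for three texts with 3, 2 and 4 tokens, the
# ``offsets`` list starts as ``[0, 3, 2, 4]``; dropping the last entry and
# taking the cumulative sum yields the start index of each sequence in the
# concatenated text tensor.
#
# .. code-block:: python
#
#    lengths = torch.tensor([0, 3, 2, 4])  # leading 0, then each text's length
#    print(lengths[:-1].cumsum(dim=0))     # tensor([0, 3, 5])
#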
######################################################################
# Define the model
# ----------------
#
# The model is composed of the
# `nn.EmbeddingBag <https://pytorch.org/docs/stable/nn.html?highlight=embeddingbag#torch.nn.EmbeddingBag>`__
# layer plus a linear layer for the classification purpose.
# ``nn.EmbeddingBag`` with the default mode of "mean" computes the mean
# value of a "bag" of embeddings. Although the text entries here have
# different lengths, the ``nn.EmbeddingBag`` module requires no padding,
# since the text lengths are saved in the offsets.
#
# Additionally, since ``nn.EmbeddingBag`` accumulates the average across
# the embeddings on the fly, it can enhance the performance and memory
# efficiency when processing a sequence of tensors.
#
# .. image:: ../_static/img/text_sentiment_ngrams_model.png
#
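######################################################################
# A minimal sketch of this behavior (the sizes and indices below are
# illustrative assumptions): in "mean" mode, ``nn.EmbeddingBag`` returns
# one averaged embedding per bag, so variable-length inputs need no padding.
#
# .. code-block:: python
#
#    import torch
#    from torch import nn
#
#    bag = nn.EmbeddingBag(num_embeddings=10, embedding_dim=4, mode="mean")
#    flat = torch.tensor([1, 2, 3, 4, 5])  # two texts: [1, 2, 3] and [4, 5]
#    starts = torch.tensor([0, 3])         # where each text begins
#    print(bag(flat, starts).shape)        # torch.Size([2, 4]), one row per bag
#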
from torch import nn


class TextClassificationModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_class):
        super(TextClassificationModel, self).__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=False)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)
######################################################################
# Initiate an instance
# --------------------
#
# The ``AG_NEWS`` dataset has four kinds of labels, so the number of
# classes is four.
#
# .. code-block:: sh
#
#    1 : World
#    2 : Sports
#    3 : Business
#    4 : Sci/Tec
#
# We build a model with an embedding dimension of 64. The vocab size is
# equal to the length of the vocabulary instance. The number of classes is
# equal to the number of labels.
#
train_iter = AG_NEWS(split="train")
num_class = len(set([label for (label, text) in train_iter]))
vocab_size = len(vocab)
emsize = 64
model = TextClassificationModel(vocab_size, emsize, num_class).to(device)
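######################################################################
# As an optional sanity check (a sketch, not required by the rest of the
# tutorial), the number of trainable parameters can be inspected:
#
# .. code-block:: python
#
#    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
#    print(f"The model has {num_params:,} trainable parameters")
#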
######################################################################
# Define functions to train the model and evaluate results
# ---------------------------------------------------------
#
import time


def train(dataloader):
    model.train()
    total_acc, total_count = 0, 0
    log_interval = 500
    start_time = time.time()

    for idx, (label, text, offsets) in enumerate(dataloader):
        optimizer.zero_grad()
        predicted_label = model(text, offsets)
        loss = criterion(predicted_label, label)
        loss.backward()
        # Clip gradients to stabilize training with the high learning rate
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()
        total_acc += (predicted_label.argmax(1) == label).sum().item()
        total_count += label.size(0)
        if idx % log_interval == 0 and idx > 0:
            elapsed = time.time() - start_time
            print(
                "| epoch {:3d} | {:5d}/{:5d} batches "
                "| accuracy {:8.3f}".format(
                    epoch, idx, len(dataloader), total_acc / total_count
                )
            )
            total_acc, total_count = 0, 0
            start_time = time.time()


def evaluate(dataloader):
    model.eval()
    total_acc, total_count = 0, 0

    with torch.no_grad():
        for idx, (label, text, offsets) in enumerate(dataloader):
            predicted_label = model(text, offsets)
            loss = criterion(predicted_label, label)
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            total_count += label.size(0)
    return total_acc / total_count
######################################################################
# Split the dataset and run the model
# -----------------------------------
#
# Since the original ``AG_NEWS`` has no validation dataset, we split the
# training dataset into train and validation sets with a split ratio of
# 0.95 (train) and 0.05 (validation). Here we use the
# `torch.utils.data.dataset.random_split <https://pytorch.org/docs/stable/data.html?highlight=random_split#torch.utils.data.random_split>`__
# function from the PyTorch core library.
#
# The `CrossEntropyLoss <https://pytorch.org/docs/stable/nn.html?highlight=crossentropyloss#torch.nn.CrossEntropyLoss>`__
# criterion combines ``nn.LogSoftmax()`` and ``nn.NLLLoss()`` in a single
# class.
# The `SGD <https://pytorch.org/docs/stable/_modules/torch/optim/sgd.html>`__
# optimizer implements stochastic gradient descent. The initial learning
# rate is set to 5.0, and
# `StepLR <https://pytorch.org/docs/master/_modules/torch/optim/lr_scheduler.html#StepLR>`__
# is used to adjust the learning rate through the epochs.
#
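######################################################################
# As a quick sketch of that equivalence (the random tensors below are
# illustrative assumptions): applying ``nn.NLLLoss`` to ``nn.LogSoftmax``
# output matches ``nn.CrossEntropyLoss`` applied to the raw logits.
#
# .. code-block:: python
#
#    import torch
#    from torch import nn
#
#    logits = torch.randn(3, 4)  # batch of 3 samples, 4 classes
#    target = torch.tensor([0, 2, 3])
#    ce = nn.CrossEntropyLoss()(logits, target)
#    nll = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), target)
#    assert torch.allclose(ce, nll)  # the two formulations agree
#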
from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset

# Hyperparameters
EPOCHS = 10  # epoch
LR = 5  # learning rate
BATCH_SIZE = 64  # batch size for training

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
total_accu = None
train_iter, test_iter = AG_NEWS()
train_dataset = to_map_style_dataset(train_iter)
test_dataset = to_map_style_dataset(test_iter)
num_train = int(len(train_dataset) * 0.95)
split_train_, split_valid_ = random_split(
    train_dataset, [num_train, len(train_dataset) - num_train]
)

train_dataloader = DataLoader(
    split_train_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch
)
valid_dataloader = DataLoader(
    split_valid_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch
)
test_dataloader = DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch
)
for epoch in range(1, EPOCHS + 1):
    epoch_start_time = time.time()
    train(train_dataloader)
    accu_val = evaluate(valid_dataloader)
    # Decay the learning rate if validation accuracy stops improving
    if total_accu is not None and total_accu > accu_val:
        scheduler.step()
    else:
        total_accu = accu_val
    print("-" * 59)
    print(
        "| end of epoch {:3d} | time: {:5.2f}s | "
        "valid accuracy {:8.3f} ".format(
            epoch, time.time() - epoch_start_time, accu_val
        )
    )
    print("-" * 59)
######################################################################
# Evaluate the model with the test dataset
# -----------------------------------------
#


######################################################################
# Checking the results of the test dataset...
print("Checking the results of test dataset.")
accu_test = evaluate(test_dataloader)
print("test accuracy {:8.3f}".format(accu_test))
######################################################################
# Test on a random news
# ------------------------
#
# Use the best model so far and test a golf news article.
#

ag_news_label = {1: "World", 2: "Sports", 3: "Business", 4: "Sci/Tec"}


def predict(text, text_pipeline):
    with torch.no_grad():
        text = torch.tensor(text_pipeline(text))
        output = model(text, torch.tensor([0]))
        return output.argmax(1).item() + 1
ex_text_str = "MEMPHIS, Tenn. β Four days ago, Jon Rahm was \
enduring the seasonβs worst weather conditions on Sunday at The \
Open on his way to a closing 75 at Royal Portrush, which \
considering the wind and the rain was a respectable showing. \
Thursdayβs first round at the WGC-FedEx St. Jude Invitational \
was another story. With temperatures in the mid-80s and hardly any \
wind, the Spaniard was 13 strokes better in a flawless round. \
Thanks to his best putting performance on the PGA Tour, Rahm \
finished with an 8-under 62 for a three-stroke lead, which \
was even more impressive considering heβd never played the \
front nine at TPC Southwind."
model = model.to("cpu")
print("This is a %s news" % ag_news_label[predict(ex_text_str, text_pipeline)])