rnn_word_model.py
import functools
import tensorflow as tf
import tensorflow.contrib.rnn as rnn


def lazy_property(func):
    """Caches a property's value so the wrapped graph-building code runs only once."""
    attribute = '_cache_' + func.__name__

    @property
    @functools.wraps(func)
    def wrapper(self):
        # Build the ops under a name scope on first access, then reuse them.
        if not hasattr(self, attribute):
            with tf.name_scope(func.__name__):
                setattr(self, attribute, func(self))
        return getattr(self, attribute)

    return wrapper
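

# Note: properties decorated with @lazy_property build their graph ops on the
# first access and return the cached tensor on every later access, e.g.
# (illustrative only, assuming a constructed `model`):
#
#     model.loss  # first access: constructs the loss op under a name scope
#     model.loss  # later accesses: return the same cached tensor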


class RNNWordModel:
    """RNN-based TF model for inferring a class from a character-level one-hot representation of a word."""

    def __init__(self, inputs, targets, seq_length, dropout,
                 cell_type, num_layers, num_hidden, optimizer):
        self.inputs = inputs
        self.targets = targets
        self.seq_length = seq_length
        self.dropout = dropout

        self.cell_type = cell_type
        self.num_layers = num_layers
        self.num_hidden = num_hidden
        self.optimizer = optimizer or tf.train.AdamOptimizer()

        self.batch_size = tf.shape(self.inputs)[0]
        self.num_classes = int(self.targets.get_shape()[1])

        # Touch each lazy property once so the whole graph is built eagerly
        # at construction time.
        self.cell
        self.last_output
        self.logits
        self.prediction
        self.error
        self.loss
        self.training

    @lazy_property
    def cell(self):
        # Stack num_layers recurrent cells, each wrapped with output dropout.
        layers = []
        for i in range(self.num_layers):
            if self.cell_type == rnn.LSTMCell:
                # LSTM cells get peephole connections; other cell types use defaults.
                layer = self.cell_type(self.num_hidden[i], use_peepholes=True)
            else:
                layer = self.cell_type(self.num_hidden[i])
            layer = rnn.DropoutWrapper(layer, output_keep_prob=1.0 - self.dropout)
            layers.append(layer)

        if self.num_layers > 1:
            return rnn.MultiRNNCell(layers)
        else:
            return layers[0]

    @lazy_property
    def last_output(self):
        rnn_outputs, _ = tf.nn.dynamic_rnn(
            self.cell, self.inputs,
            sequence_length=self.seq_length,
            dtype=self.inputs.dtype
        )

        # Select the output at each sequence's true last step rather than at
        # the (possibly padded) final time step.
        indices = tf.stack([
            tf.range(0, self.batch_size),
            self.seq_length - 1
        ], axis=1)
        return tf.gather_nd(rnn_outputs, indices)

    @lazy_property
    def logits(self):
        return tf.contrib.layers.fully_connected(
            self.last_output, self.num_classes,
            activation_fn=None
        )

    @lazy_property
    def prediction(self):
        return tf.nn.softmax(self.logits)

    @lazy_property
    def error(self):
        # Softmax is monotonic, so taking argmax over the logits is equivalent
        # to taking argmax over the softmax prediction.
        mistakes = tf.not_equal(
            tf.argmax(self.targets, 1),
            tf.argmax(self.logits, 1)
        )
        return tf.reduce_mean(tf.cast(mistakes, tf.float32))

    @lazy_property
    def loss(self):
        return tf.losses.softmax_cross_entropy(
            self.targets, self.logits)

    @lazy_property
    def training(self):
        return self.optimizer.minimize(self.loss)
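

# --- Example usage ----------------------------------------------------------
# A minimal sketch of wiring the model into a TF 1.x graph. The alphabet size
# (NUM_CHARS), class count (NUM_CLASSES), layer sizes, and cell choice below
# are illustrative assumptions, not values taken from this repository.
if __name__ == '__main__':
    NUM_CHARS = 128     # assumed one-hot alphabet size
    NUM_CLASSES = 10    # assumed number of target classes

    inputs = tf.placeholder(tf.float32, [None, None, NUM_CHARS])  # batch x time x chars
    targets = tf.placeholder(tf.float32, [None, NUM_CLASSES])     # one-hot labels
    seq_length = tf.placeholder(tf.int32, [None])                 # true word lengths
    dropout = tf.placeholder_with_default(0.0, [])                # 0.0 => keep everything

    model = RNNWordModel(
        inputs, targets, seq_length, dropout,
        cell_type=rnn.GRUCell, num_layers=2, num_hidden=[64, 64],
        optimizer=None  # falls back to AdamOptimizer
    )

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Feed real batches here, e.g.:
        # sess.run(model.training, {inputs: x, targets: y,
        #                           seq_length: lengths, dropout: 0.5})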