Commit

Update to support seq2seq lstm (#36)
LongxingTan committed Jun 19, 2023
1 parent b6a6faa commit 39966dc
Showing 4 changed files with 107 additions and 24 deletions.
66 changes: 66 additions & 0 deletions tests/test_models/test_seq2seq.py
@@ -52,5 +52,71 @@ def test_model(self):
y = model(x)
self.assertEqual(y.shape, (2, predict_sequence_length, 1), "incorrect output shape")

def test_model_gru_attn(self):
predict_sequence_length = 8
custom_model_params = {
"rnn_type": "gru",
"bi_direction": False,
"rnn_size": 64,
"dense_size": 64,
"num_stacked_layers": 1,
"scheduler_sampling": 0, # teacher forcing
"use_attention": True,
"attention_sizes": 64,
"attention_heads": 2,
"attention_dropout": 0,
"skip_connect_circle": False,
"skip_connect_mean": False,
}
model = Seq2seq(predict_sequence_length=predict_sequence_length, custom_model_params=custom_model_params)

x = tf.random.normal([2, 16, 3])
y = model(x)
self.assertEqual(y.shape, (2, predict_sequence_length, 1), "incorrect output shape")

def test_model_lstm(self):
predict_sequence_length = 8
custom_model_params = {
"rnn_type": "lstm",
"bi_direction": False,
"rnn_size": 64,
"dense_size": 64,
"num_stacked_layers": 1,
"scheduler_sampling": 0, # teacher forcing
"use_attention": False,
"attention_sizes": 64,
"attention_heads": 2,
"attention_dropout": 0,
"skip_connect_circle": False,
"skip_connect_mean": False,
}
model = Seq2seq(predict_sequence_length=predict_sequence_length, custom_model_params=custom_model_params)

x = tf.random.normal([2, 16, 3])
y = model(x)
self.assertEqual(y.shape, (2, predict_sequence_length, 1), "incorrect output shape")

def test_model_lstm_gru(self):
predict_sequence_length = 8
custom_model_params = {
"rnn_type": "lstm",
"bi_direction": False,
"rnn_size": 64,
"dense_size": 64,
"num_stacked_layers": 1,
"scheduler_sampling": 0, # teacher forcing
"use_attention": True,
"attention_sizes": 64,
"attention_heads": 2,
"attention_dropout": 0,
"skip_connect_circle": False,
"skip_connect_mean": False,
}
model = Seq2seq(predict_sequence_length=predict_sequence_length, custom_model_params=custom_model_params)

x = tf.random.normal([2, 16, 3])
y = model(x)
self.assertEqual(y.shape, (2, predict_sequence_length, 1), "incorrect output shape")

def test_train(self):
pass
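
For reference, the new LSTM path can also be exercised outside the unittest harness; the following is a minimal sketch that assumes Seq2seq is importable from tfts.models.seq2seq and that the parameter names match the tests above.

import tensorflow as tf
from tfts.models.seq2seq import Seq2seq

# Mirrors test_model_lstm: an LSTM encoder-decoder without attention
custom_model_params = {
    "rnn_type": "lstm",
    "bi_direction": False,
    "rnn_size": 64,
    "dense_size": 64,
    "num_stacked_layers": 1,
    "scheduler_sampling": 0,  # teacher forcing
    "use_attention": False,
    "attention_sizes": 64,
    "attention_heads": 2,
    "attention_dropout": 0,
    "skip_connect_circle": False,
    "skip_connect_mean": False,
}
model = Seq2seq(predict_sequence_length=8, custom_model_params=custom_model_params)
x = tf.random.normal([2, 16, 3])  # (batch, input_seq_length, num_features)
y = model(x)                      # expected shape: (2, 8, 1)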
2 changes: 1 addition & 1 deletion tfts/models/rnn.py
@@ -192,7 +192,7 @@ def __init__(self, predict_sequence_length=3, custom_model_params=None) -> None:
self.dense2 = Dense(1)

def __call__(self, inputs, teacher=None):
"""_summary_
"""RNN model2
Parameters
----------
61 changes: 39 additions & 22 deletions tfts/models/seq2seq.py
@@ -55,14 +55,14 @@ def __init__(
)

def __call__(self, inputs, teacher=None):
"""A RNN seq2seq structure for time series
"""An RNN seq2seq structure for time series
:param inputs: _description_
:type inputs: _type_
:param teacher: teacher forcing decoding, defaults to None
:type teacher: _type_, optional
:return: _description_
:rtype: _type_
:type: _type_
"""
if isinstance(inputs, (list, tuple)):
x, encoder_feature, decoder_feature = inputs
@@ -105,36 +105,42 @@ def __call__(self, inputs, teacher=None):
class Encoder(tf.keras.layers.Layer):
def __init__(self, rnn_type, rnn_size, rnn_dropout=0, dense_size=32, **kwargs):
super(Encoder, self).__init__(**kwargs)
self.rnn_type = rnn_type
if rnn_type.lower() == "gru":
self.rnn = GRU(
units=rnn_size, activation="tanh", return_state=True, return_sequences=True, dropout=rnn_dropout
)
elif self.rnn_type.lower() == "lstm":
elif rnn_type.lower() == "lstm":
self.rnn = LSTM(
units=self.rnn_size,
units=rnn_size,
activation="tanh",
return_state=True,
return_sequences=True,
dropout=self.rnn_dropout,
dropout=rnn_dropout,
)
self.dense = Dense(units=dense_size, activation="tanh")

def call(self, inputs):
"""_summary_
"""Seq2seq encoder
Parameters
----------
inputs : _type_
inputs : tf.Tensor
_description_
Returns
-------
_type_
_description_
tf.Tensor
batch_size * input_seq_length * rnn_size, state: batch_size * rnn_size
"""
# outputs: batch_size * input_seq_length * rnn_size, state: batch_size * rnn_size
outputs, state = self.rnn(inputs)
state = self.dense(state)
if self.rnn_type.lower() == "gru":
outputs, state = self.rnn(inputs)
state = self.dense(state)
elif self.rnn_type.lower() == "lstm":
outputs, state1, state2 = self.rnn(inputs)
state = (state1, state2)
else:
raise ValueError("No supported rnn type of {}".format(self.rnn_type))
# encoder_hidden_state = tuple(self.dense(hidden_state) for _ in range(params['num_stacked_layers']))
# outputs = self.dense(outputs) # => batch_size * input_seq_length * dense_size
return outputs, state
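
The rnn_type branch above mirrors how Keras recurrent layers expose state when return_state=True: GRU returns (outputs, state) while LSTM returns (outputs, hidden_state, cell_state), so the LSTM state has to be carried forward as a tuple. A standalone sketch of that signature difference, independent of tfts:

import tensorflow as tf
from tensorflow.keras.layers import GRU, LSTM

x = tf.random.normal([2, 16, 3])  # (batch, input_seq_length, num_features)

gru = GRU(units=64, return_sequences=True, return_state=True)
outputs, state = gru(x)              # outputs: (2, 16, 64), state: (2, 64)

lstm = LSTM(units=64, return_sequences=True, return_state=True)
outputs, state_h, state_c = lstm(x)  # state_h and state_c: (2, 64) each
state = (state_h, state_c)           # the tuple handed to the decoder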
@@ -164,14 +170,15 @@ def build(self, input_shape):
if self.rnn_type.lower() == "gru":
self.rnn_cell = GRUCell(self.rnn_size)
elif self.rnn_type.lower() == "lstm":
self.rnn = LSTMCell(units=self.rnn_size)
self.rnn_cell = LSTMCell(units=self.rnn_size)
self.dense = Dense(units=1, activation=None)
if self.use_attention:
self.attention = FullAttention(
hidden_size=self.attention_sizes,
num_heads=self.attention_heads,
attention_dropout=self.attention_dropout,
)
super().build(input_shape)

def call(
self,
@@ -183,7 +190,7 @@ def call(
training=None,
**kwargs
):
"""_summary_
"""Seq2seq decoder1: step by step
:param decoder_features: _description_
:type decoder_features: _type_
@@ -221,10 +228,21 @@ def call(
this_input = tf.concat([this_input, decoder_features[:, i]], axis=-1)

if self.use_attention:
att = self.attention(
tf.expand_dims(prev_state, 1), k=kwargs["encoder_output"], v=kwargs["encoder_output"]
)
att = tf.squeeze(att, 1)
if self.rnn_type.lower() == "gru":
# q: (batch, 1, feature), att_output: (batch, 1, feature)
att = self.attention(
tf.expand_dims(prev_state, 1), k=kwargs["encoder_output"], v=kwargs["encoder_output"]
)
att = tf.squeeze(att, 1) # (batch, feature)
elif self.rnn_type.lower() == "lstm":
# q: (batch, 1, feature * 2), att_output: (batch, 1, feature)
att = self.attention(
tf.expand_dims(tf.concat(prev_state, 1), 1),
k=kwargs["encoder_output"],
v=kwargs["encoder_output"],
)
att = tf.squeeze(att, 1) # (batch, feature)

this_input = tf.concat([this_input, att], axis=-1)

this_output, this_state = self.rnn_cell(this_input, prev_state)
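
In the attention branch above, the query is the previous decoder state reshaped to (batch, 1, feature); for LSTM the hidden and cell state are first concatenated along the feature axis, which is why the query width doubles. A shape-only sketch (tensor names here are illustrative, not from the library):

import tensorflow as tf

batch_size, rnn_size = 2, 64
h = tf.random.normal([batch_size, rnn_size])   # hidden state
c = tf.random.normal([batch_size, rnn_size])   # cell state

q_gru = tf.expand_dims(h, 1)                           # (batch, 1, rnn_size)
q_lstm = tf.expand_dims(tf.concat([h, c], axis=1), 1)  # (batch, 1, rnn_size * 2)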
@@ -268,6 +286,7 @@ def build(self, input_shape):
num_heads=self.attention_heads,
attention_dropout=self.attention_dropout,
)
super().build(input_shape)

def forward(
self,
@@ -327,7 +346,7 @@ def call(
training=None,
**kwargs
):
"""_summary_
"""Decoder model2
Parameters
----------
@@ -337,8 +356,6 @@
_description_
decoder_init_input : _type_
_description_
encoder_output : _type_
_description_
teacher : _type_, optional
_description_, by default None
@@ -384,7 +401,7 @@ def call(
training=None,
**kwargs
):
"""_summary_
"""Decoder3: just simple
Parameters
----------
2 changes: 1 addition & 1 deletion tfts/models/wavenet.py
@@ -157,7 +157,7 @@ def __call__(
training=None,
**kwargs
):
"""_summary_
"""wavenet decoder1
Parameters
----------
