Fix causal mask for transformer decoder
LongxingTan committed Jun 14, 2023
1 parent 3f9fb3b commit 71c1472
Showing 10 changed files with 65 additions and 72 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -200,6 +200,7 @@ model = AutoModel('rnn', predict_length=7, custom_model_params=custom_model_para
- seq2seq
- wavenet
- transformer
- informer

</details>

1 change: 1 addition & 0 deletions README_CN.md
@@ -204,6 +204,7 @@ model = AutoModel('rnn', predict_length=7, custom_model_params=custom_model_para
- seq2seq
- wavenet
- transformer
- informer

</details>

11 changes: 6 additions & 5 deletions examples/run_prediction.py
@@ -1,5 +1,6 @@
"""Demo of time series prediction by tfts"""
# python run_prediction.py --use_model rnn
"""Demo of time series prediction by tfts
python run_prediction.py --use_model rnn
"""

import argparse
import os
@@ -23,9 +24,9 @@ def parse_args():
parser.add_argument("--use_data", type=str, default="sine", help="dataset: sine or airpassengers")
parser.add_argument("--train_length", type=int, default=24, help="sequence length for train")
parser.add_argument("--predict_length", type=int, default=12, help="sequence length for predict")
parser.add_argument("--n_epochs", type=int, default=50, help="Number of training epochs")
parser.add_argument("--epochs", type=int, default=50, help="Number of training epochs")
parser.add_argument("--batch_size", type=int, default=16, help="Batch size for training")
parser.add_argument("--learning_rate", type=float, default=3e-4, help="learning rate for training")
parser.add_argument("--learning_rate", type=float, default=1e-4, help="learning rate for training")

return parser.parse_args()

@@ -50,7 +51,7 @@ def run_train(args):
model = AutoModel(args.use_model, predict_length=args.predict_length)

trainer = KerasTrainer(model, optimizer=optimizer, loss_fn=loss_fn)
trainer.train(train, valid, n_epochs=args.n_epochs, early_stopping=EarlyStopping("val_loss", patience=5))
trainer.train(train, valid, n_epochs=args.epochs, early_stopping=EarlyStopping("val_loss", patience=5))

pred = trainer.predict(valid[0])
trainer.plot(history=valid[0], true=valid[1], pred=pred)
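With the flag renamed from --n_epochs to --epochs, an illustrative invocation of the example script (argument values are arbitrary; the flags mirror the docstring usage line and the parse_args options above):

python examples/run_prediction.py --use_model transformer --use_data sine --train_length 24 --predict_length 12 --epochs 50 --batch_size 16 --learning_rate 1e-4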
2 changes: 1 addition & 1 deletion tests/test_examples/test_prediction.py
@@ -19,7 +19,7 @@ class args(object):
use_model = "rnn"
train_length = 10
predict_length = 5
n_epochs = 2
epochs = 2
batch_size = 32
learning_rate = 0.003

5 changes: 5 additions & 0 deletions tests/test_layers/test_attention_layer.py
@@ -3,6 +3,7 @@
import tensorflow as tf

from tfts.layers.attention_layer import FullAttention, ProbAttention, SelfAttention
from tfts.layers.mask_layer import CausalMask


class AttentionLayerTest(unittest.TestCase):
@@ -20,6 +21,10 @@ def test_full_attention_layer(self):
config = layer.get_config()
self.assertEqual(config["hidden_size"], hidden_size)

mask = CausalMask(2 * num_heads, 128).mask
y2 = layer(q, k, v, mask=mask)
self.assertEqual(y2.shape, (2, 128, hidden_size))

def test_self_attention_layer(self):
hidden_size = 64
num_heads = 4
9 changes: 6 additions & 3 deletions tests/test_layers/test_mask_layer.py
@@ -1,8 +1,11 @@
import unittest

from tfts.layers.mask_layer import MaskLayer
from tfts.layers.mask_layer import CausalMask, ProbMask


class MaskLayerTest(unittest.TestCase):
def test_conv_layer(self):
pass
def test_casual_mask_layer(self):
B = 2 * 8
L = 99
mask = CausalMask(B, L).mask
self.assertEqual(mask.shape, (B, L, L))
2 changes: 1 addition & 1 deletion tfts/layers/attention_layer.py
@@ -66,7 +66,7 @@ def call(self, q, k, v, mask=None):
k = self.dense_k(k)
v = self.dense_v(v)

q_ = tf.concat(tf.split(q, self.num_heads, axis=2), axis=0) # multi-heads transfer to
q_ = tf.concat(tf.split(q, self.num_heads, axis=2), axis=0)  # fold each head into the batch (sample) dimension
k_ = tf.concat(tf.split(k, self.num_heads, axis=2), axis=0)
v_ = tf.concat(tf.split(v, self.num_heads, axis=2), axis=0)

12 changes: 5 additions & 7 deletions tfts/layers/mask_layer.py
@@ -6,19 +6,15 @@
from tensorflow.keras import activations, constraints, initializers, regularizers


class MaskLayer(tf.keras.layers.Layer):
def __init__(self):
super().__init__()
class CausalMask:
"""Casual Mask is used for transformer decoder, used in first self-attention for decoder feature"""


class TriangularCausalMask:
def __init__(self, B, L):
mask_shape = [B, 1, L, L]
mask_shape = [B, L, L]  # heads are folded into the batch dim, so no separate head axis ([B, 1, L, L])

mask_a = tf.linalg.band_part(tf.ones(mask_shape), 0, -1) # Upper triangular matrix of 0s and 1s
mask_b = tf.linalg.band_part(tf.ones(mask_shape), 0, 0) # Diagonal matrix of 0s and 1s
mask = tf.cast(mask_a - mask_b, dtype=tf.float32)

self._mask = mask
tf.stop_gradient(self._mask)

@@ -28,6 +24,8 @@ def mask(self):


class ProbMask:
"""ProbMask for informer"""

def __init__(self, B, H, L, index, scores):
# B: batch_size, H: num_heads, L: seq_length
mask = tf.ones([L, scores.shape[-1]], tf.float32)
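For reference, a minimal standalone sketch of the band_part construction used by CausalMask above; the 4 x 4 case makes the strictly upper-triangular pattern visible (1 marks a future position that must not be attended). The snippet only restates what the class already does, with illustrative sizes.

import tensorflow as tf

B, L = 1, 4
ones = tf.ones([B, L, L])
# band_part(ones, 0, -1) keeps the upper triangle including the diagonal,
# band_part(ones, 0, 0) keeps only the diagonal, so the difference is 1
# exactly where column > row, i.e. at future positions.
mask = tf.linalg.band_part(ones, 0, -1) - tf.linalg.band_part(ones, 0, 0)
print(mask[0].numpy())
# [[0. 1. 1. 1.]
#  [0. 0. 1. 1.]
#  [0. 0. 0. 1.]
#  [0. 0. 0. 0.]]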
5 changes: 4 additions & 1 deletion tfts/models/informer.py
@@ -18,6 +18,7 @@

from tfts.layers.attention_layer import FullAttention, ProbAttention
from tfts.layers.embed_layer import DataEmbedding, TokenEmbedding
from tfts.layers.mask_layer import CausalMask

params = {
"n_encoder_layers": 1,
@@ -118,8 +119,10 @@ def __call__(self, inputs, teacher=None):
encoder_feature = self.encoder_embedding(encoder_feature) # batch * seq * embedding_size
memory = self.encoder(encoder_feature, mask=None)

B, L, _ = tf.shape(decoder_feature)
casual_mask = CausalMask(B * self.params["num_heads"], L).mask
decoder_feature = self.decoder_embedding(decoder_feature)
outputs = self.decoder(decoder_feature, memory=memory)
outputs = self.decoder(decoder_feature, memory=memory, x_mask=casual_mask)
outputs = self.projection(outputs)

if self.params["skip_connect_circle"]:
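The mask is built with batch size B * num_heads because FullAttention folds every head into the batch axis (the tf.split / tf.concat shown in attention_layer.py above), so the attention scores have shape (batch * heads, L, L) and the mask must match. How FullAttention actually consumes the mask is not visible in this diff; the sketch below assumes the usual pattern of pushing masked scores to a large negative value before the softmax, and all sizes are illustrative.

import tensorflow as tf

B, num_heads, L, head_dim = 2, 4, 12, 16

# Queries/keys after the per-head split: heads live in the batch dimension.
q_ = tf.random.normal([B * num_heads, L, head_dim])
k_ = tf.random.normal([B * num_heads, L, head_dim])
scores = tf.matmul(q_, k_, transpose_b=True) / tf.math.sqrt(float(head_dim))

# Same strictly upper-triangular construction as CausalMask, one slice per head/sample.
ones = tf.ones([B * num_heads, L, L])
causal = tf.linalg.band_part(ones, 0, -1) - tf.linalg.band_part(ones, 0, 0)

scores += causal * -1e9                  # assumed masking step: block attention to future steps
weights = tf.nn.softmax(scores, axis=-1)
print(weights.shape)                     # (8, 12, 12)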
89 changes: 35 additions & 54 deletions tfts/models/transformer.py
@@ -12,6 +12,7 @@
from tfts.layers.attention_layer import FullAttention, SelfAttention
from tfts.layers.dense_layer import FeedForwardNetwork
from tfts.layers.embed_layer import DataEmbedding, TokenEmbedding
from tfts.layers.mask_layer import CausalMask

params = {
"n_encoder_layers": 1,
@@ -61,17 +62,8 @@ def __init__(
params["ffn_filter_sizes"],
params["ffn_dropout"],
)
# self.decoder = Decoder(
# predict_sequence_length,
# params['n_decoder_layers'],
# params['attention_hidden_sizes'],
# params['num_heads'],
# params['attention_dropout'],
# params['ffn_hidden_sizes'],
# params['ffn_filter_sizes'],
# params['ffn_dropout'])

self.decoder = Decoder2(
predict_sequence_length,
embed_layer=TokenEmbedding(params["attention_hidden_sizes"]),
att_layers=[
DecoderLayer2(
@@ -90,19 +82,19 @@ def __init__(self):
self.project = Dense(1, activation=None)

def __call__(self, inputs, teacher=None):
"""_summary_
"""Time series transformer
Parameters
----------
inputs : _type_
_description_
teacher : _type_, optional
inputs : tf.Tensor
3D tensor for batch * seq_len * features
teacher : tf.Tensor, optional
_description_, by default None
Returns
-------
_type_
_description_
tf.Tensor
3D tensor for output, batch * output_seq * 1
"""
if isinstance(inputs, (list, tuple)):
x, encoder_feature, decoder_feature = inputs
@@ -126,7 +118,10 @@ def __call__(self, inputs, teacher=None):
memory = self.encoder(encoder_feature, src_mask=None)

# decoder_outputs = self.decoder(decoder_features, init_input=x[:, -1:], encoder_memory=memory, teacher=teacher)
decoder_outputs = self.decoder(decoder_feature, memory)

B, L, _ = tf.shape(decoder_feature)
casual_mask = CausalMask(B * self.params["num_heads"], L).mask
decoder_outputs = self.decoder(decoder_feature, memory, x_mask=casual_mask)
decoder_outputs = self.project(decoder_outputs)

if self.params["skip_connect_circle"]:
@@ -169,7 +164,7 @@ def build(self, input_shape):
super(Encoder, self).build(input_shape)

def call(self, encoder_inputs, src_mask=None):
"""_summary_
"""Transformer encoder
Parameters
----------
@@ -330,7 +325,7 @@ def build(self, input_shape):
super(DecoderLayer, self).build(input_shape)

def call(self, decoder_inputs, encoder_memory, tgt_mask=None, cross_mask=None):
"""_summary_
"""Decoder layer
Parameters
----------
@@ -376,9 +371,8 @@ def get_config(self):


class Decoder2(tf.keras.layers.Layer):
def __init__(self, predict_sequence_length, embed_layer, att_layers, norm_layer=None) -> None:
def __init__(self, embed_layer, att_layers, norm_layer=None) -> None:
super().__init__()
self.predict_sequence_length = predict_sequence_length
self.att_layers = att_layers
self.norm = norm_layer
self.decoder_embedding = embed_layer
@@ -389,48 +383,35 @@ def __init__(self, predict_sequence_length, embed_layer, att_layers, norm_layer=
self.drop2 = TimeDistributed(Dropout(0.1))
self.proj = TimeDistributed(Dense(1))

def decode(self, x, cross, x_mask, cross_mask):
x = self.decoder_embedding(x)
for layer in self.att_layers:
x = layer(x, cross, x_mask, cross_mask)
if self.norm is not None:
x = self.norm(x)
return x

def call(self, x, cross, x_mask=None, cross_mask=None, training=True):
"""_summary_
def call(self, x, memory, x_mask=None, memory_mask=None):
"""Transformer decoder2
Parameters
----------
x : _type_
_description_
cross : _type_
memory : _type_
_description_
x_mask : _type_, optional
_description_, by default None
cross_mask : _type_, optional
memory_mask : _type_, optional
_description_, by default None
training : bool, optional
_description_, by default True
Returns
-------
_type_
tf.Tensor
_description_
"""
x = self.decode(x, cross, x_mask, cross_mask)
x = self.decoder_embedding(x)
for layer in self.att_layers:
x = layer(x, memory, x_mask, memory_mask)
if self.norm is not None:
x = self.norm(x)

x = self.drop(x)
x = self.dense2(x)
x = self.drop2(x)
x = self.proj(x)

# if training:
# x = self.decode(x, cross, x_mask, cross_mask)
# else:
# for _ in range(self.predict_sequence_length):
# x1 = self.decode(x, cross, x_mask, cross_mask)
# print(x1.shape, x.shape)
# x = tf.concat([x, x1], axis=0)
return x


@@ -470,8 +451,8 @@ def build(self, input_shape):
)
super(DecoderLayer2, self).build(input_shape)

def call(self, decoder_inputs, encoder_memory, tgt_mask=None, cross_mask=None):
"""_summary_
def call(self, decoder_inputs, encoder_memory, decoder_mask=None, memory_mask=None):
"""Decoder layer2
Parameters
----------
@@ -481,7 +462,7 @@ def call(self, decoder_inputs, encoder_memory, tgt_mask=None, cross_mask=None):
_description_
tgt_mask : _type_, optional
_description_, by default None
cross_mask : _type_, optional
memory_mask : _type_, optional
_description_, by default None
Returns
@@ -493,14 +474,14 @@

for _, layer in enumerate(self.layers):
self_attention_layer, enc_dec_attention_layer, ffn_layer, ln_layer1, ln_layer2, ln_layer3 = layer
dec1 = x
# dec = self_attention_layer(dec, mask=tgt_mask)
# dec1 = ln_layer1(x + dec)
dec1 = enc_dec_attention_layer(dec1, encoder_memory, encoder_memory, mask=cross_mask)
dec = x
dec = self_attention_layer(dec, mask=decoder_mask)
dec1 = ln_layer1(x + dec)
dec1 = enc_dec_attention_layer(dec1, encoder_memory, encoder_memory, mask=memory_mask)
dec2 = ln_layer2(x + dec1)
dec2 = ffn_layer(dec2)
# x = ln_layer3(dec1 + dec2)
x = dec1 + dec2
x = ln_layer3(dec1 + dec2)  # note: apply the final layer norm only once here
# x = dec1 + dec2
return x

def get_config(self):
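For orientation, a stand-in sketch built from stock tf.keras.layers.MultiHeadAttention that follows the same sub-layer order DecoderLayer2 now uses: masked self-attention, add & norm, encoder-decoder attention, add & norm, feed-forward, and a single final norm. It is not the repository's FullAttention, and the mask convention is inverted: Keras expects True where attention is allowed, whereas CausalMask marks blocked (future) positions with 1. Sizes are illustrative.

import tensorflow as tf

d_model, num_heads, B, L_dec, L_enc = 64, 4, 2, 12, 24

self_attn = tf.keras.layers.MultiHeadAttention(num_heads, d_model // num_heads)
cross_attn = tf.keras.layers.MultiHeadAttention(num_heads, d_model // num_heads)
ffn = tf.keras.Sequential([tf.keras.layers.Dense(4 * d_model, activation="relu"),
                           tf.keras.layers.Dense(d_model)])
ln1, ln2, ln3 = (tf.keras.layers.LayerNormalization() for _ in range(3))

x = tf.random.normal([B, L_dec, d_model])       # decoder features
memory = tf.random.normal([B, L_enc, d_model])  # encoder output

# Boolean causal mask: True where position j <= i may be attended (lower triangle).
causal = tf.cast(tf.linalg.band_part(tf.ones([B, L_dec, L_dec]), -1, 0), tf.bool)

dec = self_attn(x, x, attention_mask=causal)    # masked self-attention
dec1 = ln1(x + dec)                             # add & norm
dec1 = cross_attn(dec1, memory)                 # encoder-decoder attention
dec2 = ln2(x + dec1)                            # add & norm (residual from the block input, as in the diff)
dec2 = ffn(dec2)                                # position-wise feed-forward
out = ln3(dec1 + dec2)                          # single final norm
print(out.shape)                                # (2, 12, 64)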
