Skip to content

Commit

Permalink
fix: modify the shape for graph mode of informer #47 (#49)
Browse files Browse the repository at this point in the history
* Update dataembed from tokenembed

* Update dataembed from tokenembed

* fix: modify the shape for graph mode #47

* fix: solve conflict of informer
  • Loading branch information
LongxingTan committed Oct 11, 2023
1 parent 4b179e8 commit 403d8a3
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 24 deletions.
46 changes: 46 additions & 0 deletions tests/test_models/test_informer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
python -m unittest -v tests/test_models/test_informer.py
"""

from typing import Any, Dict
import unittest

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import LayerNormalization

Expand All @@ -12,6 +14,8 @@
from tfts.layers.attention_layer import FullAttention, ProbAttention
from tfts.models.informer import Decoder, DecoderLayer, DistilConv, Encoder, EncoderLayer, Informer

tf.config.run_functions_eagerly(True)


class InformerTest(unittest.TestCase):
def test_model(self):
Expand Down Expand Up @@ -101,3 +105,45 @@ def test_decoder(self):
y = decoder(x, memory=memory)

self.assertEqual(y.shape, (2, 50, attention_hidden_sizes))

def test_train(self):
params: Dict[str, Any] = {
"n_encoder_layers": 1,
"n_decoder_layers": 1,
"attention_hidden_sizes": 32 * 1,
"num_heads": 1,
"attention_dropout": 0.0,
"ffn_hidden_sizes": 32 * 1,
"ffn_filter_sizes": 32 * 1,
"ffn_dropout": 0.0,
"skip_connect_circle": False,
"skip_connect_mean": False,
"prob_attention": False,
"distil_conv": False,
}

custom_params = params.copy()
custom_params["prob_attention"] = True

train_length = 49
predict_length = 10
n_encoder_feature = 2
n_decoder_feature = 3

x_train = (
np.random.rand(1, train_length, 1),
np.random.rand(1, train_length, n_encoder_feature),
np.random.rand(1, predict_length, n_decoder_feature),
)
y_train = np.random.rand(1, predict_length, 1) # target: (batch, predict_length, 1)

x_valid = (
np.random.rand(1, train_length, 1),
np.random.rand(1, train_length, n_encoder_feature),
np.random.rand(1, predict_length, n_decoder_feature),
)
y_valid = np.random.rand(1, predict_length, 1)

model = AutoModel("Informer", predict_length=predict_length, custom_model_params=custom_params)
trainer = KerasTrainer(model)
trainer.train((x_train, y_train), (x_valid, y_valid), n_epochs=1)
46 changes: 25 additions & 21 deletions tfts/layers/attention_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,36 +141,37 @@ def build(self, input_shape: Tuple[Optional[int], ...]) -> None:
super().build(input_shape)

def _prob_qk(self, q, k, sample_k, top_n):
B, H, L, E = k.shape
_, H, L, E = k.shape
_, _, S, _ = q.shape
B = tf.shape(k)[0]

k_expand = tf.broadcast_to(tf.expand_dims(k, -3), (B, H, L, S, E))
k_random_index = tf.random.uniform((S, sample_k), maxval=L, dtype=tf.int32)
k_random_index = tf.tile(k_random_index[tf.newaxis, tf.newaxis, :], [B, H, 1, 1])

batch_indexes = tf.tile(tf.range(B)[:, tf.newaxis, tf.newaxis, tf.newaxis], (1, H, L, k_random_index.shape[-1]))
head_indexes = tf.tile(tf.range(H)[tf.newaxis, :, tf.newaxis, tf.newaxis], (B, 1, L, k_random_index.shape[-1]))
k_indexes = tf.tile(tf.range(L)[tf.newaxis, tf.newaxis, :, tf.newaxis], (B, H, 1, k_random_index.shape[-1]))
indx_q_seq = tf.random.uniform((S,), maxval=L, dtype=tf.int32)
indx_k_seq = tf.random.uniform((sample_k,), maxval=L, dtype=tf.int32)

k_random_index = tf.stack([batch_indexes, head_indexes, k_indexes, k_random_index], axis=-1)
k_sample = tf.gather_nd(k_expand, k_random_index)
K_sample = tf.gather(k_expand, tf.range(S), axis=2)

qk_sample = tf.squeeze(tf.matmul(tf.expand_dims(q, -2), tf.transpose(k_sample, [0, 1, 2, 4, 3])))
m = tf.math.reduce_max(qk_sample, axis=-1) - tf.divide(tf.reduce_sum(qk_sample, axis=-1), L)
m_top = tf.math.top_k(m, top_n, sorted=False)[1]
m_top = m_top[tf.newaxis] if B == 1 else m_top
m_top = tf.tile(m_top, (1, 1, 1))
K_sample = tf.gather(K_sample, indx_q_seq, axis=2)
K_sample = tf.gather(K_sample, indx_k_seq, axis=3)

Q_K_sample = tf.squeeze(tf.matmul(tf.expand_dims(q, -2), tf.einsum("...ij->...ji", K_sample)))
M = tf.math.reduce_max(Q_K_sample, axis=-1) - tf.raw_ops.Div(x=tf.reduce_sum(Q_K_sample, axis=-1), y=L)
m_top = tf.math.top_k(M, top_n, sorted=False)[1]
m_top = m_top[tf.newaxis, tf.newaxis] if B == 1 else m_top

batch_indexes = tf.tile(tf.range(B)[:, tf.newaxis, tf.newaxis], (1, H, top_n))
head_indexes = tf.tile(tf.range(H)[tf.newaxis, :, tf.newaxis], (B, 1, top_n))

idx = tf.stack([batch_indexes, head_indexes, m_top], axis=-1)

q_reduce = tf.gather_nd(q, idx)
qk = tf.matmul(q_reduce, tf.transpose(k, (0, 1, 3, 2)))
return qk, m_top

def _get_initial_context(self, v, L_Q):
B, H, L_V, D = v.shape
_, H, L_V, D = v.shape
B = tf.shape(v)[0]
if not self.mask_flag:
v_sum = tf.math.reduce_sum(v, axis=-2)
context = tf.identity(tf.boradcast_to(tf.expand_dims(v_sum, -2), [B, H, L_Q, v_sum.shape[-1]]))
Expand All @@ -180,9 +181,10 @@ def _get_initial_context(self, v, L_Q):
return context

def _update_context(self, context_in, v, scores, index, L_Q):
B, H, L_V, D = v.shape
batch_indexes = tf.tile(tf.range(B)[:, tf.newaxis, tf.newaxis], (1, H, index.shape[-1]))
head_indexes = tf.tile(tf.range(H)[tf.newaxis, :, tf.newaxis], (B, 1, index.shape[-1]))
_, H, L_V, D = v.shape
B = tf.shape(v)[0]
batch_indexes = tf.tile(tf.range(B)[:, tf.newaxis, tf.newaxis], (1, H, tf.shape(index)[-1]))
head_indexes = tf.tile(tf.range(H)[tf.newaxis, :, tf.newaxis], (B, 1, tf.shape(index)[-1]))
index = tf.stack([batch_indexes, head_indexes, index], axis=-1)

if self.mask_flag:
Expand All @@ -193,18 +195,20 @@ def _update_context(self, context_in, v, scores, index, L_Q):
context_in = tf.tensor_scatter_nd_update(context_in, index, tf.matmul(attn, v))
return tf.convert_to_tensor(context_in)

# @tf.function
def call(self, q, k, v, mask=None):
"""Prob attention"""
q = self.dense_q(q) # project the query/key/value to num_heads * units
k = self.dense_k(k)
v = self.dense_v(v)

B, L, D = q.shape
_, L, D = q.shape
B = tf.shape(q)[0]
_, S, _ = k.shape

q_ = tf.reshape(q, (B, self.num_heads, L, -1))
k_ = tf.reshape(k, (B, self.num_heads, S, -1))
v_ = tf.reshape(v, (B, self.num_heads, S, -1))
q_ = tf.reshape(q, (-1, self.num_heads, L, self.hidden_size // self.num_heads))
k_ = tf.reshape(k, (-1, self.num_heads, S, self.hidden_size // self.num_heads))
v_ = tf.reshape(v, (-1, self.num_heads, S, self.hidden_size // self.num_heads))

u_q = self.factor * np.ceil(np.log(L)).astype("int").item()
u_k = self.factor * np.ceil(np.log(S)).astype("int").item()
Expand Down
2 changes: 1 addition & 1 deletion tfts/layers/mask_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(self, B, H, L, index, scores):
mask_expanded = tf.broadcast_to(mask, [B, H, L, scores.shape[-1]])
# mask specific q based on reduced Q
mask_Q = tf.gather_nd(mask_expanded, index)
self._mask = tf.cast(tf.reshape(mask_Q, scores.shape), tf.bool)
self._mask = tf.cast(tf.reshape(mask_Q, tf.shape(scores)), tf.bool)

@property
def mask(self):
Expand Down
2 changes: 0 additions & 2 deletions tfts/models/informer.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,7 @@ def call(self, x, mask=None):
if self.conv_layers is not None:
for attn_layer, conv_layer in zip(self.layers, self.conv_layers):
x = attn_layer(x, mask)
# print(x.shape)
# x = conv_layer(x)
# print(x.shape)
x = self.layers[-1](x, mask)

else:
Expand Down

0 comments on commit 403d8a3

Please sign in to comment.