Update transformer model (#39)
LongxingTan authored Aug 19, 2023
1 parent 310fcc3 commit b6381e4
Showing 5 changed files with 30 additions and 25 deletions.
4 changes: 2 additions & 2 deletions tfts/models/auto_config.py
@@ -18,7 +18,7 @@
 class AutoConfig(object):
     """AutoConfig for model"""

-    def __init__(self, use_model: str):
+    def __init__(self, use_model: str) -> None:
         if use_model.lower() == "seq2seq":
             self.params = seq2seq_params
         elif use_model.lower() == "rnn":
@@ -44,7 +44,7 @@ def __init__(self, use_model: str):
         # elif use_model.lower() == "gan":
         # self.params = gan_params
         else:
-            raise ValueError("unsupported model of {} yet".format(use_model))
+            raise ValueError("Unsupported model of {} yet".format(use_model))

     def get_config(self):
         return self.params
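
For orientation, here is a minimal usage sketch of the config object touched above. The constructor and get_config appear in this diff; the "transformer" model name and the import path are assumptions based on the rest of the repository, and the overridden key mirrors the new default set in tfts/models/transformer.py below.

from tfts import AutoConfig  # import path assumed; adjust to the package layout

config = AutoConfig("transformer")       # unsupported names raise the ValueError shown above
params = config.get_config()             # plain dict of hyperparameters
params["attention_hidden_sizes"] = 128   # e.g. override to the new default from this commit

Because get_config returns the params dict itself, individual hyperparameters can be tweaked this way before a model is built from the config.
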
3 changes: 3 additions & 0 deletions tfts/models/auto_model.py
@@ -77,6 +77,9 @@ def __call__(
         # assert len(x[0].shape) == 3, "The expected inputs dimension is 3, while get {}".format(len(x[0].shape))
         return self.model(x)

+    def from_pretrained(self, name: str):
+        return
+
     def build_model(self, inputs):
         outputs = self.model(inputs)
         return tf.keras.Model([inputs], [outputs])  # to handles the Keras symbolic tensors for tf2.3.1
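
The new from_pretrained method is still a stub, while build_model wraps one call on symbolic inputs into a functional tf.keras.Model. Below is a small self-contained sketch of that wrapping pattern; the Dense layer and the input shape are placeholders standing in for the actual tfts model, not code from this commit.

import tensorflow as tf

inputs = tf.keras.Input(shape=(24, 4))      # placeholder shape: 24 steps, 4 features
outputs = tf.keras.layers.Dense(8)(inputs)  # placeholder for self.model(inputs)

# Same wrapping as build_model above: a functional model that Keras' fit/predict
# machinery (and TF 2.3.x symbolic-tensor handling) can consume directly.
functional_model = tf.keras.Model([inputs], [outputs])
functional_model.compile(loss="mse", optimizer="adam")
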
27 changes: 12 additions & 15 deletions tfts/models/transformer.py
@@ -18,13 +18,12 @@
     "n_encoder_layers": 1,
     "n_decoder_layers": 1,
     "use_token_embedding": False,
-    "attention_hidden_sizes": 32 * 1,
+    "attention_hidden_sizes": 128 * 1,
     "num_heads": 1,
     "attention_dropout": 0.0,
-    "ffn_hidden_sizes": 32 * 1,
-    "ffn_filter_sizes": 32 * 1,
+    "ffn_hidden_sizes": 128 * 1,
+    "ffn_filter_sizes": 128 * 1,
     "ffn_dropout": 0.0,
-    "layer_postprocess_dropout": 0.0,
     "scheduler_sampling": 1,  # 0 means teacher forcing, 1 means use last prediction
     "skip_connect_circle": False,
     "skip_connect_mean": False,
@@ -157,10 +156,10 @@ def __init__(
     def build(self, input_shape):
         for _ in range(self.n_encoder_layers):
             attention_layer = SelfAttention(self.attention_hidden_sizes, self.num_heads, self.attention_dropout)
-            feed_forward_layer = FeedForwardNetwork(self.ffn_hidden_sizes, self.ffn_filter_sizes, self.ffn_dropout)
+            ffn_layer = FeedForwardNetwork(self.ffn_hidden_sizes, self.ffn_filter_sizes, self.ffn_dropout)
             ln_layer1 = LayerNormalization(epsilon=1e-6, dtype="float32")
             ln_layer2 = LayerNormalization(epsilon=1e-6, dtype="float32")
-            self.layers.append([attention_layer, ln_layer1, feed_forward_layer, ln_layer2])
+            self.layers.append([attention_layer, ln_layer1, ffn_layer, ln_layer2])
         super(Encoder, self).build(input_shape)

     def call(self, inputs, mask=None):
@@ -280,7 +279,7 @@ def get_causal_attention_mask(self, inputs):
         i = tf.range(sequence_length)[:, tf.newaxis]
         j = tf.range(sequence_length)
         mask = tf.cast(i >= j, dtype="int32")
-        mask = tf.reshape(mask, (1, input_shape[1], input_shape[1]))
+        mask = tf.reshape(mask, (1, sequence_length, sequence_length))
         mult = tf.concat(
             [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)],
             axis=0,
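
To make the reshape fix above easy to verify, here is a self-contained sketch of the same causal-mask construction. The batch size and sequence length are arbitrary example values, and the final tf.tile step is inferred from the mult tensor built in this hunk rather than copied from it.

import tensorflow as tf

batch_size, sequence_length = 2, 4

i = tf.range(sequence_length)[:, tf.newaxis]
j = tf.range(sequence_length)
mask = tf.cast(i >= j, dtype="int32")                           # lower-triangular causal mask
mask = tf.reshape(mask, (1, sequence_length, sequence_length))  # the corrected reshape
mult = tf.concat(
    [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)],
    axis=0,
)
mask = tf.tile(mask, mult)  # assumed follow-up step: broadcast to (batch, seq, seq)
print(mask.shape)           # (2, 4, 4)
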
@@ -314,14 +313,12 @@ def __init__(
     def build(self, input_shape):
         for _ in range(self.n_decoder_layers):
             self_attention_layer = SelfAttention(self.attention_hidden_sizes, self.num_heads, self.attention_dropout)
-            enc_dec_attention_layer = FullAttention(self.attention_hidden_sizes, self.num_heads, self.attention_dropout)
-            feed_forward_layer = FeedForwardNetwork(self.ffn_hidden_sizes, self.ffn_filter_sizes, self.ffn_dropout)
+            attention_layer = FullAttention(self.attention_hidden_sizes, self.num_heads, self.attention_dropout)
+            ffn_layer = FeedForwardNetwork(self.ffn_hidden_sizes, self.ffn_filter_sizes, self.ffn_dropout)
             ln_layer1 = LayerNormalization(epsilon=self.eps, dtype="float32")
             ln_layer2 = LayerNormalization(epsilon=self.eps, dtype="float32")
             ln_layer3 = LayerNormalization(epsilon=self.eps, dtype="float32")
-            self.layers.append(
-                [self_attention_layer, enc_dec_attention_layer, feed_forward_layer, ln_layer1, ln_layer2, ln_layer3]
-            )
+            self.layers.append([self_attention_layer, attention_layer, ffn_layer, ln_layer1, ln_layer2, ln_layer3])
         super(DecoderLayer, self).build(input_shape)

     def call(self, decoder_inputs, encoder_memory, tgt_mask=None, cross_mask=None):
@@ -346,11 +343,11 @@ def call(self, decoder_inputs, encoder_memory, tgt_mask=None, cross_mask=None):
         x = decoder_inputs

         for _, layer in enumerate(self.layers):
-            self_attention_layer, enc_dec_attention_layer, ffn_layer, ln_layer1, ln_layer2, ln_layer3 = layer
+            self_attention_layer, attention_layer, ffn_layer, ln_layer1, ln_layer2, ln_layer3 = layer
             dec = x
             dec = self_attention_layer(dec, mask=tgt_mask)
-            dec1 = ln_layer1(x + dec)
-            dec1 = enc_dec_attention_layer(dec1, encoder_memory, encoder_memory, mask=cross_mask)
+            dec = ln_layer1(x + dec)
+            dec1 = attention_layer(dec, encoder_memory, encoder_memory, mask=cross_mask)
             dec1 = ln_layer2(dec + dec1)
             dec2 = ffn_layer(dec1)
             x = ln_layer3(dec1 + dec2)
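
Besides the variable renames, the call change above fixes the residual wiring: the cross-attention branch now consumes and re-adds the layer-normalized self-attention output (dec) instead of mixing normalized and un-normalized tensors. A minimal restatement of the resulting post-LN block, with plain callables standing in for the attention, feed-forward, and normalization layers:

def decoder_block(x, memory, self_attn, cross_attn, ffn, ln1, ln2, ln3,
                  tgt_mask=None, cross_mask=None):
    """Sketch of the data flow in DecoderLayer.call after this commit."""
    dec = self_attn(x, mask=tgt_mask)                        # masked self-attention
    dec = ln1(x + dec)                                       # residual + layer norm
    dec1 = cross_attn(dec, memory, memory, mask=cross_mask)  # encoder-decoder attention
    dec1 = ln2(dec + dec1)                                   # residual + layer norm
    dec2 = ffn(dec1)                                         # position-wise feed-forward
    return ln3(dec1 + dec2)                                  # residual + layer norm
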
19 changes: 12 additions & 7 deletions tfts/trainer.py
@@ -24,14 +24,16 @@ def __init__(
         lr_scheduler=None,
         strategy=None,
         **kwargs
-    ):
+    ) -> None:
         self.model = model

         self.loss_fn = loss_fn
         self.optimizer = optimizer
         self.lr_scheduler = lr_scheduler
         self.strategy = strategy

+        for key, value in kwargs.items():
+            setattr(self, key, value)

     def train(
         self,
         train_loader,
@@ -216,7 +218,7 @@ def __init__(
         lr_scheduler=None,
         strategy=None,
         **kwargs
-    ):
+    ) -> None:
         """
         model: tf.keras.Model instance
         loss: loss function
@@ -228,14 +230,17 @@
         self.lr_scheduler = lr_scheduler
         self.strategy = strategy

+        for key, value in kwargs.items():
+            setattr(self, key, value)

     def train(
         self,
         train_dataset,
         valid_dataset=None,
         n_epochs=20,
         batch_size=64,
         steps_per_epoch=None,
-        callback_eval_metrics=None,
+        callback_metrics=None,
         early_stopping=None,
         checkpoint=None,
         verbose=2,
@@ -289,7 +294,7 @@ def train(
            self.model = self.model.build_model(inputs=inputs)

         # print(self.model.summary())
-        self.model.compile(loss=self.loss_fn, optimizer=self.optimizer, metrics=callback_eval_metrics, run_eagerly=True)
+        self.model.compile(loss=self.loss_fn, optimizer=self.optimizer, metrics=callback_metrics, run_eagerly=False)
         if isinstance(train_dataset, (list, tuple)):
             x_train, y_train = train_dataset

@@ -315,14 +320,14 @@
         )
         return self.history

-    def predict(self, x_test, batch_size=1):
+    def predict(self, x_test, batch_size: int = 1):
         y_test_pred = self.model.predict(x_test, batch_size=batch_size)
         return y_test_pred

     def get_model(self):
         return self.model

-    def save_model(self, model_dir, only_pb=True, checkpoint_dir=None):
+    def save_model(self, model_dir, only_pb=True, checkpoint_dir: str = None):
         # save the model, checkpoint_dir if you use Checkpoint callback to save your best weights
         if checkpoint_dir is not None:
             logging.info("checkpoint Loaded", checkpoint_dir)
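
Putting the trainer changes together, here is a rough end-to-end sketch of the KerasTrainer API as it reads after this commit. The import path, the AutoModel constructor arguments, and the toy data shapes are assumptions made for illustration; only the train, predict, and save_model signatures come from the diff.

import numpy as np
import tensorflow as tf
from tfts import AutoModel, KerasTrainer  # import paths assumed

# Toy data shaped (samples, steps, features) / (samples, predict_length, 1); not from the diff.
x_train = np.random.rand(64, 24, 4).astype("float32")
y_train = np.random.rand(64, 8, 1).astype("float32")

model = AutoModel("transformer", predict_length=8)  # constructor arguments assumed
trainer = KerasTrainer(
    model,
    loss_fn=tf.keras.losses.MeanSquaredError(),
    optimizer=tf.keras.optimizers.Adam(1e-3),
)
trainer.train((x_train, y_train), n_epochs=2, batch_size=16,
              callback_metrics=[tf.keras.metrics.MeanAbsoluteError()])
y_pred = trainer.predict(x_train, batch_size=16)
trainer.save_model("./transformer_saved_model")  # only_pb=True (default) exports a SavedModel

Any extra keyword argument passed to the constructor now simply becomes an attribute on the trainer via the new setattr loop, and compile runs with the renamed callback_metrics and run_eagerly=False.
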
2 changes: 1 addition & 1 deletion tfts/tuner.py
@@ -12,7 +12,7 @@
 class AutoTuner(object):
     """Auto tune parameters by optuna"""

-    def __init__(self, use_model: str):
+    def __init__(self, use_model: str) -> None:
         self.use_model = use_model

     def generate_parameter(self):
