diff --git a/train_bilal_baseline.py b/train_bilal_baseline.py
index 12b933c..fa9e337 100644
--- a/train_bilal_baseline.py
+++ b/train_bilal_baseline.py
@@ -16,6 +16,37 @@ MAX_LENGTH = 320
 
 
 
+def bilal_full_yelp_finetune():
+    """Create a BERT model using the model and parameters specified in the Bilal paper:
+    https://link.springer.com/article/10.1007/s10660-022-09560-w/tables/2
+
+    - model: TFBertForSequenceClassification
+    - learning rate: 2e-5
+    - epsilon: 1e-8
+    """
+    # Using the TFBertForSequenceClassification as specified in the paper:
+    bert_model = TFBertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
+
+    # Don't freeze any layers:
+    untrainable = []
+    trainable = [w.name for w in bert_model.weights]
+
+    for w in bert_model.weights:
+        if w.name in untrainable:
+            w._trainable = False
+        elif w.name in trainable:
+            w._trainable = True
+
+    # Compile the model:
+    bert_model.compile(
+        optimizer = tf.keras.optimizers.Adam(learning_rate=2e-5,epsilon=1e-08),
+        loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+        metrics = [tf.keras.metrics.SparseCategoricalAccuracy("accuracy")]
+    )
+
+    return bert_model
+
+
 if __name__ == "__main__":
 
     # Code for helping save models locally after training:
@@ -109,38 +140,6 @@ def get_google_drive_download_url(raw_url: str):
 
 
     print("Training ...")
-
-    def bilal_full_yelp_finetune():
-        """Create a BERT model using the model and parameters specified in the Bilal paper:
-        https://link.springer.com/article/10.1007/s10660-022-09560-w/tables/2
-
-        - model: TFBertForSequenceClassification
-        - learning rate: 2e-5
-        - epsilon: 1e-8
-        """
-        # Using the TFBertForSequenceClassification as specified in the paper:
-        bert_model = TFBertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
-
-        # Don't freeze any layers:
-        untrainable = []
-        trainable = [w.name for w in bert_model.weights]
-
-        for w in bert_model.weights:
-            if w.name in untrainable:
-                w._trainable = False
-            elif w.name in trainable:
-                w._trainable = True
-
-        # Compile the model:
-        bert_model.compile(
-            optimizer = tf.keras.optimizers.Adam(learning_rate=2e-5,epsilon=1e-08),
-            loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
-            metrics = [tf.keras.metrics.SparseCategoricalAccuracy("accuracy")]
-        )
-
-        return bert_model
-
-
     model = bilal_full_yelp_finetune()
     print(model.summary())
 
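Usage note (not part of the patch): because bilal_full_yelp_finetune() now lives at module level rather than inside the if __name__ == "__main__": block, it can be imported and reused from other scripts. The sketch below is one possible way to call it, assuming tensorflow and transformers are imported at module level in train_bilal_baseline.py and that MAX_LENGTH is exposed there as shown above; the example texts and labels are illustrative placeholders, not data from the training script.

    # Minimal usage sketch under the assumptions stated above.
    import tensorflow as tf
    from transformers import BertTokenizer

    from train_bilal_baseline import MAX_LENGTH, bilal_full_yelp_finetune

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = bilal_full_yelp_finetune()

    # Placeholder reviews with binary sentiment labels, purely for illustration:
    texts = ["Great food, friendly staff.", "Terrible service, would not return."]
    labels = tf.constant([1, 0])

    # Tokenize to the same fixed length the training script uses:
    encodings = tokenizer(
        texts,
        truncation=True,
        padding="max_length",
        max_length=MAX_LENGTH,
        return_tensors="tf",
    )

    # Fine-tune on the toy batch; real training would pass the full dataset.
    model.fit(dict(encodings), labels, epochs=1, batch_size=2)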