Skip to content

Commit

Permalink
make model importable
Browse files Browse the repository at this point in the history
  • Loading branch information
toby-p committed Jul 23, 2022
1 parent 3e9e934 commit 4ec1939
Showing 1 changed file with 31 additions and 32 deletions.
63 changes: 31 additions & 32 deletions train_bilal_baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,37 @@
MAX_LENGTH = 320


def bilal_full_yelp_finetune(untrainable_weights=()):
    """Create and compile a BERT model using the model and hyperparameters
    specified in the Bilal paper:
    https://link.springer.com/article/10.1007/s10660-022-09560-w/tables/2
    - model: TFBertForSequenceClassification
    - learning rate: 2e-5
    - epsilon: 1e-8

    Args:
        untrainable_weights: optional collection of weight names to freeze.
            Defaults to empty, i.e. every layer is fine-tuned, matching the
            paper's setup and the original behavior of this function.

    Returns:
        The compiled ``TFBertForSequenceClassification`` model.
    """
    # Using the TFBertForSequenceClassification as specified in the paper,
    # with a 2-class classification head:
    bert_model = TFBertForSequenceClassification.from_pretrained(
        'bert-base-uncased', num_labels=2
    )

    # Freeze only the requested weights (none by default).
    # NOTE(review): tf.Variable exposes `trainable` as a read-only property,
    # which is presumably why the private `_trainable` attribute is written
    # directly — fragile across TensorFlow versions; confirm on upgrade.
    frozen = set(untrainable_weights)
    for w in bert_model.weights:
        w._trainable = w.name not in frozen

    # Compile with the paper's optimizer settings. from_logits=True because
    # the sequence-classification head outputs raw logits, not probabilities:
    bert_model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=2e-5, epsilon=1e-08),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy("accuracy")],
    )

    return bert_model


if __name__ == "__main__":

# Code for helping save models locally after training:
Expand Down Expand Up @@ -109,38 +140,6 @@ def get_google_drive_download_url(raw_url: str):

print("Training ...")


def bilal_full_yelp_finetune():
    """Build and compile the BERT classifier from the Bilal paper:
    https://link.springer.com/article/10.1007/s10660-022-09560-w/tables/2
    - model: TFBertForSequenceClassification
    - learning rate: 2e-5
    - epsilon: 1e-8
    """
    # Paper-specified architecture with a binary classification head:
    model = TFBertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

    # No layers are frozen — every weight is left trainable for fine-tuning.
    frozen_names = []
    tunable_names = [weight.name for weight in model.weights]
    for weight in model.weights:
        if weight.name in frozen_names:
            weight._trainable = False
        elif weight.name in tunable_names:
            weight._trainable = True

    # Paper-specified training configuration (logits output, hence
    # from_logits=True on the loss):
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=2e-5, epsilon=1e-08),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy("accuracy")],
    )

    return model


model = bilal_full_yelp_finetune()
print(model.summary())

Expand Down

0 comments on commit 4ec1939

Please sign in to comment.