Skip to content

Commit

Permalink
Update README
Browse files Browse the repository at this point in the history
  • Loading branch information
luozhouyang committed Jan 24, 2021
1 parent d1b0c03 commit 272154f
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def build_bert_classify_model(pretrained_model_dir, trainable=True, **kwargs):
bert = Bert.from_pretrained(pretrained_model_dir, **kwargs)
bert.trainable = trainable

_, pooled_output, _, _ = bert(inputs=(input_ids, segment_ids))
sequence_outputs, pooled_output = bert(inputs=(input_ids, segment_ids))
outputs = tf.keras.layers.Dense(2, name='output')(pooled_output)
model = tf.keras.Model(inputs=[input_ids, segment_ids], outputs=outputs)
model.compile(loss='binary_cross_entropy', optimizer='adam')
Expand Down Expand Up @@ -87,7 +87,7 @@ def build_albert_classify_model(pretrained_model_dir, trainable=True, **kwargs):
albert = Albert.from_pretrained(pretrained_model_dir, **kwargs)
albert.trainable = trainable

_, pooled_output, _, _ = albert(inputs=(input_ids, segment_ids))
sequence_outputs, pooled_output = albert(inputs=(input_ids, segment_ids))
outputs = tf.keras.layers.Dense(2, name='output')(pooled_output)
model = tf.keras.Model(inputs=[input_ids, segment_ids], outputs=outputs)
model.compile(loss='binary_cross_entropy', optimizer='adam')
Expand Down

0 comments on commit 272154f

Please sign in to comment.