This is a simple PoC (proof-of-concept) model built to restore punctuation & capitalization in a given text. In other words, given a text with no punctuation and no capitalization, this model restores the punctuation and capitalization needed to make the text human-readable.
BertPuncCap is a PyTorch model built on top of a pre-trained Google BERT model by adding two linear layers that perform the two tasks simultaneously: one layer is responsible for the re-punctuation task and the other is responsible for the re-capitalization task, as shown in the following figure:
How this model works can be summarized in the following steps:
- BertPuncCap takes an input sentence that consists of `segment_size=32` (by default) tokens. If the input is shorter than `segment_size`, it is padded with the edges (both ends of the input sentence). The `segment_size` is a hyper-parameter that you can tune.
- Then, the pre-trained BERT language model returns the representations of the input tokens. The shape of the output should be `segment_size x model_dim`. If you are using BERT-base, then `model_dim=768`.
- These representations are sent to the two linear layers for classification. One layer classifies the punctuation after each token, while the other classifies the case.
- The loss function is the weighted sum of the punctuation classification loss `punc-loss` and the capitalization classification loss `cap-loss`, where $\alpha$ is a hyper-parameter that you can set in your `config.yaml` file (a plausible form of this weighted sum, together with a small architecture sketch, is shown right after this list).
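Assuming $\alpha$ balances the two losses as a convex combination (an assumption; the exact form may differ), the weighted sum would be:

$$\text{loss} = \alpha \cdot \text{punc-loss} + (1 - \alpha) \cdot \text{cap-loss}$$

To make the two-head design concrete, here is a minimal PyTorch sketch. This is not the repository's actual implementation: the class name `TwoHeadSketch`, the helper `combined_loss`, and the exact way $\alpha$ weighs the two losses are illustrative assumptions.

```python
import torch.nn as nn


class TwoHeadSketch(nn.Module):
    """Hypothetical sketch: a BERT encoder with two token-level classification heads."""

    def __init__(self, bert, num_punc=7, num_case=3, dropout=0.3):
        super().__init__()
        self.bert = bert                                  # a pre-trained transformers BertModel
        hidden = bert.config.hidden_size                  # model_dim (768 for BERT-base)
        self.dropout = nn.Dropout(dropout)
        self.punc_head = nn.Linear(hidden, num_punc)      # COMMA, PERIOD, ..., O
        self.case_head = nn.Linear(hidden, num_case)      # F, A, O

    def forward(self, input_ids, attention_mask=None):
        # token representations: (batch, segment_size, model_dim)
        hidden_states = self.bert(input_ids, attention_mask=attention_mask).last_hidden_state
        hidden_states = self.dropout(hidden_states)
        return self.punc_head(hidden_states), self.case_head(hidden_states)


def combined_loss(punc_logits, case_logits, punc_targets, case_targets, alpha=0.5):
    """Weighted sum of the two cross-entropy losses (the weighting is an assumption)."""
    ce = nn.CrossEntropyLoss()
    punc_loss = ce(punc_logits.view(-1, punc_logits.size(-1)), punc_targets.view(-1))
    case_loss = ce(case_logits.view(-1, case_logits.size(-1)), case_targets.view(-1))
    return alpha * punc_loss + (1 - alpha) * case_loss
```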
Note:
BertPuncCap was inspired by BertPunc with the following differences:
- BertPunc only handles the punctuation restoration task, while this model handles both punctuation restoration & re-capitalization.
- BertPunc only handles COMMA, PERIOD, and QUESTION_MARK, while this model handles three more punctuation marks: EXCLAMATION, COLON, and SEMICOLON. You can also add your own; it's fully configurable.
- BertPunc is not compatible with the HuggingFace `transformers` package, while this model is.
- BertPunc doesn't provide any pre-trained models, while this model provides many.
You can check this notebook for the different ways you can use this model, and for how to get the confusion matrix of the different classes.
To install the dependencies, run the following command:
pip install -r requirements.txt
You can download the pre-trained models from the following table:
Name | Pre-trained BertPuncCap | Training Data | Pre-trained BERT | Supported Languages |
---|---|---|---|---|
mbert_base_cased_fr | ( Model, Configuration ) | mTEDx | bert-base-multilingual-cased | French (fr) |
mbert_base_cased_8langs | ( Model, Configuration ) | mTEDx | bert-base-multilingual-cased | |
Now, it's very easy to use these pre-trained models; here is an example:
>>> import os
>>> from transformers import BertTokenizer, BertModel
>>> from model import BertPuncCap
>>>
>>> # load pre-trained mBERT from HuggingFace's transformers package
>>> BERT_name = "bert-base-multilingual-cased"
>>> bert_tokenizer = BertTokenizer.from_pretrained(BERT_name)
>>> bert_model = BertModel.from_pretrained(BERT_name)
>>>
>>> # load trained checkpoint
>>> checkpoint_path = os.path.join("models", "mbert_base_cased")
>>> bert_punc_cap = BertPuncCap(bert_model, bert_tokenizer, checkpoint_path)
Now that we have loaded the model, let's use it:
>>> x = ["bonsoir",
... "notre planète est recouverte à 70 % d'océan et pourtant étrangement on a choisi de l'appeler « la Terre »"
... ]
>>> # start predicting
>>> bert_punc_cap.predict(x)
[
'Bonsoir ,',
"Notre planète est recouverte à 70 % d ' océan . et pourtant étrangement , on a choisi de l ' appeler « La Terre »"
]
To train the model, you need to use the `train.py` script. Here is how you can do so:
python train.py --seed 1234 \
--pretrained_bert bert-base-multilingual-cased \
--optimizer Adam \
--criterion cross_entropy \
--alpha 0.5 \
--dataset mTEDx \
--langs fr \
--save_path ./models/mbert_base_cased \
--batch_size 1024 \
--segment_size 32 \
--dropout 0.3 \
--lr 0.00001 \
--max_epochs 50 \
--num_validations 1 \
--patience 1 \
--stop_metric overall_f1
The following is a full list of all training parameters that can be used with this model:
Parameter | Description | Possible Values | Default |
---|---|---|---|
seed | Random seed. | Any positive integer value. | 1234 |
pretrained_bert | The name of the pre-trained BERT model from HuggingFace. | | bert-base-multilingual-cased |
optimizer | The optimizer used to train the model. | Adam | - |
lr | The learning rate used by the optimizer. | Any positive number. | 0.00001 |
criterion | The criterion used to train the model. | cross_entropy | - |
alpha | The tuning parameter that weighs punc_loss & cap_loss. | Any value in [0, 1]. | 0.5 |
dataset | The dataset used for training. | mTEDx | mTEDx |
langs | List of languages from the dataset to train the model on. | Depends on the dataset. | fr |
save_path | The relative/absolute path where the model is saved. | A working path. | - |
batch_size | The batch size for training, validating, and testing. | Any positive integer value. | 256 |
segment_size | The segment size of the model. | Any positive integer value. | 32 |
dropout | The dropout rate of the linear layers built on top of BERT. | Any value between 0 and 1. | 0.3 |
max_epochs | The maximum number of epochs to train the model. | Any positive integer value. | 50 |
num_validations | The number of validations to perform per epoch. | Any positive integer value. | 1 |
patience | The number of validations to wait for a performance improvement before early stopping. | Any positive integer value. | 10 |
stop_metric | The name of the metric to monitor for early stopping. | | overall_f1 |
The list of punctuations & cases handled by this model can be seen down below:

- Punctuations:
    - COMMA
    - PERIOD
    - QUESTION
    - EXCLAMATION
    - COLON
    - SEMICOLON
    - O
- Cases:
    - F (First_Cap): When the first letter is capitalized.
    - A (All_Cap): When the whole token is capitalized.
    - O: Other
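For illustration, here is how one of the sentences from the usage example above might map to these token-level labels. This mapping is hypothetical (it is not produced by any repo code); each input token gets the punctuation label of the mark that follows it and one of the three case labels.

```python
# Hypothetical token-level labels for "Notre planète est recouverte à 70 % d'océan ."
labeled = [
    # token        punctuation  case
    ("notre",      "O",         "F"),       # "Notre": first letter capitalized
    ("planète",    "O",         "O"),
    ("est",        "O",         "O"),
    ("recouverte", "O",         "O"),
    ("à",          "O",         "O"),
    ("70",         "O",         "O"),
    ("%",          "O",         "O"),
    ("d'océan",    "PERIOD",    "O"),       # followed by a period
]
```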
The training progress will be written to a file called `progress.tsv`, which can be used to monitor the model's performance during training. In this file, you can find important metrics about the training process, such as the training/validation loss and the F1-scores of all punctuation classes (`punc_overall_f1`), all capitalization classes (`case_overall_f1`), and all classes combined (`overall_f1`).
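If you want to inspect these metrics yourself, a minimal sketch along these lines should work, assuming `progress.tsv` is tab-separated and uses column names like the ones below (the exact column names are an assumption):

```python
# Sketch: plot training metrics from progress.tsv (column names are assumed).
import pandas as pd
import matplotlib.pyplot as plt

progress = pd.read_csv("progress.tsv", sep="\t")

# Loss curves (assumed columns "train_loss" and "valid_loss").
progress[["train_loss", "valid_loss"]].plot(title="Training / validation loss")

# F1 curves, matching the early-stopping metric names.
progress[["punc_overall_f1", "case_overall_f1", "overall_f1"]].plot(title="F1 scores")

plt.show()
```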
You can use this model to re-punctuate & re-capitalize ASR transcriptions. You can use the `repunc_recap.py` script to do so, given the paths of:

- The pre-trained BertPuncCap.
- The ASR output transcription file.
- The output file.
The following is a working example:
python repunc_recap.py \
--ckpt /gfs/project/stag/users/manwar/BertPuncCap/models/mbert_base_cased_8langs \
--in /gfs/project/stag/users/manwar/results/mTEDx_4/CASCADE/XLSR/test_fr.hyp \
--out /gfs/project/stag/users/manwar/results/mTEDx_4/CASCADE/XLSR/test_punc_cap_fr.hyp
After running this command, a new file named `test_punc_cap_fr.hyp` will be created containing the punctuated and capitalized text.
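If you prefer doing this from Python instead of the command line, a sketch using the `predict()` method shown earlier would look like this (the file paths are placeholders, and `repunc_recap.py` itself may work differently):

```python
# Sketch: re-punctuate & re-capitalize an ASR hypothesis file with a loaded BertPuncCap.
# Assumes `bert_punc_cap` was loaded as in the usage example above and that the
# hypothesis file contains one utterance per line.
in_path = "test_fr.hyp"              # placeholder path
out_path = "test_punc_cap_fr.hyp"    # placeholder path

with open(in_path, encoding="utf-8") as f:
    hypotheses = [line.strip() for line in f if line.strip()]

restored = bert_punc_cap.predict(hypotheses)

with open(out_path, "w", encoding="utf-8") as f:
    f.write("\n".join(restored) + "\n")
```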
For benchmarking this model and evaluating how it performs, you can use the `benchmark.py` script. It works similarly to the previous script, where you need the absolute/relative path of:

- The pre-trained BertPuncCap.
- The ASR reference file.
NOTE:
This reference file should contain words that are already punctuated & capitalized.
The following is a working example:
python benchmark.py \
--ckpt /gfs/project/stag/users/manwar/BertPuncCap/models/mbert_base_cased_8langs \
--in /gfs/project/stag/users/manwar/results/mTEDx_4/CASCADE/XLSR/test.ref
And the following is the output, which shows the Precision, Recall, and F1 scores of the different punctuation marks and cases:
O , . ? ! : ;
Precision 0.974879 0.562257 0.536195 0.617647 0.0 0.687500 0.0
Recall 0.960403 0.604181 0.681283 0.253012 0.0 0.488889 0.0
F1 0.967587 0.582466 0.600094 0.358974 0.0 0.571429 0.0
O F A
Precision 0.962127 0.655914 0.495283
Recall 0.970909 0.559238 0.664557
F1 0.966498 0.603730 0.567568
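These are standard per-class precision / recall / F1 scores. If you want to recompute them from aligned reference and predicted token labels, a sketch like the following (using scikit-learn, which may or may not be what `benchmark.py` actually uses) would do it:

```python
# Sketch: per-class precision / recall / F1 for the punctuation labels.
from sklearn.metrics import precision_recall_fscore_support

labels = ["O", "COMMA", "PERIOD", "QUESTION", "EXCLAMATION", "COLON", "SEMICOLON"]
y_true = ["O", "COMMA", "O", "PERIOD"]   # toy reference labels, for illustration only
y_pred = ["O", "COMMA", "O", "O"]        # toy predicted labels

precision, recall, f1, _ = precision_recall_fscore_support(
    y_true, y_pred, labels=labels, zero_division=0
)
for name, p, r, f in zip(labels, precision, recall, f1):
    print(f"{name:<12} P={p:.3f}  R={r:.3f}  F1={f:.3f}")
```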