JTT-m experiments

Prepare the data

cd data
sh 01_prepare.sh
cd ..

The script 01_prepare.sh downloads Schuster et al. (2021)'s preprocessing of the FEVER and MultiNLI datasets. After the files are downloaded and unzipped, you should see something like

+ wc -l fever/train.jsonl fever/dev.jsonl fever/test.jsonl
  178059 fever/train.jsonl
   11620 fever/dev.jsonl
   11710 fever/test.jsonl
  201389 total

+ python ../../group_stats.py fever/train.jsonl
S & 99303 (61.5) & 1267 (7.7) & 100570 (56.5)
R & 27575 (17.1) & 14275 (86.3) & 41850 (23.5)
N & 34633 (21.4) & 1006 (6.1) & 35639 (20.0)

for FEVER. For MultiNLI, you will see something like

+ wc -l mnli/train.jsonl mnli/dev.jsonl mnli/test.jsonl
   392702 mnli/train.jsonl
     9832 mnli/dev.jsonl
     9832 mnli/test.jsonl
   412366 total
+ python ../../group_stats.py mnli/train.jsonl
S & 118554 (36.7) & 12345 (17.7) & 130899 (33.3)
R & 88180 (27.3) & 42723 (61.2) & 130903 (33.3)
N & 116185 (36.0) & 14715 (21.1) & 130900 (33.3)
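The tables above count training examples per gold label, split by whether the example contains a negation cue. A minimal sketch of that kind of counting is below; it is a hypothetical re-implementation, not the actual group_stats.py — the JSON field names (label, claim) and the negation-cue list are assumptions.

```python
import json
from collections import Counter

# Hypothetical negation cues; the list used by the real group_stats.py may differ.
NEG_WORDS = {"not", "no", "never", "nothing", "nobody", "n't"}

def group_counts(path):
    """Count examples per (label, has-negation) group in a .jsonl file.

    Assumes each line is a JSON object with a gold label and a claim text
    field; the actual field names in the preprocessed files may differ.
    """
    counts = Counter()
    with open(path) as f:
        for line in f:
            ex = json.loads(line)
            label = ex["label"]
            has_neg = any(t in NEG_WORDS for t in ex["claim"].lower().split())
            counts[(label, has_neg)] += 1
    return counts
```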

Reproduce JTT-m

Step 1: Do the first training to obtain the error set

Below we go through the procedure for the FEVER experiments. The process for the MNLI experiments is the same, except that we only evaluate on MultiNLI's test set, since its development set and test set are identical in this preprocessing version.

To train the model:

cd fever+sgd
sh 01_train.sh

Note that the script contains SLURM directives specifying our GPU resource requirements for the sbatch command, but it can also be run directly with the sh command.

Step 2: Prepare filtered error set for upweighting (JTT)

After training is finished, run

sh 02_predict.sh

This step

(a) makes predictions on the training set and saves the penultimate-layer embeddings of the training set for outlier removal later, and

(b) makes predictions on the development and test sets (if a test set is available).

You can view the prediction results in bert-base-uncased-128-out/eval.{train,dev,test}.txt, and the group accuracies in bert-base-uncased-128-out/eval.groups.{train,dev,test}.txt .

The results should look something like this (showing the dev set results):


tail -n 11 bert-base-uncased-128-out/eval.dev.txt

S     R     N
S  3788   118    58
R   705  2574  1044
N   530   494  2309

          S     R     N
Prec:  75.4  80.8  67.7
Rec:   95.6  59.5  69.3
F1:    84.3  68.6  68.5

Acc: 74.6
tail bert-base-uncased-128-out/eval.groups.dev.txt

Total 11620, correct 8671, wrong 2949
Avg acc: 74.6 (8671/11620)
Worst group acc: 14.0
(S, no neg): 96.0 (3777/3934)
(S, neg): 36.7 (11/30)
(R, no neg): 43.7 (1339/3067)
(R, neg): 98.3 (1235/1256)
(N, no neg): 70.9 (2296/3240)
(N, neg): 14.0 (13/93)
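The average and worst-group accuracies reported above follow directly from per-example correctness and group membership. A small sketch of that computation, under the assumption that each group is identified by a (label, has-negation) pair:

```python
from collections import defaultdict

def group_accuracies(golds, preds, groups):
    """Return per-group accuracies and the worst-group accuracy.

    golds, preds: gold and predicted labels per example;
    groups: a group id per example, e.g. (label, has-negation) pairs.
    """
    correct = defaultdict(int)
    total = defaultdict(int)
    for g, p, grp in zip(golds, preds, groups):
        total[grp] += 1
        correct[grp] += int(g == p)
    accs = {grp: correct[grp] / total[grp] for grp in total}
    return accs, min(accs.values())
```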

Once all the predictions are finished, run

sh 03_calc_mahal.sh 

(This script does not require a GPU.)

This calculates the Mahalanobis distances of the penultimate layer embeddings and saves the distances calculated in bert-base-uncased-128-out/train.mahal.npy.
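For intuition, the core of a Mahalanobis-distance computation over embeddings looks like the sketch below. This is a simplified illustration, not the actual 03_calc_mahal.sh logic — in particular, the script may fit class-conditional Gaussians rather than the single global fit shown here.

```python
import numpy as np

def mahalanobis_distances(embeddings):
    """Mahalanobis distance of each row to the sample mean.

    Simplified sketch: fits one Gaussian to all embeddings; the
    repository's script may instead use per-class statistics.
    """
    mu = embeddings.mean(axis=0)
    cov = np.cov(embeddings, rowvar=False)
    inv = np.linalg.pinv(cov)  # pseudo-inverse for numerical safety
    diff = embeddings - mu
    # sqrt of the quadratic form (x - mu)^T Sigma^{-1} (x - mu), per row
    return np.sqrt(np.einsum("nd,dk,nk->n", diff, inv, diff))
```

Examples with unusually large distances are treated as outliers and can be dropped from the error set in the next step.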

Step 3: Upweight examples

Run

sh 04_augment.sh

(This script does not require a GPU.)

This script upweights the training data in two different ways, and the upweighted training data are saved in corresponding subfolders (with the same name as their experiment folders) in ../data:

  1. JTT-m: Upweights incorrectly predicted training-set examples after removing outliers from the error set with the Mahalanobis distance method

    The folder is fever_sgd_df5_up3, meaning it uses sgd in the ERM training, filters the error set from the ERM training with 5 degrees of freedom (df5), and upweights (up) the filtered error set 3 times. The second training uses adamw as its optimizer.

  2. JTT: Upweights incorrectly predicted training-set examples

    Its folder is fever_sgd_thres1.0_up3. thres1.0 sets the threshold to 1.0 for filtering out incorrect examples by their predicted probabilities; since the default threshold is 1.0, no examples are filtered out for JTT. The examples are upweighted 3 times.
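The upweighting itself amounts to duplicating error-set examples in the training data. The sketch below illustrates both variants; it is an assumption-laden illustration, not the actual 04_augment.sh logic — in particular, whether "up 3" means 3 extra copies or 3 total copies, and the exact outlier cutoff rule, are assumptions here.

```python
def upweight(examples, errors, up=3, distances=None, cutoff=None):
    """Duplicate error-set examples `up` extra times each.

    JTT: every misclassified example is upweighted.
    JTT-m: examples whose Mahalanobis distance exceeds `cutoff` are
    first dropped from the error set (the real cutoff rule, tied to the
    df5 naming, is an assumption in this sketch).
    """
    out = list(examples)
    for i in errors:
        if distances is not None and cutoff is not None and distances[i] > cutoff:
            continue  # outlier: drop from the error set (JTT-m)
        out.extend([examples[i]] * up)
    return out
```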

Step 4: Retrain the model on upweighted training set

We use fever_sgd_df5_up3+adamw (the JTT-m experiment) as an example. The procedure is the same for fever_sgd_thres1.0_up3+adamw (the JTT experiment).

cd ..
cd fever_sgd_df5_up3+adamw
sh 01_train.sh

Once the training is finished, run

sh 02_predict.sh

The prediction results will be saved in bert-base-uncased-128-out as eval.{dev,test}.txt and eval.groups.{dev,test}.txt . You should see something like this (using the dev set as example):

==> bert-base-uncased-128-out/eval.dev.txt <==
S  3763   109    92
R   250  3748   325
N   257   359  2717

          S     R     N
Prec:  88.1  88.9  86.7
Rec:   94.9  86.7  81.5
F1:    91.4  87.8  84.0

Acc: 88.0

==> bert-base-uncased-128-out/eval.groups.dev.txt <==
Total 11620, correct 10228, wrong 1392
Avg acc: 88.0 (10228/11620)
Worst group acc: 50.0
(S, no neg): 95.3 (3748/3934)
(S, neg): 50.0 (15/30)
(R, no neg): 83.3 (2554/3067)
(R, neg): 95.1 (1194/1256)
(N, no neg): 82.2 (2664/3240)
(N, neg): 57.0 (53/93)

Reproduce ERM results

The ERM models are trained with a different optimizer from the one used for the first training in Step 1: AdamW instead of SGD. The training and prediction scripts are in fever+adamw and mnli+adamw.

To train, run

sh 01_train.sh

Run

sh 02_predict.sh

to obtain ERM model results for the test set (and dev set, if available).

Model checkpoints

We also release model checkpoints (as well as example outputs and preprocessed data) at: https://doi.org/10.5281/zenodo.7260028.