Skip to content
This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

Commit

Permalink
Fix electra (#1291)
Browse files Browse the repository at this point in the history
* update Dockerfile

* fix num_out_files

* fix run_electra

* Revert "update Dockerfile"

This reverts commit 80593a2.
  • Loading branch information
zheyuye committed Aug 8, 2020
1 parent c33e62e commit 9e268c0
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
4 changes: 1 addition & 3 deletions scripts/pretraining/data_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,7 @@ def main(args):
random.shuffle(fnames)
num_files = len(fnames)
num_out_files = min(args.num_out_files, num_files)
file_volume = math.ceil(num_files / num_out_files)
splited_files = np.array_split(fnames, file_volume)
num_out_files = len(splited_files)
splited_files = np.array_split(fnames, num_out_files)
output_files = [os.path.join(
args.output, "owt-pretrain-record-{}.npz".format(str(i).zfill(4))) for i in range(num_out_files)]
print("All preprocessed features will be saved in {} npz files".format(num_out_files))
Expand Down
8 changes: 5 additions & 3 deletions scripts/pretraining/run_electra.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,9 +472,11 @@ def train(args):
train_end_time - train_start_time))
if writer is not None:
writer.close()
model_name = args.model_name.replace('google', 'gluon')
save_dir = os.path.join(args.output_dir, model_name)
final_save(model, save_dir, tokenizer)

if local_rank == 0:
model_name = args.model_name.replace('google', 'gluon')
save_dir = os.path.join(args.output_dir, model_name)
final_save(model, save_dir, tokenizer)

# TODO(zheyuye), Directly implement a metric for weighted accuracy

Expand Down

0 comments on commit 9e268c0

Please sign in to comment.