Merge pull request #91 from StevenTang1998/main
Update
turboLJY authored Jan 11, 2021
2 parents 0dcbaa1 + 0892b58 commit d9567de
Showing 13 changed files with 32,338 additions and 22,041 deletions.
14 changes: 14 additions & 0 deletions PYPI.md
@@ -17,6 +17,7 @@ We provide the support for 6 benchmark text generation datasets. A user can appl
- **Unified and modularized framework.** TextBox is built upon PyTorch and designed to be highly modularized, by decoupling diverse models into a set of highly reusable modules.
- **Comprehensive models, benchmark datasets and standardized evaluations.** TextBox also contains a wide range of text generation models, covering the categories of VAE, GAN, RNN- or Transformer-based models, and pre-trained language models (PLMs).
- **Extensible and flexible framework.** TextBox provides convenient interfaces for common functions and modules in text generation models, such as RNN encoder-decoders, Transformer encoder-decoders, and pre-trained language models.
- **Easy and convenient to get started.** TextBox provides flexible configuration files, which allow beginners to run experiments without modifying source code and allow researchers to conduct qualitative analysis by changing a few configurations.

## Installation

@@ -88,6 +89,19 @@ This will perform the training and test of the RNN model on the COCO dataset.

If you want to run different models, parameters or datasets, the operations are the same as in **Start from source**.

### **Using Pretrained Language Model**

TextBox supports applying some pretrained language models (PLMs) to text generation. Taking GPT-2 as an example, we show how to fine-tune a PLM.

1. Download the GPT-2 model provided by Hugging Face (https://huggingface.co/gpt2/tree/main), including `config.json`, `merges.txt`, `pytorch_model.bin`, `tokenizer.json`, and `vocab.json`. Then put them in a folder at the same level as `textbox`, such as `pretrained_model/gpt2`.

2. After downloading, run the following command:

```bash
python run_textbox.py --model=GPT2 --dataset=COCO --task_type=unconditional \
--pretrained_model_path=pretrained_model/gpt2
```
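Step 1 above can be scripted. The sketch below is a convenience, not part of TextBox: it assumes the Hugging Face hub serves raw files at `https://huggingface.co/gpt2/resolve/main/<file>`, and the helper names `gpt2_urls` and `download_all` are our own.

```python
# Sketch of step 1: fetch the five GPT-2 files into pretrained_model/gpt2.
# The "resolve/main" URL layout is an assumption about the Hugging Face hub;
# gpt2_urls/download_all are illustrative helpers, not TextBox APIs.
import os
import urllib.request

FILES = ["config.json", "merges.txt", "pytorch_model.bin",
         "tokenizer.json", "vocab.json"]
BASE = "https://huggingface.co/gpt2/resolve/main/"
TARGET = os.path.join("pretrained_model", "gpt2")

def gpt2_urls():
    """Pair each file's download URL with its local destination path."""
    return [(BASE + name, os.path.join(TARGET, name)) for name in FILES]

def download_all():
    """Download any files not already present (pytorch_model.bin is large)."""
    os.makedirs(TARGET, exist_ok=True)
    for url, path in gpt2_urls():
        if not os.path.exists(path):  # skip files fetched on a previous run
            urllib.request.urlretrieve(url, path)
```

After `download_all()` completes, the `run_textbox.py` command above should find the model via `--pretrained_model_path=pretrained_model/gpt2`.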

## The Team

TextBox is developed and maintained by [AI Box](http://aibox.ruc.edu.cn/).
240 changes: 191 additions & 49 deletions README.md

Large diffs are not rendered by default.

250 changes: 196 additions & 54 deletions README_CN.md

Large diffs are not rendered by default.

10,000 changes: 10,000 additions & 0 deletions generated_examples/GPT2-COCO.txt

Large diffs are not rendered by default.

10,000 changes: 10,000 additions & 0 deletions generated_examples/HybridVAE-EMNLP_news.txt

Large diffs are not rendered by default.

10,000 changes: 0 additions & 10,000 deletions generated_examples/LeakGAN_EMNLP.txt

This file was deleted.

1,936 changes: 1,936 additions & 0 deletions generated_examples/RNNEncDec-GigaWord.txt

Large diffs are not rendered by default.

1,936 changes: 0 additions & 1,936 deletions generated_examples/RNNEncDec_GigaWord.txt

This file was deleted.

10,000 changes: 0 additions & 10,000 deletions generated_examples/RankGAN_IMDB.txt

This file was deleted.

10,000 changes: 10,000 additions & 0 deletions generated_examples/SeqGAN-IMDB.txt

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion textbox/module/Generator/MaskGANGenerator.py
@@ -203,7 +203,7 @@ def adversarial_loss(self, inputs, lengths, targets, targets_present, discrimina
         outputs, log_probs, logits = self.forward(inputs, lengths, targets, targets_present)
         fake_predictions, _ = discriminator(inputs, lengths, outputs, targets_present, self.embedder)
         fake_predictions = fake_predictions.detach()
-        est_state_values = discriminator.critic(inputs, outputs, self.embedder)
+        est_state_values = discriminator.critic(outputs, self.embedder)
         rl_loss, critic_loss = self.calculate_reinforce_objective(log_probs, fake_predictions, targets_present,
                                                                   est_state_values)
         return (rl_loss, critic_loss)
1 change: 0 additions & 1 deletion textbox/properties/dataset/IMDB.yaml
@@ -4,4 +4,3 @@ max_seq_length: 100
 split_strategy: "by_ratio"
 split_ratio: [0.8,0.1,0.1]
 source_language: "English"
-train_batch_size: 100
