Official implementation of GeDi: Generative Discriminator Guided Sequence Generation
Blogpost here
Colab Notebook on controlling topic using GeDi here
Sept 29, 2020: Adding support for GeDi-guided GPT-3 generation (API key needed)
GeDi is a method of using class-conditional language models (which we refer to as generative discriminators (GeDis)) to guide generation from other (potentially much larger) language models. This has several advantages over finetuning large language models directly including:
- significantly less training computation.
- maintaining the diversity of the original language model (If we finetune a large pretrained language model to a specific attribute dataset, we will likely reduce the broad generation capabilities of the model).
- teaching the language model what not to generate. This is especially useful for applications like detoxification.
GeDi is a form of discriminator guided generation. A discriminator that can classify an attribute could be used to guide language model generation towards that attribute by classifying the sequences that result from candidate next tokens. However, using a normal discriminator (such as BERT) to do this would be very computationally expensive during generation, since it would require feeding every candidate next token into the discriminator one-by-one to be classified. With generative discriminators, by contrast, we can classify candidate next tokens very efficiently during generation using Bayes' rule (see Section 3.1 of the paper). As an added bonus, generative discriminators can be used as zero-shot classifiers, and can therefore be used to guide generation towards unseen topics.
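The Bayes' rule step described above can be illustrated with a small, self-contained sketch. It is not the repository's implementation: it assumes two classes with equal priors, skips any length normalization, and uses made-up function names (`log_posteriors`, `guided_probs`) and a toy 3-token vocabulary. The point is that one pass of the class-conditional model per class scores every candidate token at once, and the resulting posterior then reweights the base language model.

```python
import math

def log_posteriors(prefix_logp_desired, prefix_logp_undesired,
                   next_logp_desired, next_logp_undesired):
    """Return log P(desired class | prefix + token) for every candidate token.

    Assumption for illustration: two classes with equal priors.
    """
    out = []
    for lp_d, lp_u in zip(next_logp_desired, next_logp_undesired):
        joint_d = prefix_logp_desired + lp_d    # log P(prefix, token | desired)
        joint_u = prefix_logp_undesired + lp_u  # log P(prefix, token | undesired)
        m = max(joint_d, joint_u)
        log_norm = m + math.log(math.exp(joint_d - m) + math.exp(joint_u - m))
        out.append(joint_d - log_norm)          # Bayes' rule in log space
    return out

def guided_probs(base_logp, log_post, omega):
    """Reweight base LM log-probs by posterior**omega, then renormalize."""
    scores = [b + omega * lp for b, lp in zip(base_logp, log_post)]
    m = max(scores)
    z = sum(math.exp(s - m) for s in scores)
    return [math.exp(s - m) / z for s in scores]

# Toy example: uniform base LM, class-conditional models disagree on tokens 0 and 2.
base = [math.log(1 / 3)] * 3
desired = [math.log(0.7), math.log(0.2), math.log(0.1)]
undesired = [math.log(0.1), math.log(0.2), math.log(0.7)]

post = log_posteriors(0.0, 0.0, desired, undesired)
probs = guided_probs(base, post, omega=2.0)
```

In this toy case the guided distribution shifts mass towards token 0, which the desired-class model prefers. The `omega` exponent plays the role of the `--disc_weight` argument described later in this README.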
- Python 3.7, PyTorch 1.4 (We recommend creating a container using the pytorch/pytorch:1.4-cuda10.1-cudnn7-devel official pytorch docker image.)
- Run scripts/setup.sh:
cd scripts
bash setup.sh
This will install the following:
- First download the models:
cd scripts
bash get_models.sh
This downloads and saves the topic, sentiment, and detoxifier models in the folder ../pretrained_models
- To generate, use bash run_generation.sh, which calls ../generate_GeDi.py with the appropriate arguments (set for topic generation by default).
Important arguments include:
- --mode: can be set to topic, sentiment, or detoxify
- --gen_type: can be set to gedi for GeDi guided generation, cclm for class conditional generation, or gpt2 to generate from raw GPT-2
- --gen_length: max length of generation
- --gedi_model_name_or_path: path to GeDi model. If unused, will assume you ran bash get_models.sh and infer model directory from the --mode argument
- --filter_p: equal to 1 - \rho in Equation 7 of the paper
- --target_p: equal to \tau from the paper
- --disc_weight: exponent for posterior weighting (\omega in Equation 6 of the paper)
- --fp16: converts GPT2-XL weights to fp16 for faster generation and less GPU memory usage
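To give a feel for how --filter_p and --target_p interact, here is a minimal sketch of a posterior-based filter. It reflects our reading of the filtering heuristic in the paper and is an assumption, not the code in generate_GeDi.py: tokens are ranked by class posterior, the smallest head of that ranking holding at least \rho = 1 - filter_p of the weighted probability mass is kept, and any token whose posterior exceeds \tau = target_p is kept as well. The function name `gedi_filter` and the toy numbers are hypothetical.

```python
def gedi_filter(weighted_probs, posteriors, filter_p, target_p):
    """Zero out low-posterior tokens and renormalize the rest.

    weighted_probs: posterior-weighted next-token probabilities (sum to 1)
    posteriors:     P(desired class | prefix + token) for each token
    """
    rho = 1.0 - filter_p
    # Rank tokens from highest to lowest class posterior.
    order = sorted(range(len(posteriors)), key=lambda i: posteriors[i], reverse=True)

    keep = set()
    mass = 0.0
    for i in order:
        keep.add(i)
        mass += weighted_probs[i]
        if mass >= rho:          # smallest head with at least rho of the mass
            break

    # Additionally keep every token the discriminator is confident about.
    keep |= {i for i, p in enumerate(posteriors) if p > target_p}

    total = sum(weighted_probs[i] for i in keep)
    return [weighted_probs[i] / total if i in keep else 0.0
            for i in range(len(weighted_probs))]

# Toy example with a 4-token vocabulary.
filtered = gedi_filter(
    weighted_probs=[0.5, 0.3, 0.15, 0.05],
    posteriors=[0.9, 0.2, 0.85, 0.1],
    filter_p=0.4,
    target_p=0.8,
)
```

Here tokens 1 and 3 are removed because their posteriors are low and the two highest-posterior tokens already cover the required mass.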
Running will allow you to enter control codes and prompts for generation in a continuous loop until you exit.
- Set --mode topic in scripts/run_generation.sh
- You will be prompted to give a topic code. The model was trained on world, sports, business, and science, but can often generate other topics zero-shot, for instance space, fire, climate, education.
- If the topic code you give is more than one BPE token, the model often struggles because the 4 training topics were all 1 BPE token. You will be warned that this might not work, but can proceed by hitting enter again (or can type a new topic code).
- After the topic code, you will be asked to give a prompt to the model to condition on for generation.
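The one-BPE-token caveat above can be checked before running generation. The sketch below is only an illustration: `warn_if_multi_token` and `toy_encode` are hypothetical names, and the toy encoder is a stand-in so the example runs without downloading a tokenizer.

```python
from typing import Callable, List

def warn_if_multi_token(code: str, encode: Callable[[str], List[int]]) -> bool:
    """Return True (and print a warning) if `code` encodes to more than one token."""
    n = len(encode(code))
    if n > 1:
        print(f"Warning: topic code '{code}' is {n} BPE tokens; control may be weaker.")
        return True
    return False

# Stand-in encoder for illustration only: known words map to a single id,
# anything else is split into one id per character.
toy_vocab = {"world": 0, "sports": 1, "business": 2, "science": 3, "space": 4}

def toy_encode(text: str) -> List[int]:
    return [toy_vocab[text]] if text in toy_vocab else [ord(c) for c in text]

single = warn_if_multi_token("space", toy_encode)
multi = warn_if_multi_token("astrophysics", toy_encode)
```

With Hugging Face Transformers installed, the `encode` method of a GPT-2 tokenizer could be passed in place of `toy_encode`. Whether the generation script tokenizes the code with or without a leading space is not stated here, so results from such a check may differ from the script's own warning.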
- Set --mode sentiment in scripts/run_generation.sh
- The model can controllably generate positive or negative text. When generalizing to other domains such as stories, this often translates to positive/negative mood or tone of the story (since sentiment implies an opinion).
- The model is set to positive sentiment by default. You will be prompted for the opportunity to change to negative sentiment by typing n. Note that the negative model can be very negative, and this sometimes results in toxic or offensive samples.
- You will then be asked to give a prompt to the model to condition on for generation.
- Set --mode detoxify in scripts/run_generation.sh
- This mode can be used to avoid generating toxic or offensive text.
- You will then be asked to give a prompt to the model to condition on for generation.
- GeDi can often find a way to navigate especially aggressive prompts, but it does occasionally still generate toxic text if given certain prompts. We observed that this can be a problem for longer generations.
- Two of the baselines we consider are generating from GPT-2 (will give same result regardless of control codes), and generating from the GeDi model directly as a class-conditional language model (instead of using it to guide generation from GPT-2).
- Set --gen_type gpt2 to generate from GPT-2, and --gen_type cclm to generate directly from the GeDi as a class-conditional language model. --gen_type cclm corresponds to all experiments in Section 5 of the paper, and the CC-LM baselines in Section 6.1.
- If you have your own GPT-3 API secret key, you can use GeDi to guide decoding from GPT-3.
- This is somewhat limited, since the GPT-3 API only allows access to the top 100 next token log probabilities.
- It reuses the settings for controlling GPT-2 (which uses all next token log probs); retuning for GPT-3 could give better results.
- It is also slow (up to 1 second per token) because modifying GPT-3 decoding requires calling the API one token at a time.
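The restriction to the top 100 log probabilities means the guided choice can only be made among the candidates the API returns. A rough sketch of one greedy step under that constraint is shown below. It is an assumption-laden illustration, not the repository's code: `pick_next_token` is a hypothetical helper, the dictionaries stand in for an API response and for GeDi posteriors, and no real API call is made.

```python
import math

def pick_next_token(top_logprobs, gedi_log_posterior, disc_weight):
    """Greedy choice among the candidates exposed by the API.

    top_logprobs:       dict token -> base LM log-prob (top-k tokens only)
    gedi_log_posterior: dict token -> log P(desired class | prefix + token)
    disc_weight:        exponent applied to the posterior
    """
    best_token, best_score = None, float("-inf")
    for token, base_lp in top_logprobs.items():
        if token not in gedi_log_posterior:
            continue  # cannot score tokens the guiding model has no entry for
        score = base_lp + disc_weight * gedi_log_posterior[token]
        if score > best_score:
            best_token, best_score = token, score
    return best_token

# Toy candidates: the base model prefers " awful", the guide prefers " great".
top = {" great": -2.0, " awful": -1.0}
post = {" great": math.log(0.9), " awful": math.log(0.1), " fine": math.log(0.6)}

guided = pick_next_token(top, post, disc_weight=1.0)
unguided = pick_next_token(top, post, disc_weight=0.0)
```

Because each step needs fresh log probabilities for the updated prefix, a loop built on this would call the API once per generated token, which is consistent with the slowness noted above.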
To control sentiment from GPT-3 using your API key (should have prefix "sk"):
pip install openai
python ../generate_GeDi.py --penalize_cond --gen_length 100 --mode sentiment --gpt3_api_key sk-xxxxxxxx
You can also try changing the --mode or other arguments. To generate directly from GPT-3 without GeDi using our same greedy decoding scheme:
python ../generate_GeDi.py --penalize_cond --gen_length 100 --mode sentiment --gen_type gpt2 --gpt3_api_key sk-xxxxxxx
- This repository includes code to train a topic GeDi using GeDi training.
- There are some differences between this training script and the one used to train the pretrained model. The pretrained model only used half of AG news, and there were some slight differences in preprocessing.
- This runs in about 5 hours on a 16GB V100 GPU on GCP.
- First, download and process the topic data:
cd scripts
bash get_data.sh
- Then run training using:
bash run_training.sh
which calls ../train_GeDi.py with the appropriate arguments.
- The directory where the model is saved is specified by the output_dir argument.
- When generating from your trained GeDi, you will need to call ../generate_GeDi.py (called from bash run_generation.sh) with --gedi_model_name_or_path set to the directory of your trained model.
@article{KrauseGeDi2020,
title={{GeDi: Generative Discriminator Guided Sequence Generation}},
author={Krause, Ben and Gotmare, Akhilesh Deepak and McCann, Bryan and Keskar, Nitish Shirish and Joty, Shafiq and Socher, Richard and Rajani, Nazneen Fatema},
journal={arXiv preprint arXiv:2009.06367},
year={2020}
}
The code is released under the BSD-3 License (see LICENSE.txt for details), but we also ask that users respect the following:
This software should not be used to promote or profit from violence, hate, and division, environmental destruction, abuse of human rights, or the destruction of people's physical and mental health.