accelerate integration #58

Merged · Dec 30, 2022
Commits
48 commits
9c977d0
working v1
younesbelkada Dec 27, 2022
1971cea
add `accelerate` on requirements
younesbelkada Dec 27, 2022
45cad09
add `accelerate` on `setup.py`
younesbelkada Dec 27, 2022
a0ebdaa
add `datasets` on `setup.py`
younesbelkada Dec 27, 2022
dec21f3
small updates
younesbelkada Dec 27, 2022
4254292
rm unneeded file
younesbelkada Dec 27, 2022
19f4d92
replace with `generate`
younesbelkada Dec 27, 2022
35330a9
Update trl/trainer/accelerate_ppo.py
younesbelkada Dec 27, 2022
34773de
correct return
younesbelkada Dec 27, 2022
b810d8a
add dataloader support
younesbelkada Dec 27, 2022
e4c57b2
add `wandb` to `setup.py`
younesbelkada Dec 27, 2022
7516b37
refactor
younesbelkada Dec 27, 2022
40f81e0
test
younesbelkada Dec 27, 2022
b1638e5
fix test
younesbelkada Dec 27, 2022
e2e7a90
rename file
younesbelkada Dec 27, 2022
96b4115
refactor
younesbelkada Dec 27, 2022
5eb46ad
remove unneeded device assignment
younesbelkada Dec 27, 2022
609f718
fix correct device assignment
younesbelkada Dec 27, 2022
4d57b47
standardize docstrings
younesbelkada Dec 27, 2022
fac85b5
add `wandb` on `dev`
younesbelkada Dec 27, 2022
c1b166b
fix slow convergence
younesbelkada Dec 28, 2022
9495f2a
oops
younesbelkada Dec 28, 2022
c813857
revert fix
younesbelkada Dec 28, 2022
157eca6
revert patch
younesbelkada Dec 28, 2022
2efb961
Merge remote-tracking branch 'origin/master' into accelerate-ppo
younesbelkada Dec 28, 2022
0a1c9a2
remove unneeded reshape
younesbelkada Dec 28, 2022
b6004f0
add input safety checker
younesbelkada Dec 28, 2022
f47b907
refactor
younesbelkada Dec 28, 2022
2918a8e
Apply suggestions from code review
younesbelkada Dec 29, 2022
747d5f0
refactor
younesbelkada Dec 29, 2022
7615994
some refactor
younesbelkada Dec 29, 2022
65be5bd
remove unneeded hack
younesbelkada Dec 29, 2022
edd5ea3
adapt dataset
younesbelkada Dec 29, 2022
76c2afd
fix test
younesbelkada Dec 29, 2022
5d41170
remove rollout
younesbelkada Dec 29, 2022
7843a34
remove timing
younesbelkada Dec 29, 2022
6cd89d5
remove `shuffle=True`
younesbelkada Dec 29, 2022
4e802e8
remove `LengthSampler` from trainer
younesbelkada Dec 29, 2022
6012a9b
refactor
younesbelkada Dec 29, 2022
d2c363f
remove text length sampler args from config
younesbelkada Dec 29, 2022
d048bbe
change collate_fn
younesbelkada Dec 29, 2022
66f23b1
fix silent bug
younesbelkada Dec 29, 2022
e318307
rename
younesbelkada Dec 29, 2022
31d12d6
move file
younesbelkada Dec 29, 2022
48c1070
refactor base trainer
younesbelkada Dec 29, 2022
e9cec71
fix collate
younesbelkada Dec 29, 2022
9a987d4
Merge remote-tracking branch 'origin/master' into accelerate-ppo
younesbelkada Dec 29, 2022
244f001
final bug
younesbelkada Dec 29, 2022
139 changes: 0 additions & 139 deletions examples/scripts/04-ppo-sentiment.py

This file was deleted.

150 changes: 150 additions & 0 deletions examples/scripts/ppo-sentiment.py
@@ -0,0 +1,150 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from tqdm import tqdm
tqdm.pandas()

from transformers import pipeline, AutoTokenizer
from datasets import load_dataset

from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from trl.core import LengthSampler

########################################################################
# This is a fully working, simple example of using trl with accelerate.
#
# This example fine-tunes a GPT2 model on the IMDB dataset using PPO
# (proximal policy optimization) in any of the following settings
# (with the same script):
# - single CPU or single GPU
# - multi GPUs (using PyTorch distributed mode)
# - multi GPUs (using DeepSpeed ZeRO-Offload stages 1 & 2)
# - fp16 (mixed-precision) or fp32 (normal precision)
#
# To run it in each of these various modes, first initialize the accelerate
# configuration with `accelerate config`.
#
########################################################################

# We first define the configuration of the experiment, defining the model, the dataset,
# the training parameters, and the PPO parameters.
# Check the default arguments in the `PPOConfig` class for more details.
config = PPOConfig(
    model_name="lvwerra/gpt2-imdb",
    learning_rate=1.41e-5,
)

# We then define the arguments to pass to the sentiment analysis pipeline.
# We set `return_all_scores` to True to get the sentiment score for every label, not just the top one.
sent_kwargs = {
    "return_all_scores": True,
    "function_to_apply": "none",
    "batch_size": config.forward_batch_size
}

# Below is an example function to build the dataset. In our case, we use the IMDB dataset
# from the `datasets` library. One should customize this function to train the model on
# their own dataset.
def build_dataset(config, dataset_name="imdb", input_min_text_length=2, input_max_text_length=8):
    """
    Build the dataset for training. This builds the dataset from `load_dataset`; one should
    customize this function to train the model on their own dataset.

    Args:
        dataset_name (`str`):
            The name of the dataset to be loaded.

    Returns:
        dataset (`datasets.Dataset`):
            The tokenized dataset for training.
    """
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    tokenizer.pad_token = tokenizer.eos_token
    # load imdb with datasets
    ds = load_dataset(dataset_name, split='train')
    ds = ds.rename_columns({'text': 'review'})
    ds = ds.filter(lambda x: len(x["review"]) > 200, batched=False)

    input_size = LengthSampler(input_min_text_length, input_max_text_length)

    def tokenize(sample):
        sample["input_ids"] = tokenizer.encode(sample["review"])[:input_size()]
        sample["query"] = tokenizer.decode(sample["input_ids"])
        return sample

    ds = ds.map(tokenize, batched=False)
    ds.set_format(type='torch')
    return ds

# We retrieve the dataset by calling the `build_dataset` function;
# the PPOTrainer builds its dataloader from it.
dataset = build_dataset(config)

def collater(data):
    # Transpose a list of per-sample dicts into a dict of per-key lists (no padding applied).
    return dict((key, [d[key] for d in data]) for key in data[0])

# Now let's build the model, the reference model, and the tokenizer.
model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
tokenizer = AutoTokenizer.from_pretrained(config.model_name)

# The GPT-2 tokenizer has no pad token by default, so we set it to the eos_token
# (only needed for this model).
tokenizer.pad_token = tokenizer.eos_token

# We then build the PPOTrainer, passing the model, the reference model, the tokenizer,
# the dataset, and the data collator.
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, dataset=dataset, data_collator=collater)

# We then build the sentiment analysis pipeline, passing the model name and the
# sentiment analysis pipeline arguments. Let's also make sure to set the device
# to the same device as the PPOTrainer.
device = ppo_trainer.accelerator.device
if ppo_trainer.accelerator.num_processes == 1:
    device = 0 if torch.cuda.is_available() else "cpu"  # to avoid a `pipeline` bug
sentiment_pipe = pipeline("sentiment-analysis", "lvwerra/distilbert-imdb", device=device)

# We then define the arguments to pass to the `generate` function. These arguments
# are passed to the `generate` function of the PPOTrainer, which is a wrapper around
# the `generate` function of the trained model.
generation_kwargs = {
    "min_length": -1,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id
}
output_min_length = 4
output_max_length = 16
output_length_sampler = LengthSampler(output_min_length, output_max_length)

for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    query_tensors = batch['input_ids']

    #### Get response from gpt2
    response_tensors = []
    for query in query_tensors:
        gen_len = output_length_sampler()
        generation_kwargs["max_new_tokens"] = gen_len
        response = ppo_trainer.generate(query, **generation_kwargs)
        response_tensors.append(response.squeeze()[-gen_len:])
    batch['response'] = [tokenizer.decode(r.squeeze()) for r in response_tensors]

    #### Compute sentiment score
    texts = [q + r for q, r in zip(batch['query'], batch['response'])]
    pipe_outputs = sentiment_pipe(texts, **sent_kwargs)
    rewards = [torch.tensor(output[1]["score"]).to(device) for output in pipe_outputs]

    #### Run PPO step
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
    ppo_trainer.log_stats(stats, batch, rewards)
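
A note on the data flow above: `build_dataset` returns a tokenized `datasets.Dataset`, and the `collater` passed to the `PPOTrainer` simply transposes each list of per-sample dicts into a dict of per-key lists, so `batch['input_ids']` arrives as a list of variable-length tensors rather than a padded matrix. The following standalone sketch (illustrative only, not part of the diff) shows that behavior with made-up samples; to actually run the example, one would presumably call `accelerate config` once and then `accelerate launch examples/scripts/ppo-sentiment.py`.

import torch

def collater(data):
    # Transpose a list of dicts into a dict of lists (same logic as in the example script).
    return dict((key, [d[key] for d in data]) for key in data[0])

samples = [
    {"input_ids": torch.tensor([31, 42, 7]), "query": "This movie"},
    {"input_ids": torch.tensor([8, 99]), "query": "I went"},
]
batch = collater(samples)
print(batch["query"])      # ['This movie', 'I went']
print(batch["input_ids"])  # [tensor([31, 42,  7]), tensor([ 8, 99])] -- variable lengths, no padding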
1 change: 1 addition & 0 deletions requirements.txt
@@ -5,5 +5,6 @@ datasets==1.17.0
torch>=1.4.0
tqdm
transformers
accelerate
wandb==0.10.20
matplotlib==3.5.1
4 changes: 2 additions & 2 deletions setup.py
@@ -22,8 +22,8 @@

requirements = cfg.get('requirements','').split()
extras = {
"test" : ["pytest","pytest-xdist",],
"dev" : ["pytest","pytest-xdist", "black", "isort", "flake8>=3.8.3"],
"test" : ["pytest","pytest-xdist","accelerate", "datasets", "wandb"],
"dev" : ["pytest","pytest-xdist", "black", "isort", "flake8>=3.8.3", "accelerate", "datasets", "wandb"],
}
lic = licenses[cfg['license']]
min_python = cfg['min_python']
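
For context on how the `extras` dict above is consumed (the relevant `setup()` call sits outside this hunk, so the following is a hedged sketch of the usual convention rather than the repository's exact code): the dict is typically handed to `setuptools.setup` via `extras_require`, which is what makes `pip install -e ".[dev]"` or `pip install trl[test]` pull in `accelerate`, `datasets` and `wandb`.

from setuptools import setup, find_packages

# Hypothetical, trimmed-down setup call; the real setup.py builds related values from its `cfg` object.
extras = {
    "test": ["pytest", "pytest-xdist", "accelerate", "datasets", "wandb"],
    "dev": ["pytest", "pytest-xdist", "black", "isort", "flake8>=3.8.3", "accelerate", "datasets", "wandb"],
}

setup(
    name="trl",
    packages=find_packages(),
    extras_require=extras,  # enables `pip install trl[test]` / `pip install trl[dev]`
)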
38 changes: 30 additions & 8 deletions tests/test_gpt2_model.py
@@ -4,9 +4,20 @@
from transformers import GPT2Tokenizer

from trl import AutoModelForCausalLMWithValueHead
from trl.gpt2 import respond_to_batch
from trl.core import respond_to_batch

from trl.ppo import PPOTrainer
from trl import PPOTrainer, PPOConfig

class DummyDataset(torch.utils.data.Dataset):
    def __init__(self, query_data, response_data):
        self.query_data = query_data
        self.response_data = response_data

    def __len__(self):
        return len(self.query_data)

    def __getitem__(self, idx):
        return self.query_data[idx], self.response_data[idx]


def test_gpt2_model():
@@ -16,8 +27,8 @@ def test_gpt2_model():
    gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

    # initialize trainer
    ppo_config = {"batch_size": 1, "forward_batch_size": 1}
    ppo_trainer = PPOTrainer(gpt2_model, gpt2_model_ref, gpt2_tokenizer, **ppo_config)
    ppo_config = {"batch_size": 2, "forward_batch_size": 1, "log_with_wandb": False}
    ppo_config = PPOConfig(**ppo_config)

    # encode a query
    query_txt = "This morning I went to the "
@@ -28,12 +39,23 @@
    assert response_tensor.shape == (1, 20)
    response_txt = gpt2_tokenizer.decode(response_tensor[0, :])

    # define a reward for response
    # (this could be any reward such as human feedback or output from another model)
    reward = [torch.tensor(1.0)]
    # create a dummy dataset
    min_length = min(len(query_tensor[0]), len(response_tensor[0]))
    dummy_dataset = DummyDataset([query_tensor[:, :min_length].squeeze(0) for _ in range(2)], [response_tensor[:, :min_length].squeeze(0) for _ in range(2)])
    dummy_dataloader = torch.utils.data.DataLoader(
        dummy_dataset, batch_size=2, shuffle=True
    )

    ppo_trainer = PPOTrainer(config=ppo_config, model=gpt2_model, ref_model=gpt2_model_ref, tokenizer=gpt2_tokenizer, dataset=dummy_dataset)
    dummy_dataloader = ppo_trainer.dataloader
    # train model with ppo
    train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)
    for query_tensor, response_tensor in dummy_dataloader:
        # define a reward for response
        # (this could be any reward such as human feedback or output from another model)
        reward = [torch.tensor(1.0), torch.tensor(0.0)]
        # train model
        train_stats = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward)
        break

    EXPECTED_STATS = [
        "objective/kl",
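
One detail of the updated test worth spelling out: `DummyDataset` yields equal-length `(query, response)` tensor pairs, so PyTorch's default collate function stacks each pair into `(batch_size, seq_len)` tensors, and the loop then unpacks them back into per-sample lists before calling `ppo_trainer.step`. A small standalone sketch of that batching behavior (dummy tensors only, independent of trl):

import torch
from torch.utils.data import DataLoader, Dataset

class PairDataset(Dataset):
    # Same structure as the DummyDataset used in the test.
    def __init__(self, query_data, response_data):
        self.query_data = query_data
        self.response_data = response_data

    def __len__(self):
        return len(self.query_data)

    def __getitem__(self, idx):
        return self.query_data[idx], self.response_data[idx]

queries = [torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6])]
responses = [torch.tensor([7, 8, 9]), torch.tensor([10, 11, 12])]
loader = DataLoader(PairDataset(queries, responses), batch_size=2)

query_batch, response_batch = next(iter(loader))
print(query_batch.shape)         # torch.Size([2, 3]) -- default collate stacked the equal-length samples
print([q for q in query_batch])  # list of per-sample tensors, the form the test passes to ppo_trainer.step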
3 changes: 2 additions & 1 deletion trl/__init__.py
@@ -1,3 +1,4 @@
__version__ = "0.1.1"

from .models import AutoModelForCausalLMWithValueHead
from .models import AutoModelForCausalLMWithValueHead
from .trainer import PPOTrainer, PPOConfig
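
With `PPOTrainer` and `PPOConfig` now re-exported from the package root, user code can import everything it needs directly from `trl`, as the new example script does. A minimal sketch of that construction path, reusing only names that appear in this PR (the dataset and collator are elided and would be built as in examples/scripts/ppo-sentiment.py):

from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer

config = PPOConfig(model_name="lvwerra/gpt2-imdb", learning_rate=1.41e-5)
model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
tokenizer.pad_token = tokenizer.eos_token

# dataset and data_collator built as in the example script above
# ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, dataset=dataset, data_collator=collater)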