Added Reward Backpropogation Support #1585
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
# Aligning Text-to-Image Diffusion Models with Reward Backpropagation | ||
|
||
## The why | ||
I don't think as a reader I understand whether the following table justifies the name of this section. Would you mind elaborating? |
Thanks, I added a better why statement. |
||
|
||
If your reward function is differentiable, directly backpropagating gradients from the reward model to the diffusion model is significantly more sample- and compute-efficient (25x) than using a policy gradient algorithm such as DDPO. | ||
AlignProp performs full backpropagation through time, which allows updating the earlier denoising steps via reward backpropagation. | ||
|
||
<div style="text-align: center"><img src="https://align-prop.github.io/reward_tuning.png"/></div> | ||
|
||
|
||
## Getting started with `examples/scripts/alignprop.py` | ||
|
||
The `alignprop.py` script is a working example of using the `AlignProp` trainer to finetune a Stable Diffusion model. This example explicitly configures a small subset of the overall parameters associated with the config object (`AlignPropConfig`). | ||
|
||
**Note:** one A100 GPU is recommended to get this running. For lower-memory setups, consider setting `truncated_backprop_rand` to False. With the default settings, this performs truncated backpropagation with K=1. | ||
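The effect of truncation can be pictured with a small, library-free sketch. It is an illustration rather than the trainer's actual code: it assumes that gradients are kept only for denoising steps whose index is at or above a cutoff, and that in the randomized mode a single cutoff is drawn from the configured range (a simplification). `steps_with_gradient` is a hypothetical helper, and its parameter names are chosen to mirror the config fields mentioned on this page.

```python
import random


def steps_with_gradient(sample_num_steps, truncated_backprop_rand,
                        truncated_backprop_timestep=49,
                        truncated_rand_backprop_minmax=(0, 50),
                        rng=None):
    """Return the denoising step indices that would keep gradients.

    Assumption for illustration: steps before the cutoff are detached,
    steps at or after it stay in the autograd graph.
    """
    rng = rng or random.Random()
    if truncated_backprop_rand:
        low, high = truncated_rand_backprop_minmax
        cutoff = rng.randint(low, high)  # both ends inclusive
    else:
        cutoff = truncated_backprop_timestep
    return [step for step in range(sample_num_steps) if step >= cutoff]


# Fixed truncation at step 49 of 50: only the last step keeps gradients (K=1)
print(steps_with_gradient(50, truncated_backprop_rand=False))  # [49]
```

Under these assumptions, a smaller cutoff keeps more steps in the graph and therefore needs more memory, which is why the fixed setting is the lighter option.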
|
||
Almost every configuration parameter has a default. Only one command-line flag is required to get things up and running. The user is expected to have a [huggingface user access token](https://huggingface.co/docs/hub/security-tokens) that will be used to upload the model to the HuggingFace Hub after finetuning. Enter the following bash command to get things running: | ||
|
||
```bash | ||
python alignprop.py --hf_user_access_token <token> | ||
``` | ||
|
||
To obtain the documentation of `alignprop.py`, please run `python alignprop.py --help` | ||
|
||
Keep the following in mind (the code checks this for you as well) when configuring the trainer in general, beyond the use case of the example script: | ||
|
||
- The configurable randomized truncation range (`--alignprop_config.truncated_rand_backprop_minmax=(0,50)`): the first number should be greater than or equal to 0, while the second number should be less than or equal to the number of diffusion timesteps (`sample_num_steps`). | ||
- The configurable truncation backprop absolute step (`--alignprop_config.truncated_backprop_timestep=49`): the number should be less than the number of diffusion timesteps (`sample_num_steps`); it only matters when `truncated_backprop_rand` is set to False. | ||
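The two constraints above can be written down as a short check. This is a minimal sketch and not the validation code that ships with the trainer: `truncation_setting_errors` is a hypothetical helper, and its default values simply repeat the numbers quoted in the list.

```python
def truncation_setting_errors(sample_num_steps, minmax=(0, 50), timestep=49):
    """Collect violations of the truncation constraints described above."""
    errors = []
    low, high = minmax
    if low < 0:
        errors.append("truncated_rand_backprop_minmax[0] must be >= 0")
    if high > sample_num_steps:
        errors.append("truncated_rand_backprop_minmax[1] must be <= sample_num_steps")
    if timestep >= sample_num_steps:
        errors.append("truncated_backprop_timestep must be < sample_num_steps")
    return errors


print(truncation_setting_errors(50))  # []
# With only 20 sampling steps, both the range end and the absolute step are too large
for message in truncation_setting_errors(20):
    print(message)
```

Returning a list of messages instead of raising on the first problem makes it easy to report every misconfigured value at once.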
|
||
## Setting up the image logging hook function | ||
|
||
Expect the function to be given a dictionary with keys | ||
```python | ||
['images', 'prompts', 'prompt_metadata', 'rewards'] | ||
``` | ||
where `images`, `prompts`, `prompt_metadata`, and `rewards` are batched. | ||
You are free to log however you want; the use of `wandb` or `tensorboard` is recommended. | ||
|
||
### Key terms | ||
|
||
- `rewards` : The reward/score is a numerical value associated with the generated image and is key to steering the RL process | ||
- `prompts` : The prompt is the text that is used to generate the image | ||
- `prompt_metadata` : The prompt metadata is the metadata associated with the prompt. A situation where this will not be empty is when the reward model comprises a [`FLAVA`](https://huggingface.co/docs/transformers/model_doc/flava) setup where questions and ground-truth answers (linked to the generated image) are expected with the generated image (see here: https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/rewards.py#L45) | ||
- `images` : The image generated by the Stable Diffusion model | ||
|
||
Example code for logging sampled images with `wandb` is given below. | ||
|
||
```python | ||
import numpy as np | ||
from PIL import Image | ||
|
||
|
||
# for logging these images to wandb | ||
def image_outputs_hook(image_data, global_step, accelerate_logger): | ||
    # image_data holds one batch of images together with their prompts and rewards | ||
    result = {} | ||
    images, prompts, rewards = image_data["images"], image_data["prompts"], image_data["rewards"] | ||
    for i, image in enumerate(images): | ||
        # convert a CHW float tensor to an HWC uint8 PIL image | ||
        pil = Image.fromarray( | ||
            (image.detach().cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8) | ||
        ) | ||
        pil = pil.resize((256, 256)) | ||
        # key each image by its truncated prompt and its reward | ||
        result[f"{prompts[i]:.25} | {rewards[i]:.2f}"] = [pil] | ||
    accelerate_logger.log_images( | ||
        result, | ||
        step=global_step, | ||
    ) | ||
``` | ||
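Before wiring a hook into training, it can help to exercise it on a hand-built batch. The sketch below is library-free and only illustrates the calling convention: `FakeLogger` is a hypothetical stand-in for the tracker object passed as `accelerate_logger`, `caption_hook` is a hypothetical hook, and the batch uses plain Python values in place of tensors.

```python
class FakeLogger:
    """Hypothetical stand-in that records what a hook tries to log."""

    def __init__(self):
        self.calls = []

    def log_images(self, values, step=None):
        self.calls.append((step, values))


def caption_hook(image_data, global_step, accelerate_logger):
    """Key each image by a truncated prompt and its reward."""
    result = {}
    for image, prompt, reward in zip(
        image_data["images"], image_data["prompts"], image_data["rewards"]
    ):
        result[f"{prompt:.25} | {reward:.2f}"] = [image]
    accelerate_logger.log_images(result, step=global_step)


batch = {
    "images": ["<image 0>", "<image 1>"],  # placeholders, not real tensors
    "prompts": ["cat", "dog"],
    "prompt_metadata": [{}, {}],
    "rewards": [5.1234, 4.5],
}
logger = FakeLogger()
caption_hook(batch, global_step=3, accelerate_logger=logger)
print(list(logger.calls[0][1]))  # ['cat | 5.12', 'dog | 4.50']
```

Because the hook only touches the dictionary and the logger it is handed, the same function can be checked in isolation and then passed to the trainer unchanged.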
|
||
### Using the finetuned model | ||
|
||
Assuming you've finished all the epochs and have pushed your model to the hub, you can use the finetuned model as follows: | ||
|
||
```python | ||
import os | ||
|
||
from diffusers import StableDiffusionPipeline | ||
|
||
pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") | ||
pipeline.to("cuda") | ||
|
||
pipeline.load_lora_weights('mihirpd/alignprop-trl-aesthetics') | ||
|
||
prompts = ["squirrel", "crab", "starfish", "whale", "sponge", "plankton"] | ||
results = pipeline(prompts) | ||
|
||
# make sure the output directory exists before saving | ||
os.makedirs("dump", exist_ok=True) | ||
for prompt, image in zip(prompts, results.images): | ||
    image.save(f"dump/{prompt}.png") | ||
``` | ||
|
||
## Credits | ||
|
||
This work is heavily influenced by the repo [here](https://github.com/mihirp1998/AlignProp/) and the associated paper [Aligning Text-to-Image Diffusion Models with Reward Backpropagation by Mihir Prabhudesai, Anirudh Goyal, Deepak Pathak, Katerina Fragkiadaki](https://arxiv.org/abs/2310.03739). |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
# Copyright 2023 metric-space, The HuggingFace Team. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
""" | ||
python examples/scripts/alignprop.py \ | ||
--num_epochs=20 \ | ||
--train_gradient_accumulation_steps=4 \ | ||
--sample_num_steps=50 \ | ||
--train_batch_size=8 \ | ||
--tracker_project_name="stable_diffusion_training" \ | ||
--log_with="wandb" | ||
""" | ||
import os | ||
from dataclasses import dataclass, field | ||
import numpy as np | ||
import torch | ||
import torch.nn as nn | ||
from transformers import HfArgumentParser | ||
from trl import AlignPropConfig, AlignPropTrainer, DefaultDDPOStableDiffusionPipeline | ||
from trl.models.auxiliary_modules import aesthetic_scorer | ||
|
||
|
||
@dataclass | ||
class ScriptArguments: | ||
    pretrained_model: str = field( | ||
        default="runwayml/stable-diffusion-v1-5", metadata={"help": "the pretrained model to use"} | ||
    ) | ||
    pretrained_revision: str = field(default="main", metadata={"help": "the pretrained model revision to use"}) | ||
    hf_hub_model_id: str = field( | ||
        default="alignprop-finetuned-stable-diffusion", metadata={"help": "HuggingFace repo to save model weights to"} | ||
    ) | ||
    hf_hub_aesthetic_model_id: str = field( | ||
        default="trl-lib/ddpo-aesthetic-predictor", | ||
        metadata={"help": "HuggingFace model ID for aesthetic scorer model weights"}, | ||
    ) | ||
    hf_hub_aesthetic_model_filename: str = field( | ||
        default="aesthetic-model.pth", | ||
        metadata={"help": "HuggingFace model filename for aesthetic scorer model weights"}, | ||
    ) | ||
    use_lora: bool = field(default=True, metadata={"help": "Whether to use LoRA."}) | ||
|
||
|
||
|
||
# list of example prompts to feed stable diffusion | ||
animals = [ | ||
"cat", | ||
"dog", | ||
"horse", | ||
"monkey", | ||
"rabbit", | ||
"zebra", | ||
"spider", | ||
"bird", | ||
"sheep", | ||
"deer", | ||
"cow", | ||
"goat", | ||
"lion", | ||
"frog", | ||
"chicken", | ||
"duck", | ||
"goose", | ||
"bee", | ||
"pig", | ||
"turkey", | ||
"fly", | ||
"llama", | ||
"camel", | ||
"bat", | ||
"gorilla", | ||
"hedgehog", | ||
"kangaroo", | ||
] | ||
|
||
|
||
def prompt_fn(): | ||
    return np.random.choice(animals), {} | ||
|
||
|
||
|
||
def image_outputs_logger(image_pair_data, global_step, accelerate_logger): | ||
    # For the sake of this example, we only log the first four images of the batch, | ||
    # keyed by their prompts | ||
    result = {} | ||
    images, prompts = image_pair_data["images"], image_pair_data["prompts"] | ||
    for i, image in enumerate(images[:4]): | ||
        prompt = prompts[i] | ||
        result[f"{prompt}"] = image.unsqueeze(0).float() | ||
    accelerate_logger.log_images( | ||
        result, | ||
        step=global_step, | ||
    ) | ||
|
||
|
||
if __name__ == "__main__": | ||
    parser = HfArgumentParser((ScriptArguments, AlignPropConfig)) | ||
    args, alignprop_config = parser.parse_args_into_dataclasses() | ||
    alignprop_config.project_kwargs = { | ||
        "logging_dir": "./logs", | ||
        "automatic_checkpoint_naming": True, | ||
        "total_limit": 5, | ||
        "project_dir": "./save", | ||
    } | ||
|
||
    pipeline = DefaultDDPOStableDiffusionPipeline( | ||
        args.pretrained_model, pretrained_model_revision=args.pretrained_revision, use_lora=args.use_lora | ||
    ) | ||
    trainer = AlignPropTrainer( | ||
        alignprop_config, | ||
        aesthetic_scorer(args.hf_hub_aesthetic_model_id, args.hf_hub_aesthetic_model_filename), | ||
        prompt_fn, | ||
        pipeline, | ||
        image_samples_hook=image_outputs_logger, | ||
    ) | ||
|
||
    trainer.train() | ||
|
||
    trainer.push_to_hub(args.hf_hub_model_id) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
# Copyright 2023 metric-space, The HuggingFace Team. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import gc | ||
import unittest | ||
import torch | ||
|
||
from trl import is_diffusers_available, is_peft_available | ||
|
||
from .testing_utils import require_diffusers | ||
|
||
|
||
if is_diffusers_available() and is_peft_available(): | ||
    from trl import AlignPropConfig, AlignPropTrainer, DefaultDDPOStableDiffusionPipeline | ||
|
||
|
||
def scorer_function(images, prompts, metadata): | ||
    return torch.randn(1) * 3.0, {} | ||
|
||
|
||
def prompt_function(): | ||
    return ("cabbages", {}) | ||
|
||
|
||
@require_diffusers | ||
class AlignPropTrainerTester(unittest.TestCase): | ||
    """ | ||
    Test the AlignPropTrainer class. | ||
    """ | ||
|
||
    def setUp(self): | ||
        self.alignprop_config = AlignPropConfig( | ||
            num_epochs=2, | ||
            train_gradient_accumulation_steps=1, | ||
            train_batch_size=2, | ||
            truncated_backprop_rand=False, | ||
            mixed_precision=None, | ||
            save_freq=1000000, | ||
        ) | ||
        pretrained_model = "hf-internal-testing/tiny-stable-diffusion-torch" | ||
        pretrained_revision = "main" | ||
|
||
        pipeline = DefaultDDPOStableDiffusionPipeline( | ||
            pretrained_model, pretrained_model_revision=pretrained_revision, use_lora=False | ||
        ) | ||
|
||
        self.trainer = AlignPropTrainer(self.alignprop_config, scorer_function, prompt_function, pipeline) | ||
|
||
        return super().setUp() | ||
|
||
    def tearDown(self) -> None: | ||
        gc.collect() | ||
|
||
    def test_generate_samples(self): | ||
        output_pairs = self.trainer._generate_samples(2, with_grad=True) | ||
        assert len(output_pairs.keys()) == 3 | ||
        assert len(output_pairs["images"]) == 2 | ||
|
||
    def test_calculate_loss(self): | ||
        sample = self.trainer._generate_samples(2) | ||
|
||
        images = sample["images"] | ||
        prompts = sample["prompts"] | ||
|
||
        assert images.shape == (2, 3, 128, 128) | ||
        assert len(prompts) == 2 | ||
|
||
        rewards = self.trainer.compute_rewards(sample) | ||
        loss = self.trainer.calculate_loss(rewards) | ||
|
||
        assert torch.isfinite(loss.cpu()) | ||
|
||
|
||
@require_diffusers | ||
class AlignPropTrainerWithLoRATester(AlignPropTrainerTester): | ||
    """ | ||
    Test the AlignPropTrainer class with LoRA enabled. | ||
    """ | ||
|
||
    def setUp(self): | ||
        self.alignprop_config = AlignPropConfig( | ||
            num_epochs=2, | ||
            train_gradient_accumulation_steps=1, | ||
            mixed_precision=None, | ||
            truncated_backprop_rand=False, | ||
            save_freq=1000000, | ||
        ) | ||
|
||
        pretrained_model = "hf-internal-testing/tiny-stable-diffusion-torch" | ||
        pretrained_revision = "main" | ||
|
||
        pipeline = DefaultDDPOStableDiffusionPipeline( | ||
            pretrained_model, pretrained_model_revision=pretrained_revision, use_lora=True | ||
        ) | ||
|
||
        # super().setUp() is deliberately not called here: the parent setUp would | ||
        # rebuild the trainer with use_lora=False and overwrite the LoRA trainer | ||
        self.trainer = AlignPropTrainer(self.alignprop_config, scorer_function, prompt_function, pipeline) |
Very nice documentation page!
Thanks! Yes, the import wasn't needed; I committed it.