-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added Reward Backpropogation Support #1585
Changes from 4 commits
209fb01
0db8104
8030ac7
4b9dc1a
af84272
f3ff177
c3fe757
4f8501e
34af985
c0a6ce3
fbeafbc
d804207
405db53
8296a07
4607f96
fce0de2
32ec653
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
# Aligning Text-to-Image Diffusion Models with Reward Backpropagation | ||
|
||
## The why | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think as a reader I understand if the following table justifies the name of this section. Would you mind elaborating? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks, I added a better why statement. |
||
|
||
| Before | After finetuning | | ||
| --- | --- | | ||
| <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/pre_squirrel.png"/></div> | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/post_squirrel.png"/></div> | | ||
| <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/pre_crab.png"/></div> | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/post_crab.png"/></div> | | ||
| <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/pre_starfish.png"/></div> | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/post_starfish.png"/></div> | | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These images are the ones generated from DDPO no? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, i wanted to update them, although I wasn't sure how to do it, as they linked to a huggingface internal webpage https://huggingface.co/datasets/trl-internal-testing/ If you can guide me on how to do it, i can update them. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You can open a PR to https://huggingface.co/datasets/trl-internal-testing/ repository adding the resultant images you want. |
||
|
||
|
||
## Getting started with Stable Diffusion finetuning with reinforcement learning | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think this is needed. We should strive to keep the API documentation lean and precise. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes i removed it. |
||
|
||
The machinery for finetuning of Stable Diffusion models with reinforcement learning makes heavy use of HuggingFace's `diffusers` | ||
library. A reason for stating this is that getting started requires a bit of familiarity with the `diffusers` library concepts, mainly two of them - pipelines and schedulers. | ||
Right out of the box (`diffusers` library), there isn't a `Pipeline` nor a `Scheduler` instance that is suitable for finetuning with reinforcement learning. Some adjustments need to made. | ||
|
||
There is a pipeline interface that is provided by this library that is required to be implemented to be used with the `DDPOTrainer`, which is the main machinery for fine-tuning Stable Diffusion with reinforcement learning. **Note: Only the StableDiffusion architecture is supported at this point.** | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here it references DDPO trainer There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for pointing this. I have fixed this. |
||
There is a default implementation of this interface that you can use out of the box. Assuming the default implementation is sufficient and/or to get things moving, refer to the training example alongside this guide. | ||
|
||
The point of the interface is to fuse the pipeline and the scheduler into one object which allows for minimalness in terms of having the constraints all in one place. The interface was designed in hopes of catering to pipelines and schedulers beyond the examples in this repository and elsewhere at this time of writing. Also the scheduler step is a method of this pipeline interface and this may seem redundant given that the raw scheduler is accessible via the interface but this is the only way to constrain the scheduler step output to an output type befitting of the algorithm at hand (DDPO). | ||
|
||
For a more detailed look into the interface and the associated default implementation, go [here](https://github.com/lvwerra/trl/tree/main/trl/models/modeling_sd_base.py) | ||
|
||
Note that the default implementation has a LoRA implementation path and a non-LoRA based implementation path. The LoRA flag enabled by default and this can be turned off by passing in the flag to do so. LORA based training is faster and the LORA associated model hyperparameters responsible for model convergence aren't as finicky as non-LORA based training. | ||
|
||
Also in addition, there is the expectation of providing a reward function and a prompt function. The reward function is used to evaluate the generated images and the prompt function is used to generate the prompts that are used to generate the images. | ||
|
||
## Getting started with `examples/scripts/alignprop.py` | ||
|
||
The `alignprop.py` script is a working example of using the `AlignProp` trainer to finetune a Stable Diffusion model. This example explicitly configures a small subset of the overall parameters associated with the config object (`AlignPropConfig`). | ||
|
||
**Note:** one A100 GPU is recommended to get this running. For lower memory setting, consider setting truncated_backprop_rand to False. With default settings this will do truncated backpropagation with K=1. | ||
|
||
Almost every configuration parameter has a default. There is only one commandline flag argument that is required of the user to get things up and running. The user is expected to have a [huggingface user access token](https://huggingface.co/docs/hub/security-tokens) that will be used to upload the model post finetuning to HuggingFace hub. The following bash command is to be entered to get things running | ||
|
||
```batch | ||
python alignprop.py --hf_user_access_token <token> | ||
``` | ||
|
||
To obtain the documentation of `stable_diffusion_tuning.py`, please run `python stable_diffusion_tuning.py --help` | ||
|
||
The following are things to keep in mind (The code checks this for you as well) in general while configuring the trainer (beyond the use case of using the example script) | ||
|
||
- The configurable randomized truncation range (`--alignprop_config.truncated_rand_backprop_minmax=(0,50)`) the first number should be equal and greater to 0, while the second number should equal or less to the number of diffusion timesteps (sample_num_steps) | ||
- The configurable truncation backprop absolute step (`--alignprop_config.truncated_backprop_timestep=49`) the number should be less than the number of diffusion timesteps (sample_num_steps), it only matters when truncated_backprop_rand is set to False | ||
|
||
## Setting up the image logging hook function | ||
|
||
Expect the function to be given a dictionary with keys | ||
```python | ||
['image', 'prompt', 'prompt_metadata', 'rewards'] | ||
|
||
``` | ||
and `image`, `prompt`, `prompt_metadata`, `rewards`are batched. | ||
You are free to log however you want the use of `wandb` or `tensorboard` is recommended. | ||
|
||
### Key terms | ||
|
||
- `rewards` : The rewards/score is a numerical associated with the generated image and is key to steering the RL process | ||
- `prompt` : The prompt is the text that is used to generate the image | ||
- `prompt_metadata` : The prompt metadata is the metadata associated with the prompt. A situation where this will not be empty is when the reward model comprises of a [`FLAVA`](https://huggingface.co/docs/transformers/model_doc/flava) setup where questions and ground answers (linked to the generated image) are expected with the generated image (See here: https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/rewards.py#L45) | ||
- `image` : The image generated by the Stable Diffusion model | ||
|
||
Example code for logging sampled images with `wandb` is given below. | ||
|
||
```python | ||
# for logging these images to wandb | ||
|
||
def image_outputs_hook(image_data, global_step, accelerate_logger): | ||
# For the sake of this example, we only care about the last batch | ||
# hence we extract the last element of the list | ||
result = {} | ||
images, prompts, rewards = [image_data['images'],image_data['prompts'],image_data['rewards']] | ||
for i, image in enumerate(images): | ||
pil = Image.fromarray( | ||
(image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8) | ||
) | ||
pil = pil.resize((256, 256)) | ||
result[f"{prompts[i]:.25} | {rewards[i]:.2f}"] = [pil] | ||
accelerate_logger.log_images( | ||
result, | ||
step=global_step, | ||
) | ||
|
||
``` | ||
|
||
### Using the finetuned model | ||
|
||
Assuming you've done with all the epochs and have pushed up your model to the hub, you can use the finetuned model as follows | ||
|
||
```python | ||
|
||
import torch | ||
from trl import DefaultDDPOStableDiffusionPipeline | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do we have to use a non-diffusers pipeline here? Does There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes indeed, i changed it to StableDiffusionPipeline from diffusers |
||
|
||
pipeline = DefaultDDPOStableDiffusionPipeline("metric-space/alignprop-finetuned-sd-model") | ||
|
||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") | ||
|
||
# memory optimization | ||
pipeline.vae.to(device, torch.float16) | ||
pipeline.text_encoder.to(device, torch.float16) | ||
pipeline.unet.to(device, torch.float16) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These LoCs could be reduce if we do: pipeline = DefaultDDPOStableDiffusionPipeline("metric-space/alignprop-finetuned-sd-model", torch_dtype=torch.float16)
pipeline = pipeline.to("cuda") There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Additionally, https://huggingface.co/metric-space/alignprop-finetuned-sd-model is not available. Let's make sure we're using the right checkpoint ids here. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes i reduced it and fixed the checkpoint ids. |
||
|
||
prompts = ["squirrel", "crab", "starfish", "whale","sponge", "plankton"] | ||
results = pipeline(prompts) | ||
|
||
for prompt, image in zip(prompts,results.images): | ||
image.save(f"{prompt}.png") | ||
|
||
``` | ||
|
||
## Credits | ||
|
||
This work is heavily influenced by the repo [here](https://github.com/mihirp1998/AlignProp/) and the associated paper [Aligning Text-to-Image Diffusion Models with Reward Backpropagation | ||
by Mihir Prabhudesai, Anirudh Goyal, Deepak Pathak, Katerina Fragkiadaki](https://arxiv.org/abs/2310.03739). |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,212 @@ | ||
# Copyright 2023 metric-space, The HuggingFace Team. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
""" | ||
python examples/scripts/ddpo.py \ | ||
--num_epochs=200 \ | ||
--train_gradient_accumulation_steps=1 \ | ||
--sample_num_steps=50 \ | ||
--sample_batch_size=6 \ | ||
--train_batch_size=3 \ | ||
--sample_num_batches_per_epoch=4 \ | ||
--per_prompt_stat_tracking=True \ | ||
--per_prompt_stat_tracking_buffer_size=32 \ | ||
--tracker_project_name="stable_diffusion_training" \ | ||
--log_with="wandb" | ||
""" | ||
import os | ||
import torchvision | ||
from dataclasses import dataclass, field | ||
|
||
import numpy as np | ||
import torch | ||
import torch.nn as nn | ||
from huggingface_hub import hf_hub_download | ||
from huggingface_hub.utils import EntryNotFoundError | ||
from transformers import CLIPModel, CLIPProcessor, HfArgumentParser | ||
|
||
from trl import AlignPropConfig, AlignPropTrainer, DefaultDDPOStableDiffusionPipeline | ||
from trl.import_utils import is_npu_available, is_xpu_available | ||
|
||
|
||
@dataclass | ||
class ScriptArguments: | ||
pretrained_model: str = field( | ||
default="runwayml/stable-diffusion-v1-5", metadata={"help": "the pretrained model to use"} | ||
) | ||
pretrained_revision: str = field(default="main", metadata={"help": "the pretrained model revision to use"}) | ||
hf_hub_model_id: str = field( | ||
default="ddpo-finetuned-stable-diffusion", metadata={"help": "HuggingFace repo to save model weights to"} | ||
) | ||
hf_hub_aesthetic_model_id: str = field( | ||
default="trl-lib/ddpo-aesthetic-predictor", | ||
metadata={"help": "HuggingFace model ID for aesthetic scorer model weights"}, | ||
) | ||
hf_hub_aesthetic_model_filename: str = field( | ||
default="aesthetic-model.pth", | ||
metadata={"help": "HuggingFace model filename for aesthetic scorer model weights"}, | ||
) | ||
use_lora: bool = field(default=True, metadata={"help": "Whether to use LoRA."}) | ||
|
||
|
||
class MLP(nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.layers = nn.Sequential( | ||
nn.Linear(768, 1024), | ||
nn.Dropout(0.2), | ||
nn.Linear(1024, 128), | ||
nn.Dropout(0.2), | ||
nn.Linear(128, 64), | ||
nn.Dropout(0.1), | ||
nn.Linear(64, 16), | ||
nn.Linear(16, 1), | ||
) | ||
|
||
def forward(self, embed): | ||
return self.layers(embed) | ||
|
||
|
||
class AestheticScorer(torch.nn.Module): | ||
""" | ||
This model attempts to predict the aesthetic score of an image. The aesthetic score | ||
is a numerical approximation of how much a specific image is liked by humans on average. | ||
This is from https://github.com/christophschuhmann/improved-aesthetic-predictor | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why are we copy-pasting these modules from the DDPO script? @younesbelkada would it make sense to have a separate module for these ( There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. They are not exactly copy pasted, as DDPO had clamp and no_grad operations within them, which were preventing gradients from backpropagating. Anyhow I still transfered the above reward function code from alignprop.py to trl/models/auxiliary_modules.py, as u suggested. |
||
""" | ||
|
||
def __init__(self, *, dtype, model_id, model_filename): | ||
super().__init__() | ||
self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14") | ||
self.normalize = torchvision.transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], | ||
std=[0.26862954, 0.26130258, 0.27577711]) | ||
self.target_size = 224 | ||
self.mlp = MLP() | ||
try: | ||
cached_path = hf_hub_download(model_id, model_filename) | ||
except EntryNotFoundError: | ||
cached_path = os.path.join(model_id, model_filename) | ||
state_dict = torch.load(cached_path, map_location=torch.device("cpu")) | ||
self.mlp.load_state_dict(state_dict) | ||
self.dtype = dtype | ||
self.eval() | ||
|
||
def __call__(self, images): | ||
device = next(self.parameters()).device | ||
images = torchvision.transforms.Resize(self.target_size)(images) | ||
images = self.normalize(images).to(self.dtype).to(device) | ||
embed = self.clip.get_image_features(pixel_values=images) | ||
# normalize embedding | ||
embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True) | ||
reward = self.mlp(embed).squeeze(1) | ||
return reward | ||
|
||
|
||
def aesthetic_scorer(hub_model_id, model_filename): | ||
scorer = AestheticScorer( | ||
model_id=hub_model_id, | ||
model_filename=model_filename, | ||
dtype=torch.float32, | ||
) | ||
if is_npu_available(): | ||
scorer = scorer.npu() | ||
elif is_xpu_available(): | ||
scorer = scorer.xpu() | ||
else: | ||
scorer = scorer.cuda() | ||
|
||
def _fn(images, prompts, metadata): | ||
images = (images).clamp(0, 1) | ||
scores = scorer(images) | ||
return scores, {} | ||
|
||
return _fn | ||
|
||
|
||
# list of example prompts to feed stable diffusion | ||
animals = [ | ||
"cat", | ||
"dog", | ||
"horse", | ||
"monkey", | ||
"rabbit", | ||
"zebra", | ||
"spider", | ||
"bird", | ||
"sheep", | ||
"deer", | ||
"cow", | ||
"goat", | ||
"lion", | ||
"frog", | ||
"chicken", | ||
"duck", | ||
"goose", | ||
"bee", | ||
"pig", | ||
"turkey", | ||
"fly", | ||
"llama", | ||
"camel", | ||
"bat", | ||
"gorilla", | ||
"hedgehog", | ||
"kangaroo", | ||
] | ||
|
||
|
||
def prompt_fn(): | ||
return np.random.choice(animals), {} | ||
|
||
|
||
|
||
def image_outputs_logger(image_pair_data, global_step, accelerate_logger): | ||
# For the sake of this example, we will only log the last batch of images | ||
# and associated data | ||
result = {} | ||
images, prompts, rewards = [image_pair_data['images'],image_pair_data['prompts'],image_pair_data['rewards']] | ||
for i, image in enumerate(images[:4]): | ||
prompt = prompts[i] | ||
reward = rewards[i].item() | ||
result[f"{prompt:.25} | {reward:.2f}"] = image.unsqueeze(0).float() | ||
|
||
accelerate_logger.log_images( | ||
result, | ||
step=global_step, | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = HfArgumentParser((ScriptArguments, AlignPropConfig)) | ||
args, alignprop_config = parser.parse_args_into_dataclasses() | ||
alignprop_config.project_kwargs = { | ||
"logging_dir": "./logs", | ||
"automatic_checkpoint_naming": True, | ||
"total_limit": 5, | ||
"project_dir": "./save", | ||
} | ||
|
||
pipeline = DefaultDDPOStableDiffusionPipeline( | ||
args.pretrained_model, pretrained_model_revision=args.pretrained_revision, use_lora=args.use_lora | ||
) | ||
|
||
trainer = AlignPropTrainer( | ||
alignprop_config, | ||
aesthetic_scorer(args.hf_hub_aesthetic_model_id, args.hf_hub_aesthetic_model_filename), | ||
prompt_fn, | ||
pipeline, | ||
image_samples_hook=image_outputs_logger, | ||
) | ||
|
||
trainer.train() | ||
|
||
trainer.push_to_hub(args.hf_hub_model_id) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very nice documentation page !
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! Yes the import wasn't needed, i committed it.