Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Reward Backpropogation Support #1585

Merged
merged 17 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 117 additions & 0 deletions docs/source/alignprop_trainer.mdx
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice documentation page !

Copy link
Contributor Author

@mihirp1998 mihirp1998 Jun 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Yes the import wasn't needed, i committed it.

Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# Aligning Text-to-Image Diffusion Models with Reward Backpropagation

## The why
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think as a reader I understand if the following table justifies the name of this section. Would you mind elaborating?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, I added a better why statement.


| Before | After finetuning |
| --- | --- |
| <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/pre_squirrel.png"/></div> | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/post_squirrel.png"/></div> |
| <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/pre_crab.png"/></div> | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/post_crab.png"/></div> |
| <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/pre_starfish.png"/></div> | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/post_starfish.png"/></div> |
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These images are the ones generated from DDPO no?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, i wanted to update them, although I wasn't sure how to do it, as they linked to a huggingface internal webpage https://huggingface.co/datasets/trl-internal-testing/

If you can guide me on how to do it, i can update them.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can open a PR to https://huggingface.co/datasets/trl-internal-testing/ repository adding the resultant images you want.



## Getting started with Stable Diffusion finetuning with reinforcement learning
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is needed. We should strive to keep the API documentation lean and precise.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes i removed it.


The machinery for finetuning of Stable Diffusion models with reinforcement learning makes heavy use of HuggingFace's `diffusers`
library. A reason for stating this is that getting started requires a bit of familiarity with the `diffusers` library concepts, mainly two of them - pipelines and schedulers.
Right out of the box (`diffusers` library), there isn't a `Pipeline` nor a `Scheduler` instance that is suitable for finetuning with reinforcement learning. Some adjustments need to made.

There is a pipeline interface that is provided by this library that is required to be implemented to be used with the `AlignPropTrainer`, which is one of the methods for fine-tuning Stable Diffusion with reward backpropagation. **Note: Only the StableDiffusion architecture is supported at this point.**
There is a default implementation of this interface that you can use out of the box. Assuming the default implementation is sufficient and/or to get things moving, refer to the training example alongside this guide.

The point of the interface is to fuse the pipeline and the scheduler into one object which allows for minimalness in terms of having the constraints all in one place. The interface was designed in hopes of catering to pipelines and schedulers beyond the examples in this repository and elsewhere at this time of writing. Also the scheduler step is a method of this pipeline interface and this may seem redundant given that the raw scheduler is accessible via the interface but this is one of the way to pass gradients from the scheduler step befitting of the algorithm at hand (AlignProp).

For a more detailed look into the interface and the associated default implementation, go [here](https://github.com/lvwerra/trl/tree/main/trl/models/modeling_sd_base.py)

Note that the default implementation has a LoRA implementation path and a non-LoRA based implementation path. The LoRA flag enabled by default and this can be turned off by passing in the flag to do so. LORA based training is faster and the LORA associated model hyperparameters responsible for model convergence aren't as finicky as non-LORA based training.

Also in addition, there is the expectation of providing a reward function and a prompt function. The reward function is used to evaluate the generated images and the prompt function is used to generate the prompts that are used to generate the images.

## Getting started with `examples/scripts/alignprop.py`

The `alignprop.py` script is a working example of using the `AlignProp` trainer to finetune a Stable Diffusion model. This example explicitly configures a small subset of the overall parameters associated with the config object (`AlignPropConfig`).

**Note:** one A100 GPU is recommended to get this running. For lower memory setting, consider setting truncated_backprop_rand to False. With default settings this will do truncated backpropagation with K=1.

Almost every configuration parameter has a default. There is only one commandline flag argument that is required of the user to get things up and running. The user is expected to have a [huggingface user access token](https://huggingface.co/docs/hub/security-tokens) that will be used to upload the model post finetuning to HuggingFace hub. The following bash command is to be entered to get things running

```batch
python alignprop.py --hf_user_access_token <token>
```

To obtain the documentation of `stable_diffusion_tuning.py`, please run `python stable_diffusion_tuning.py --help`

The following are things to keep in mind (The code checks this for you as well) in general while configuring the trainer (beyond the use case of using the example script)

- The configurable randomized truncation range (`--alignprop_config.truncated_rand_backprop_minmax=(0,50)`) the first number should be equal and greater to 0, while the second number should equal or less to the number of diffusion timesteps (sample_num_steps)
- The configurable truncation backprop absolute step (`--alignprop_config.truncated_backprop_timestep=49`) the number should be less than the number of diffusion timesteps (sample_num_steps), it only matters when truncated_backprop_rand is set to False

## Setting up the image logging hook function

Expect the function to be given a dictionary with keys
```python
['image', 'prompt', 'prompt_metadata', 'rewards']

```
and `image`, `prompt`, `prompt_metadata`, `rewards`are batched.
You are free to log however you want the use of `wandb` or `tensorboard` is recommended.

### Key terms

- `rewards` : The rewards/score is a numerical associated with the generated image and is key to steering the RL process
- `prompt` : The prompt is the text that is used to generate the image
- `prompt_metadata` : The prompt metadata is the metadata associated with the prompt. A situation where this will not be empty is when the reward model comprises of a [`FLAVA`](https://huggingface.co/docs/transformers/model_doc/flava) setup where questions and ground answers (linked to the generated image) are expected with the generated image (See here: https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/rewards.py#L45)
- `image` : The image generated by the Stable Diffusion model

Example code for logging sampled images with `wandb` is given below.

```python
# for logging these images to wandb

def image_outputs_hook(image_data, global_step, accelerate_logger):
# For the sake of this example, we only care about the last batch
# hence we extract the last element of the list
result = {}
images, prompts, rewards = [image_data['images'],image_data['prompts'],image_data['rewards']]
for i, image in enumerate(images):
pil = Image.fromarray(
(image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
)
pil = pil.resize((256, 256))
result[f"{prompts[i]:.25} | {rewards[i]:.2f}"] = [pil]
accelerate_logger.log_images(
result,
step=global_step,
)

```

### Using the finetuned model

Assuming you've done with all the epochs and have pushed up your model to the hub, you can use the finetuned model as follows

```python

import torch
from trl import DefaultDDPOStableDiffusionPipeline
Copy link
Member

@sayakpaul sayakpaul May 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we have to use a non-diffusers pipeline here? Does DiffusionPipeline from diffusers not work here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes indeed, i changed it to StableDiffusionPipeline from diffusers


pipeline = DefaultDDPOStableDiffusionPipeline("metric-space/alignprop-finetuned-sd-model")

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# memory optimization
pipeline.vae.to(device, torch.float16)
pipeline.text_encoder.to(device, torch.float16)
pipeline.unet.to(device, torch.float16)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These LoCs could be reduce if we do:

pipeline = DefaultDDPOStableDiffusionPipeline("metric-space/alignprop-finetuned-sd-model", torch_dtype=torch.float16)
pipeline = pipeline.to("cuda")

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Additionally, https://huggingface.co/metric-space/alignprop-finetuned-sd-model is not available. Let's make sure we're using the right checkpoint ids here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes i reduced it and fixed the checkpoint ids.


prompts = ["squirrel", "crab", "starfish", "whale","sponge", "plankton"]
results = pipeline(prompts)

for prompt, image in zip(prompts,results.images):
image.save(f"{prompt}.png")

```

## Credits

This work is heavily influenced by the repo [here](https://github.com/mihirp1998/AlignProp/) and the associated paper [Aligning Text-to-Image Diffusion Models with Reward Backpropagation
by Mihir Prabhudesai, Anirudh Goyal, Deepak Pathak, Katerina Fragkiadaki](https://arxiv.org/abs/2310.03739).
212 changes: 212 additions & 0 deletions examples/scripts/alignprop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
# Copyright 2023 metric-space, The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
python examples/scripts/ddpo.py \
--num_epochs=200 \
--train_gradient_accumulation_steps=1 \
--sample_num_steps=50 \
--sample_batch_size=6 \
--train_batch_size=3 \
--sample_num_batches_per_epoch=4 \
--per_prompt_stat_tracking=True \
--per_prompt_stat_tracking_buffer_size=32 \
--tracker_project_name="stable_diffusion_training" \
--log_with="wandb"
"""
import os
import torchvision
from dataclasses import dataclass, field

import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError
from transformers import CLIPModel, CLIPProcessor, HfArgumentParser

from trl import AlignPropConfig, AlignPropTrainer, DefaultDDPOStableDiffusionPipeline
from trl.import_utils import is_npu_available, is_xpu_available


@dataclass
class ScriptArguments:
pretrained_model: str = field(
default="runwayml/stable-diffusion-v1-5", metadata={"help": "the pretrained model to use"}
)
pretrained_revision: str = field(default="main", metadata={"help": "the pretrained model revision to use"})
hf_hub_model_id: str = field(
default="ddpo-finetuned-stable-diffusion", metadata={"help": "HuggingFace repo to save model weights to"}
)
hf_hub_aesthetic_model_id: str = field(
default="trl-lib/ddpo-aesthetic-predictor",
metadata={"help": "HuggingFace model ID for aesthetic scorer model weights"},
)
hf_hub_aesthetic_model_filename: str = field(
default="aesthetic-model.pth",
metadata={"help": "HuggingFace model filename for aesthetic scorer model weights"},
)
use_lora: bool = field(default=True, metadata={"help": "Whether to use LoRA."})


class MLP(nn.Module):
def __init__(self):
super().__init__()
self.layers = nn.Sequential(
nn.Linear(768, 1024),
nn.Dropout(0.2),
nn.Linear(1024, 128),
nn.Dropout(0.2),
nn.Linear(128, 64),
nn.Dropout(0.1),
nn.Linear(64, 16),
nn.Linear(16, 1),
)

def forward(self, embed):
return self.layers(embed)


class AestheticScorer(torch.nn.Module):
"""
This model attempts to predict the aesthetic score of an image. The aesthetic score
is a numerical approximation of how much a specific image is liked by humans on average.
This is from https://github.com/christophschuhmann/improved-aesthetic-predictor
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are we copy-pasting these modules from the DDPO script?

@younesbelkada would it make sense to have a separate module for these (auxiliary_modules, perhaps)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They are not exactly copy pasted, as DDPO had clamp and no_grad operations within them, which were preventing gradients from backpropagating.

Anyhow I still transfered the above reward function code from alignprop.py to trl/models/auxiliary_modules.py, as u suggested.

"""

def __init__(self, *, dtype, model_id, model_filename):
super().__init__()
self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
self.normalize = torchvision.transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
std=[0.26862954, 0.26130258, 0.27577711])
self.target_size = 224
self.mlp = MLP()
try:
cached_path = hf_hub_download(model_id, model_filename)
except EntryNotFoundError:
cached_path = os.path.join(model_id, model_filename)
state_dict = torch.load(cached_path, map_location=torch.device("cpu"))
self.mlp.load_state_dict(state_dict)
self.dtype = dtype
self.eval()

def __call__(self, images):
device = next(self.parameters()).device
images = torchvision.transforms.Resize(self.target_size)(images)
images = self.normalize(images).to(self.dtype).to(device)
embed = self.clip.get_image_features(pixel_values=images)
# normalize embedding
embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True)
reward = self.mlp(embed).squeeze(1)
return reward


def aesthetic_scorer(hub_model_id, model_filename):
scorer = AestheticScorer(
model_id=hub_model_id,
model_filename=model_filename,
dtype=torch.float32,
)
if is_npu_available():
scorer = scorer.npu()
elif is_xpu_available():
scorer = scorer.xpu()
else:
scorer = scorer.cuda()

def _fn(images, prompts, metadata):
images = (images).clamp(0, 1)
scores = scorer(images)
return scores, {}

return _fn


# list of example prompts to feed stable diffusion
animals = [
"cat",
"dog",
"horse",
"monkey",
"rabbit",
"zebra",
"spider",
"bird",
"sheep",
"deer",
"cow",
"goat",
"lion",
"frog",
"chicken",
"duck",
"goose",
"bee",
"pig",
"turkey",
"fly",
"llama",
"camel",
"bat",
"gorilla",
"hedgehog",
"kangaroo",
]


def prompt_fn():
return np.random.choice(animals), {}



def image_outputs_logger(image_pair_data, global_step, accelerate_logger):
# For the sake of this example, we will only log the last batch of images
# and associated data
result = {}
images, prompts, rewards = [image_pair_data['images'],image_pair_data['prompts'],image_pair_data['rewards']]
for i, image in enumerate(images[:4]):
prompt = prompts[i]
reward = rewards[i].item()
result[f"{prompt:.25} | {reward:.2f}"] = image.unsqueeze(0).float()

accelerate_logger.log_images(
result,
step=global_step,
)


if __name__ == "__main__":
parser = HfArgumentParser((ScriptArguments, AlignPropConfig))
args, alignprop_config = parser.parse_args_into_dataclasses()
alignprop_config.project_kwargs = {
"logging_dir": "./logs",
"automatic_checkpoint_naming": True,
"total_limit": 5,
"project_dir": "./save",
}

pipeline = DefaultDDPOStableDiffusionPipeline(
args.pretrained_model, pretrained_model_revision=args.pretrained_revision, use_lora=args.use_lora
)

trainer = AlignPropTrainer(
alignprop_config,
aesthetic_scorer(args.hf_hub_aesthetic_model_id, args.hf_hub_aesthetic_model_filename),
prompt_fn,
pipeline,
image_samples_hook=image_outputs_logger,
)

trainer.train()

trainer.push_to_hub(args.hf_hub_model_id)
Loading