Added Reward Backpropagation Support #1585

Merged (17 commits) on Jun 24, 2024
2 changes: 2 additions & 0 deletions docs/source/_toctree.yml
@@ -37,6 +37,8 @@
title: CPO Trainer
- local: ddpo_trainer
title: Denoising Diffusion Policy Optimization
- local: alignprop_trainer
title: AlignProp Trainer
- local: orpo_trainer
title: ORPO Trainer
- local: iterative_sft_trainer
91 changes: 91 additions & 0 deletions docs/source/alignprop_trainer.mdx
**Review comment (Contributor):** Very nice documentation page!

**@mihirp1998 (Contributor Author), Jun 12, 2024:** Thanks! Yes, the import wasn't needed; I committed the change.

@@ -0,0 +1,91 @@
# Aligning Text-to-Image Diffusion Models with Reward Backpropagation

## The why
**Review comment (Member):** I don't think as a reader I understand if the following table justifies the name of this section. Would you mind elaborating?

**@mihirp1998 (Contributor Author):** Thanks, I added a better why statement.

If your reward function is differentiable, directly backpropagating gradients from the reward model to the diffusion model is significantly more sample- and compute-efficient (25x) than a policy-gradient algorithm such as DDPO.
AlignProp does full backpropagation through time, which allows updating the earlier denoising steps via reward backpropagation.

<div style="text-align: center"><img src="https://align-prop.github.io/reward_tuning.png"/></div>
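
Conceptually, the trainer generates images with the diffusion model, scores them with the differentiable reward, and updates the model to increase that reward, optionally truncating how many denoising steps the gradient flows through. The following is a minimal, self-contained sketch of this idea, using a stand-in denoiser and a toy reward rather than the actual TRL implementation:

```python
import torch

# Sketch of (truncated) reward backpropagation; illustrative only, not TRL's code.
K = 1  # number of final denoising steps to backpropagate through
num_steps = 50  # total denoising steps

denoiser = torch.nn.Conv2d(3, 3, kernel_size=3, padding=1)  # stand-in for the UNet
optimizer = torch.optim.AdamW(denoiser.parameters(), lr=1e-4)


def toy_reward(x):
    # Differentiable stand-in for an aesthetic or preference reward model.
    return x.mean(dim=(1, 2, 3))


sample = torch.randn(4, 3, 64, 64)  # initial noise for a batch of 4 "images"
with torch.no_grad():
    # Early denoising steps run without building a graph (the truncated part).
    for _ in range(num_steps - K):
        sample = sample - 0.1 * denoiser(sample)

for _ in range(K):
    # Only the last K steps keep the graph, so gradients reach the denoiser.
    sample = sample - 0.1 * denoiser(sample)

loss = -toy_reward(sample).mean()  # minimizing -reward = gradient ascent on the reward
loss.backward()
optimizer.step()
```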


## Getting started with `examples/scripts/alignprop.py`

The `alignprop.py` script is a working example of using the `AlignProp` trainer to finetune a Stable Diffusion model. This example explicitly configures a small subset of the overall parameters associated with the config object (`AlignPropConfig`).

**Note:** one A100 GPU is recommended to run this example. For a lower-memory setup, consider setting `truncated_backprop_rand` to `False`. With the default settings, this performs truncated backpropagation with K=1.

Almost every configuration parameter has a default, and only one command-line argument is required to get things up and running. The user is expected to have a [Hugging Face user access token](https://huggingface.co/docs/hub/security-tokens), which is used to upload the model to the Hugging Face Hub after finetuning. Enter the following bash command to get things running:

```bash
python alignprop.py --hf_user_access_token <token>
```

To obtain the documentation of `alignprop.py`, please run `python alignprop.py --help`.

The following points should be kept in mind when configuring the trainer beyond the example script (the code checks these for you as well); a sketch of a consistent configuration follows this list:

- The configurable randomized truncation range (`--alignprop_config.truncated_rand_backprop_minmax=(0,50)`): the first number should be greater than or equal to 0, and the second number should be less than or equal to the number of diffusion timesteps (`sample_num_steps`).
- The configurable absolute truncation step (`--alignprop_config.truncated_backprop_timestep=49`): the number should be less than the number of diffusion timesteps (`sample_num_steps`); it only matters when `truncated_backprop_rand` is set to `False`.
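
For instance, a configuration consistent with these constraints (a sketch, assuming 50 sampling steps and the field names used above) might look like:

```python
from trl import AlignPropConfig

config = AlignPropConfig(
    sample_num_steps=50,
    truncated_backprop_rand=True,            # randomize the truncation point each update
    truncated_rand_backprop_minmax=(0, 50),  # must lie within [0, sample_num_steps]
    # truncated_backprop_timestep=49,        # only used when truncated_backprop_rand=False
)
```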

## Setting up the image logging hook function

Expect the function to be given a dictionary with keys
```python
['image', 'prompt', 'prompt_metadata', 'rewards']
```
and `image`, `prompt`, `prompt_metadata`, and `rewards` are batched.
You are free to log however you want, though using `wandb` or `tensorboard` is recommended.

### Key terms

- `rewards` : The reward/score is a numerical value associated with the generated image and is key to steering the RL process
- `prompt` : The prompt is the text that is used to generate the image
- `prompt_metadata` : The prompt metadata is the metadata associated with the prompt. A situation where this will not be empty is when the reward model comprises a [`FLAVA`](https://huggingface.co/docs/transformers/model_doc/flava) setup where questions and ground-truth answers (linked to the generated image) are expected with the generated image (see here: https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/rewards.py#L45)
- `image` : The image generated by the Stable Diffusion model
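
For reference, the reward function passed to the trainer takes batched images, prompts, and prompt metadata and returns per-image scores together with a metadata dict (this matches the scorer signature used in the example script and tests). A toy differentiable reward, purely for illustration (the name `brightness_reward` is hypothetical), could look like:

```python
import torch


def brightness_reward(images, prompts, prompt_metadata):
    # Toy differentiable reward: mean pixel brightness per generated image.
    # Because it is computed with differentiable ops on the image tensor,
    # gradients can flow from the score back into the diffusion model.
    rewards = images.mean(dim=(1, 2, 3))
    return rewards, {}
```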

Example code for logging sampled images with `wandb` is given below.

```python
# for logging these images to wandb
import numpy as np
from PIL import Image


def image_outputs_hook(image_data, global_step, accelerate_logger):
    # For the sake of this example, we only care about the last batch,
    # hence we extract the last element of the list
    result = {}
    images, prompts, rewards = image_data["images"], image_data["prompts"], image_data["rewards"]
    for i, image in enumerate(images):
        pil = Image.fromarray(
            (image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
        )
        pil = pil.resize((256, 256))
        result[f"{prompts[i]:.25} | {rewards[i]:.2f}"] = [pil]
    accelerate_logger.log_images(
        result,
        step=global_step,
    )
```

### Using the finetuned model

Assuming you're done with all the epochs and have pushed your model to the Hub, you can use the finetuned model as follows:

```python
from diffusers import StableDiffusionPipeline
pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipeline.to("cuda")

pipeline.load_lora_weights('mihirpd/alignprop-trl-aesthetics')

prompts = ["squirrel", "crab", "starfish", "whale", "sponge", "plankton"]
results = pipeline(prompts)

for prompt, image in zip(prompts, results.images):
    image.save(f"dump/{prompt}.png")
```

## Credits

This work is heavily influenced by the repo [here](https://github.com/mihirp1998/AlignProp/) and the associated paper [Aligning Text-to-Image Diffusion Models with Reward Backpropagation
by Mihir Prabhudesai, Anirudh Goyal, Deepak Pathak, Katerina Fragkiadaki](https://arxiv.org/abs/2310.03739).
129 changes: 129 additions & 0 deletions examples/scripts/alignprop.py
@@ -0,0 +1,129 @@
# Copyright 2023 metric-space, The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Total Batch size = 128 = 4 (num_gpus) * 8 (per_device_batch) * 4 (accumulation steps)
Feel free to reduce the batch size, or increase truncated_rand_backprop_min, to reduce memory usage.

CUDA_VISIBLE_DEVICES=0,1,2,3 python examples/scripts/alignprop.py \
--num_epochs=20 \
--train_gradient_accumulation_steps=4 \
--sample_num_steps=50 \
--train_batch_size=8 \
--tracker_project_name="stable_diffusion_training" \
--log_with="wandb"

"""
from dataclasses import dataclass, field

import numpy as np
from transformers import HfArgumentParser

from trl import AlignPropConfig, AlignPropTrainer, DefaultDDPOStableDiffusionPipeline
from trl.models.auxiliary_modules import aesthetic_scorer


@dataclass
class ScriptArguments:
    pretrained_model: str = field(
        default="runwayml/stable-diffusion-v1-5", metadata={"help": "the pretrained model to use"}
    )
    pretrained_revision: str = field(default="main", metadata={"help": "the pretrained model revision to use"})
    hf_hub_model_id: str = field(
        default="alignprop-finetuned-stable-diffusion", metadata={"help": "HuggingFace repo to save model weights to"}
    )
    hf_hub_aesthetic_model_id: str = field(
        default="trl-lib/ddpo-aesthetic-predictor",
        metadata={"help": "HuggingFace model ID for aesthetic scorer model weights"},
    )
    hf_hub_aesthetic_model_filename: str = field(
        default="aesthetic-model.pth",
        metadata={"help": "HuggingFace model filename for aesthetic scorer model weights"},
    )
    use_lora: bool = field(default=True, metadata={"help": "Whether to use LoRA."})


# list of example prompts to feed stable diffusion
animals = [
"cat",
"dog",
"horse",
"monkey",
"rabbit",
"zebra",
"spider",
"bird",
"sheep",
"deer",
"cow",
"goat",
"lion",
"frog",
"chicken",
"duck",
"goose",
"bee",
"pig",
"turkey",
"fly",
"llama",
"camel",
"bat",
"gorilla",
"hedgehog",
"kangaroo",
]


def prompt_fn():
    return np.random.choice(animals), {}


def image_outputs_logger(image_pair_data, global_step, accelerate_logger):
    # For the sake of this example, we will only log the last batch of images
    # and associated data
    result = {}
    images, prompts, _ = [image_pair_data["images"], image_pair_data["prompts"], image_pair_data["rewards"]]
    for i, image in enumerate(images[:4]):
        prompt = prompts[i]
        result[f"{prompt}"] = image.unsqueeze(0).float()
    accelerate_logger.log_images(
        result,
        step=global_step,
    )


if __name__ == "__main__":
    parser = HfArgumentParser((ScriptArguments, AlignPropConfig))
    args, alignprop_config = parser.parse_args_into_dataclasses()
    alignprop_config.project_kwargs = {
        "logging_dir": "./logs",
        "automatic_checkpoint_naming": True,
        "total_limit": 5,
        "project_dir": "./save",
    }

    pipeline = DefaultDDPOStableDiffusionPipeline(
        args.pretrained_model, pretrained_model_revision=args.pretrained_revision, use_lora=args.use_lora
    )
    trainer = AlignPropTrainer(
        alignprop_config,
        aesthetic_scorer(args.hf_hub_aesthetic_model_id, args.hf_hub_aesthetic_model_filename),
        prompt_fn,
        pipeline,
        image_samples_hook=image_outputs_logger,
    )

    trainer.train()

    trainer.push_to_hub(args.hf_hub_model_id)
109 changes: 109 additions & 0 deletions tests/test_alignprop_trainer.py
@@ -0,0 +1,109 @@
# Copyright 2023 metric-space, The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest

import torch

from trl import is_diffusers_available, is_peft_available

from .testing_utils import require_diffusers


if is_diffusers_available() and is_peft_available():
    from trl import AlignPropConfig, AlignPropTrainer, DefaultDDPOStableDiffusionPipeline


def scorer_function(images, prompts, metadata):
    return torch.randn(1) * 3.0, {}


def prompt_function():
return ("cabbages", {})


@require_diffusers
class AlignPropTrainerTester(unittest.TestCase):
"""
Test the AlignPropTrainer class.
"""

def setUp(self):
self.alignprop_config = AlignPropConfig(
num_epochs=2,
train_gradient_accumulation_steps=1,
train_batch_size=2,
truncated_backprop_rand=False,
mixed_precision=None,
save_freq=1000000,
)
pretrained_model = "hf-internal-testing/tiny-stable-diffusion-torch"
pretrained_revision = "main"

pipeline = DefaultDDPOStableDiffusionPipeline(
pretrained_model, pretrained_model_revision=pretrained_revision, use_lora=False
)

self.trainer = AlignPropTrainer(self.alignprop_config, scorer_function, prompt_function, pipeline)

return super().setUp()

    def tearDown(self) -> None:
        gc.collect()

    def test_generate_samples(self):
        output_pairs = self.trainer._generate_samples(2, with_grad=True)
        assert len(output_pairs.keys()) == 3
        assert len(output_pairs["images"]) == 2

    def test_calculate_loss(self):
        sample = self.trainer._generate_samples(2)

        images = sample["images"]
        prompts = sample["prompts"]

        assert images.shape == (2, 3, 128, 128)
        assert len(prompts) == 2

        rewards = self.trainer.compute_rewards(sample)
        loss = self.trainer.calculate_loss(rewards)

        assert torch.isfinite(loss.cpu())


@require_diffusers
class AlignPropTrainerWithLoRATester(AlignPropTrainerTester):
"""
Test the AlignPropTrainer class.
"""

def setUp(self):
self.alignprop_config = AlignPropConfig(
num_epochs=2,
train_gradient_accumulation_steps=1,
mixed_precision=None,
truncated_backprop_rand=False,
save_freq=1000000,
)

pretrained_model = "hf-internal-testing/tiny-stable-diffusion-torch"
pretrained_revision = "main"

pipeline = DefaultDDPOStableDiffusionPipeline(
pretrained_model, pretrained_model_revision=pretrained_revision, use_lora=True
)

self.trainer = AlignPropTrainer(self.alignprop_config, scorer_function, prompt_function, pipeline)

return super().setUp()
4 changes: 4 additions & 0 deletions trl/__init__.py
@@ -39,6 +39,8 @@
"DPOTrainer",
"CPOConfig",
"CPOTrainer",
"AlignPropConfig",
"AlignPropTrainer",
"IterativeSFTTrainer",
"KTOConfig",
"KTOTrainer",
@@ -105,6 +107,8 @@
DPOTrainer,
CPOConfig,
CPOTrainer,
AlignPropConfig,
AlignPropTrainer,
IterativeSFTTrainer,
KTOConfig,
KTOTrainer,