From 209fb01a01b8974a3a5d33b7b60f3d2e1ca975fe Mon Sep 17 00:00:00 2001 From: matrix Mihir Prabhduesai Date: Mon, 22 Apr 2024 23:35:41 -0400 Subject: [PATCH 01/15] added alignprop template --- examples/scripts/alignprop.py | 210 +++++++++++ tests/test_alignprop_trainer.py | 127 +++++++ trl/__init__.py | 5 + trl/trainer/__init__.py | 3 + trl/trainer/alignprop_config.py | 120 ++++++ trl/trainer/alignprop_trainer.py | 629 +++++++++++++++++++++++++++++++ 6 files changed, 1094 insertions(+) create mode 100644 examples/scripts/alignprop.py create mode 100644 tests/test_alignprop_trainer.py create mode 100644 trl/trainer/alignprop_config.py create mode 100644 trl/trainer/alignprop_trainer.py diff --git a/examples/scripts/alignprop.py b/examples/scripts/alignprop.py new file mode 100644 index 0000000000..8e43202684 --- /dev/null +++ b/examples/scripts/alignprop.py @@ -0,0 +1,210 @@ +# Copyright 2023 metric-space, The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +python examples/scripts/ddpo.py \ + --num_epochs=200 \ + --train_gradient_accumulation_steps=1 \ + --sample_num_steps=50 \ + --sample_batch_size=6 \ + --train_batch_size=3 \ + --sample_num_batches_per_epoch=4 \ + --per_prompt_stat_tracking=True \ + --per_prompt_stat_tracking_buffer_size=32 \ + --tracker_project_name="stable_diffusion_training" \ + --log_with="wandb" +""" +import os +from dataclasses import dataclass, field + +import numpy as np +import torch +import torch.nn as nn +from huggingface_hub import hf_hub_download +from huggingface_hub.utils import EntryNotFoundError +from transformers import CLIPModel, CLIPProcessor, HfArgumentParser + +from trl import DDPOConfig, DDPOTrainer, DefaultDDPOStableDiffusionPipeline +from trl.import_utils import is_npu_available, is_xpu_available + + +@dataclass +class ScriptArguments: + pretrained_model: str = field( + default="runwayml/stable-diffusion-v1-5", metadata={"help": "the pretrained model to use"} + ) + pretrained_revision: str = field(default="main", metadata={"help": "the pretrained model revision to use"}) + hf_hub_model_id: str = field( + default="ddpo-finetuned-stable-diffusion", metadata={"help": "HuggingFace repo to save model weights to"} + ) + hf_hub_aesthetic_model_id: str = field( + default="trl-lib/ddpo-aesthetic-predictor", + metadata={"help": "HuggingFace model ID for aesthetic scorer model weights"}, + ) + hf_hub_aesthetic_model_filename: str = field( + default="aesthetic-model.pth", + metadata={"help": "HuggingFace model filename for aesthetic scorer model weights"}, + ) + use_lora: bool = field(default=True, metadata={"help": "Whether to use LoRA."}) + + +class MLP(nn.Module): + def __init__(self): + super().__init__() + self.layers = nn.Sequential( + nn.Linear(768, 1024), + nn.Dropout(0.2), + nn.Linear(1024, 128), + nn.Dropout(0.2), + nn.Linear(128, 64), + nn.Dropout(0.1), + nn.Linear(64, 16), + nn.Linear(16, 1), + ) + + @torch.no_grad() + def forward(self, embed): + return self.layers(embed) + + +class AestheticScorer(torch.nn.Module): + """ + This model attempts to predict the aesthetic score of an image. The aesthetic score + is a numerical approximation of how much a specific image is liked by humans on average. + This is from https://github.com/christophschuhmann/improved-aesthetic-predictor + """ + + def __init__(self, *, dtype, model_id, model_filename): + super().__init__() + self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14") + self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14") + self.mlp = MLP() + try: + cached_path = hf_hub_download(model_id, model_filename) + except EntryNotFoundError: + cached_path = os.path.join(model_id, model_filename) + state_dict = torch.load(cached_path, map_location=torch.device("cpu")) + self.mlp.load_state_dict(state_dict) + self.dtype = dtype + self.eval() + + @torch.no_grad() + def __call__(self, images): + device = next(self.parameters()).device + inputs = self.processor(images=images, return_tensors="pt") + inputs = {k: v.to(self.dtype).to(device) for k, v in inputs.items()} + embed = self.clip.get_image_features(**inputs) + # normalize embedding + embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True) + return self.mlp(embed).squeeze(1) + + +def aesthetic_scorer(hub_model_id, model_filename): + scorer = AestheticScorer( + model_id=hub_model_id, + model_filename=model_filename, + dtype=torch.float32, + ) + if is_npu_available(): + scorer = scorer.npu() + elif is_xpu_available(): + scorer = scorer.xpu() + else: + scorer = scorer.cuda() + + def _fn(images, prompts, metadata): + images = (images * 255).round().clamp(0, 255).to(torch.uint8) + scores = scorer(images) + return scores, {} + + return _fn + + +# list of example prompts to feed stable diffusion +animals = [ + "cat", + "dog", + "horse", + "monkey", + "rabbit", + "zebra", + "spider", + "bird", + "sheep", + "deer", + "cow", + "goat", + "lion", + "frog", + "chicken", + "duck", + "goose", + "bee", + "pig", + "turkey", + "fly", + "llama", + "camel", + "bat", + "gorilla", + "hedgehog", + "kangaroo", +] + + +def prompt_fn(): + return np.random.choice(animals), {} + + +def image_outputs_logger(image_data, global_step, accelerate_logger): + # For the sake of this example, we will only log the last batch of images + # and associated data + result = {} + images, prompts, _, rewards, _ = image_data[-1] + + for i, image in enumerate(images): + prompt = prompts[i] + reward = rewards[i].item() + result[f"{prompt:.25} | {reward:.2f}"] = image.unsqueeze(0).float() + + accelerate_logger.log_images( + result, + step=global_step, + ) + + +if __name__ == "__main__": + parser = HfArgumentParser((ScriptArguments, DDPOConfig)) + args, ddpo_config = parser.parse_args_into_dataclasses() + ddpo_config.project_kwargs = { + "logging_dir": "./logs", + "automatic_checkpoint_naming": True, + "total_limit": 5, + "project_dir": "./save", + } + + pipeline = DefaultDDPOStableDiffusionPipeline( + args.pretrained_model, pretrained_model_revision=args.pretrained_revision, use_lora=args.use_lora + ) + + trainer = DDPOTrainer( + ddpo_config, + aesthetic_scorer(args.hf_hub_aesthetic_model_id, args.hf_hub_aesthetic_model_filename), + prompt_fn, + pipeline, + image_samples_hook=image_outputs_logger, + ) + + trainer.train() + + trainer.push_to_hub(args.hf_hub_model_id) diff --git a/tests/test_alignprop_trainer.py b/tests/test_alignprop_trainer.py new file mode 100644 index 0000000000..87c7de68eb --- /dev/null +++ b/tests/test_alignprop_trainer.py @@ -0,0 +1,127 @@ +# Copyright 2023 metric-space, The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import gc +import unittest + +import torch + +from trl import is_diffusers_available, is_peft_available + +from .testing_utils import require_diffusers + + +if is_diffusers_available() and is_peft_available(): + from trl import DDPOConfig, DDPOTrainer, DefaultDDPOStableDiffusionPipeline + + +def scorer_function(images, prompts, metadata): + return torch.randn(1) * 3.0, {} + + +def prompt_function(): + return ("cabbages", {}) + + +@require_diffusers +class DDPOTrainerTester(unittest.TestCase): + """ + Test the DDPOTrainer class. + """ + + def setUp(self): + self.ddpo_config = DDPOConfig( + num_epochs=2, + train_gradient_accumulation_steps=1, + per_prompt_stat_tracking_buffer_size=32, + sample_num_batches_per_epoch=2, + sample_batch_size=2, + mixed_precision=None, + save_freq=1000000, + ) + pretrained_model = "hf-internal-testing/tiny-stable-diffusion-torch" + pretrained_revision = "main" + + pipeline = DefaultDDPOStableDiffusionPipeline( + pretrained_model, pretrained_model_revision=pretrained_revision, use_lora=False + ) + + self.trainer = DDPOTrainer(self.ddpo_config, scorer_function, prompt_function, pipeline) + + return super().setUp() + + def tearDown(self) -> None: + gc.collect() + + def test_loss(self): + advantage = torch.tensor([-1.0]) + clip_range = 0.0001 + ratio = torch.tensor([1.0]) + loss = self.trainer.loss(advantage, clip_range, ratio) + assert loss.item() == 1.0 + + def test_generate_samples(self): + samples, output_pairs = self.trainer._generate_samples(1, 2) + assert len(samples) == 1 + assert len(output_pairs) == 1 + assert len(output_pairs[0][0]) == 2 + + def test_calculate_loss(self): + samples, _ = self.trainer._generate_samples(1, 2) + sample = samples[0] + + latents = sample["latents"][0, 0].unsqueeze(0) + next_latents = sample["next_latents"][0, 0].unsqueeze(0) + log_probs = sample["log_probs"][0, 0].unsqueeze(0) + timesteps = sample["timesteps"][0, 0].unsqueeze(0) + prompt_embeds = sample["prompt_embeds"] + advantage = torch.tensor([1.0], device=prompt_embeds.device) + + assert latents.shape == (1, 4, 64, 64) + assert next_latents.shape == (1, 4, 64, 64) + assert log_probs.shape == (1,) + assert timesteps.shape == (1,) + assert prompt_embeds.shape == (2, 77, 32) + loss, approx_kl, clipfrac = self.trainer.calculate_loss( + latents, timesteps, next_latents, log_probs, advantage, prompt_embeds + ) + + assert torch.isfinite(loss.cpu()) + + +@require_diffusers +class DDPOTrainerWithLoRATester(DDPOTrainerTester): + """ + Test the DDPOTrainer class. + """ + + def setUp(self): + self.ddpo_config = DDPOConfig( + num_epochs=2, + train_gradient_accumulation_steps=1, + per_prompt_stat_tracking_buffer_size=32, + sample_num_batches_per_epoch=2, + sample_batch_size=2, + mixed_precision=None, + save_freq=1000000, + ) + pretrained_model = "hf-internal-testing/tiny-stable-diffusion-torch" + pretrained_revision = "main" + + pipeline = DefaultDDPOStableDiffusionPipeline( + pretrained_model, pretrained_model_revision=pretrained_revision, use_lora=True + ) + + self.trainer = DDPOTrainer(self.ddpo_config, scorer_function, prompt_function, pipeline) + + return super().setUp() diff --git a/trl/__init__.py b/trl/__init__.py index 252c4e8bc9..57af615895 100644 --- a/trl/__init__.py +++ b/trl/__init__.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING from .import_utils import _LazyModule, is_diffusers_available, OptionalDependencyNotAvailable +import torch.utils.checkpoint as checkpoint _import_structure = { "core": [ @@ -38,6 +39,8 @@ "DPOTrainer", "CPOConfig", "CPOTrainer", + "AlignPropConfig", + "AlignPropTrainer", "IterativeSFTTrainer", "KTOConfig", "KTOTrainer", @@ -102,6 +105,8 @@ DPOTrainer, CPOConfig, CPOTrainer, + AlignPropConfig, + AlignPropTrainer, IterativeSFTTrainer, KTOConfig, KTOTrainer, diff --git a/trl/trainer/__init__.py b/trl/trainer/__init__.py index 6310e2a0f9..9244771cfd 100644 --- a/trl/trainer/__init__.py +++ b/trl/trainer/__init__.py @@ -34,6 +34,8 @@ ], "cpo_config": ["CPOConfig"], "cpo_trainer": ["CPOTrainer"], + "alignprop_config": ["AlignPropConfig"], + "alignprop_trainer": ["AlignPropTrainer"], "iterative_sft_trainer": [ "IterativeSFTTrainer", ], @@ -82,6 +84,7 @@ from .iterative_sft_trainer import IterativeSFTTrainer from .cpo_config import CPOConfig from .cpo_trainer import CPOTrainer + from .alignprop_config import AlignPropConfig from .kto_config import KTOConfig from .kto_trainer import KTOTrainer from .model_config import ModelConfig diff --git a/trl/trainer/alignprop_config.py b/trl/trainer/alignprop_config.py new file mode 100644 index 0000000000..b73bd58d05 --- /dev/null +++ b/trl/trainer/alignprop_config.py @@ -0,0 +1,120 @@ +import os +import sys +import warnings +from dataclasses import dataclass, field +from typing import Literal, Optional + +from ..core import flatten_dict +from ..import_utils import is_bitsandbytes_available, is_torchvision_available + + +@dataclass +class DDPOConfig: + """ + Configuration class for DDPOTrainer + """ + + # common parameters + exp_name: str = os.path.basename(sys.argv[0])[: -len(".py")] + """the name of this experiment (by default is the file name without the extension name)""" + run_name: Optional[str] = "" + """Run name for wandb logging and checkpoint saving.""" + seed: int = 0 + """Seed value for random generations""" + log_with: Optional[Literal["wandb", "tensorboard"]] = None + """Log with either 'wandb' or 'tensorboard', check https://huggingface.co/docs/accelerate/usage_guides/tracking for more details""" + tracker_kwargs: dict = field(default_factory=dict) + """Keyword arguments for the tracker (e.g. wandb_project)""" + accelerator_kwargs: dict = field(default_factory=dict) + """Keyword arguments for the accelerator""" + project_kwargs: dict = field(default_factory=dict) + """Keyword arguments for the accelerator project config (e.g. `logging_dir`)""" + tracker_project_name: str = "trl" + """Name of project to use for tracking""" + logdir: str = "logs" + """Top-level logging directory for checkpoint saving.""" + + # hyperparameters + num_epochs: int = 100 + """Number of epochs to train.""" + save_freq: int = 1 + """Number of epochs between saving model checkpoints.""" + num_checkpoint_limit: int = 5 + """Number of checkpoints to keep before overwriting old ones.""" + mixed_precision: str = "fp16" + """Mixed precision training.""" + allow_tf32: bool = True + """Allow tf32 on Ampere GPUs.""" + resume_from: Optional[str] = "" + """Resume training from a checkpoint.""" + sample_num_steps: int = 50 + """Number of sampler inference steps.""" + sample_eta: float = 1.0 + """Eta parameter for the DDIM sampler.""" + sample_guidance_scale: float = 5.0 + """Classifier-free guidance weight.""" + sample_batch_size: int = 1 + """Batch size (per GPU!) to use for sampling.""" + sample_num_batches_per_epoch: int = 2 + """Number of batches to sample per epoch.""" + train_batch_size: int = 1 + """Batch size (per GPU!) to use for training.""" + train_use_8bit_adam: bool = False + """Whether to use the 8bit Adam optimizer from bitsandbytes.""" + train_learning_rate: float = 3e-4 + """Learning rate.""" + train_adam_beta1: float = 0.9 + """Adam beta1.""" + train_adam_beta2: float = 0.999 + """Adam beta2.""" + train_adam_weight_decay: float = 1e-4 + """Adam weight decay.""" + train_adam_epsilon: float = 1e-8 + """Adam epsilon.""" + train_gradient_accumulation_steps: int = 1 + """Number of gradient accumulation steps.""" + train_max_grad_norm: float = 1.0 + """Maximum gradient norm for gradient clipping.""" + train_num_inner_epochs: int = 1 + """Number of inner epochs per outer epoch.""" + train_cfg: bool = True + """Whether or not to use classifier-free guidance during training.""" + train_adv_clip_max: float = 5 + """Clip advantages to the range.""" + train_clip_range: float = 1e-4 + """The PPO clip range.""" + train_timestep_fraction: float = 1.0 + """The fraction of timesteps to train on.""" + per_prompt_stat_tracking: bool = False + """Whether to track statistics for each prompt separately.""" + per_prompt_stat_tracking_buffer_size: int = 16 + """Number of reward values to store in the buffer for each prompt.""" + per_prompt_stat_tracking_min_count: int = 16 + """The minimum number of reward values to store in the buffer.""" + async_reward_computation: bool = False + """Whether to compute rewards asynchronously.""" + max_workers: int = 2 + """The maximum number of workers to use for async reward computation.""" + negative_prompts: Optional[str] = "" + """Comma-separated list of prompts to use as negative examples.""" + + def to_dict(self): + output_dict = {} + for key, value in self.__dict__.items(): + output_dict[key] = value + return flatten_dict(output_dict) + + def __post_init__(self): + if self.log_with not in ["wandb", "tensorboard"]: + warnings.warn( + "Accelerator tracking only supports image logging if `log_with` is set to 'wandb' or 'tensorboard'." + ) + + if self.log_with == "wandb" and not is_torchvision_available(): + warnings.warn("Wandb image logging requires torchvision to be installed") + + if self.train_use_8bit_adam and not is_bitsandbytes_available(): + raise ImportError( + "You need to install bitsandbytes to use 8bit Adam. " + "You can install it with `pip install bitsandbytes`." + ) diff --git a/trl/trainer/alignprop_trainer.py b/trl/trainer/alignprop_trainer.py new file mode 100644 index 0000000000..df219da707 --- /dev/null +++ b/trl/trainer/alignprop_trainer.py @@ -0,0 +1,629 @@ +# Copyright 2023 DDPO-pytorch authors (Kevin Black), metric-space, The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import warnings +from collections import defaultdict +from concurrent import futures +from typing import Any, Callable, Optional, Tuple +from warnings import warn + +import torch +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from huggingface_hub import whoami + +from ..models import DDPOStableDiffusionPipeline +from . import BaseTrainer, DDPOConfig +from .utils import PerPromptStatTracker + + +logger = get_logger(__name__) + + +MODEL_CARD_TEMPLATE = """--- +license: apache-2.0 +tags: +- trl +- ddpo +- diffusers +- reinforcement-learning +- text-to-image +- stable-diffusion +--- + +# {model_name} + +This is a diffusion model that has been fine-tuned with reinforcement learning to + guide the model outputs according to a value, function, or human feedback. The model can be used for image generation conditioned with text. + +""" + + +class DDPOTrainer(BaseTrainer): + """ + The DDPOTrainer uses Deep Diffusion Policy Optimization to optimise diffusion models. + Note, this trainer is heavily inspired by the work here: https://github.com/kvablack/ddpo-pytorch + As of now only Stable Diffusion based pipelines are supported + + Attributes: + **config** (`DDPOConfig`) -- Configuration object for DDPOTrainer. Check the documentation of `PPOConfig` for more + details. + **reward_function** (Callable[[torch.Tensor, Tuple[str], Tuple[Any]], torch.Tensor]) -- Reward function to be used + **prompt_function** (Callable[[], Tuple[str, Any]]) -- Function to generate prompts to guide model + **sd_pipeline** (`DDPOStableDiffusionPipeline`) -- Stable Diffusion pipeline to be used for training. + **image_samples_hook** (Optional[Callable[[Any, Any, Any], Any]]) -- Hook to be called to log images + """ + + _tag_names = ["trl", "ddpo"] + + def __init__( + self, + config: DDPOConfig, + reward_function: Callable[[torch.Tensor, Tuple[str], Tuple[Any]], torch.Tensor], + prompt_function: Callable[[], Tuple[str, Any]], + sd_pipeline: DDPOStableDiffusionPipeline, + image_samples_hook: Optional[Callable[[Any, Any, Any], Any]] = None, + ): + if image_samples_hook is None: + warn("No image_samples_hook provided; no images will be logged") + + self.prompt_fn = prompt_function + self.reward_fn = reward_function + self.config = config + self.image_samples_callback = image_samples_hook + + accelerator_project_config = ProjectConfiguration(**self.config.project_kwargs) + + if self.config.resume_from: + self.config.resume_from = os.path.normpath(os.path.expanduser(self.config.resume_from)) + if "checkpoint_" not in os.path.basename(self.config.resume_from): + # get the most recent checkpoint in this directory + checkpoints = list( + filter( + lambda x: "checkpoint_" in x, + os.listdir(self.config.resume_from), + ) + ) + if len(checkpoints) == 0: + raise ValueError(f"No checkpoints found in {self.config.resume_from}") + checkpoint_numbers = sorted([int(x.split("_")[-1]) for x in checkpoints]) + self.config.resume_from = os.path.join( + self.config.resume_from, + f"checkpoint_{checkpoint_numbers[-1]}", + ) + + accelerator_project_config.iteration = checkpoint_numbers[-1] + 1 + + # number of timesteps within each trajectory to train on + self.num_train_timesteps = int(self.config.sample_num_steps * self.config.train_timestep_fraction) + + self.accelerator = Accelerator( + log_with=self.config.log_with, + mixed_precision=self.config.mixed_precision, + project_config=accelerator_project_config, + # we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the + # number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get + # the total number of optimizer steps to accumulate across. + gradient_accumulation_steps=self.config.train_gradient_accumulation_steps * self.num_train_timesteps, + **self.config.accelerator_kwargs, + ) + + is_okay, message = self._config_check() + if not is_okay: + raise ValueError(message) + + is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard" + + if self.accelerator.is_main_process: + self.accelerator.init_trackers( + self.config.tracker_project_name, + config=dict(ddpo_trainer_config=config.to_dict()) if not is_using_tensorboard else config.to_dict(), + init_kwargs=self.config.tracker_kwargs, + ) + + logger.info(f"\n{config}") + + set_seed(self.config.seed, device_specific=True) + + self.sd_pipeline = sd_pipeline + + self.sd_pipeline.set_progress_bar_config( + position=1, + disable=not self.accelerator.is_local_main_process, + leave=False, + desc="Timestep", + dynamic_ncols=True, + ) + + # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + if self.accelerator.mixed_precision == "fp16": + inference_dtype = torch.float16 + elif self.accelerator.mixed_precision == "bf16": + inference_dtype = torch.bfloat16 + else: + inference_dtype = torch.float32 + + self.sd_pipeline.vae.to(self.accelerator.device, dtype=inference_dtype) + self.sd_pipeline.text_encoder.to(self.accelerator.device, dtype=inference_dtype) + self.sd_pipeline.unet.to(self.accelerator.device, dtype=inference_dtype) + + trainable_layers = self.sd_pipeline.get_trainable_layers() + + self.accelerator.register_save_state_pre_hook(self._save_model_hook) + self.accelerator.register_load_state_pre_hook(self._load_model_hook) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if self.config.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + self.optimizer = self._setup_optimizer( + trainable_layers.parameters() if not isinstance(trainable_layers, list) else trainable_layers + ) + + self.neg_prompt_embed = self.sd_pipeline.text_encoder( + self.sd_pipeline.tokenizer( + [""] if self.config.negative_prompts is None else self.config.negative_prompts, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=self.sd_pipeline.tokenizer.model_max_length, + ).input_ids.to(self.accelerator.device) + )[0] + + if config.per_prompt_stat_tracking: + self.stat_tracker = PerPromptStatTracker( + config.per_prompt_stat_tracking_buffer_size, + config.per_prompt_stat_tracking_min_count, + ) + + # NOTE: for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses + # more memory + self.autocast = self.sd_pipeline.autocast or self.accelerator.autocast + + if hasattr(self.sd_pipeline, "use_lora") and self.sd_pipeline.use_lora: + unet, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer) + self.trainable_layers = list(filter(lambda p: p.requires_grad, unet.parameters())) + else: + self.trainable_layers, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer) + + if self.config.async_reward_computation: + self.executor = futures.ThreadPoolExecutor(max_workers=config.max_workers) + + if config.resume_from: + logger.info(f"Resuming from {config.resume_from}") + self.accelerator.load_state(config.resume_from) + self.first_epoch = int(config.resume_from.split("_")[-1]) + 1 + else: + self.first_epoch = 0 + + def compute_rewards(self, prompt_image_pairs, is_async=False): + if not is_async: + rewards = [] + for images, prompts, prompt_metadata in prompt_image_pairs: + reward, reward_metadata = self.reward_fn(images, prompts, prompt_metadata) + rewards.append( + ( + torch.as_tensor(reward, device=self.accelerator.device), + reward_metadata, + ) + ) + else: + rewards = self.executor.map(lambda x: self.reward_fn(*x), prompt_image_pairs) + rewards = [ + (torch.as_tensor(reward.result(), device=self.accelerator.device), reward_metadata.result()) + for reward, reward_metadata in rewards + ] + + return zip(*rewards) + + def step(self, epoch: int, global_step: int): + """ + Perform a single step of training. + + Args: + epoch (int): The current epoch. + global_step (int): The current global step. + + Side Effects: + - Model weights are updated + - Logs the statistics to the accelerator trackers. + - If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step, and the accelerator tracker. + + Returns: + global_step (int): The updated global step. + + """ + samples, prompt_image_data = self._generate_samples( + iterations=self.config.sample_num_batches_per_epoch, + batch_size=self.config.sample_batch_size, + ) + + # collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...) + samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()} + rewards, rewards_metadata = self.compute_rewards( + prompt_image_data, is_async=self.config.async_reward_computation + ) + + for i, image_data in enumerate(prompt_image_data): + image_data.extend([rewards[i], rewards_metadata[i]]) + + if self.image_samples_callback is not None: + self.image_samples_callback(prompt_image_data, global_step, self.accelerator.trackers[0]) + + rewards = torch.cat(rewards) + rewards = self.accelerator.gather(rewards).cpu().numpy() + + self.accelerator.log( + { + "reward": rewards, + "epoch": epoch, + "reward_mean": rewards.mean(), + "reward_std": rewards.std(), + }, + step=global_step, + ) + + if self.config.per_prompt_stat_tracking: + # gather the prompts across processes + prompt_ids = self.accelerator.gather(samples["prompt_ids"]).cpu().numpy() + prompts = self.sd_pipeline.tokenizer.batch_decode(prompt_ids, skip_special_tokens=True) + advantages = self.stat_tracker.update(prompts, rewards) + else: + advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8) + + # ungather advantages; keep the entries corresponding to the samples on this process + samples["advantages"] = ( + torch.as_tensor(advantages) + .reshape(self.accelerator.num_processes, -1)[self.accelerator.process_index] + .to(self.accelerator.device) + ) + + del samples["prompt_ids"] + + total_batch_size, num_timesteps = samples["timesteps"].shape + + for inner_epoch in range(self.config.train_num_inner_epochs): + # shuffle samples along batch dimension + perm = torch.randperm(total_batch_size, device=self.accelerator.device) + samples = {k: v[perm] for k, v in samples.items()} + + # shuffle along time dimension independently for each sample + # still trying to understand the code below + perms = torch.stack( + [torch.randperm(num_timesteps, device=self.accelerator.device) for _ in range(total_batch_size)] + ) + + for key in ["timesteps", "latents", "next_latents", "log_probs"]: + samples[key] = samples[key][ + torch.arange(total_batch_size, device=self.accelerator.device)[:, None], + perms, + ] + + original_keys = samples.keys() + original_values = samples.values() + # rebatch them as user defined train_batch_size is different from sample_batch_size + reshaped_values = [v.reshape(-1, self.config.train_batch_size, *v.shape[1:]) for v in original_values] + + # Transpose the list of original values + transposed_values = zip(*reshaped_values) + # Create new dictionaries for each row of transposed values + samples_batched = [dict(zip(original_keys, row_values)) for row_values in transposed_values] + + self.sd_pipeline.unet.train() + global_step = self._train_batched_samples(inner_epoch, epoch, global_step, samples_batched) + # ensure optimization step at the end of the inner epoch + if not self.accelerator.sync_gradients: + raise ValueError( + "Optimization step should have been performed by this point. Please check calculated gradient accumulation settings." + ) + + if epoch != 0 and epoch % self.config.save_freq == 0 and self.accelerator.is_main_process: + self.accelerator.save_state() + + return global_step + + def calculate_loss(self, latents, timesteps, next_latents, log_probs, advantages, embeds): + """ + Calculate the loss for a batch of an unpacked sample + + Args: + latents (torch.Tensor): + The latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height, width] + timesteps (torch.Tensor): + The timesteps sampled from the diffusion model, shape: [batch_size] + next_latents (torch.Tensor): + The next latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height, width] + log_probs (torch.Tensor): + The log probabilities of the latents, shape: [batch_size] + advantages (torch.Tensor): + The advantages of the latents, shape: [batch_size] + embeds (torch.Tensor): + The embeddings of the prompts, shape: [2*batch_size or batch_size, ...] + Note: the "or" is because if train_cfg is True, the expectation is that negative prompts are concatenated to the embeds + + Returns: + loss (torch.Tensor), approx_kl (torch.Tensor), clipfrac (torch.Tensor) + (all of these are of shape (1,)) + """ + with self.autocast(): + if self.config.train_cfg: + noise_pred = self.sd_pipeline.unet( + torch.cat([latents] * 2), + torch.cat([timesteps] * 2), + embeds, + ).sample + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.config.sample_guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + else: + noise_pred = self.sd_pipeline.unet( + latents, + timesteps, + embeds, + ).sample + # compute the log prob of next_latents given latents under the current model + + scheduler_step_output = self.sd_pipeline.scheduler_step( + noise_pred, + timesteps, + latents, + eta=self.config.sample_eta, + prev_sample=next_latents, + ) + + log_prob = scheduler_step_output.log_probs + + advantages = torch.clamp( + advantages, + -self.config.train_adv_clip_max, + self.config.train_adv_clip_max, + ) + + ratio = torch.exp(log_prob - log_probs) + + loss = self.loss(advantages, self.config.train_clip_range, ratio) + + approx_kl = 0.5 * torch.mean((log_prob - log_probs) ** 2) + + clipfrac = torch.mean((torch.abs(ratio - 1.0) > self.config.train_clip_range).float()) + + return loss, approx_kl, clipfrac + + def loss( + self, + advantages: torch.Tensor, + clip_range: float, + ratio: torch.Tensor, + ): + unclipped_loss = -advantages * ratio + clipped_loss = -advantages * torch.clamp( + ratio, + 1.0 - clip_range, + 1.0 + clip_range, + ) + return torch.mean(torch.maximum(unclipped_loss, clipped_loss)) + + def _setup_optimizer(self, trainable_layers_parameters): + if self.config.train_use_8bit_adam: + import bitsandbytes + + optimizer_cls = bitsandbytes.optim.AdamW8bit + else: + optimizer_cls = torch.optim.AdamW + + return optimizer_cls( + trainable_layers_parameters, + lr=self.config.train_learning_rate, + betas=(self.config.train_adam_beta1, self.config.train_adam_beta2), + weight_decay=self.config.train_adam_weight_decay, + eps=self.config.train_adam_epsilon, + ) + + def _save_model_hook(self, models, weights, output_dir): + self.sd_pipeline.save_checkpoint(models, weights, output_dir) + weights.pop() # ensures that accelerate doesn't try to handle saving of the model + + def _load_model_hook(self, models, input_dir): + self.sd_pipeline.load_checkpoint(models, input_dir) + models.pop() # ensures that accelerate doesn't try to handle loading of the model + + def _generate_samples(self, iterations, batch_size): + """ + Generate samples from the model + + Args: + iterations (int): Number of iterations to generate samples for + batch_size (int): Batch size to use for sampling + + Returns: + samples (List[Dict[str, torch.Tensor]]), prompt_image_pairs (List[List[Any]]) + """ + samples = [] + prompt_image_pairs = [] + self.sd_pipeline.unet.eval() + + sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1) + + for _ in range(iterations): + prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)]) + + prompt_ids = self.sd_pipeline.tokenizer( + prompts, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=self.sd_pipeline.tokenizer.model_max_length, + ).input_ids.to(self.accelerator.device) + prompt_embeds = self.sd_pipeline.text_encoder(prompt_ids)[0] + + with self.autocast(): + sd_output = self.sd_pipeline( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=sample_neg_prompt_embeds, + num_inference_steps=self.config.sample_num_steps, + guidance_scale=self.config.sample_guidance_scale, + eta=self.config.sample_eta, + output_type="pt", + ) + + images = sd_output.images + latents = sd_output.latents + log_probs = sd_output.log_probs + + latents = torch.stack(latents, dim=1) # (batch_size, num_steps + 1, ...) + log_probs = torch.stack(log_probs, dim=1) # (batch_size, num_steps, 1) + timesteps = self.sd_pipeline.scheduler.timesteps.repeat(batch_size, 1) # (batch_size, num_steps) + + samples.append( + { + "prompt_ids": prompt_ids, + "prompt_embeds": prompt_embeds, + "timesteps": timesteps, + "latents": latents[:, :-1], # each entry is the latent before timestep t + "next_latents": latents[:, 1:], # each entry is the latent after timestep t + "log_probs": log_probs, + "negative_prompt_embeds": sample_neg_prompt_embeds, + } + ) + prompt_image_pairs.append([images, prompts, prompt_metadata]) + + return samples, prompt_image_pairs + + def _train_batched_samples(self, inner_epoch, epoch, global_step, batched_samples): + """ + Train on a batch of samples. Main training segment + + Args: + inner_epoch (int): The current inner epoch + epoch (int): The current epoch + global_step (int): The current global step + batched_samples (List[Dict[str, torch.Tensor]]): The batched samples to train on + + Side Effects: + - Model weights are updated + - Logs the statistics to the accelerator trackers. + + Returns: + global_step (int): The updated global step + """ + info = defaultdict(list) + for _i, sample in enumerate(batched_samples): + if self.config.train_cfg: + # concat negative prompts to sample prompts to avoid two forward passes + embeds = torch.cat([sample["negative_prompt_embeds"], sample["prompt_embeds"]]) + else: + embeds = sample["prompt_embeds"] + + for j in range(self.num_train_timesteps): + with self.accelerator.accumulate(self.sd_pipeline.unet): + loss, approx_kl, clipfrac = self.calculate_loss( + sample["latents"][:, j], + sample["timesteps"][:, j], + sample["next_latents"][:, j], + sample["log_probs"][:, j], + sample["advantages"], + embeds, + ) + info["approx_kl"].append(approx_kl) + info["clipfrac"].append(clipfrac) + info["loss"].append(loss) + + self.accelerator.backward(loss) + if self.accelerator.sync_gradients: + self.accelerator.clip_grad_norm_( + self.trainable_layers.parameters() + if not isinstance(self.trainable_layers, list) + else self.trainable_layers, + self.config.train_max_grad_norm, + ) + self.optimizer.step() + self.optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if self.accelerator.sync_gradients: + # log training-related stuff + info = {k: torch.mean(torch.stack(v)) for k, v in info.items()} + info = self.accelerator.reduce(info, reduction="mean") + info.update({"epoch": epoch, "inner_epoch": inner_epoch}) + self.accelerator.log(info, step=global_step) + global_step += 1 + info = defaultdict(list) + return global_step + + def _config_check(self) -> Tuple[bool, str]: + samples_per_epoch = ( + self.config.sample_batch_size * self.accelerator.num_processes * self.config.sample_num_batches_per_epoch + ) + total_train_batch_size = ( + self.config.train_batch_size + * self.accelerator.num_processes + * self.config.train_gradient_accumulation_steps + ) + + if not self.config.sample_batch_size >= self.config.train_batch_size: + return ( + False, + f"Sample batch size ({self.config.sample_batch_size}) must be greater than or equal to the train batch size ({self.config.train_batch_size})", + ) + if not self.config.sample_batch_size % self.config.train_batch_size == 0: + return ( + False, + f"Sample batch size ({self.config.sample_batch_size}) must be divisible by the train batch size ({self.config.train_batch_size})", + ) + if not samples_per_epoch % total_train_batch_size == 0: + return ( + False, + f"Number of samples per epoch ({samples_per_epoch}) must be divisible by the total train batch size ({total_train_batch_size})", + ) + return True, "" + + def train(self, epochs: Optional[int] = None): + """ + Train the model for a given number of epochs + """ + global_step = 0 + if epochs is None: + epochs = self.config.num_epochs + for epoch in range(self.first_epoch, epochs): + global_step = self.step(epoch, global_step) + + def create_model_card(self, path: str, model_name: Optional[str] = "TRL DDPO Model") -> None: + """Creates and saves a model card for a TRL model. + + Args: + path (`str`): The path to save the model card to. + model_name (`str`, *optional*): The name of the model, defaults to `TRL DDPO Model`. + """ + try: + user = whoami()["name"] + # handle the offline case + except Exception: + warnings.warn("Cannot retrieve user information assuming you are running in offline mode.") + return + + if not os.path.exists(path): + os.makedirs(path) + + model_card_content = MODEL_CARD_TEMPLATE.format(model_name=model_name, model_id=f"{user}/{path}") + with open(os.path.join(path, "README.md"), "w", encoding="utf-8") as f: + f.write(model_card_content) + + def _save_pretrained(self, save_directory): + self.sd_pipeline.save_pretrained(save_directory) + self.create_model_card(save_directory) From 0db810428916a03b1f72e0fc38f1fd3cc1dd53bc Mon Sep 17 00:00:00 2001 From: matrix Mihir Prabhduesai Date: Thu, 25 Apr 2024 02:47:48 -0400 Subject: [PATCH 02/15] added alignprop support --- docs/source/alignprop_trainer.mdx | 116 ++++++++ examples/scripts/alignprop.py | 38 +-- tests/test_alignprop_trainer.py | 71 ++--- trl/models/modeling_sd_base.py | 244 ++++++++++++++++- trl/trainer/alignprop_config.py | 35 +-- trl/trainer/alignprop_trainer.py | 430 +++++++++--------------------- 6 files changed, 534 insertions(+), 400 deletions(-) create mode 100644 docs/source/alignprop_trainer.mdx diff --git a/docs/source/alignprop_trainer.mdx b/docs/source/alignprop_trainer.mdx new file mode 100644 index 0000000000..13e7f9492d --- /dev/null +++ b/docs/source/alignprop_trainer.mdx @@ -0,0 +1,116 @@ +# Denoising Diffusion Policy Optimization +## The why + +| Before | After DDPO finetuning | +| --- | --- | +|
|
| +|
|
| +|
|
| + + +## Getting started with Stable Diffusion finetuning with reinforcement learning + +The machinery for finetuning of Stable Diffusion models with reinforcement learning makes heavy use of HuggingFace's `diffusers` +library. A reason for stating this is that getting started requires a bit of familiarity with the `diffusers` library concepts, mainly two of them - pipelines and schedulers. +Right out of the box (`diffusers` library), there isn't a `Pipeline` nor a `Scheduler` instance that is suitable for finetuning with reinforcement learning. Some adjustments need to made. + +There is a pipeline interface that is provided by this library that is required to be implemented to be used with the `DDPOTrainer`, which is the main machinery for fine-tuning Stable Diffusion with reinforcement learning. **Note: Only the StableDiffusion architecture is supported at this point.** +There is a default implementation of this interface that you can use out of the box. Assuming the default implementation is sufficient and/or to get things moving, refer to the training example alongside this guide. + +The point of the interface is to fuse the pipeline and the scheduler into one object which allows for minimalness in terms of having the constraints all in one place. The interface was designed in hopes of catering to pipelines and schedulers beyond the examples in this repository and elsewhere at this time of writing. Also the scheduler step is a method of this pipeline interface and this may seem redundant given that the raw scheduler is accessible via the interface but this is the only way to constrain the scheduler step output to an output type befitting of the algorithm at hand (DDPO). + +For a more detailed look into the interface and the associated default implementation, go [here](https://github.com/lvwerra/trl/tree/main/trl/models/modeling_sd_base.py) + +Note that the default implementation has a LoRA implementation path and a non-LoRA based implementation path. The LoRA flag enabled by default and this can be turned off by passing in the flag to do so. LORA based training is faster and the LORA associated model hyperparameters responsible for model convergence aren't as finicky as non-LORA based training. + +Also in addition, there is the expectation of providing a reward function and a prompt function. The reward function is used to evaluate the generated images and the prompt function is used to generate the prompts that are used to generate the images. + +## Getting started with `examples/scripts/alignprop.py` + +The `alignprop.py` script is a working example of using the `AlignProp` trainer to finetune a Stable Diffusion model. This example explicitly configures a small subset of the overall parameters associated with the config object (`AlignPropConfig`). + +**Note:** one A100 GPU is recommended to get this running. For lower memory setting, consider setting truncated_backprop_rand to False. With default settings this will do truncated backpropagation with K=1. + +Almost every configuration parameter has a default. There is only one commandline flag argument that is required of the user to get things up and running. The user is expected to have a [huggingface user access token](https://huggingface.co/docs/hub/security-tokens) that will be used to upload the model post finetuning to HuggingFace hub. The following bash command is to be entered to get things running + +```batch +python alignprop.py --hf_user_access_token +``` + +To obtain the documentation of `stable_diffusion_tuning.py`, please run `python stable_diffusion_tuning.py --help` + +The following are things to keep in mind (The code checks this for you as well) in general while configuring the trainer (beyond the use case of using the example script) + +- The configurable randomized truncation range (`--alignprop_config.truncated_rand_backprop_minmax=(0,50)`) the first number should be equal and greater to 0, while the second number should equal or less to the number of diffusion timesteps (sample_num_steps) +- The configurable truncation backprop absolute step (`--alignprop_config.truncated_backprop_timestep=49`) the number should be less than the number of diffusion timesteps (sample_num_steps), it only matters when truncated_backprop_rand is set to False + +## Setting up the image logging hook function + +Expect the function to be given a dictionary with keys +```python +['image', 'prompt', 'prompt_metadata', 'rewards'] + +``` +and `image`, `prompt`, `prompt_metadata`, `rewards`are batched. +You are free to log however you want the use of `wandb` or `tensorboard` is recommended. + +### Key terms + +- `rewards` : The rewards/score is a numerical associated with the generated image and is key to steering the RL process +- `prompt` : The prompt is the text that is used to generate the image +- `prompt_metadata` : The prompt metadata is the metadata associated with the prompt. A situation where this will not be empty is when the reward model comprises of a [`FLAVA`](https://huggingface.co/docs/transformers/model_doc/flava) setup where questions and ground answers (linked to the generated image) are expected with the generated image (See here: https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/rewards.py#L45) +- `image` : The image generated by the Stable Diffusion model + +Example code for logging sampled images with `wandb` is given below. + +```python +# for logging these images to wandb + +def image_outputs_hook(image_data, global_step, accelerate_logger): + # For the sake of this example, we only care about the last batch + # hence we extract the last element of the list + result = {} + images, prompts, rewards = [image_data['images'],image_data['prompts'],image_data['rewards']] + for i, image in enumerate(images): + pil = Image.fromarray( + (image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8) + ) + pil = pil.resize((256, 256)) + result[f"{prompts[i]:.25} | {rewards[i]:.2f}"] = [pil] + accelerate_logger.log_images( + result, + step=global_step, + ) + +``` + +### Using the finetuned model + +Assuming you've done with all the epochs and have pushed up your model to the hub, you can use the finetuned model as follows + +```python + +import torch +from trl import DefaultDDPOStableDiffusionPipeline + +pipeline = DefaultDDPOStableDiffusionPipeline("metric-space/alignprop-finetuned-sd-model") + +device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + +# memory optimization +pipeline.vae.to(device, torch.float16) +pipeline.text_encoder.to(device, torch.float16) +pipeline.unet.to(device, torch.float16) + +prompts = ["squirrel", "crab", "starfish", "whale","sponge", "plankton"] +results = pipeline(prompts) + +for prompt, image in zip(prompts,results.images): + image.save(f"{prompt}.png") + +``` + +## Credits + +This work is heavily influenced by the repo [here](https://github.com/mihirp1998/AlignProp/) and the associated paper [Aligning Text-to-Image Diffusion Models with Reward Backpropagation + by Mihir Prabhudesai, Anirudh Goyal, Deepak Pathak, Katerina Fragkiadaki](https://arxiv.org/abs/2310.03739). \ No newline at end of file diff --git a/examples/scripts/alignprop.py b/examples/scripts/alignprop.py index 8e43202684..cb5ea960bd 100644 --- a/examples/scripts/alignprop.py +++ b/examples/scripts/alignprop.py @@ -25,6 +25,7 @@ --log_with="wandb" """ import os +import torchvision from dataclasses import dataclass, field import numpy as np @@ -34,7 +35,7 @@ from huggingface_hub.utils import EntryNotFoundError from transformers import CLIPModel, CLIPProcessor, HfArgumentParser -from trl import DDPOConfig, DDPOTrainer, DefaultDDPOStableDiffusionPipeline +from trl import AlignPropConfig, AlignPropTrainer, DefaultDDPOStableDiffusionPipeline from trl.import_utils import is_npu_available, is_xpu_available @@ -72,7 +73,6 @@ def __init__(self): nn.Linear(16, 1), ) - @torch.no_grad() def forward(self, embed): return self.layers(embed) @@ -87,7 +87,9 @@ class AestheticScorer(torch.nn.Module): def __init__(self, *, dtype, model_id, model_filename): super().__init__() self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14") - self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14") + self.normalize = torchvision.transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], + std=[0.26862954, 0.26130258, 0.27577711]) + self.target_size = 224 self.mlp = MLP() try: cached_path = hf_hub_download(model_id, model_filename) @@ -98,15 +100,15 @@ def __init__(self, *, dtype, model_id, model_filename): self.dtype = dtype self.eval() - @torch.no_grad() def __call__(self, images): device = next(self.parameters()).device - inputs = self.processor(images=images, return_tensors="pt") - inputs = {k: v.to(self.dtype).to(device) for k, v in inputs.items()} - embed = self.clip.get_image_features(**inputs) + images = torchvision.transforms.Resize(self.target_size)(images) + images = self.normalize(images).to(self.dtype).to(device) + embed = self.clip.get_image_features(pixel_values=images) # normalize embedding embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True) - return self.mlp(embed).squeeze(1) + reward = self.mlp(embed).squeeze(1) + return reward def aesthetic_scorer(hub_model_id, model_filename): @@ -123,7 +125,7 @@ def aesthetic_scorer(hub_model_id, model_filename): scorer = scorer.cuda() def _fn(images, prompts, metadata): - images = (images * 255).round().clamp(0, 255).to(torch.uint8) + images = (images).clamp(0, 1) scores = scorer(images) return scores, {} @@ -166,13 +168,13 @@ def prompt_fn(): return np.random.choice(animals), {} -def image_outputs_logger(image_data, global_step, accelerate_logger): + +def image_outputs_logger(image_pair_data, global_step, accelerate_logger): # For the sake of this example, we will only log the last batch of images # and associated data result = {} - images, prompts, _, rewards, _ = image_data[-1] - - for i, image in enumerate(images): + images, prompts, rewards = [image_pair_data['images'],image_pair_data['prompts'],image_pair_data['rewards']] + for i, image in enumerate(images[:4]): prompt = prompts[i] reward = rewards[i].item() result[f"{prompt:.25} | {reward:.2f}"] = image.unsqueeze(0).float() @@ -184,9 +186,9 @@ def image_outputs_logger(image_data, global_step, accelerate_logger): if __name__ == "__main__": - parser = HfArgumentParser((ScriptArguments, DDPOConfig)) - args, ddpo_config = parser.parse_args_into_dataclasses() - ddpo_config.project_kwargs = { + parser = HfArgumentParser((ScriptArguments, AlignPropConfig)) + args, alignprop_config = parser.parse_args_into_dataclasses() + alignprop_config.project_kwargs = { "logging_dir": "./logs", "automatic_checkpoint_naming": True, "total_limit": 5, @@ -197,8 +199,8 @@ def image_outputs_logger(image_data, global_step, accelerate_logger): args.pretrained_model, pretrained_model_revision=args.pretrained_revision, use_lora=args.use_lora ) - trainer = DDPOTrainer( - ddpo_config, + trainer = AlignPropTrainer( + alignprop_config, aesthetic_scorer(args.hf_hub_aesthetic_model_id, args.hf_hub_aesthetic_model_filename), prompt_fn, pipeline, diff --git a/tests/test_alignprop_trainer.py b/tests/test_alignprop_trainer.py index 87c7de68eb..9058fd35db 100644 --- a/tests/test_alignprop_trainer.py +++ b/tests/test_alignprop_trainer.py @@ -13,7 +13,6 @@ # limitations under the License. import gc import unittest - import torch from trl import is_diffusers_available, is_peft_available @@ -22,7 +21,7 @@ if is_diffusers_available() and is_peft_available(): - from trl import DDPOConfig, DDPOTrainer, DefaultDDPOStableDiffusionPipeline + from trl import AlignPropConfig, AlignPropTrainer, DefaultDDPOStableDiffusionPipeline def scorer_function(images, prompts, metadata): @@ -34,18 +33,17 @@ def prompt_function(): @require_diffusers -class DDPOTrainerTester(unittest.TestCase): +class AlignPropTrainerTester(unittest.TestCase): """ - Test the DDPOTrainer class. + Test the AlignPropTrainer class. """ def setUp(self): - self.ddpo_config = DDPOConfig( + self.alignprop_config = AlignPropConfig( num_epochs=2, train_gradient_accumulation_steps=1, - per_prompt_stat_tracking_buffer_size=32, - sample_num_batches_per_epoch=2, - sample_batch_size=2, + train_batch_size=2, + truncated_backprop_rand=False, mixed_precision=None, save_freq=1000000, ) @@ -56,65 +54,50 @@ def setUp(self): pretrained_model, pretrained_model_revision=pretrained_revision, use_lora=False ) - self.trainer = DDPOTrainer(self.ddpo_config, scorer_function, prompt_function, pipeline) + self.trainer = AlignPropTrainer(self.alignprop_config, scorer_function, prompt_function, pipeline) return super().setUp() def tearDown(self) -> None: gc.collect() - def test_loss(self): - advantage = torch.tensor([-1.0]) - clip_range = 0.0001 - ratio = torch.tensor([1.0]) - loss = self.trainer.loss(advantage, clip_range, ratio) - assert loss.item() == 1.0 - def test_generate_samples(self): - samples, output_pairs = self.trainer._generate_samples(1, 2) - assert len(samples) == 1 - assert len(output_pairs) == 1 - assert len(output_pairs[0][0]) == 2 + output_pairs = self.trainer._generate_samples(2, with_grad=True) + assert len(output_pairs.keys()) == 3 + assert len(output_pairs['images']) == 2 def test_calculate_loss(self): - samples, _ = self.trainer._generate_samples(1, 2) - sample = samples[0] - - latents = sample["latents"][0, 0].unsqueeze(0) - next_latents = sample["next_latents"][0, 0].unsqueeze(0) - log_probs = sample["log_probs"][0, 0].unsqueeze(0) - timesteps = sample["timesteps"][0, 0].unsqueeze(0) - prompt_embeds = sample["prompt_embeds"] - advantage = torch.tensor([1.0], device=prompt_embeds.device) - - assert latents.shape == (1, 4, 64, 64) - assert next_latents.shape == (1, 4, 64, 64) - assert log_probs.shape == (1,) - assert timesteps.shape == (1,) - assert prompt_embeds.shape == (2, 77, 32) - loss, approx_kl, clipfrac = self.trainer.calculate_loss( - latents, timesteps, next_latents, log_probs, advantage, prompt_embeds + sample = self.trainer._generate_samples(2) + + images = sample["images"] + prompts = sample["prompts"] + + assert images.shape == (2, 3, 128, 128) + assert len(prompts) == 2 + + rewards = self.trainer.compute_rewards( + sample ) + loss = self.trainer.calculate_loss(rewards) assert torch.isfinite(loss.cpu()) @require_diffusers -class DDPOTrainerWithLoRATester(DDPOTrainerTester): +class AlignPropTrainerWithLoRATester(AlignPropTrainerTester): """ - Test the DDPOTrainer class. + Test the AlignPropTrainer class. """ def setUp(self): - self.ddpo_config = DDPOConfig( + self.alignprop_config = AlignPropConfig( num_epochs=2, train_gradient_accumulation_steps=1, - per_prompt_stat_tracking_buffer_size=32, - sample_num_batches_per_epoch=2, - sample_batch_size=2, mixed_precision=None, + truncated_backprop_rand=False, save_freq=1000000, ) + pretrained_model = "hf-internal-testing/tiny-stable-diffusion-torch" pretrained_revision = "main" @@ -122,6 +105,6 @@ def setUp(self): pretrained_model, pretrained_model_revision=pretrained_revision, use_lora=True ) - self.trainer = DDPOTrainer(self.ddpo_config, scorer_function, prompt_function, pipeline) + self.trainer = AlignPropTrainer(self.alignprop_config, scorer_function, prompt_function, pipeline) return super().setUp() diff --git a/trl/models/modeling_sd_base.py b/trl/models/modeling_sd_base.py index bbca699fff..c9ce23fb88 100644 --- a/trl/models/modeling_sd_base.py +++ b/trl/models/modeling_sd_base.py @@ -17,7 +17,8 @@ import warnings from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Union - +import torch.utils.checkpoint as checkpoint +import random import numpy as np import torch from diffusers import DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel @@ -233,7 +234,6 @@ def scheduler_step( # - eta -> η # - pred_sample_direction -> "direction pointing to x_t" # - pred_prev_sample -> "x_t-1" - # 1. get previous step value (=t-1) prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps # to prevent OOB on gather @@ -527,6 +527,243 @@ def pipeline_step( return DDPOPipelineOutput(image, all_latents, all_log_probs) +def pipeline_step_with_grad( + self, + prompt: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + truncated_backprop: bool = True, + truncated_backprop_rand: bool = True, + gradient_checkpoint: bool = True, + truncated_backprop_timestep: int = 49, + truncated_rand_backprop_minmax: tuple = (0,50), + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, +): + r""" + Function to get RGB image with gradients attached to the model weights. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + truncated_backprop (`bool`, *optional*, defaults to True): + Truncated Backpropation to fixed timesteps, helps prevent collapse during diffusion reward training as shown in AlignProp (https://arxiv.org/abs/2310.03739). + truncated_backprop_rand (`bool`, *optional*, defaults to True): + Truncated Randomized Backpropation randomizes truncation to different diffusion timesteps, this helps prevent collapse during diffusion reward training as shown in AlignProp (https://arxiv.org/abs/2310.03739). + Enabling truncated_backprop_rand allows adapting earlier timesteps in diffusion while not resulting in a collapse. + gradient_checkpoint (`bool`, *optional*, defaults to True): + Adds gradient checkpointing to Unet forward pass. Reduces GPU memory consumption while slightly increasing the training time. + truncated_backprop_timestep (`int`, *optional*, defaults to 49): + Absolute timestep to which the gradients are being backpropagated. Higher number reduces the memory usage and reduces the chances of collapse. + While a lower value, allows more semantic changes in the diffusion generations, as the earlier diffusion timesteps are getting updated. + However it also increases the chances of collapse. + truncated_rand_backprop_minmax (`Tuple`, *optional*, defaults to (0,50)): + Range for randomized backprop. Here the value at 0 index indicates the earlier diffusion timestep to update (closer to noise), while the value + at index 1 indicates the later diffusion timestep to update. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + guidance_rescale (`float`, *optional*, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + + Examples: + + Returns: + `DDPOPipelineOutput`: The generated image, the predicted latents used to generate the image and the associated log probabilities + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + with torch.no_grad(): + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + # 6. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + all_latents = [latents] + all_log_probs = [] + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + if gradient_checkpoint: + noise_pred = checkpoint.checkpoint(self.unet, latent_model_input, t, prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, use_reentrant=False)[0] + else: + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + + if truncated_backprop: + if truncated_backprop_rand: + rand_timestep = random.randint(truncated_rand_backprop_minmax[0],truncated_rand_backprop_minmax[1]) + if i < rand_timestep: + noise_pred = noise_pred.detach() + else: + if i < truncated_backprop_timestep: + noise_pred = noise_pred.detach() + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + scheduler_output = scheduler_step(self.scheduler, noise_pred, t, latents, eta) + latents = scheduler_output.latents + log_prob = scheduler_output.log_probs + + all_latents.append(latents) + all_log_probs.append(log_prob) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + return DDPOPipelineOutput(image, all_latents, all_log_probs) + class DefaultDDPOStableDiffusionPipeline(DDPOStableDiffusionPipeline): def __init__(self, pretrained_model_name: str, *, pretrained_model_revision: str = "main", use_lora: bool = True): @@ -563,6 +800,9 @@ def __init__(self, pretrained_model_name: str, *, pretrained_model_revision: str def __call__(self, *args, **kwargs) -> DDPOPipelineOutput: return pipeline_step(self.sd_pipeline, *args, **kwargs) + def rgb_with_grad(self, *args, **kwargs) -> DDPOPipelineOutput: + return pipeline_step_with_grad(self.sd_pipeline, *args, **kwargs) + def scheduler_step(self, *args, **kwargs) -> DDPOSchedulerOutput: return scheduler_step(self.sd_pipeline.scheduler, *args, **kwargs) diff --git a/trl/trainer/alignprop_config.py b/trl/trainer/alignprop_config.py index b73bd58d05..1af8c27add 100644 --- a/trl/trainer/alignprop_config.py +++ b/trl/trainer/alignprop_config.py @@ -9,9 +9,9 @@ @dataclass -class DDPOConfig: +class AlignPropConfig: """ - Configuration class for DDPOTrainer + Configuration class for AlignPropTrainer """ # common parameters @@ -53,10 +53,6 @@ class DDPOConfig: """Eta parameter for the DDIM sampler.""" sample_guidance_scale: float = 5.0 """Classifier-free guidance weight.""" - sample_batch_size: int = 1 - """Batch size (per GPU!) to use for sampling.""" - sample_num_batches_per_epoch: int = 2 - """Number of batches to sample per epoch.""" train_batch_size: int = 1 """Batch size (per GPU!) to use for training.""" train_use_8bit_adam: bool = False @@ -75,28 +71,15 @@ class DDPOConfig: """Number of gradient accumulation steps.""" train_max_grad_norm: float = 1.0 """Maximum gradient norm for gradient clipping.""" - train_num_inner_epochs: int = 1 - """Number of inner epochs per outer epoch.""" - train_cfg: bool = True - """Whether or not to use classifier-free guidance during training.""" - train_adv_clip_max: float = 5 - """Clip advantages to the range.""" - train_clip_range: float = 1e-4 - """The PPO clip range.""" - train_timestep_fraction: float = 1.0 - """The fraction of timesteps to train on.""" - per_prompt_stat_tracking: bool = False - """Whether to track statistics for each prompt separately.""" - per_prompt_stat_tracking_buffer_size: int = 16 - """Number of reward values to store in the buffer for each prompt.""" - per_prompt_stat_tracking_min_count: int = 16 - """The minimum number of reward values to store in the buffer.""" - async_reward_computation: bool = False - """Whether to compute rewards asynchronously.""" - max_workers: int = 2 - """The maximum number of workers to use for async reward computation.""" negative_prompts: Optional[str] = "" """Comma-separated list of prompts to use as negative examples.""" + truncated_backprop_rand: bool = True + """Truncated Randomized Backpropation randomizes truncation to different diffusion timesteps""" + truncated_backprop_timestep: int = 49 + """Absolute timestep to which the gradients are being backpropagated. If truncated_backprop_rand is False""" + truncated_rand_backprop_minmax: tuple = (0,50) + """Range of diffusion timesteps for randomized truncated backprop.""" + def to_dict(self): output_dict = {} diff --git a/trl/trainer/alignprop_trainer.py b/trl/trainer/alignprop_trainer.py index df219da707..eeff83693d 100644 --- a/trl/trainer/alignprop_trainer.py +++ b/trl/trainer/alignprop_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2023 DDPO-pytorch authors (Kevin Black), metric-space, The HuggingFace Team. All rights reserved. +# Copyright 2023 AlignProp-pytorch authors (Mihir Prabhudesai), metric-space, The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import os import warnings from collections import defaultdict @@ -26,8 +25,10 @@ from huggingface_hub import whoami from ..models import DDPOStableDiffusionPipeline -from . import BaseTrainer, DDPOConfig from .utils import PerPromptStatTracker +from . import BaseTrainer, AlignPropConfig + + logger = get_logger(__name__) @@ -37,7 +38,7 @@ license: apache-2.0 tags: - trl -- ddpo +- alignprop - diffusers - reinforcement-learning - text-to-image @@ -46,20 +47,19 @@ # {model_name} -This is a diffusion model that has been fine-tuned with reinforcement learning to - guide the model outputs according to a value, function, or human feedback. The model can be used for image generation conditioned with text. +This is a pipeline that finetunes a diffusion model with reward gradients. The model can be used for image generation conditioned with text. """ -class DDPOTrainer(BaseTrainer): +class AlignPropTrainer(BaseTrainer): """ - The DDPOTrainer uses Deep Diffusion Policy Optimization to optimise diffusion models. - Note, this trainer is heavily inspired by the work here: https://github.com/kvablack/ddpo-pytorch + The AlignPropTrainer uses Deep Diffusion Policy Optimization to optimise diffusion models. + Note, this trainer is heavily inspired by the work here: https://github.com/mihirp1998/AlignProp/ As of now only Stable Diffusion based pipelines are supported Attributes: - **config** (`DDPOConfig`) -- Configuration object for DDPOTrainer. Check the documentation of `PPOConfig` for more + **config** (`AlignPropConfig`) -- Configuration object for AlignPropTrainer. Check the documentation of `PPOConfig` for more details. **reward_function** (Callable[[torch.Tensor, Tuple[str], Tuple[Any]], torch.Tensor]) -- Reward function to be used **prompt_function** (Callable[[], Tuple[str, Any]]) -- Function to generate prompts to guide model @@ -67,11 +67,11 @@ class DDPOTrainer(BaseTrainer): **image_samples_hook** (Optional[Callable[[Any, Any, Any], Any]]) -- Hook to be called to log images """ - _tag_names = ["trl", "ddpo"] + _tag_names = ["trl", "alignprop"] def __init__( self, - config: DDPOConfig, + config: AlignPropConfig, reward_function: Callable[[torch.Tensor, Tuple[str], Tuple[Any]], torch.Tensor], prompt_function: Callable[[], Tuple[str, Any]], sd_pipeline: DDPOStableDiffusionPipeline, @@ -107,8 +107,6 @@ def __init__( accelerator_project_config.iteration = checkpoint_numbers[-1] + 1 - # number of timesteps within each trajectory to train on - self.num_train_timesteps = int(self.config.sample_num_steps * self.config.train_timestep_fraction) self.accelerator = Accelerator( log_with=self.config.log_with, @@ -117,7 +115,7 @@ def __init__( # we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the # number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get # the total number of optimizer steps to accumulate across. - gradient_accumulation_steps=self.config.train_gradient_accumulation_steps * self.num_train_timesteps, + gradient_accumulation_steps=self.config.train_gradient_accumulation_steps, **self.config.accelerator_kwargs, ) @@ -130,7 +128,7 @@ def __init__( if self.accelerator.is_main_process: self.accelerator.init_trackers( self.config.tracker_project_name, - config=dict(ddpo_trainer_config=config.to_dict()) if not is_using_tensorboard else config.to_dict(), + config=dict(alignprop_trainer_config=config.to_dict()) if not is_using_tensorboard else config.to_dict(), init_kwargs=self.config.tracker_kwargs, ) @@ -185,12 +183,6 @@ def __init__( ).input_ids.to(self.accelerator.device) )[0] - if config.per_prompt_stat_tracking: - self.stat_tracker = PerPromptStatTracker( - config.per_prompt_stat_tracking_buffer_size, - config.per_prompt_stat_tracking_min_count, - ) - # NOTE: for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses # more memory self.autocast = self.sd_pipeline.autocast or self.accelerator.autocast @@ -201,9 +193,6 @@ def __init__( else: self.trainable_layers, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer) - if self.config.async_reward_computation: - self.executor = futures.ThreadPoolExecutor(max_workers=config.max_workers) - if config.resume_from: logger.info(f"Resuming from {config.resume_from}") self.accelerator.load_state(config.resume_from) @@ -211,25 +200,9 @@ def __init__( else: self.first_epoch = 0 - def compute_rewards(self, prompt_image_pairs, is_async=False): - if not is_async: - rewards = [] - for images, prompts, prompt_metadata in prompt_image_pairs: - reward, reward_metadata = self.reward_fn(images, prompts, prompt_metadata) - rewards.append( - ( - torch.as_tensor(reward, device=self.accelerator.device), - reward_metadata, - ) - ) - else: - rewards = self.executor.map(lambda x: self.reward_fn(*x), prompt_image_pairs) - rewards = [ - (torch.as_tensor(reward.result(), device=self.accelerator.device), reward_metadata.result()) - for reward, reward_metadata in rewards - ] - - return zip(*rewards) + def compute_rewards(self, prompt_image_pairs): + reward, reward_metadata = self.reward_fn(prompt_image_pairs['images'], prompt_image_pairs['prompts'], prompt_image_pairs['prompt_metadata']) + return reward def step(self, epoch: int, global_step: int): """ @@ -248,162 +221,80 @@ def step(self, epoch: int, global_step: int): global_step (int): The updated global step. """ - samples, prompt_image_data = self._generate_samples( - iterations=self.config.sample_num_batches_per_epoch, - batch_size=self.config.sample_batch_size, - ) - - # collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...) - samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()} - rewards, rewards_metadata = self.compute_rewards( - prompt_image_data, is_async=self.config.async_reward_computation - ) - - for i, image_data in enumerate(prompt_image_data): - image_data.extend([rewards[i], rewards_metadata[i]]) + info = defaultdict(list) + + self.sd_pipeline.unet.train() + + for inner_iters in range(self.config.train_gradient_accumulation_steps): + with self.accelerator.accumulate(self.sd_pipeline.unet), self.autocast(), torch.enable_grad(): + prompt_image_pairs = self._generate_samples( + batch_size=self.config.train_batch_size, + ) - if self.image_samples_callback is not None: - self.image_samples_callback(prompt_image_data, global_step, self.accelerator.trackers[0]) - - rewards = torch.cat(rewards) - rewards = self.accelerator.gather(rewards).cpu().numpy() - - self.accelerator.log( - { - "reward": rewards, - "epoch": epoch, - "reward_mean": rewards.mean(), - "reward_std": rewards.std(), - }, - step=global_step, - ) + rewards = self.compute_rewards( + prompt_image_pairs + ) + + prompt_image_pairs["rewards"] = rewards + + rewards_vis = self.accelerator.gather(rewards).detach().cpu().numpy() + + loss = self.calculate_loss(rewards) + + self.accelerator.backward(loss) + + if self.accelerator.sync_gradients: + self.accelerator.clip_grad_norm_( + self.trainable_layers.parameters() + if not isinstance(self.trainable_layers, list) + else self.trainable_layers, + self.config.train_max_grad_norm, + ) - if self.config.per_prompt_stat_tracking: - # gather the prompts across processes - prompt_ids = self.accelerator.gather(samples["prompt_ids"]).cpu().numpy() - prompts = self.sd_pipeline.tokenizer.batch_decode(prompt_ids, skip_special_tokens=True) - advantages = self.stat_tracker.update(prompts, rewards) + self.optimizer.step() + self.optimizer.zero_grad() + + info["reward_mean"].append(rewards_vis.mean()) + info["reward_std"].append(rewards_vis.std()) + info["loss"].append(loss.item()) + + # Checks if the accelerator has performed an optimization step behind the scenes + if self.accelerator.sync_gradients: + # log training-related stuff + info = {k: torch.mean(torch.tensor(v)) for k, v in info.items()} + info = self.accelerator.reduce(info, reduction="mean") + info.update({"epoch": epoch, "inner_iters": inner_iters}) + self.accelerator.log(info, step=global_step) + global_step += 1 + info = defaultdict(list) else: - advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8) - - # ungather advantages; keep the entries corresponding to the samples on this process - samples["advantages"] = ( - torch.as_tensor(advantages) - .reshape(self.accelerator.num_processes, -1)[self.accelerator.process_index] - .to(self.accelerator.device) - ) - - del samples["prompt_ids"] - - total_batch_size, num_timesteps = samples["timesteps"].shape - - for inner_epoch in range(self.config.train_num_inner_epochs): - # shuffle samples along batch dimension - perm = torch.randperm(total_batch_size, device=self.accelerator.device) - samples = {k: v[perm] for k, v in samples.items()} - - # shuffle along time dimension independently for each sample - # still trying to understand the code below - perms = torch.stack( - [torch.randperm(num_timesteps, device=self.accelerator.device) for _ in range(total_batch_size)] + raise ValueError( + "Optimization step should have been performed by this point. Please check calculated gradient accumulation settings." ) - - for key in ["timesteps", "latents", "next_latents", "log_probs"]: - samples[key] = samples[key][ - torch.arange(total_batch_size, device=self.accelerator.device)[:, None], - perms, - ] - - original_keys = samples.keys() - original_values = samples.values() - # rebatch them as user defined train_batch_size is different from sample_batch_size - reshaped_values = [v.reshape(-1, self.config.train_batch_size, *v.shape[1:]) for v in original_values] - - # Transpose the list of original values - transposed_values = zip(*reshaped_values) - # Create new dictionaries for each row of transposed values - samples_batched = [dict(zip(original_keys, row_values)) for row_values in transposed_values] - - self.sd_pipeline.unet.train() - global_step = self._train_batched_samples(inner_epoch, epoch, global_step, samples_batched) - # ensure optimization step at the end of the inner epoch - if not self.accelerator.sync_gradients: - raise ValueError( - "Optimization step should have been performed by this point. Please check calculated gradient accumulation settings." - ) + + # Logs generated images + if self.image_samples_callback is not None: + self.image_samples_callback(prompt_image_pairs, global_step, self.accelerator.trackers[0]) if epoch != 0 and epoch % self.config.save_freq == 0 and self.accelerator.is_main_process: self.accelerator.save_state() - + return global_step - def calculate_loss(self, latents, timesteps, next_latents, log_probs, advantages, embeds): + def calculate_loss(self, rewards): """ Calculate the loss for a batch of an unpacked sample Args: - latents (torch.Tensor): - The latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height, width] - timesteps (torch.Tensor): - The timesteps sampled from the diffusion model, shape: [batch_size] - next_latents (torch.Tensor): - The next latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height, width] - log_probs (torch.Tensor): - The log probabilities of the latents, shape: [batch_size] - advantages (torch.Tensor): - The advantages of the latents, shape: [batch_size] - embeds (torch.Tensor): - The embeddings of the prompts, shape: [2*batch_size or batch_size, ...] - Note: the "or" is because if train_cfg is True, the expectation is that negative prompts are concatenated to the embeds - + rewards (torch.Tensor): + Differentiable reward scalars for each generated image, shape: [batch_size] Returns: - loss (torch.Tensor), approx_kl (torch.Tensor), clipfrac (torch.Tensor) + loss (torch.Tensor) (all of these are of shape (1,)) """ - with self.autocast(): - if self.config.train_cfg: - noise_pred = self.sd_pipeline.unet( - torch.cat([latents] * 2), - torch.cat([timesteps] * 2), - embeds, - ).sample - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.config.sample_guidance_scale * ( - noise_pred_text - noise_pred_uncond - ) - else: - noise_pred = self.sd_pipeline.unet( - latents, - timesteps, - embeds, - ).sample - # compute the log prob of next_latents given latents under the current model - - scheduler_step_output = self.sd_pipeline.scheduler_step( - noise_pred, - timesteps, - latents, - eta=self.config.sample_eta, - prev_sample=next_latents, - ) - - log_prob = scheduler_step_output.log_probs - - advantages = torch.clamp( - advantages, - -self.config.train_adv_clip_max, - self.config.train_adv_clip_max, - ) - - ratio = torch.exp(log_prob - log_probs) - - loss = self.loss(advantages, self.config.train_clip_range, ratio) - - approx_kl = 0.5 * torch.mean((log_prob - log_probs) ** 2) - - clipfrac = torch.mean((torch.abs(ratio - 1.0) > self.config.train_clip_range).float()) - - return loss, approx_kl, clipfrac + # Loss is specific to Aesthetic Reward function used in AlignProp (https://arxiv.org/pdf/2310.03739.pdf) + loss = 10.0 - (rewards).mean() + return loss def loss( self, @@ -443,154 +334,73 @@ def _load_model_hook(self, models, input_dir): self.sd_pipeline.load_checkpoint(models, input_dir) models.pop() # ensures that accelerate doesn't try to handle loading of the model - def _generate_samples(self, iterations, batch_size): + def _generate_samples(self, batch_size, with_grad= True): """ Generate samples from the model Args: - iterations (int): Number of iterations to generate samples for batch_size (int): Batch size to use for sampling + with_grad (bool): Whether the generated RGBs should have gradients attached to it. Returns: - samples (List[Dict[str, torch.Tensor]]), prompt_image_pairs (List[List[Any]]) + prompt_image_pairs (Dict[Any]) """ samples = [] - prompt_image_pairs = [] - self.sd_pipeline.unet.eval() + prompt_image_pairs = {} sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1) - for _ in range(iterations): - prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)]) - - prompt_ids = self.sd_pipeline.tokenizer( - prompts, - return_tensors="pt", - padding="max_length", - truncation=True, - max_length=self.sd_pipeline.tokenizer.model_max_length, - ).input_ids.to(self.accelerator.device) - prompt_embeds = self.sd_pipeline.text_encoder(prompt_ids)[0] - - with self.autocast(): - sd_output = self.sd_pipeline( - prompt_embeds=prompt_embeds, - negative_prompt_embeds=sample_neg_prompt_embeds, - num_inference_steps=self.config.sample_num_steps, - guidance_scale=self.config.sample_guidance_scale, - eta=self.config.sample_eta, - output_type="pt", - ) - - images = sd_output.images - latents = sd_output.latents - log_probs = sd_output.log_probs - - latents = torch.stack(latents, dim=1) # (batch_size, num_steps + 1, ...) - log_probs = torch.stack(log_probs, dim=1) # (batch_size, num_steps, 1) - timesteps = self.sd_pipeline.scheduler.timesteps.repeat(batch_size, 1) # (batch_size, num_steps) - - samples.append( - { - "prompt_ids": prompt_ids, - "prompt_embeds": prompt_embeds, - "timesteps": timesteps, - "latents": latents[:, :-1], # each entry is the latent before timestep t - "next_latents": latents[:, 1:], # each entry is the latent after timestep t - "log_probs": log_probs, - "negative_prompt_embeds": sample_neg_prompt_embeds, - } - ) - prompt_image_pairs.append([images, prompts, prompt_metadata]) + prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)]) - return samples, prompt_image_pairs + prompt_ids = self.sd_pipeline.tokenizer( + prompts, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=self.sd_pipeline.tokenizer.model_max_length, + ).input_ids.to(self.accelerator.device) - def _train_batched_samples(self, inner_epoch, epoch, global_step, batched_samples): - """ - Train on a batch of samples. Main training segment + prompt_embeds = self.sd_pipeline.text_encoder(prompt_ids)[0] - Args: - inner_epoch (int): The current inner epoch - epoch (int): The current epoch - global_step (int): The current global step - batched_samples (List[Dict[str, torch.Tensor]]): The batched samples to train on + if with_grad: + sd_output = self.sd_pipeline.rgb_with_grad( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=sample_neg_prompt_embeds, + num_inference_steps=self.config.sample_num_steps, + guidance_scale=self.config.sample_guidance_scale, + eta=self.config.sample_eta, + truncated_backprop_rand= self.config.truncated_backprop_rand, + truncated_backprop_timestep= self.config.truncated_backprop_timestep, + truncated_rand_backprop_minmax= self.config.truncated_rand_backprop_minmax, + output_type="pt", + ) + else: + sd_output = self.sd_pipeline( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=sample_neg_prompt_embeds, + num_inference_steps=self.config.sample_num_steps, + guidance_scale=self.config.sample_guidance_scale, + eta=self.config.sample_eta, + output_type="pt", + ) + + images = sd_output.images + latents = sd_output.latents + log_probs = sd_output.log_probs - Side Effects: - - Model weights are updated - - Logs the statistics to the accelerator trackers. + prompt_image_pairs["images"] = images + prompt_image_pairs["prompts"] = prompts + prompt_image_pairs["prompt_metadata"] = prompt_metadata + + return prompt_image_pairs - Returns: - global_step (int): The updated global step - """ - info = defaultdict(list) - for _i, sample in enumerate(batched_samples): - if self.config.train_cfg: - # concat negative prompts to sample prompts to avoid two forward passes - embeds = torch.cat([sample["negative_prompt_embeds"], sample["prompt_embeds"]]) - else: - embeds = sample["prompt_embeds"] - - for j in range(self.num_train_timesteps): - with self.accelerator.accumulate(self.sd_pipeline.unet): - loss, approx_kl, clipfrac = self.calculate_loss( - sample["latents"][:, j], - sample["timesteps"][:, j], - sample["next_latents"][:, j], - sample["log_probs"][:, j], - sample["advantages"], - embeds, - ) - info["approx_kl"].append(approx_kl) - info["clipfrac"].append(clipfrac) - info["loss"].append(loss) - - self.accelerator.backward(loss) - if self.accelerator.sync_gradients: - self.accelerator.clip_grad_norm_( - self.trainable_layers.parameters() - if not isinstance(self.trainable_layers, list) - else self.trainable_layers, - self.config.train_max_grad_norm, - ) - self.optimizer.step() - self.optimizer.zero_grad() - - # Checks if the accelerator has performed an optimization step behind the scenes - if self.accelerator.sync_gradients: - # log training-related stuff - info = {k: torch.mean(torch.stack(v)) for k, v in info.items()} - info = self.accelerator.reduce(info, reduction="mean") - info.update({"epoch": epoch, "inner_epoch": inner_epoch}) - self.accelerator.log(info, step=global_step) - global_step += 1 - info = defaultdict(list) - return global_step def _config_check(self) -> Tuple[bool, str]: - samples_per_epoch = ( - self.config.sample_batch_size * self.accelerator.num_processes * self.config.sample_num_batches_per_epoch - ) total_train_batch_size = ( self.config.train_batch_size * self.accelerator.num_processes * self.config.train_gradient_accumulation_steps ) - - if not self.config.sample_batch_size >= self.config.train_batch_size: - return ( - False, - f"Sample batch size ({self.config.sample_batch_size}) must be greater than or equal to the train batch size ({self.config.train_batch_size})", - ) - if not self.config.sample_batch_size % self.config.train_batch_size == 0: - return ( - False, - f"Sample batch size ({self.config.sample_batch_size}) must be divisible by the train batch size ({self.config.train_batch_size})", - ) - if not samples_per_epoch % total_train_batch_size == 0: - return ( - False, - f"Number of samples per epoch ({samples_per_epoch}) must be divisible by the total train batch size ({total_train_batch_size})", - ) return True, "" def train(self, epochs: Optional[int] = None): @@ -603,12 +413,12 @@ def train(self, epochs: Optional[int] = None): for epoch in range(self.first_epoch, epochs): global_step = self.step(epoch, global_step) - def create_model_card(self, path: str, model_name: Optional[str] = "TRL DDPO Model") -> None: + def create_model_card(self, path: str, model_name: Optional[str] = "TRL AlignProp Model") -> None: """Creates and saves a model card for a TRL model. Args: path (`str`): The path to save the model card to. - model_name (`str`, *optional*): The name of the model, defaults to `TRL DDPO Model`. + model_name (`str`, *optional*): The name of the model, defaults to `TRL AlignProp Model`. """ try: user = whoami()["name"] From 4b9dc1afd1d059a1abc3c9dd053037245ca727ab Mon Sep 17 00:00:00 2001 From: Mihir Prabhudesai Date: Thu, 25 Apr 2024 14:00:43 -0400 Subject: [PATCH 03/15] Update alignprop_trainer.mdx --- docs/source/alignprop_trainer.mdx | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/source/alignprop_trainer.mdx b/docs/source/alignprop_trainer.mdx index 13e7f9492d..5cd0fc34f2 100644 --- a/docs/source/alignprop_trainer.mdx +++ b/docs/source/alignprop_trainer.mdx @@ -1,7 +1,8 @@ -# Denoising Diffusion Policy Optimization +# Aligning Text-to-Image Diffusion Models with Reward Backpropagation + ## The why -| Before | After DDPO finetuning | +| Before | After finetuning | | --- | --- | |
|
| |
|
| @@ -113,4 +114,4 @@ for prompt, image in zip(prompts,results.images): ## Credits This work is heavily influenced by the repo [here](https://github.com/mihirp1998/AlignProp/) and the associated paper [Aligning Text-to-Image Diffusion Models with Reward Backpropagation - by Mihir Prabhudesai, Anirudh Goyal, Deepak Pathak, Katerina Fragkiadaki](https://arxiv.org/abs/2310.03739). \ No newline at end of file + by Mihir Prabhudesai, Anirudh Goyal, Deepak Pathak, Katerina Fragkiadaki](https://arxiv.org/abs/2310.03739). From af84272a6dfbcc29ecb68f0900374718a09b64ec Mon Sep 17 00:00:00 2001 From: Mihir Prabhudesai Date: Tue, 28 May 2024 17:44:19 -0400 Subject: [PATCH 04/15] Update alignprop_trainer.mdx --- docs/source/alignprop_trainer.mdx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/alignprop_trainer.mdx b/docs/source/alignprop_trainer.mdx index 5cd0fc34f2..cc39c1bead 100644 --- a/docs/source/alignprop_trainer.mdx +++ b/docs/source/alignprop_trainer.mdx @@ -15,10 +15,10 @@ The machinery for finetuning of Stable Diffusion models with reinforcement learn library. A reason for stating this is that getting started requires a bit of familiarity with the `diffusers` library concepts, mainly two of them - pipelines and schedulers. Right out of the box (`diffusers` library), there isn't a `Pipeline` nor a `Scheduler` instance that is suitable for finetuning with reinforcement learning. Some adjustments need to made. -There is a pipeline interface that is provided by this library that is required to be implemented to be used with the `DDPOTrainer`, which is the main machinery for fine-tuning Stable Diffusion with reinforcement learning. **Note: Only the StableDiffusion architecture is supported at this point.** +There is a pipeline interface that is provided by this library that is required to be implemented to be used with the `AlignPropTrainer`, which is one of the methods for fine-tuning Stable Diffusion with reward backpropagation. **Note: Only the StableDiffusion architecture is supported at this point.** There is a default implementation of this interface that you can use out of the box. Assuming the default implementation is sufficient and/or to get things moving, refer to the training example alongside this guide. -The point of the interface is to fuse the pipeline and the scheduler into one object which allows for minimalness in terms of having the constraints all in one place. The interface was designed in hopes of catering to pipelines and schedulers beyond the examples in this repository and elsewhere at this time of writing. Also the scheduler step is a method of this pipeline interface and this may seem redundant given that the raw scheduler is accessible via the interface but this is the only way to constrain the scheduler step output to an output type befitting of the algorithm at hand (DDPO). +The point of the interface is to fuse the pipeline and the scheduler into one object which allows for minimalness in terms of having the constraints all in one place. The interface was designed in hopes of catering to pipelines and schedulers beyond the examples in this repository and elsewhere at this time of writing. Also the scheduler step is a method of this pipeline interface and this may seem redundant given that the raw scheduler is accessible via the interface but this is one of the way to pass gradients from the scheduler step befitting of the algorithm at hand (AlignProp). For a more detailed look into the interface and the associated default implementation, go [here](https://github.com/lvwerra/trl/tree/main/trl/models/modeling_sd_base.py) From f3ff177556a3cce347caf8b1317364e077ba989d Mon Sep 17 00:00:00 2001 From: matrix Mihir Prabhduesai Date: Sat, 1 Jun 2024 16:34:17 -0400 Subject: [PATCH 05/15] added better why statement --- docs/source/alignprop_trainer.mdx | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/docs/source/alignprop_trainer.mdx b/docs/source/alignprop_trainer.mdx index cc39c1bead..2094790566 100644 --- a/docs/source/alignprop_trainer.mdx +++ b/docs/source/alignprop_trainer.mdx @@ -2,12 +2,16 @@ ## The why -| Before | After finetuning | -| --- | --- | -|
|
| -|
|
| -|
|
| +If your reward function is differentiable, directly backpropagating gradients from the reward models is significantly more sample efficient than doing policy gradient algorithm like DDPO. +AlignProp does full backpropagation through time, which allows updating the earlier steps of denoising via reward backpropagation. + +
+ + ## Getting started with Stable Diffusion finetuning with reinforcement learning From c3fe757e923f5ce055f19dd03398ecf5485c15ef Mon Sep 17 00:00:00 2001 From: matrix Mihir Prabhduesai Date: Sat, 1 Jun 2024 22:12:56 -0400 Subject: [PATCH 06/15] fixed inference code --- docs/source/alignprop_trainer.mdx | 42 +++++-------------------------- 1 file changed, 6 insertions(+), 36 deletions(-) diff --git a/docs/source/alignprop_trainer.mdx b/docs/source/alignprop_trainer.mdx index 2094790566..f1c508f529 100644 --- a/docs/source/alignprop_trainer.mdx +++ b/docs/source/alignprop_trainer.mdx @@ -2,33 +2,11 @@ ## The why - -If your reward function is differentiable, directly backpropagating gradients from the reward models is significantly more sample efficient than doing policy gradient algorithm like DDPO. +If your reward function is differentiable, directly backpropagating gradients from the reward models to the diffusion model is significantly more sample and compute efficient (25x) than doing policy gradient algorithm like DDPO. AlignProp does full backpropagation through time, which allows updating the earlier steps of denoising via reward backpropagation.
- - -## Getting started with Stable Diffusion finetuning with reinforcement learning - -The machinery for finetuning of Stable Diffusion models with reinforcement learning makes heavy use of HuggingFace's `diffusers` -library. A reason for stating this is that getting started requires a bit of familiarity with the `diffusers` library concepts, mainly two of them - pipelines and schedulers. -Right out of the box (`diffusers` library), there isn't a `Pipeline` nor a `Scheduler` instance that is suitable for finetuning with reinforcement learning. Some adjustments need to made. - -There is a pipeline interface that is provided by this library that is required to be implemented to be used with the `AlignPropTrainer`, which is one of the methods for fine-tuning Stable Diffusion with reward backpropagation. **Note: Only the StableDiffusion architecture is supported at this point.** -There is a default implementation of this interface that you can use out of the box. Assuming the default implementation is sufficient and/or to get things moving, refer to the training example alongside this guide. - -The point of the interface is to fuse the pipeline and the scheduler into one object which allows for minimalness in terms of having the constraints all in one place. The interface was designed in hopes of catering to pipelines and schedulers beyond the examples in this repository and elsewhere at this time of writing. Also the scheduler step is a method of this pipeline interface and this may seem redundant given that the raw scheduler is accessible via the interface but this is one of the way to pass gradients from the scheduler step befitting of the algorithm at hand (AlignProp). - -For a more detailed look into the interface and the associated default implementation, go [here](https://github.com/lvwerra/trl/tree/main/trl/models/modeling_sd_base.py) - -Note that the default implementation has a LoRA implementation path and a non-LoRA based implementation path. The LoRA flag enabled by default and this can be turned off by passing in the flag to do so. LORA based training is faster and the LORA associated model hyperparameters responsible for model convergence aren't as finicky as non-LORA based training. - -Also in addition, there is the expectation of providing a reward function and a prompt function. The reward function is used to evaluate the generated images and the prompt function is used to generate the prompts that are used to generate the images. ## Getting started with `examples/scripts/alignprop.py` @@ -94,25 +72,17 @@ def image_outputs_hook(image_data, global_step, accelerate_logger): Assuming you've done with all the epochs and have pushed up your model to the hub, you can use the finetuned model as follows ```python +from diffusers import StableDiffusionPipeline +pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") +pipeline.to("cuda") -import torch -from trl import DefaultDDPOStableDiffusionPipeline - -pipeline = DefaultDDPOStableDiffusionPipeline("metric-space/alignprop-finetuned-sd-model") - -device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") - -# memory optimization -pipeline.vae.to(device, torch.float16) -pipeline.text_encoder.to(device, torch.float16) -pipeline.unet.to(device, torch.float16) +pipeline.load_lora_weights('mihirpd/alignprop-trl-aesthetics') prompts = ["squirrel", "crab", "starfish", "whale","sponge", "plankton"] results = pipeline(prompts) for prompt, image in zip(prompts,results.images): - image.save(f"{prompt}.png") - + image.save(f"dump/{prompt}.png") ``` ## Credits From 4f8501ec3aede13bae05322e32855acdcc17052d Mon Sep 17 00:00:00 2001 From: matrix Mihir Prabhduesai Date: Sun, 2 Jun 2024 01:06:22 -0400 Subject: [PATCH 07/15] changed self to pipeline --- trl/models/modeling_sd_base.py | 56 ++++++++++++++++++---------------- 1 file changed, 30 insertions(+), 26 deletions(-) diff --git a/trl/models/modeling_sd_base.py b/trl/models/modeling_sd_base.py index c9ce23fb88..14e5fb9a6f 100644 --- a/trl/models/modeling_sd_base.py +++ b/trl/models/modeling_sd_base.py @@ -528,7 +528,7 @@ def pipeline_step( return DDPOPipelineOutput(image, all_latents, all_log_probs) def pipeline_step_with_grad( - self, + pipeline, prompt: Optional[Union[str, List[str]]] = None, height: Optional[int] = None, width: Optional[int] = None, @@ -554,8 +554,8 @@ def pipeline_step_with_grad( guidance_rescale: float = 0.0, ): r""" - Function to get RGB image with gradients attached to the model weights. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. - width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + Function to get RGB image with gradients attached to the model weights. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. height (`int`, *optional*, defaults to pipeline.unet.config.sample_size * pipeline.vae_scale_factor): The height in pixels of the generated image. + width (`int`, *optional*, defaults to pipeline.unet.config.sample_size * pipeline.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the @@ -617,7 +617,7 @@ def pipeline_step_with_grad( called at every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in + `pipeline.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). guidance_rescale (`float`, *optional*, defaults to 0.7): Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are @@ -631,12 +631,12 @@ def pipeline_step_with_grad( `DDPOPipelineOutput`: The generated image, the predicted latents used to generate the image and the associated log probabilities """ # 0. Default height and width to unet - height = height or self.unet.config.sample_size * self.vae_scale_factor - width = width or self.unet.config.sample_size * self.vae_scale_factor + height = height or pipeline.unet.config.sample_size * pipeline.vae_scale_factor + width = width or pipeline.unet.config.sample_size * pipeline.vae_scale_factor with torch.no_grad(): # 1. Check inputs. Raise error if not correct - self.check_inputs( + pipeline.check_inputs( prompt, height, width, @@ -654,7 +654,7 @@ def pipeline_step_with_grad( else: batch_size = prompt_embeds.shape[0] - device = self._execution_device + device = pipeline._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. @@ -662,7 +662,7 @@ def pipeline_step_with_grad( # 3. Encode input prompt text_encoder_lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None - prompt_embeds = self._encode_prompt( + prompt_embeds = pipeline._encode_prompt( prompt, device, num_images_per_prompt, @@ -674,12 +674,12 @@ def pipeline_step_with_grad( ) # 4. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps + pipeline.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = pipeline.scheduler.timesteps # 5. Prepare latent variables - num_channels_latents = self.unet.config.in_channels - latents = self.prepare_latents( + num_channels_latents = pipeline.unet.config.in_channels + latents = pipeline.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, @@ -690,20 +690,20 @@ def pipeline_step_with_grad( latents, ) # 6. Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + num_warmup_steps = len(timesteps) - num_inference_steps * pipeline.scheduler.order all_latents = [latents] all_log_probs = [] - with self.progress_bar(total=num_inference_steps) as progress_bar: + with pipeline.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = pipeline.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual if gradient_checkpoint: - noise_pred = checkpoint.checkpoint(self.unet, latent_model_input, t, prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, use_reentrant=False)[0] + noise_pred = checkpoint.checkpoint(pipeline.unet, latent_model_input, t, prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, use_reentrant=False)[0] else: - noise_pred = self.unet( + noise_pred = pipeline.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, @@ -711,13 +711,17 @@ def pipeline_step_with_grad( return_dict=False, )[0] - + # truncating backpropagation is critical for preventing overoptimization (https://arxiv.org/abs/2304.05977). if truncated_backprop: + # Randomized truncation randomizes the truncation process (https://arxiv.org/abs/2310.03739) + # the range of truncation is defined by truncated_rand_backprop_minmax + # Setting truncated_rand_backprop_minmax[0] to be low will allow the model to update earlier timesteps in the diffusion chain, while setitng it high will reduce the memory usage. if truncated_backprop_rand: rand_timestep = random.randint(truncated_rand_backprop_minmax[0],truncated_rand_backprop_minmax[1]) if i < rand_timestep: noise_pred = noise_pred.detach() else: + # fixed truncation process if i < truncated_backprop_timestep: noise_pred = noise_pred.detach() @@ -731,7 +735,7 @@ def pipeline_step_with_grad( noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) # compute the previous noisy sample x_t -> x_t-1 - scheduler_output = scheduler_step(self.scheduler, noise_pred, t, latents, eta) + scheduler_output = scheduler_step(pipeline.scheduler, noise_pred, t, latents, eta) latents = scheduler_output.latents log_prob = scheduler_output.log_probs @@ -739,14 +743,14 @@ def pipeline_step_with_grad( all_log_probs.append(log_prob) # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) if not output_type == "latent": - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + image = pipeline.vae.decode(latents / pipeline.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = pipeline.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents has_nsfw_concept = None @@ -756,11 +760,11 @@ def pipeline_step_with_grad( else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + image = pipeline.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) # Offload last model to CPU - if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: - self.final_offload_hook.offload() + if hasattr(pipeline, "final_offload_hook") and pipeline.final_offload_hook is not None: + pipeline.final_offload_hook.offload() return DDPOPipelineOutput(image, all_latents, all_log_probs) From 34af985e69ae846c2231b80603dc6a8a9e67104c Mon Sep 17 00:00:00 2001 From: matrix Mihir Prabhduesai Date: Sun, 2 Jun 2024 01:07:53 -0400 Subject: [PATCH 08/15] removed aesthetic classifier --- examples/scripts/alignprop.py | 87 ++--------------------------------- 1 file changed, 4 insertions(+), 83 deletions(-) diff --git a/examples/scripts/alignprop.py b/examples/scripts/alignprop.py index cb5ea960bd..42c79c2c28 100644 --- a/examples/scripts/alignprop.py +++ b/examples/scripts/alignprop.py @@ -25,18 +25,13 @@ --log_with="wandb" """ import os -import torchvision from dataclasses import dataclass, field - import numpy as np import torch import torch.nn as nn -from huggingface_hub import hf_hub_download -from huggingface_hub.utils import EntryNotFoundError -from transformers import CLIPModel, CLIPProcessor, HfArgumentParser - +from transformers import HfArgumentParser from trl import AlignPropConfig, AlignPropTrainer, DefaultDDPOStableDiffusionPipeline -from trl.import_utils import is_npu_available, is_xpu_available +from trl.models.auxiliary_modules import aesthetic_scorer @dataclass @@ -46,7 +41,7 @@ class ScriptArguments: ) pretrained_revision: str = field(default="main", metadata={"help": "the pretrained model revision to use"}) hf_hub_model_id: str = field( - default="ddpo-finetuned-stable-diffusion", metadata={"help": "HuggingFace repo to save model weights to"} + default="alignprop-finetuned-stable-diffusion", metadata={"help": "HuggingFace repo to save model weights to"} ) hf_hub_aesthetic_model_id: str = field( default="trl-lib/ddpo-aesthetic-predictor", @@ -59,78 +54,6 @@ class ScriptArguments: use_lora: bool = field(default=True, metadata={"help": "Whether to use LoRA."}) -class MLP(nn.Module): - def __init__(self): - super().__init__() - self.layers = nn.Sequential( - nn.Linear(768, 1024), - nn.Dropout(0.2), - nn.Linear(1024, 128), - nn.Dropout(0.2), - nn.Linear(128, 64), - nn.Dropout(0.1), - nn.Linear(64, 16), - nn.Linear(16, 1), - ) - - def forward(self, embed): - return self.layers(embed) - - -class AestheticScorer(torch.nn.Module): - """ - This model attempts to predict the aesthetic score of an image. The aesthetic score - is a numerical approximation of how much a specific image is liked by humans on average. - This is from https://github.com/christophschuhmann/improved-aesthetic-predictor - """ - - def __init__(self, *, dtype, model_id, model_filename): - super().__init__() - self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14") - self.normalize = torchvision.transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], - std=[0.26862954, 0.26130258, 0.27577711]) - self.target_size = 224 - self.mlp = MLP() - try: - cached_path = hf_hub_download(model_id, model_filename) - except EntryNotFoundError: - cached_path = os.path.join(model_id, model_filename) - state_dict = torch.load(cached_path, map_location=torch.device("cpu")) - self.mlp.load_state_dict(state_dict) - self.dtype = dtype - self.eval() - - def __call__(self, images): - device = next(self.parameters()).device - images = torchvision.transforms.Resize(self.target_size)(images) - images = self.normalize(images).to(self.dtype).to(device) - embed = self.clip.get_image_features(pixel_values=images) - # normalize embedding - embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True) - reward = self.mlp(embed).squeeze(1) - return reward - - -def aesthetic_scorer(hub_model_id, model_filename): - scorer = AestheticScorer( - model_id=hub_model_id, - model_filename=model_filename, - dtype=torch.float32, - ) - if is_npu_available(): - scorer = scorer.npu() - elif is_xpu_available(): - scorer = scorer.xpu() - else: - scorer = scorer.cuda() - - def _fn(images, prompts, metadata): - images = (images).clamp(0, 1) - scores = scorer(images) - return scores, {} - - return _fn - # list of example prompts to feed stable diffusion animals = [ @@ -177,8 +100,7 @@ def image_outputs_logger(image_pair_data, global_step, accelerate_logger): for i, image in enumerate(images[:4]): prompt = prompts[i] reward = rewards[i].item() - result[f"{prompt:.25} | {reward:.2f}"] = image.unsqueeze(0).float() - + result[f"{prompt}"] = image.unsqueeze(0).float() accelerate_logger.log_images( result, step=global_step, @@ -198,7 +120,6 @@ def image_outputs_logger(image_pair_data, global_step, accelerate_logger): pipeline = DefaultDDPOStableDiffusionPipeline( args.pretrained_model, pretrained_model_revision=args.pretrained_revision, use_lora=args.use_lora ) - trainer = AlignPropTrainer( alignprop_config, aesthetic_scorer(args.hf_hub_aesthetic_model_id, args.hf_hub_aesthetic_model_filename), From c0a6ce32014db391e19b24e36123bd4fcb258c8c Mon Sep 17 00:00:00 2001 From: matrix Mihir Prabhduesai Date: Sun, 2 Jun 2024 01:08:34 -0400 Subject: [PATCH 09/15] added aesthetic to auxiliary models --- trl/models/auxiliary_modules.py | 92 +++++++++++++++++++++++++++++++++ 1 file changed, 92 insertions(+) create mode 100644 trl/models/auxiliary_modules.py diff --git a/trl/models/auxiliary_modules.py b/trl/models/auxiliary_modules.py new file mode 100644 index 0000000000..27f1483447 --- /dev/null +++ b/trl/models/auxiliary_modules.py @@ -0,0 +1,92 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch.nn as nn +import torch +from transformers import CLIPModel, CLIPProcessor +from huggingface_hub.utils import EntryNotFoundError +from huggingface_hub import hf_hub_download +from trl.import_utils import is_npu_available, is_xpu_available +import torchvision + +class MLP(nn.Module): + def __init__(self): + super().__init__() + self.layers = nn.Sequential( + nn.Linear(768, 1024), + nn.Dropout(0.2), + nn.Linear(1024, 128), + nn.Dropout(0.2), + nn.Linear(128, 64), + nn.Dropout(0.1), + nn.Linear(64, 16), + nn.Linear(16, 1), + ) + + def forward(self, embed): + return self.layers(embed) + + +class AestheticScorer(torch.nn.Module): + """ + This model attempts to predict the aesthetic score of an image. The aesthetic score + is a numerical approximation of how much a specific image is liked by humans on average. + This is from https://github.com/christophschuhmann/improved-aesthetic-predictor + """ + + def __init__(self, *, dtype, model_id, model_filename): + super().__init__() + self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14") + self.normalize = torchvision.transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], + std=[0.26862954, 0.26130258, 0.27577711]) + self.target_size = 224 + self.mlp = MLP() + try: + cached_path = hf_hub_download(model_id, model_filename) + except EntryNotFoundError: + cached_path = os.path.join(model_id, model_filename) + state_dict = torch.load(cached_path, map_location=torch.device("cpu")) + self.mlp.load_state_dict(state_dict) + self.dtype = dtype + self.eval() + + def __call__(self, images): + device = next(self.parameters()).device + images = torchvision.transforms.Resize(self.target_size)(images) + images = self.normalize(images).to(self.dtype).to(device) + embed = self.clip.get_image_features(pixel_values=images) + # normalize embedding + embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True) + reward = self.mlp(embed).squeeze(1) + return reward + + +def aesthetic_scorer(hub_model_id, model_filename): + scorer = AestheticScorer( + model_id=hub_model_id, + model_filename=model_filename, + dtype=torch.float32, + ) + if is_npu_available(): + scorer = scorer.npu() + elif is_xpu_available(): + scorer = scorer.xpu() + else: + scorer = scorer.cuda() + + def _fn(images, prompts, metadata): + images = (images).clamp(0, 1) + scores = scorer(images) + return scores, {} + + return _fn From fbeafbc0543e6c571c69d8ab078307a897e334dc Mon Sep 17 00:00:00 2001 From: matrix Mihir Prabhduesai Date: Sun, 2 Jun 2024 01:09:54 -0400 Subject: [PATCH 10/15] added unseen prompt logging --- trl/trainer/alignprop_trainer.py | 34 ++++++++++++++++++++++++++------ 1 file changed, 28 insertions(+), 6 deletions(-) diff --git a/trl/trainer/alignprop_trainer.py b/trl/trainer/alignprop_trainer.py index eeff83693d..1530582356 100644 --- a/trl/trainer/alignprop_trainer.py +++ b/trl/trainer/alignprop_trainer.py @@ -47,7 +47,7 @@ # {model_name} -This is a pipeline that finetunes a diffusion model with reward gradients. The model can be used for image generation conditioned with text. +This is a pipeline that finetunes a diffusion model with reward backpropagation while using randomized truncation (https://arxiv.org/abs/2310.03739). The model can be used for image generation conditioned with text. """ @@ -271,10 +271,28 @@ def step(self, epoch: int, global_step: int): raise ValueError( "Optimization step should have been performed by this point. Please check calculated gradient accumulation settings." ) - # Logs generated images - if self.image_samples_callback is not None: - self.image_samples_callback(prompt_image_pairs, global_step, self.accelerator.trackers[0]) + if self.image_samples_callback is not None and global_step % self.config.log_image_freq == 0: + with torch.no_grad(): + train_prompts = ['lion', 'duck', 'llama', 'cat'] + train_prompt_image_pairs = self._generate_samples( + batch_size=len(train_prompts), prompts=train_prompts + ) + train_rewards = self.compute_rewards( + train_prompt_image_pairs + ) + train_prompt_image_pairs['rewards'] = train_rewards + test_prompts = ['elephant', 'dolphin', 'panda', 'penguin', 'octopus', 'koala', 'crocodile', 'chimpanzee'] + test_prompt_image_pairs = self._generate_samples( + batch_size=len(test_prompts), prompts=test_prompts + ) + test_rewards = self.compute_rewards(test_prompt_image_pairs) + test_prompt_image_pairs['rewards'] = test_rewards + + logs = {"test_reward": test_rewards.mean(), "train_reward": train_rewards.mean()} + self.accelerator.log(logs, step=global_step) + + self.image_samples_callback(train_prompt_image_pairs, global_step, self.accelerator.trackers[0]) if epoch != 0 and epoch % self.config.save_freq == 0 and self.accelerator.is_main_process: self.accelerator.save_state() @@ -334,7 +352,7 @@ def _load_model_hook(self, models, input_dir): self.sd_pipeline.load_checkpoint(models, input_dir) models.pop() # ensures that accelerate doesn't try to handle loading of the model - def _generate_samples(self, batch_size, with_grad= True): + def _generate_samples(self, batch_size, with_grad= True, prompts=None): """ Generate samples from the model @@ -350,7 +368,11 @@ def _generate_samples(self, batch_size, with_grad= True): sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1) - prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)]) + if prompts is None: + prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)]) + else: + prompt_metadata = [{} for _ in range(batch_size)] + prompt_ids = self.sd_pipeline.tokenizer( prompts, From d804207e060555e2ed9fd63119935c543ff6859c Mon Sep 17 00:00:00 2001 From: matrix Mihir Prabhduesai Date: Sun, 2 Jun 2024 03:52:59 -0400 Subject: [PATCH 11/15] removed unseen prompt log --- trl/trainer/alignprop_config.py | 2 ++ trl/trainer/alignprop_trainer.py | 21 +-------------------- 2 files changed, 3 insertions(+), 20 deletions(-) diff --git a/trl/trainer/alignprop_config.py b/trl/trainer/alignprop_config.py index 1af8c27add..fb7878de6a 100644 --- a/trl/trainer/alignprop_config.py +++ b/trl/trainer/alignprop_config.py @@ -23,6 +23,8 @@ class AlignPropConfig: """Seed value for random generations""" log_with: Optional[Literal["wandb", "tensorboard"]] = None """Log with either 'wandb' or 'tensorboard', check https://huggingface.co/docs/accelerate/usage_guides/tracking for more details""" + log_image_freq =1 + """Logging Frequency for images""" tracker_kwargs: dict = field(default_factory=dict) """Keyword arguments for the tracker (e.g. wandb_project)""" accelerator_kwargs: dict = field(default_factory=dict) diff --git a/trl/trainer/alignprop_trainer.py b/trl/trainer/alignprop_trainer.py index 1530582356..ba81eff354 100644 --- a/trl/trainer/alignprop_trainer.py +++ b/trl/trainer/alignprop_trainer.py @@ -273,26 +273,7 @@ def step(self, epoch: int, global_step: int): ) # Logs generated images if self.image_samples_callback is not None and global_step % self.config.log_image_freq == 0: - with torch.no_grad(): - train_prompts = ['lion', 'duck', 'llama', 'cat'] - train_prompt_image_pairs = self._generate_samples( - batch_size=len(train_prompts), prompts=train_prompts - ) - train_rewards = self.compute_rewards( - train_prompt_image_pairs - ) - train_prompt_image_pairs['rewards'] = train_rewards - test_prompts = ['elephant', 'dolphin', 'panda', 'penguin', 'octopus', 'koala', 'crocodile', 'chimpanzee'] - test_prompt_image_pairs = self._generate_samples( - batch_size=len(test_prompts), prompts=test_prompts - ) - test_rewards = self.compute_rewards(test_prompt_image_pairs) - test_prompt_image_pairs['rewards'] = test_rewards - - logs = {"test_reward": test_rewards.mean(), "train_reward": train_rewards.mean()} - self.accelerator.log(logs, step=global_step) - - self.image_samples_callback(train_prompt_image_pairs, global_step, self.accelerator.trackers[0]) + self.image_samples_callback(prompt_image_pairs, global_step, self.accelerator.trackers[0]) if epoch != 0 and epoch % self.config.save_freq == 0 and self.accelerator.is_main_process: self.accelerator.save_state() From 405db53cf8482e434f1cdef245eb0f671a0cc4cc Mon Sep 17 00:00:00 2001 From: matrix Mihir Prabhduesai Date: Sun, 2 Jun 2024 04:02:02 -0400 Subject: [PATCH 12/15] fixed minor --- examples/scripts/alignprop.py | 12 ++++-------- trl/trainer/alignprop_config.py | 2 +- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/examples/scripts/alignprop.py b/examples/scripts/alignprop.py index 42c79c2c28..c783e7bf9c 100644 --- a/examples/scripts/alignprop.py +++ b/examples/scripts/alignprop.py @@ -12,15 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -python examples/scripts/ddpo.py \ - --num_epochs=200 \ - --train_gradient_accumulation_steps=1 \ +python examples/scripts/alignprop.py \ + --num_epochs=20 \ + --train_gradient_accumulation_steps=4 \ --sample_num_steps=50 \ - --sample_batch_size=6 \ - --train_batch_size=3 \ - --sample_num_batches_per_epoch=4 \ - --per_prompt_stat_tracking=True \ - --per_prompt_stat_tracking_buffer_size=32 \ + --train_batch_size=8 \ --tracker_project_name="stable_diffusion_training" \ --log_with="wandb" """ diff --git a/trl/trainer/alignprop_config.py b/trl/trainer/alignprop_config.py index fb7878de6a..2634b39e5c 100644 --- a/trl/trainer/alignprop_config.py +++ b/trl/trainer/alignprop_config.py @@ -59,7 +59,7 @@ class AlignPropConfig: """Batch size (per GPU!) to use for training.""" train_use_8bit_adam: bool = False """Whether to use the 8bit Adam optimizer from bitsandbytes.""" - train_learning_rate: float = 3e-4 + train_learning_rate: float = 1e-3 """Learning rate.""" train_adam_beta1: float = 0.9 """Adam beta1.""" From 8296a07fd5a815d9c1c31f2314d4cc1c080f0623 Mon Sep 17 00:00:00 2001 From: Mihir Prabhudesai Date: Wed, 12 Jun 2024 09:20:22 -0400 Subject: [PATCH 13/15] remove not needed import in trl/__init__.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> --- trl/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/trl/__init__.py b/trl/__init__.py index 8601cd8324..fcd91595a4 100644 --- a/trl/__init__.py +++ b/trl/__init__.py @@ -4,7 +4,6 @@ from typing import TYPE_CHECKING from .import_utils import _LazyModule, is_diffusers_available, OptionalDependencyNotAvailable -import torch.utils.checkpoint as checkpoint _import_structure = { "core": [ From 4607f96634b7f1a44db98ca1ad319a3878d8a4bf Mon Sep 17 00:00:00 2001 From: matrix Mihir Prabhduesai Date: Wed, 12 Jun 2024 16:47:33 -0400 Subject: [PATCH 14/15] fixed styling --- examples/scripts/alignprop.py | 18 ++++---- tests/test_alignprop_trainer.py | 13 +++--- trl/__init__.py | 1 - trl/models/auxiliary_modules.py | 20 +++++---- trl/models/modeling_sd_base.py | 39 +++++++++++------ trl/trainer/__init__.py | 4 +- trl/trainer/alignprop_config.py | 7 ++-- trl/trainer/alignprop_trainer.py | 72 ++++++++++++-------------------- 8 files changed, 84 insertions(+), 90 deletions(-) diff --git a/examples/scripts/alignprop.py b/examples/scripts/alignprop.py index c783e7bf9c..f482c49da8 100644 --- a/examples/scripts/alignprop.py +++ b/examples/scripts/alignprop.py @@ -12,20 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -python examples/scripts/alignprop.py \ +Total Batch size = 128 = 4 (num_gpus) * 8 (per_device_batch) * 4 (accumulation steps) +Feel free to reduce batch size or increasing truncated_rand_backprop_min to a higher value to reduce memory usage. + +CUDA_VISIBLE_DEVICES=0,1,2,3 python examples/scripts/alignprop.py \ --num_epochs=20 \ --train_gradient_accumulation_steps=4 \ --sample_num_steps=50 \ --train_batch_size=8 \ --tracker_project_name="stable_diffusion_training" \ --log_with="wandb" + """ -import os from dataclasses import dataclass, field + import numpy as np -import torch -import torch.nn as nn -from transformers import HfArgumentParser +from transformers import HfArgumentParser + from trl import AlignPropConfig, AlignPropTrainer, DefaultDDPOStableDiffusionPipeline from trl.models.auxiliary_modules import aesthetic_scorer @@ -50,7 +53,6 @@ class ScriptArguments: use_lora: bool = field(default=True, metadata={"help": "Whether to use LoRA."}) - # list of example prompts to feed stable diffusion animals = [ "cat", @@ -87,15 +89,13 @@ def prompt_fn(): return np.random.choice(animals), {} - def image_outputs_logger(image_pair_data, global_step, accelerate_logger): # For the sake of this example, we will only log the last batch of images # and associated data result = {} - images, prompts, rewards = [image_pair_data['images'],image_pair_data['prompts'],image_pair_data['rewards']] + images, prompts, _ = [image_pair_data["images"], image_pair_data["prompts"], image_pair_data["rewards"]] for i, image in enumerate(images[:4]): prompt = prompts[i] - reward = rewards[i].item() result[f"{prompt}"] = image.unsqueeze(0).float() accelerate_logger.log_images( result, diff --git a/tests/test_alignprop_trainer.py b/tests/test_alignprop_trainer.py index 9058fd35db..7faff69da3 100644 --- a/tests/test_alignprop_trainer.py +++ b/tests/test_alignprop_trainer.py @@ -13,6 +13,7 @@ # limitations under the License. import gc import unittest + import torch from trl import is_diffusers_available, is_peft_available @@ -64,21 +65,19 @@ def tearDown(self) -> None: def test_generate_samples(self): output_pairs = self.trainer._generate_samples(2, with_grad=True) assert len(output_pairs.keys()) == 3 - assert len(output_pairs['images']) == 2 + assert len(output_pairs["images"]) == 2 def test_calculate_loss(self): sample = self.trainer._generate_samples(2) - + images = sample["images"] prompts = sample["prompts"] assert images.shape == (2, 3, 128, 128) assert len(prompts) == 2 - rewards = self.trainer.compute_rewards( - sample - ) - loss = self.trainer.calculate_loss(rewards) + rewards = self.trainer.compute_rewards(sample) + loss = self.trainer.calculate_loss(rewards) assert torch.isfinite(loss.cpu()) @@ -97,7 +96,7 @@ def setUp(self): truncated_backprop_rand=False, save_freq=1000000, ) - + pretrained_model = "hf-internal-testing/tiny-stable-diffusion-torch" pretrained_revision = "main" diff --git a/trl/__init__.py b/trl/__init__.py index 8601cd8324..fcd91595a4 100644 --- a/trl/__init__.py +++ b/trl/__init__.py @@ -4,7 +4,6 @@ from typing import TYPE_CHECKING from .import_utils import _LazyModule, is_diffusers_available, OptionalDependencyNotAvailable -import torch.utils.checkpoint as checkpoint _import_structure = { "core": [ diff --git a/trl/models/auxiliary_modules.py b/trl/models/auxiliary_modules.py index 27f1483447..eb6b71936e 100644 --- a/trl/models/auxiliary_modules.py +++ b/trl/models/auxiliary_modules.py @@ -11,13 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import torch.nn as nn +import os import torch -from transformers import CLIPModel, CLIPProcessor -from huggingface_hub.utils import EntryNotFoundError +import torch.nn as nn +import torchvision from huggingface_hub import hf_hub_download +from huggingface_hub.utils import EntryNotFoundError +from transformers import CLIPModel + from trl.import_utils import is_npu_available, is_xpu_available -import torchvision + class MLP(nn.Module): def __init__(self): @@ -47,9 +50,10 @@ class AestheticScorer(torch.nn.Module): def __init__(self, *, dtype, model_id, model_filename): super().__init__() self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14") - self.normalize = torchvision.transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], - std=[0.26862954, 0.26130258, 0.27577711]) - self.target_size = 224 + self.normalize = torchvision.transforms.Normalize( + mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711] + ) + self.target_size = 224 self.mlp = MLP() try: cached_path = hf_hub_download(model_id, model_filename) @@ -67,7 +71,7 @@ def __call__(self, images): embed = self.clip.get_image_features(pixel_values=images) # normalize embedding embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True) - reward = self.mlp(embed).squeeze(1) + reward = self.mlp(embed).squeeze(1) return reward diff --git a/trl/models/modeling_sd_base.py b/trl/models/modeling_sd_base.py index 14e5fb9a6f..44d3e70034 100644 --- a/trl/models/modeling_sd_base.py +++ b/trl/models/modeling_sd_base.py @@ -14,13 +14,14 @@ import contextlib import os +import random import warnings from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Union -import torch.utils.checkpoint as checkpoint -import random + import numpy as np import torch +import torch.utils.checkpoint as checkpoint from diffusers import DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg @@ -527,6 +528,7 @@ def pipeline_step( return DDPOPipelineOutput(image, all_latents, all_log_probs) + def pipeline_step_with_grad( pipeline, prompt: Optional[Union[str, List[str]]] = None, @@ -538,7 +540,7 @@ def pipeline_step_with_grad( truncated_backprop_rand: bool = True, gradient_checkpoint: bool = True, truncated_backprop_timestep: int = 49, - truncated_rand_backprop_minmax: tuple = (0,50), + truncated_rand_backprop_minmax: tuple = (0, 50), negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, @@ -574,8 +576,8 @@ def pipeline_step_with_grad( gradient_checkpoint (`bool`, *optional*, defaults to True): Adds gradient checkpointing to Unet forward pass. Reduces GPU memory consumption while slightly increasing the training time. truncated_backprop_timestep (`int`, *optional*, defaults to 49): - Absolute timestep to which the gradients are being backpropagated. Higher number reduces the memory usage and reduces the chances of collapse. - While a lower value, allows more semantic changes in the diffusion generations, as the earlier diffusion timesteps are getting updated. + Absolute timestep to which the gradients are being backpropagated. Higher number reduces the memory usage and reduces the chances of collapse. + While a lower value, allows more semantic changes in the diffusion generations, as the earlier diffusion timesteps are getting updated. However it also increases the chances of collapse. truncated_rand_backprop_minmax (`Tuple`, *optional*, defaults to (0,50)): Range for randomized backprop. Here the value at 0 index indicates the earlier diffusion timestep to update (closer to noise), while the value @@ -659,9 +661,11 @@ def pipeline_step_with_grad( # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 - + # 3. Encode input prompt - text_encoder_lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) prompt_embeds = pipeline._encode_prompt( prompt, device, @@ -698,11 +702,18 @@ def pipeline_step_with_grad( # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = pipeline.scheduler.scale_model_input(latent_model_input, t) - + # predict the noise residual if gradient_checkpoint: - noise_pred = checkpoint.checkpoint(pipeline.unet, latent_model_input, t, prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, use_reentrant=False)[0] - else: + noise_pred = checkpoint.checkpoint( + pipeline.unet, + latent_model_input, + t, + prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + use_reentrant=False, + )[0] + else: noise_pred = pipeline.unet( latent_model_input, t, @@ -711,13 +722,15 @@ def pipeline_step_with_grad( return_dict=False, )[0] - # truncating backpropagation is critical for preventing overoptimization (https://arxiv.org/abs/2304.05977). + # truncating backpropagation is critical for preventing overoptimization (https://arxiv.org/abs/2304.05977). if truncated_backprop: # Randomized truncation randomizes the truncation process (https://arxiv.org/abs/2310.03739) # the range of truncation is defined by truncated_rand_backprop_minmax # Setting truncated_rand_backprop_minmax[0] to be low will allow the model to update earlier timesteps in the diffusion chain, while setitng it high will reduce the memory usage. if truncated_backprop_rand: - rand_timestep = random.randint(truncated_rand_backprop_minmax[0],truncated_rand_backprop_minmax[1]) + rand_timestep = random.randint( + truncated_rand_backprop_minmax[0], truncated_rand_backprop_minmax[1] + ) if i < rand_timestep: noise_pred = noise_pred.detach() else: @@ -805,7 +818,7 @@ def __call__(self, *args, **kwargs) -> DDPOPipelineOutput: return pipeline_step(self.sd_pipeline, *args, **kwargs) def rgb_with_grad(self, *args, **kwargs) -> DDPOPipelineOutput: - return pipeline_step_with_grad(self.sd_pipeline, *args, **kwargs) + return pipeline_step_with_grad(self.sd_pipeline, *args, **kwargs) def scheduler_step(self, *args, **kwargs) -> DDPOSchedulerOutput: return scheduler_step(self.sd_pipeline.scheduler, *args, **kwargs) diff --git a/trl/trainer/__init__.py b/trl/trainer/__init__.py index a700c2c181..8d43e7761e 100644 --- a/trl/trainer/__init__.py +++ b/trl/trainer/__init__.py @@ -34,7 +34,7 @@ "cpo_config": ["CPOConfig"], "cpo_trainer": ["CPOTrainer"], "alignprop_config": ["AlignPropConfig"], - "alignprop_trainer": ["AlignPropTrainer"], + "alignprop_trainer": ["AlignPropTrainer"], "iterative_sft_trainer": ["IterativeSFTTrainer"], "kto_config": ["KTOConfig"], "kto_trainer": ["KTOTrainer"], @@ -83,7 +83,7 @@ from .iterative_sft_trainer import IterativeSFTTrainer from .cpo_config import CPOConfig from .cpo_trainer import CPOTrainer - from .alignprop_config import AlignPropConfig + from .alignprop_config import AlignPropConfig from .kto_config import KTOConfig from .kto_trainer import KTOTrainer from .model_config import ModelConfig diff --git a/trl/trainer/alignprop_config.py b/trl/trainer/alignprop_config.py index 2634b39e5c..7bd4cd32bd 100644 --- a/trl/trainer/alignprop_config.py +++ b/trl/trainer/alignprop_config.py @@ -23,8 +23,8 @@ class AlignPropConfig: """Seed value for random generations""" log_with: Optional[Literal["wandb", "tensorboard"]] = None """Log with either 'wandb' or 'tensorboard', check https://huggingface.co/docs/accelerate/usage_guides/tracking for more details""" - log_image_freq =1 - """Logging Frequency for images""" + log_image_freq = 1 + """Logging Frequency for images""" tracker_kwargs: dict = field(default_factory=dict) """Keyword arguments for the tracker (e.g. wandb_project)""" accelerator_kwargs: dict = field(default_factory=dict) @@ -79,9 +79,8 @@ class AlignPropConfig: """Truncated Randomized Backpropation randomizes truncation to different diffusion timesteps""" truncated_backprop_timestep: int = 49 """Absolute timestep to which the gradients are being backpropagated. If truncated_backprop_rand is False""" - truncated_rand_backprop_minmax: tuple = (0,50) + truncated_rand_backprop_minmax: tuple = (0, 50) """Range of diffusion timesteps for randomized truncated backprop.""" - def to_dict(self): output_dict = {} diff --git a/trl/trainer/alignprop_trainer.py b/trl/trainer/alignprop_trainer.py index ba81eff354..9024a410d8 100644 --- a/trl/trainer/alignprop_trainer.py +++ b/trl/trainer/alignprop_trainer.py @@ -14,7 +14,6 @@ import os import warnings from collections import defaultdict -from concurrent import futures from typing import Any, Callable, Optional, Tuple from warnings import warn @@ -25,10 +24,7 @@ from huggingface_hub import whoami from ..models import DDPOStableDiffusionPipeline -from .utils import PerPromptStatTracker -from . import BaseTrainer, AlignPropConfig - - +from . import AlignPropConfig, BaseTrainer logger = get_logger(__name__) @@ -107,7 +103,6 @@ def __init__( accelerator_project_config.iteration = checkpoint_numbers[-1] + 1 - self.accelerator = Accelerator( log_with=self.config.log_with, mixed_precision=self.config.mixed_precision, @@ -119,16 +114,14 @@ def __init__( **self.config.accelerator_kwargs, ) - is_okay, message = self._config_check() - if not is_okay: - raise ValueError(message) - is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard" if self.accelerator.is_main_process: self.accelerator.init_trackers( self.config.tracker_project_name, - config=dict(alignprop_trainer_config=config.to_dict()) if not is_using_tensorboard else config.to_dict(), + config=dict(alignprop_trainer_config=config.to_dict()) + if not is_using_tensorboard + else config.to_dict(), init_kwargs=self.config.tracker_kwargs, ) @@ -201,7 +194,9 @@ def __init__( self.first_epoch = 0 def compute_rewards(self, prompt_image_pairs): - reward, reward_metadata = self.reward_fn(prompt_image_pairs['images'], prompt_image_pairs['prompts'], prompt_image_pairs['prompt_metadata']) + reward, reward_metadata = self.reward_fn( + prompt_image_pairs["images"], prompt_image_pairs["prompts"], prompt_image_pairs["prompt_metadata"] + ) return reward def step(self, epoch: int, global_step: int): @@ -222,27 +217,25 @@ def step(self, epoch: int, global_step: int): """ info = defaultdict(list) - + self.sd_pipeline.unet.train() - - for inner_iters in range(self.config.train_gradient_accumulation_steps): + + for _ in range(self.config.train_gradient_accumulation_steps): with self.accelerator.accumulate(self.sd_pipeline.unet), self.autocast(), torch.enable_grad(): prompt_image_pairs = self._generate_samples( batch_size=self.config.train_batch_size, ) - rewards = self.compute_rewards( - prompt_image_pairs - ) - + rewards = self.compute_rewards(prompt_image_pairs) + prompt_image_pairs["rewards"] = rewards - + rewards_vis = self.accelerator.gather(rewards).detach().cpu().numpy() - + loss = self.calculate_loss(rewards) - + self.accelerator.backward(loss) - + if self.accelerator.sync_gradients: self.accelerator.clip_grad_norm_( self.trainable_layers.parameters() @@ -253,17 +246,17 @@ def step(self, epoch: int, global_step: int): self.optimizer.step() self.optimizer.zero_grad() - + info["reward_mean"].append(rewards_vis.mean()) info["reward_std"].append(rewards_vis.std()) - info["loss"].append(loss.item()) + info["loss"].append(loss.item()) # Checks if the accelerator has performed an optimization step behind the scenes if self.accelerator.sync_gradients: # log training-related stuff info = {k: torch.mean(torch.tensor(v)) for k, v in info.items()} info = self.accelerator.reduce(info, reduction="mean") - info.update({"epoch": epoch, "inner_iters": inner_iters}) + info.update({"epoch": epoch}) self.accelerator.log(info, step=global_step) global_step += 1 info = defaultdict(list) @@ -277,7 +270,7 @@ def step(self, epoch: int, global_step: int): if epoch != 0 and epoch % self.config.save_freq == 0 and self.accelerator.is_main_process: self.accelerator.save_state() - + return global_step def calculate_loss(self, rewards): @@ -333,7 +326,7 @@ def _load_model_hook(self, models, input_dir): self.sd_pipeline.load_checkpoint(models, input_dir) models.pop() # ensures that accelerate doesn't try to handle loading of the model - def _generate_samples(self, batch_size, with_grad= True, prompts=None): + def _generate_samples(self, batch_size, with_grad=True, prompts=None): """ Generate samples from the model @@ -344,7 +337,6 @@ def _generate_samples(self, batch_size, with_grad= True, prompts=None): Returns: prompt_image_pairs (Dict[Any]) """ - samples = [] prompt_image_pairs = {} sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1) @@ -354,7 +346,6 @@ def _generate_samples(self, batch_size, with_grad= True, prompts=None): else: prompt_metadata = [{} for _ in range(batch_size)] - prompt_ids = self.sd_pipeline.tokenizer( prompts, return_tensors="pt", @@ -372,9 +363,9 @@ def _generate_samples(self, batch_size, with_grad= True, prompts=None): num_inference_steps=self.config.sample_num_steps, guidance_scale=self.config.sample_guidance_scale, eta=self.config.sample_eta, - truncated_backprop_rand= self.config.truncated_backprop_rand, - truncated_backprop_timestep= self.config.truncated_backprop_timestep, - truncated_rand_backprop_minmax= self.config.truncated_rand_backprop_minmax, + truncated_backprop_rand=self.config.truncated_backprop_rand, + truncated_backprop_timestep=self.config.truncated_backprop_timestep, + truncated_rand_backprop_minmax=self.config.truncated_rand_backprop_minmax, output_type="pt", ) else: @@ -386,25 +377,14 @@ def _generate_samples(self, batch_size, with_grad= True, prompts=None): eta=self.config.sample_eta, output_type="pt", ) - + images = sd_output.images - latents = sd_output.latents - log_probs = sd_output.log_probs prompt_image_pairs["images"] = images prompt_image_pairs["prompts"] = prompts prompt_image_pairs["prompt_metadata"] = prompt_metadata - - return prompt_image_pairs - - def _config_check(self) -> Tuple[bool, str]: - total_train_batch_size = ( - self.config.train_batch_size - * self.accelerator.num_processes - * self.config.train_gradient_accumulation_steps - ) - return True, "" + return prompt_image_pairs def train(self, epochs: Optional[int] = None): """ From 32ec653f74ea141967eef2f6fcf709c887f6d211 Mon Sep 17 00:00:00 2001 From: matrix Mihir Prabhduesai Date: Wed, 12 Jun 2024 16:52:42 -0400 Subject: [PATCH 15/15] updated _toctree --- docs/source/_toctree.yml | 2 ++ trl/models/auxiliary_modules.py | 1 + 2 files changed, 3 insertions(+) diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index e69a418f30..37ca392e4e 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -37,6 +37,8 @@ title: CPO Trainer - local: ddpo_trainer title: Denoising Diffusion Policy Optimization + - local: alignprop_trainer + title: AlignProp Trainer - local: orpo_trainer title: ORPO Trainer - local: iterative_sft_trainer diff --git a/trl/models/auxiliary_modules.py b/trl/models/auxiliary_modules.py index eb6b71936e..ed1f9b7507 100644 --- a/trl/models/auxiliary_modules.py +++ b/trl/models/auxiliary_modules.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os + import torch import torch.nn as nn import torchvision