From 0eea2059072bec70c4d17437c0e0cac5597ce626 Mon Sep 17 00:00:00 2001 From: Anton Kukulianski Date: Fri, 7 Jun 2024 15:29:35 +0100 Subject: [PATCH] Add InstructPix2Pix pipeline support. --- optimum/neuron/__init__.py | 2 + optimum/neuron/modeling_diffusion.py | 8 +- optimum/neuron/pipelines/__init__.py | 2 + .../neuron/pipelines/diffusers/__init__.py | 1 + ...eline_stable_diffusion_instruct_pix2pix.py | 521 ++++++++++++++++++ tests/inference/inference_utils.py | 1 + .../test_stable_diffusion_pipeline.py | 17 + 7 files changed, 551 insertions(+), 1 deletion(-) create mode 100644 optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_instruct_pix2pix.py diff --git a/optimum/neuron/__init__.py b/optimum/neuron/__init__.py index 369107cc7..2da4598dc 100644 --- a/optimum/neuron/__init__.py +++ b/optimum/neuron/__init__.py @@ -44,6 +44,7 @@ "NeuronStableDiffusionPipeline", "NeuronStableDiffusionImg2ImgPipeline", "NeuronStableDiffusionInpaintPipeline", + "NeuronStableDiffusionInstructPix2PixPipeline", "NeuronLatentConsistencyModelPipeline", "NeuronStableDiffusionXLPipeline", "NeuronStableDiffusionXLImg2ImgPipeline", @@ -78,6 +79,7 @@ NeuronLatentConsistencyModelPipeline, NeuronStableDiffusionImg2ImgPipeline, NeuronStableDiffusionInpaintPipeline, + NeuronStableDiffusionInstructPix2PixPipeline, NeuronStableDiffusionPipeline, NeuronStableDiffusionXLImg2ImgPipeline, NeuronStableDiffusionXLInpaintPipeline, diff --git a/optimum/neuron/modeling_diffusion.py b/optimum/neuron/modeling_diffusion.py index 6a85e738f..6a85aa253 100644 --- a/optimum/neuron/modeling_diffusion.py +++ b/optimum/neuron/modeling_diffusion.py @@ -89,6 +89,7 @@ NeuronLatentConsistencyPipelineMixin, NeuronStableDiffusionImg2ImgPipelineMixin, NeuronStableDiffusionInpaintPipelineMixin, + NeuronStableDiffusionInstructPix2PixPipelineMixin, NeuronStableDiffusionPipelineMixin, NeuronStableDiffusionXLImg2ImgPipelineMixin, NeuronStableDiffusionXLInpaintPipelineMixin, @@ -1003,6 +1004,12 @@ class NeuronStableDiffusionInpaintPipeline( __call__ = NeuronStableDiffusionInpaintPipelineMixin.__call__ +class NeuronStableDiffusionInstructPix2PixPipeline( + NeuronStableDiffusionPipelineBase, NeuronStableDiffusionInstructPix2PixPipelineMixin +): + __call__ = NeuronStableDiffusionInstructPix2PixPipelineMixin.__call__ + + class NeuronLatentConsistencyModelPipeline(NeuronStableDiffusionPipelineBase, NeuronLatentConsistencyPipelineMixin): __call__ = NeuronLatentConsistencyPipelineMixin.__call__ @@ -1081,7 +1088,6 @@ class NeuronStableDiffusionXLInpaintPipeline( if is_neuronx_available(): # TO REMOVE: This class will be included directly in the DDP API of Neuron SDK 2.20 class WeightSeparatedDataParallel(torch_neuronx.DataParallel): - def _load_modules(self, module): try: self.device_ids.sort() diff --git a/optimum/neuron/pipelines/__init__.py b/optimum/neuron/pipelines/__init__.py index 41312ce82..6e5c166be 100644 --- a/optimum/neuron/pipelines/__init__.py +++ b/optimum/neuron/pipelines/__init__.py @@ -24,6 +24,7 @@ "NeuronStableDiffusionPipelineMixin", "NeuronStableDiffusionImg2ImgPipelineMixin", "NeuronStableDiffusionInpaintPipelineMixin", + "NeuronStableDiffusionInstructPix2PixPipelineMixin", "NeuronLatentConsistencyPipelineMixin", "NeuronStableDiffusionXLPipelineMixin", "NeuronStableDiffusionXLImg2ImgPipelineMixin", @@ -36,6 +37,7 @@ NeuronLatentConsistencyPipelineMixin, NeuronStableDiffusionImg2ImgPipelineMixin, NeuronStableDiffusionInpaintPipelineMixin, + NeuronStableDiffusionInstructPix2PixPipelineMixin, NeuronStableDiffusionPipelineMixin, NeuronStableDiffusionXLImg2ImgPipelineMixin, NeuronStableDiffusionXLInpaintPipelineMixin, diff --git a/optimum/neuron/pipelines/diffusers/__init__.py b/optimum/neuron/pipelines/diffusers/__init__.py index b30664695..edd4922ca 100644 --- a/optimum/neuron/pipelines/diffusers/__init__.py +++ b/optimum/neuron/pipelines/diffusers/__init__.py @@ -17,6 +17,7 @@ from .pipeline_stable_diffusion import NeuronStableDiffusionPipelineMixin from .pipeline_stable_diffusion_img2img import NeuronStableDiffusionImg2ImgPipelineMixin from .pipeline_stable_diffusion_inpaint import NeuronStableDiffusionInpaintPipelineMixin +from .pipeline_stable_diffusion_instruct_pix2pix import NeuronStableDiffusionInstructPix2PixPipelineMixin from .pipeline_stable_diffusion_xl import NeuronStableDiffusionXLPipelineMixin from .pipeline_stable_diffusion_xl_img2img import NeuronStableDiffusionXLImg2ImgPipelineMixin from .pipeline_stable_diffusion_xl_inpaint import NeuronStableDiffusionXLInpaintPipelineMixin diff --git a/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_instruct_pix2pix.py b/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_instruct_pix2pix.py new file mode 100644 index 000000000..37dd27f71 --- /dev/null +++ b/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_instruct_pix2pix.py @@ -0,0 +1,521 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Override some diffusers API for NeuronStableDiffusionInstructPix2PixPipeline""" + +import logging +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union + +import PIL +import torch +from diffusers import StableDiffusionInstructPix2PixPipeline +from diffusers.loaders import TextualInversionLoaderMixin +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.utils.deprecation_utils import deprecate + +from .pipeline_utils import StableDiffusionPipelineMixin + + +if TYPE_CHECKING: + from diffusers.image_processor import PipelineImageInput + + +logger = logging.getLogger(__name__) + + +class NeuronStableDiffusionInstructPix2PixPipelineMixin( + StableDiffusionPipelineMixin, StableDiffusionInstructPix2PixPipeline +): + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Optional["PipelineImageInput"] = None, + num_inference_steps: int = 100, + guidance_scale: float = 7.5, + image_guidance_scale: float = 1.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.FloatTensor` `np.ndarray`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image` or tensor representing an image batch to be repainted according to `prompt`. Can also accept + image latents as `image`, but if passing latents directly it is not encoded again. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + image_guidance_scale (`float`, *optional*, defaults to 1.5): + Push the generated image towards the inital `image`. Image guidance scale is enabled by setting + `image_guidance_scale > 1`. Higher image guidance scale encourages generated images that are closely + linked to the source `image`, usually at the expense of lower image quality. This pipeline requires a + value of at least `1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeine class. + + Examples: + + ```py + >>> import PIL + >>> import requests + >>> import torch + >>> from io import BytesIO + + >>> from optimum.neuron import NeuronStableDiffusionInstructPix2PixPipeline + + + >>> def download_image(url): + ... response = requests.get(url) + ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") + + + >>> img_url = "https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png" + + >>> init_image = download_image(img_url).resize((512, 512)) + >>> compiler_args = {"auto_cast": "matmul", "auto_cast_type": "bf16"} + >>> input_shapes = {"batch_size": 1, "height": 512, "width": 512} + >>> pipe = NeuronStableDiffusionInstructPix2PixPipeline.from_pretrained( + ... "timbrooks/instruct-pix2pix", export=True, dynamic_batch_size=True, **compiler_args, **input_shapes, + ... ) + >>> pipe.save_pretrained("sd_ip2p/") + + >>> prompt = "in the style of Van Gogh" + >>> image = pipe(prompt=prompt, image=init_image).images[0] + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + + # 0. Check inputs + self.check_inputs( + prompt, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + ) + self._guidance_scale = guidance_scale + self._image_guidance_scale = image_guidance_scale + + if image is None: + raise ValueError("`image` input cannot be undefined.") + + # 1. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + neuron_batch_size = self.unet.config.neuron["static_batch_size"] + self.check_num_images_per_prompt(batch_size, neuron_batch_size, num_images_per_prompt) + + # check if scheduler is in sigmas space + scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas") + + # 2. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 3. Preprocess image + height = self.vae_encoder.config.neuron["static_height"] + width = self.vae_encoder.config.neuron["static_width"] + image = self.image_processor.preprocess(image, height=height, width=width) + + # 4. set timesteps + self.scheduler.set_timesteps(num_inference_steps) + timesteps = self.scheduler.timesteps + + # 5. Prepare Image latents + image_latents = self.prepare_image_latents( + image, + batch_size, + num_images_per_prompt, + self.do_classifier_free_guidance, + generator, + ) + + height, width = image_latents.shape[-2:] + height = height * self.vae_scale_factor + width = width * self.vae_scale_factor + + # 6. Prepare latent variables + num_channels_latents = self.vae_decoder.config.latent_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + generator, + latents, + ) + + # 7. Check that shapes of latents and image match the UNet channels + num_channels_image = image_latents.shape[1] + if num_channels_latents + num_channels_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_image`: {num_channels_image} " + f" = {num_channels_latents+num_channels_image}. Please verify the config of" + " `pipeline.unet` or your `image` input." + ) + + # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 9. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Expand the latents if we are doing classifier free guidance. + # The latents are expanded 3 times because for pix2pix the guidance\ + # is applied for both the text and the input image. + latent_model_input = torch.cat([latents] * 3) if self.do_classifier_free_guidance else latents + + # concat latents, image_latents in the channel dimension + scaled_latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + scaled_latent_model_input = torch.cat([scaled_latent_model_input, image_latents], dim=1) + + # predict the noise residual + noise_pred = self.unet( + scaled_latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + )[0] + + # Hack: + # For karras style schedulers the model does classifer free guidance using the + # predicted_original_sample instead of the noise_pred. So we need to compute the + # predicted_original_sample here if we are using a karras style scheduler. + if scheduler_is_in_sigma_space: + step_index = (self.scheduler.timesteps == t).nonzero()[0].item() + sigma = self.scheduler.sigmas[step_index] + noise_pred = latent_model_input - sigma * noise_pred + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3) + noise_pred = ( + noise_pred_uncond + + self.guidance_scale * (noise_pred_text - noise_pred_image) + + self.image_guidance_scale * (noise_pred_image - noise_pred_uncond) + ) + + # Hack: + # For karras style schedulers the model does classifer free guidance using the + # predicted_original_sample instead of the noise_pred. But the scheduler.step function + # expects the noise_pred and computes the predicted_original_sample internally. So we + # need to overwrite the noise_pred here such that the value of the computed + # predicted_original_sample is correct. + if scheduler_is_in_sigma_space: + noise_pred = (noise_pred - latents) / (-sigma) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + image_latents = callback_outputs.pop("image_latents", image_latents) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + image = self.vae_decoder(latents / getattr(self.vae_decoder.config, "scaling_factor", 0.18215))[0] + image, has_nsfw_concept = self.run_safety_checker(image, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionInstructPix2PixPipeline.prepare_image_latents + def prepare_image_latents( + self, image, batch_size, num_images_per_prompt, do_classifier_free_guidance, generator=None + ): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + image_latents = image + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if isinstance(generator, list): + image_latents = [self.vae_encoder(sample=image[i : i + 1])[0] for i in range(batch_size)] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = self.vae_encoder(sample=image)[0] + + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand image_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + if do_classifier_free_guidance: + uncond_image_latents = torch.zeros_like(image_latents) + image_latents = torch.cat([image_latents, image_latents, uncond_image_latents], dim=0) + + return image_latents + + # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionInstructPix2PixPipeline._encode_prompt + def _encode_prompt( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int, + do_classifier_free_guidance: bool, + negative_prompt: Optional[Union[str, List]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_ prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(text_input_ids) + prompt_embeds = prompt_embeds[0] + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = self.text_encoder(uncond_input.input_ids) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + # pix2pix has two negative embeddings, and unlike in other pipelines latents are ordered [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds] + prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds, negative_prompt_embeds]) + + return prompt_embeds diff --git a/tests/inference/inference_utils.py b/tests/inference/inference_utils.py index bd68d67f8..6d40018aa 100644 --- a/tests/inference/inference_utils.py +++ b/tests/inference/inference_utils.py @@ -46,6 +46,7 @@ "roberta": "hf-internal-testing/tiny-random-RobertaModel", "roformer": "hf-internal-testing/tiny-random-RoFormerModel", "stable-diffusion": "hf-internal-testing/tiny-stable-diffusion-torch", + "stable-diffusion-ip2p": "hf-internal-testing/tiny-stable-diffusion-pix2pix", "stable-diffusion-xl": "echarlaix/tiny-random-stable-diffusion-xl", "xlm": "hf-internal-testing/tiny-random-XLMModel", "xlm-roberta": "hf-internal-testing/tiny-xlm-roberta", diff --git a/tests/inference/test_stable_diffusion_pipeline.py b/tests/inference/test_stable_diffusion_pipeline.py index 174eac1f1..b46ffd6a7 100644 --- a/tests/inference/test_stable_diffusion_pipeline.py +++ b/tests/inference/test_stable_diffusion_pipeline.py @@ -24,6 +24,7 @@ NeuronLatentConsistencyModelPipeline, NeuronStableDiffusionImg2ImgPipeline, NeuronStableDiffusionInpaintPipeline, + NeuronStableDiffusionInstructPix2PixPipeline, NeuronStableDiffusionPipeline, NeuronStableDiffusionXLImg2ImgPipeline, NeuronStableDiffusionXLInpaintPipeline, @@ -126,6 +127,22 @@ def test_inpaint_export_and_inference(self, model_arch): image = neuron_pipeline(prompt=prompt, image=init_image, mask_image=mask_image).images[0] self.assertIsInstance(image, PIL.Image.Image) + @parameterized.expand(["stable-diffusion-ip2p"], skip_on_empty=True) + def test_instruct_pix2pix_export_and_inference(self, model_arch): + neuron_pipeline = NeuronStableDiffusionInstructPix2PixPipeline.from_pretrained( + MODEL_NAMES[model_arch], + export=True, + dynamic_batch_size=True, + **self.STATIC_INPUTS_SHAPES, + **self.COMPILER_ARGS, + ) + + img_url = "https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png" + init_image = download_image(img_url).resize((512, 512)) + prompt = "in the style of Van Gogh" + image = neuron_pipeline(prompt=prompt, image=init_image).images[0] + self.assertIsInstance(image, PIL.Image.Image) + @parameterized.expand(["latent-consistency"], skip_on_empty=True) def test_lcm_export_and_inference(self, model_arch): neuron_pipeline = NeuronLatentConsistencyModelPipeline.from_pretrained(