diff --git a/optimum/exporters/neuron/__main__.py b/optimum/exporters/neuron/__main__.py index 7ca391480..ce38b5a63 100644 --- a/optimum/exporters/neuron/__main__.py +++ b/optimum/exporters/neuron/__main__.py @@ -598,12 +598,6 @@ def main_export( # Validate compiled model if do_validation is True: - if library_name == "diffusers": - # Do not validate vae encoder due to the sampling randomness - neuron_outputs.pop("vae_encoder") - models_and_neuron_configs.pop("vae_encoder", None) - output_model_names.pop("vae_encoder", None) - try: validate_models_outputs( models_and_neuron_configs=models_and_neuron_configs, diff --git a/optimum/exporters/neuron/model_configs.py b/optimum/exporters/neuron/model_configs.py index 80973af57..63a5996c9 100644 --- a/optimum/exporters/neuron/model_configs.py +++ b/optimum/exporters/neuron/model_configs.py @@ -724,7 +724,7 @@ def inputs(self) -> List[str]: @property def outputs(self) -> List[str]: - return ["latent_sample"] + return ["latent_parameters"] def generate_dummy_inputs(self, return_tuple: bool = False, **kwargs): dummy_inputs = super().generate_dummy_inputs(**kwargs) diff --git a/optimum/exporters/neuron/utils.py b/optimum/exporters/neuron/utils.py index eec314526..a5673c6ae 100644 --- a/optimum/exporters/neuron/utils.py +++ b/optimum/exporters/neuron/utils.py @@ -421,7 +421,7 @@ def get_submodels_for_export_stable_diffusion( # VAE Encoder vae_encoder = copy.deepcopy(pipeline.vae) - vae_encoder.forward = lambda sample: {"latent_sample": vae_encoder.encode(x=sample)["latent_dist"].sample()} + vae_encoder.forward = lambda sample: {"latent_parameters": vae_encoder.encode(x=sample)["latent_dist"].parameters} models_for_export.append((DIFFUSION_MODEL_VAE_ENCODER_NAME, vae_encoder)) # VAE Decoder diff --git a/optimum/neuron/modeling_diffusion.py b/optimum/neuron/modeling_diffusion.py index d32704cc9..915b98d6a 100644 --- a/optimum/neuron/modeling_diffusion.py +++ b/optimum/neuron/modeling_diffusion.py @@ -12,10 +12,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
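The VAE encoder change above is what makes the `vae_encoder` validation skip in `main_export` unnecessary: the traced graph now returns the deterministic `latent_parameters` tensor of the posterior instead of a sampled latent, so the sampling randomness moves out of the compiled model and back onto the host. A minimal sketch of the host-side counterpart, assuming a compiled encoder that returns the concatenated mean/log-variance tensor as its first output:

```python
import torch
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution


def encode_to_latents(traced_vae_encoder, sample: torch.Tensor) -> torch.Tensor:
    # The traced graph is fully deterministic: it only computes the posterior parameters,
    # so compiled vs. reference outputs can be compared exactly during export validation.
    latent_parameters = traced_vae_encoder(sample)[0]
    # Randomness is reintroduced on the host, outside the compiled graph.
    latent_dist = DiagonalGaussianDistribution(parameters=latent_parameters)
    return latent_dist.sample()
```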
-"""NeuroStableDiffusionPipeline class for inference of diffusion models on neuron devices.""" +"""NeuronDiffusionPipelineBase class for inference of diffusion models on neuron devices.""" import copy import importlib +import inspect import logging import os import shutil @@ -27,7 +28,7 @@ import torch from huggingface_hub import snapshot_download -from transformers import CLIPFeatureExtractor, CLIPTokenizer, PretrainedConfig +from transformers import CLIPFeatureExtractor, CLIPTokenizer, PretrainedConfig, T5Tokenizer from transformers.modeling_outputs import ModelOutput from ..exporters.neuron import ( @@ -44,6 +45,7 @@ DIFFUSION_MODEL_CONTROLNET_NAME, DIFFUSION_MODEL_TEXT_ENCODER_2_NAME, DIFFUSION_MODEL_TEXT_ENCODER_NAME, + DIFFUSION_MODEL_TRANSFORMER_NAME, DIFFUSION_MODEL_UNET_NAME, DIFFUSION_MODEL_VAE_DECODER_NAME, DIFFUSION_MODEL_VAE_ENCODER_NAME, @@ -73,30 +75,33 @@ if is_diffusers_available(): from diffusers import ( ControlNetModel, - DDIMScheduler, + LatentConsistencyModelPipeline, LCMScheduler, - LMSDiscreteScheduler, - PNDMScheduler, + PixArtAlphaPipeline, + StableDiffusionControlNetPipeline, + StableDiffusionImg2ImgPipeline, + StableDiffusionInpaintPipeline, + StableDiffusionInstructPix2PixPipeline, StableDiffusionPipeline, + StableDiffusionXLControlNetPipeline, StableDiffusionXLImg2ImgPipeline, + StableDiffusionXLInpaintPipeline, + StableDiffusionXLPipeline, ) from diffusers.configuration_utils import FrozenDict from diffusers.image_processor import VaeImageProcessor + from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution from diffusers.models.controlnet import ControlNetOutput + from diffusers.models.modeling_outputs import AutoencoderKLOutput from diffusers.pipelines.controlnet import MultiControlNetModel + from diffusers.pipelines.pipeline_utils import DiffusionPipeline + from diffusers.schedulers import SchedulerMixin from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME - from diffusers.utils import CONFIG_NAME, is_invisible_watermark_available + from diffusers.utils import CONFIG_NAME from .pipelines import ( - NeuronLatentConsistencyPipelineMixin, NeuronStableDiffusionControlNetPipelineMixin, - NeuronStableDiffusionImg2ImgPipelineMixin, - NeuronStableDiffusionInpaintPipelineMixin, - NeuronStableDiffusionInstructPix2PixPipelineMixin, - NeuronStableDiffusionPipelineMixin, NeuronStableDiffusionXLControlNetPipelineMixin, - NeuronStableDiffusionXLImg2ImgPipelineMixin, - NeuronStableDiffusionXLInpaintPipelineMixin, NeuronStableDiffusionXLPipelineMixin, ) @@ -108,26 +113,41 @@ logger = logging.getLogger(__name__) -class NeuronStableDiffusionPipelineBase(NeuronTracedModel): - auto_model_class = StableDiffusionPipeline +class NeuronDiffusionPipelineBase(NeuronTracedModel): + auto_model_class = DiffusionPipeline + task = None library_name = "diffusers" base_model_prefix = "neuron_model" config_name = "model_index.json" sub_component_config_name = "config.json" + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "vae_encoder", + "image_encoder", + "unet", + "transformer", + "feature_extractor", + ] def __init__( self, - text_encoder: torch.jit._script.ScriptModule, - unet: torch.jit._script.ScriptModule, - vae_decoder: Union[torch.jit._script.ScriptModule, "NeuronModelVaeDecoder"], config: Dict[str, Any], configs: Dict[str, "PretrainedConfig"], neuron_configs: Dict[str, "NeuronDefaultConfig"], - tokenizer: CLIPTokenizer, - scheduler: Union[DDIMScheduler, PNDMScheduler, 
LMSDiscreteScheduler, LCMScheduler], - data_parallel_mode: Literal["none", "unet", "all"], - vae_encoder: Optional[Union[torch.jit._script.ScriptModule, "NeuronModelVaeEncoder"]] = None, + data_parallel_mode: Literal["none", "unet", "transformer", "all"], + scheduler: Optional[SchedulerMixin], + vae_decoder: Union[torch.jit._script.ScriptModule, "NeuronModelVaeDecoder"], + text_encoder: Optional[Union[torch.jit._script.ScriptModule, "NeuronModelTextEncoder"]] = None, text_encoder_2: Optional[Union[torch.jit._script.ScriptModule, "NeuronModelTextEncoder"]] = None, + unet: Optional[Union[torch.jit._script.ScriptModule, "NeuronModelUnet"]] = None, + transformer: Optional[Union[torch.jit._script.ScriptModule, "NeuronModelTransformer"]] = None, + vae_encoder: Optional[Union[torch.jit._script.ScriptModule, "NeuronModelVaeEncoder"]] = None, + image_encoder: Optional[torch.jit._script.ScriptModule] = None, + safety_checker: Optional[torch.jit._script.ScriptModule] = None, + tokenizer: Optional[Union[CLIPTokenizer, T5Tokenizer]] = None, tokenizer_2: Optional[CLIPTokenizer] = None, feature_extractor: Optional[CLIPFeatureExtractor] = None, controlnet: Optional[ @@ -138,49 +158,71 @@ def __init__( "NeuronMultiControlNetModel", ] ] = None, + # stable diffusion xl specific arguments + requires_aesthetics_score: bool = False, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: Optional[bool] = None, model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, model_and_config_save_paths: Optional[Dict[str, Tuple[str, Path]]] = None, ): """ Args: - text_encoder (`torch.jit._script.ScriptModule`): - The Neuron TorchScript module associated to the text encoder. - unet (`torch.jit._script.ScriptModule`): - The Neuron TorchScript module associated to the U-NET. - vae_decoder (`torch.jit._script.ScriptModule`): - The Neuron TorchScript module associated to the VAE decoder. config (`Dict[str, Any]`): A config dictionary from which the model components will be instantiated. Make sure to only load configuration files of compatible classes. configs (Dict[str, "PretrainedConfig"], defaults to `None`): A dictionary configurations for components of the pipeline. neuron_configs (Dict[str, "NeuronDefaultConfig"], defaults to `None`): - A list of Neuron configurations. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - scheduler (`Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]`): - A scheduler to be used in combination with the U-NET component to denoise the encoded image latents. + A list of Neuron configurations related to the compilation. data_parallel_mode (`Literal["none", "unet", "all"]`): Mode to decide what components to load into both NeuronCores of a Neuron device. Can be "none"(no data parallel), "unet"(only load unet into both cores of each device), "all"(load the whole pipeline into both cores). - vae_encoder (`Optional[torch.jit._script.ScriptModule]`, defaults to `None`): - The Neuron TorchScript module associated to the VAE encoder. - text_encoder_2 (`Optional[torch.jit._script.ScriptModule]`, defaults to `None`): + scheduler (`Optional[SchedulerMixin]`): + A scheduler to be used in combination with the U-NET component to denoise the encoded image latents. + vae_decoder (`Union[torch.jit._script.ScriptModule, "NeuronModelVaeDecoder"]`): + The Neuron TorchScript module associated to the VAE decoder. 
+ text_encoder (`Optional[Union[torch.jit._script.ScriptModule, "NeuronModelTextEncoder"]]`, defaults to `None`): + The Neuron TorchScript module associated to the text encoder. + text_encoder_2 (`Optional[Union[torch.jit._script.ScriptModule, "NeuronModelTextEncoder"]]`, defaults to `None`): The Neuron TorchScript module associated to the second frozen text encoder. Stable Diffusion XL uses the text and pool portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), specifically the [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) variant. - controlnet (`Optional[Union[torch.jit._script.ScriptModule, List[torch.jit._script.ScriptModule], "NeuronControlNetModel", "NeuronMultiControlNetModel"]]`, defaults to `None`): - The Neuron TorchScript module(s) associated to the ControlNet(s). + unet (`Optional[Union[torch.jit._script.ScriptModule, "NeuronModelUnet"]]`, defaults to `None`): + The Neuron TorchScript module associated to the U-NET. + transformer (`Optional[Union[torch.jit._script.ScriptModule, "NeuronModelTransformer"]]`, defaults to `None`): + The Neuron TorchScript module associated to the diffuser transformer. + vae_encoder (`Optional[Union[torch.jit._script.ScriptModule, "NeuronModelVaeEncoder"]]`, defaults to `None`): + The Neuron TorchScript module associated to the VAE encoder. + image_encoder (`Optional[torch.jit._script.ScriptModule]`, defaults to `None`): + The Neuron TorchScript module associated to the frozen CLIP image-encoder. + safety_checker (`Optional[torch.jit._script.ScriptModule]`, defaults to `None`): + The Neuron TorchScript module associated to the Classification module that estimates whether generated images could be considered offensive or harmful. + tokenizer (`Optional[Union[CLIPTokenizer, T5Tokenizer]]`, defaults to `None`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer) for stable diffusion models, + or tokenizer of class [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer) for diffusion transformers. tokenizer_2 (`Optional[CLIPTokenizer]`, defaults to `None`): Second tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). feature_extractor (`Optional[CLIPFeatureExtractor]`, defaults to `None`): A model extracting features from generated images to be used as inputs for the `safety_checker` + controlnet (`Optional[Union[torch.jit._script.ScriptModule, List[torch.jit._script.ScriptModule], "NeuronControlNetModel", "NeuronMultiControlNetModel"]]`, defaults to `None`): + The Neuron TorchScript module(s) associated to the ControlNet(s). + requires_aesthetics_score (`bool`, defaults to `False`): + Whether the `unet` requires an `aesthetic_score` condition to be passed during inference. Also see the + config of `stabilityai/stable-diffusion-xl-refiner-1-0`. + force_zeros_for_empty_prompt (`bool`, defaults to `True`): + Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + add_watermarker (`Optional[bool]`, defaults to `None`): + Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to + watermark output images. 
If not defined, it will default to True if the package is installed, otherwise no + watermarker will be used. model_save_dir (`Optional[Union[str, Path, TemporaryDirectory]]`, defaults to `None`): The directory under which the exported Neuron models were saved. model_and_config_save_paths (`Optional[Dict[str, Tuple[str, Path]]]`, defaults to `None`): The paths where exported Neuron models were saved. """ + # configurations self._internal_dict = config self.data_parallel_mode = data_parallel_mode self.configs = configs @@ -189,6 +231,7 @@ def __init__( neuron_config._config.neuron["dynamic_batch_size"] for neuron_config in self.neuron_configs.values() ) + # pipeline components self.text_encoder = ( NeuronModelTextEncoder( text_encoder, @@ -196,8 +239,8 @@ def __init__( self.configs[DIFFUSION_MODEL_TEXT_ENCODER_NAME], self.neuron_configs[DIFFUSION_MODEL_TEXT_ENCODER_NAME], ) - if text_encoder is not None - else None + if text_encoder is not None and not isinstance(text_encoder, NeuronModelTextEncoder) + else text_encoder ) self.text_encoder_2 = ( NeuronModelTextEncoder( @@ -209,28 +252,44 @@ def __init__( if text_encoder_2 is not None and not isinstance(text_encoder_2, NeuronModelTextEncoder) else text_encoder_2 ) - self.unet = NeuronModelUnet( - unet, self, self.configs[DIFFUSION_MODEL_UNET_NAME], self.neuron_configs[DIFFUSION_MODEL_UNET_NAME] + self.unet = ( + NeuronModelUnet( + unet, self, self.configs[DIFFUSION_MODEL_UNET_NAME], self.neuron_configs[DIFFUSION_MODEL_UNET_NAME] + ) + if unet is not None and not isinstance(unet, NeuronModelUnet) + else unet + ) + self.transformer = ( + NeuronModelTransformer( + transformer, + self, + self.configs[DIFFUSION_MODEL_TRANSFORMER_NAME], + self.neuron_configs[DIFFUSION_MODEL_TRANSFORMER_NAME], + ) + if transformer is not None and not isinstance(transformer, NeuronModelTransformer) + else transformer ) - if vae_encoder is not None and not isinstance(vae_encoder, NeuronModelVaeEncoder): - self.vae_encoder = NeuronModelVaeEncoder( + self.vae_encoder = ( + NeuronModelVaeEncoder( vae_encoder, self, self.configs[DIFFUSION_MODEL_VAE_ENCODER_NAME], self.neuron_configs[DIFFUSION_MODEL_VAE_ENCODER_NAME], ) - else: - self.vae_encoder = vae_encoder - - if vae_decoder is not None and not isinstance(vae_decoder, NeuronModelVaeDecoder): - self.vae_decoder = NeuronModelVaeDecoder( + if vae_encoder is not None and not isinstance(vae_encoder, NeuronModelVaeEncoder) + else vae_encoder + ) + self.vae_decoder = ( + NeuronModelVaeDecoder( vae_decoder, self, self.configs[DIFFUSION_MODEL_VAE_DECODER_NAME], self.neuron_configs[DIFFUSION_MODEL_VAE_DECODER_NAME], ) - else: - self.vae_decoder = vae_decoder + if vae_decoder is not None and not isinstance(vae_decoder, NeuronModelVaeDecoder) + else vae_decoder + ) + self.vae = NeuronModelVae(self.vae_encoder, self.vae_decoder) if ( controlnet @@ -254,31 +313,48 @@ def __init__( self.tokenizer = tokenizer self.tokenizer_2 = tokenizer_2 self.scheduler = scheduler + + # change lcm scheduler which extends the denoising procedure self.is_lcm = False - if NeuronStableDiffusionPipelineBase.is_lcm(self.unet.config): + if NeuronDiffusionPipelineBase.is_lcm(self.unet.config): self.is_lcm = True self.scheduler = LCMScheduler.from_config(self.scheduler.config) + self.feature_extractor = feature_extractor - self.safety_checker = None - sub_models = { - DIFFUSION_MODEL_TEXT_ENCODER_NAME: self.text_encoder, - DIFFUSION_MODEL_UNET_NAME: self.unet, - DIFFUSION_MODEL_VAE_DECODER_NAME: self.vae_decoder, + self.image_encoder = image_encoder # TODO: 
implement the class `NeuronImageEncoder`. + self.safety_checker = safety_checker # TODO: implement the class `NeuronStableDiffusionSafetyChecker`. + + all_possible_init_args = { + "vae": self.vae, + "unet": self.unet, + "transformer": self.transformer, + "text_encoder": self.text_encoder, + "text_encoder_2": self.text_encoder_2, + "controlnet": self.controlnet, + "image_encoder": self.image_encoder, + "safety_checker": self.safety_checker, + "scheduler": self.scheduler, + "tokenizer": self.tokenizer, + "tokenizer_2": self.tokenizer_2, + "feature_extractor": self.feature_extractor, + "requires_aesthetics_score": requires_aesthetics_score, + "force_zeros_for_empty_prompt": force_zeros_for_empty_prompt, + "add_watermarker": add_watermarker, } - if self.text_encoder_2 is not None: - sub_models[DIFFUSION_MODEL_TEXT_ENCODER_2_NAME] = self.text_encoder_2 - if self.vae_encoder is not None: - sub_models[DIFFUSION_MODEL_VAE_ENCODER_NAME] = self.vae_encoder - - for name in sub_models.keys(): - self._internal_dict[name] = ("optimum", sub_models[name].__class__.__name__) - self._internal_dict.pop("vae", None) + diffusers_pipeline_args = {} + for key in inspect.signature(self.auto_model_class).parameters.keys(): + if key in all_possible_init_args: + diffusers_pipeline_args[key] = all_possible_init_args[key] + self.auto_model_class.__init__(self, **diffusers_pipeline_args) self._attributes_init(model_save_dir) self.model_and_config_save_paths = model_and_config_save_paths if model_and_config_save_paths else None + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.register_to_config(requires_aesthetics_score=requires_aesthetics_score) - if hasattr(self.vae_decoder.config, "block_out_channels"): - self.vae_scale_factor = 2 ** (len(self.vae_decoder.config.block_out_channels) - 1) + # Calculate static shapes + if hasattr(self.vae.config, "block_out_channels"): + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) else: self.vae_scale_factor = 8 @@ -306,12 +382,13 @@ def is_lcm(unet_config): @staticmethod @requires_torch_neuronx def load_model( - data_parallel_mode: Optional[Literal["none", "unet", "all"]], - text_encoder_path: Union[str, Path], - unet_path: Union[str, Path], - vae_decoder_path: Optional[Union[str, Path]] = None, - vae_encoder_path: Optional[Union[str, Path]] = None, + data_parallel_mode: Optional[Literal["none", "unet", "transformer", "all"]], + text_encoder_path: Optional[Union[str, Path]] = None, text_encoder_2_path: Optional[Union[str, Path]] = None, + unet_path: Optional[Union[str, Path]] = None, + transformer_path: Optional[Union[str, Path]] = None, + vae_encoder_path: Optional[Union[str, Path]] = None, + vae_decoder_path: Optional[Union[str, Path]] = None, controlnet_paths: Optional[List[Path]] = None, dynamic_batch_size: bool = False, to_neuron: bool = False, @@ -324,16 +401,18 @@ def load_model( data_parallel_mode (`Optional[Literal["none", "unet", "all"]]`): Mode to decide what components to load into both NeuronCores of a Neuron device. Can be "none"(no data parallel), "unet"(only load unet into both cores of each device), "all"(load the whole pipeline into both cores). - text_encoder_path (`Union[str, Path]`): + text_encoder_path (`Union[str, Path]`, defaults to `None`): Path of the compiled text encoder. - unet_path (`Union[str, Path]`): + text_encoder_2_path (`Optional[Union[str, Path]]`, defaults to `None`): + Path of the compiled second frozen text encoder. SDXL only. 
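Stepping back to the constructor above: instead of registering a hand-picked set of submodules, the base class now collects every possible component into `all_possible_init_args` and filters it against the signature of the concrete diffusers pipeline before calling its `__init__`. The idea in isolation (a sketch, not the exact implementation):

```python
import inspect

from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline


def select_pipeline_kwargs(pipeline_cls, all_possible_init_args: dict) -> dict:
    # Keep only the arguments the concrete diffusers pipeline accepts, so a single
    # base __init__ can serve SD, SDXL, LCM or PixArt pipelines without per-class plumbing.
    accepted = inspect.signature(pipeline_cls).parameters.keys()
    return {name: value for name, value in all_possible_init_args.items() if name in accepted}


# e.g. "tokenizer_2" is kept for SDXL but filtered out for vanilla Stable Diffusion:
args = {"tokenizer_2": object(), "scheduler": object()}
assert "tokenizer_2" in select_pipeline_kwargs(StableDiffusionXLPipeline, args)
assert "tokenizer_2" not in select_pipeline_kwargs(StableDiffusionPipeline, args)
```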
+ unet_path (`Optional[Union[str, Path]]`, defaults to `None`): Path of the compiled U-NET. - vae_decoder_path (`Optional[Union[str, Path]]`, defaults to `None`): - Path of the compiled VAE decoder. + transformer_path (`Optional[Union[str, Path]]`, defaults to `None`): + Path of the compiled diffusion transformer. vae_encoder_path (`Optional[Union[str, Path]]`, defaults to `None`): Path of the compiled VAE encoder. It is optional, only used for tasks taking images as input. - text_encoder_2_path (`Optional[Union[str, Path]]`, defaults to `None`): - Path of the compiled second frozen text encoder. SDXL only. + vae_decoder_path (`Optional[Union[str, Path]]`, defaults to `None`): + Path of the compiled VAE decoder. controlnet_paths (`Optional[List[Path]]`, defaults to `None`): Path of the compiled controlnets. dynamic_batch_size (`bool`, defaults to `False`): @@ -343,81 +422,86 @@ def load_model( """ submodels = { "text_encoder": text_encoder_path, + "text_encoder_2": text_encoder_2_path, "unet": unet_path, - "vae_decoder": vae_decoder_path, + "transformer": transformer_path, "vae_encoder": vae_encoder_path, - "text_encoder_2": text_encoder_2_path, + "vae_decoder": vae_decoder_path, "controlnet": controlnet_paths, } + def _load_models_to_neuron(submodels, models_on_both_cores=None, models_on_a_single_core=None): + # loading models to both cores, eg. unet, transformer. + if models_on_both_cores: + for model_name in models_on_both_cores: + submodel_paths = submodels[model_name] + # for the case of multiple controlnets the path could be a list + if not isinstance(submodel_paths, list): + submodel_paths = [submodel_paths] + submodels_list = [] + for submodel_path in submodel_paths: + if submodel_path is not None and submodel_path.is_file(): + submodel = NeuronTracedModel.load_model( + submodel_path, to_neuron=False + ) # No need to load to neuron manually when dp + submodel = torch_neuronx.DataParallel( + submodel, + [0, 1], + set_dynamic_batching=dynamic_batch_size, + ) + submodels_list.append(submodel) + if submodels_list: + submodels[model_name] = submodels_list if len(submodels_list) > 1 else submodels_list[0] + else: + submodels[model_name] = None + # loading models to a single core, eg. text encoders, vae. 
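To make the "both cores" branch of `_load_models_to_neuron` above concrete, here is a minimal sketch of that duplication path, assuming `torch_neuronx` is available and a compiled TorchScript artifact at a hypothetical path:

```python
import torch
import torch_neuronx

# Duplicate a compiled submodel (e.g. the UNet or the diffusion transformer) on
# NeuronCores 0 and 1; under classifier-free guidance the denoiser receives a
# doubled batch, so it is the component that benefits most from data parallelism.
submodel = torch.jit.load("unet/model.neuron")  # hypothetical path to a traced module
submodel = torch_neuronx.DataParallel(submodel, [0, 1], set_dynamic_batching=False)
```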
+ if models_on_a_single_core: + for model_name in models_on_a_single_core: + submodel_paths = submodels[model_name] + # for the case of multiple controlnets the path could be a list + if not isinstance(submodel_paths, list): + submodel_paths = [submodel_paths] + submodels_list = [] + for submodel_path in submodel_paths: + if submodel_path is not None and submodel_path.is_file(): + submodel = NeuronTracedModel.load_model(submodel_path, to_neuron=to_neuron) + submodels_list.append(submodel) + if submodels_list: + submodels[model_name] = submodels_list if len(submodels_list) > 1 else submodels_list[0] + else: + submodels[model_name] = None + return submodels + if data_parallel_mode == "all": logger.info("Loading the whole pipeline into both Neuron Cores...") - for submodel_name, submodel_paths in submodels.items(): - if not isinstance(submodel_paths, list): - submodel_paths = [submodel_paths] - submodels_list = [] - for submodel_path in submodel_paths: - if submodel_path is not None and submodel_path.is_file(): - submodel = NeuronTracedModel.load_model( - submodel_path, to_neuron=False - ) # No need to load to neuron manually when dp - submodel = torch_neuronx.DataParallel( - submodel, - [0, 1], - set_dynamic_batching=dynamic_batch_size, - ) - submodels_list.append(submodel) - if submodels_list: - submodels[submodel_name] = submodels_list if len(submodels_list) > 1 else submodels_list[0] - else: - submodels[submodel_name] = None + submodels = _load_models_to_neuron(submodels=submodels, models_on_both_cores=list(submodels)) elif data_parallel_mode == "unet": logger.info("Loading only U-Net into both Neuron Cores...") - submodels.pop("unet") - submodels.pop("controlnet") # controlnet takes inputs with the same batch_size as the unet - for submodel_name, submodel_path in submodels.items(): - if submodel_path is not None and submodel_path.is_file(): - submodels[submodel_name] = NeuronTracedModel.load_model(submodel_path, to_neuron=to_neuron) - else: - submodels[submodel_name] = None - # load unet - unet = NeuronTracedModel.load_model( - unet_path, to_neuron=False - ) # No need to load to neuron manually when dp - submodels["unet"] = torch_neuronx.DataParallel( - unet, - [0, 1], - set_dynamic_batching=dynamic_batch_size, + models_on_a_single_core = list(submodels) + models_on_a_single_core.remove("unet") + models_on_a_single_core.remove( + "controlnet" + ) # controlnet takes inputs with the same batch_size as the unet + submodels = _load_models_to_neuron( + submodels=submodels, + models_on_both_cores=["unet", "controlnet"], + models_on_a_single_core=models_on_a_single_core, + ) + elif data_parallel_mode == "transformer": + logger.info("Loading only diffusion transformer into both Neuron Cores...") + models_on_a_single_core = list(submodels) + models_on_a_single_core.remove("transformer") + models_on_a_single_core.remove( + "controlnet" + ) # controlnet takes inputs with the same batch_size as the transformer + submodels = _load_models_to_neuron( + submodels=submodels, + models_on_both_cores=["transformer", "controlnet"], + models_on_a_single_core=models_on_a_single_core, ) - # load controlnets - if controlnet_paths: - controlnets = [] - for controlnet_path in controlnet_paths: - if controlnet_path.is_file(): - controlnet = NeuronTracedModel.load_model( - controlnet_path, to_neuron=False - ) # No need to load to neuron manually when dp - controlnets.append( - torch_neuronx.DataParallel(controlnet, [0, 1], set_dynamic_batching=dynamic_batch_size) - ) - if controlnets: - submodels["controlnet"] = 
controlnets if len(controlnets) > 1 else controlnets[0] - else: - submodels["controlnet"] = None elif data_parallel_mode == "none": logger.info("Loading the pipeline without any data parallelism...") - for submodel_name, submodel_paths in submodels.items(): - if not isinstance(submodel_paths, list): - submodel_paths = [submodel_paths] - submodels_list = [] - for submodel_path in submodel_paths: - if submodel_path is not None and submodel_path.is_file(): - submodel = NeuronTracedModel.load_model(submodel_path, to_neuron=to_neuron) - submodels_list.append(submodel) - if submodels_list: - submodels[submodel_name] = submodels_list if len(submodels_list) > 1 else submodels_list[0] - else: - submodels[submodel_name] = None + submodels = _load_models_to_neuron(submodels=submodels, models_on_a_single_core=list(submodels)) else: raise ValueError("You need to pass `data_parallel_mode` to define Neuron Core allocation.") @@ -425,7 +509,7 @@ def load_model( def replace_weights(self, weights: Optional[Union[Dict[str, torch.Tensor], torch.nn.Module]] = None): check_if_weights_replacable(self.configs, weights) - model_names = ["text_encoder", "text_encoder_2", "unet", "vae_decoder", "vae_encoder"] + model_names = ["text_encoder", "text_encoder_2", "unet", "transformer", "vae_decoder", "vae_encoder"] for name in model_names: model = getattr(self, name, None) weight = getattr(weights, name, None) @@ -433,13 +517,22 @@ def replace_weights(self, weights: Optional[Union[Dict[str, torch.Tensor], torch model = replace_weights(model.model, weight) @staticmethod - def set_default_dp_mode(unet_config): - if NeuronStableDiffusionPipelineBase.is_lcm(unet_config) is True: - # LCM applies guidance using guidance embeddings, so we can load the whole pipeline into both cores. - return "all" + def set_default_dp_mode(configs: Dict): + if "unet" in configs: + unet_config = configs["unet"] + if NeuronDiffusionPipelineBase.is_lcm(unet_config) is True: + # LCM applies guidance using guidance embeddings, so we can load the whole pipeline into both cores. + return "all" + else: + # Load U-Net into both cores for classifier-free guidance which doubles batch size of inputs passed to the U-Net. + return "unet" + elif "transformer" in configs: + return "transformer" else: - # Load U-Net into both cores for classifier-free guidance which doubles batch size of inputs passed to the U-Net. - return "unet" + logger.warning( + "There is no unet nor transformer in your pipeline, the data parallelism will be disabled, make sure that you are loading the model correctly!" 
+ ) + return "none" def _save_pretrained( self, @@ -447,6 +540,7 @@ def _save_pretrained( text_encoder_file_name: str = NEURON_FILE_NAME, text_encoder_2_file_name: str = NEURON_FILE_NAME, unet_file_name: str = NEURON_FILE_NAME, + transformer_file_name: str = NEURON_FILE_NAME, vae_encoder_file_name: str = NEURON_FILE_NAME, vae_decoder_file_name: str = NEURON_FILE_NAME, controlnet_file_name: str = NEURON_FILE_NAME, @@ -461,14 +555,21 @@ def _save_pretrained( return save_directory = Path(save_directory) - if not self.model_and_config_save_paths.get(DIFFUSION_MODEL_VAE_ENCODER_NAME)[0].is_file(): - self.model_and_config_save_paths.pop(DIFFUSION_MODEL_VAE_ENCODER_NAME) - - if not self.model_and_config_save_paths.get(DIFFUSION_MODEL_TEXT_ENCODER_NAME)[0].is_file(): - self.model_and_config_save_paths.pop(DIFFUSION_MODEL_TEXT_ENCODER_NAME) - if not self.model_and_config_save_paths.get(DIFFUSION_MODEL_TEXT_ENCODER_2_NAME)[0].is_file(): - self.model_and_config_save_paths.pop(DIFFUSION_MODEL_TEXT_ENCODER_2_NAME) + def _remove_submodel_if_non_exist(model_names): + for model_name in model_names: + if not self.model_and_config_save_paths.get(model_name)[0].is_file(): + self.model_and_config_save_paths.pop(model_name) + + _remove_submodel_if_non_exist( + [ + DIFFUSION_MODEL_TEXT_ENCODER_NAME, + DIFFUSION_MODEL_TEXT_ENCODER_2_NAME, + DIFFUSION_MODEL_UNET_NAME, + DIFFUSION_MODEL_TRANSFORMER_NAME, + DIFFUSION_MODEL_VAE_ENCODER_NAME, + ] + ) if not self.model_and_config_save_paths.get(DIFFUSION_MODEL_CONTROLNET_NAME)[0]: self.model_and_config_save_paths.pop(DIFFUSION_MODEL_CONTROLNET_NAME) @@ -486,6 +587,9 @@ def _save_pretrained( / DIFFUSION_MODEL_TEXT_ENCODER_2_NAME / text_encoder_2_file_name, DIFFUSION_MODEL_UNET_NAME: save_directory / DIFFUSION_MODEL_UNET_NAME / unet_file_name, + DIFFUSION_MODEL_TRANSFORMER_NAME: save_directory + / DIFFUSION_MODEL_TRANSFORMER_NAME + / transformer_file_name, DIFFUSION_MODEL_VAE_ENCODER_NAME: save_directory / DIFFUSION_MODEL_VAE_ENCODER_NAME / vae_encoder_file_name, @@ -546,16 +650,13 @@ def _from_pretrained( text_encoder_file_name: Optional[str] = NEURON_FILE_NAME, text_encoder_2_file_name: Optional[str] = NEURON_FILE_NAME, unet_file_name: Optional[str] = NEURON_FILE_NAME, + transformer_file_name: Optional[str] = NEURON_FILE_NAME, vae_encoder_file_name: Optional[str] = NEURON_FILE_NAME, vae_decoder_file_name: Optional[str] = NEURON_FILE_NAME, controlnet_file_name: Optional[str] = NEURON_FILE_NAME, - text_encoder_2: Optional["NeuronModelTextEncoder"] = None, - vae_encoder: Optional["NeuronModelVaeEncoder"] = None, - vae_decoder: Optional["NeuronModelVaeDecoder"] = None, - controlnet: Optional[Union["NeuronControlNetModel", "NeuronMultiControlNetModel"]] = None, local_files_only: bool = False, model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, - data_parallel_mode: Optional[Literal["none", "unet", "all"]] = None, + data_parallel_mode: Optional[Literal["none", "unet", "transformer", "all"]] = None, **kwargs, # To share kwargs only available for `_from_transformers` ): model_id = str(model_id) @@ -570,6 +671,7 @@ def _from_pretrained( text_encoder_file_name, text_encoder_2_file_name, unet_file_name, + transformer_file_name, vae_encoder_file_name, vae_decoder_file_name, controlnet_file_name, @@ -617,6 +719,10 @@ def _from_pretrained( new_model_save_dir / DIFFUSION_MODEL_UNET_NAME / unet_file_name, new_model_save_dir / DIFFUSION_MODEL_UNET_NAME / cls.sub_component_config_name, ), + "transformer": ( + new_model_save_dir / DIFFUSION_MODEL_TRANSFORMER_NAME / 
transformer_file_name, + new_model_save_dir / DIFFUSION_MODEL_TRANSFORMER_NAME / cls.sub_component_config_name, + ), "vae_encoder": ( new_model_save_dir / DIFFUSION_MODEL_VAE_ENCODER_NAME / vae_encoder_file_name, new_model_save_dir / DIFFUSION_MODEL_VAE_ENCODER_NAME / cls.sub_component_config_name, @@ -658,20 +764,16 @@ def _from_pretrained( neuron_configs[name] = sub_neuron_configs if len(sub_neuron_configs) > 1 else sub_neuron_configs[0] if data_parallel_mode is None: - data_parallel_mode = cls.set_default_dp_mode(configs["unet"]) + data_parallel_mode = cls.set_default_dp_mode(configs) pipe = cls.load_model( data_parallel_mode=data_parallel_mode, text_encoder_path=model_and_config_save_paths["text_encoder"][0], unet_path=model_and_config_save_paths["unet"][0], - vae_decoder_path=model_and_config_save_paths["vae_decoder"][0] if vae_decoder is None else None, - vae_encoder_path=model_and_config_save_paths["vae_encoder"][0] if vae_encoder is None else None, - text_encoder_2_path=model_and_config_save_paths["text_encoder_2"][0] if text_encoder_2 is None else None, - controlnet_paths=( - model_and_config_save_paths["controlnet"][0] - if controlnet is None and model_and_config_save_paths["controlnet"][0] - else None - ), + vae_decoder_path=model_and_config_save_paths["vae_decoder"][0], + vae_encoder_path=model_and_config_save_paths["vae_encoder"][0], + text_encoder_2_path=model_and_config_save_paths["text_encoder_2"][0], + controlnet_paths=model_and_config_save_paths["controlnet"][0], dynamic_batch_size=neuron_configs[DIFFUSION_MODEL_UNET_NAME].dynamic_batch_size, to_neuron=not inline_weights_to_neff, ) @@ -681,15 +783,16 @@ def _from_pretrained( return cls( text_encoder=pipe.get("text_encoder"), + text_encoder_2=pipe.get("text_encoder_2"), unet=pipe.get("unet"), - vae_decoder=vae_decoder or pipe.get("vae_decoder"), + transformer=pipe.get("transformer"), + vae_encoder=pipe.get("vae_encoder"), + vae_decoder=pipe.get("vae_decoder"), + controlnet=pipe.get("controlnet"), config=config, tokenizer=sub_models.get("tokenizer", None), - scheduler=sub_models.get("scheduler"), - vae_encoder=vae_encoder or pipe.get("vae_encoder"), - text_encoder_2=text_encoder_2 or pipe.get("text_encoder_2"), - controlnet=controlnet or pipe.get("controlnet"), tokenizer_2=sub_models.get("tokenizer_2", None), + scheduler=sub_models.get("scheduler"), feature_extractor=sub_models.get("feature_extractor", None), data_parallel_mode=data_parallel_mode, configs=configs, @@ -727,14 +830,14 @@ def _export( auto_cast_type: Optional[str] = "bf16", dynamic_batch_size: bool = False, output_hidden_states: bool = False, - data_parallel_mode: Optional[Literal["none", "unet", "all"]] = None, + data_parallel_mode: Optional[Literal["none", "unet", "transformer", "all"]] = None, lora_model_ids: Optional[Union[str, List[str]]] = None, lora_weight_names: Optional[Union[str, List[str]]] = None, lora_adapter_names: Optional[Union[str, List[str]]] = None, lora_scales: Optional[Union[float, List[float]]] = None, controlnet_ids: Optional[Union[str, List[str]]] = None, **kwargs_shapes, - ) -> "NeuronStableDiffusionPipelineBase": + ) -> "NeuronDiffusionPipelineBase": """ Args: model_id (`Union[str, Path]`): @@ -791,7 +894,7 @@ def _export( batch size during the compilation, but it comes with a potential tradeoff in terms of latency. output_hidden_states (`bool`, defaults to `False`): Whether or not for the traced text encoders to return the hidden states of all layers. 
- data_parallel_mode (`Optional[Literal["none", "unet", "all"]]`, defaults to `None`): + data_parallel_mode (`Optional[Literal["none", "unet", "transformer", "all"]]`, defaults to `None`): Mode to decide what components to load into both NeuronCores of a Neuron device. Can be "none"(no data parallel), "unet"(only load unet into both cores of each device), "all"(load the whole pipeline into both cores). lora_model_ids (`Optional[Union[str, List[str]]]`, defaults to `None`): @@ -808,7 +911,10 @@ def _export( Shapes to use during inference. This argument allows to override the default shapes used during the export. """ if task is None: - task = TasksManager.infer_task_from_model(cls.auto_model_class) + if cls.task is not None: + task = cls.task + else: + task = TasksManager.infer_task_from_model(cls.auto_model_class) # mandatory shapes input_shapes = normalize_stable_diffusion_input_shapes(kwargs_shapes) @@ -957,23 +1063,67 @@ def _load_config(cls, config_name_or_path: Union[str, os.PathLike], **kwargs): def _save_config(self, save_directory): self.save_config(save_directory) + @property + def components(self) -> Dict[str, Any]: + components = { + "vae_encoder": self.vae_encoder, + "vae_decoder": self.vae_decoder, + "unet": self.unet, + "transformer": self.transformer, + "text_encoder": self.text_encoder, + "text_encoder_2": self.text_encoder_2, + "image_encoder": self.image_encoder, + "safety_checker": self.safety_checker, + "neuron_configs": self.neuron_configs, + "data_parallel_mode": self.data_parallel_mode, + "feature_extractor": self.feature_extractor, + "configs": self.configs, + "config": self.config, + "tokenizer": self.tokenizer, + "tokenizer_2": self.tokenizer_2, + "scheduler": self.scheduler, + } + return components + + @property + def do_classifier_free_guidance(self): + return ( + self._guidance_scale > 1 + and self.unet.config.time_cond_proj_dim is None + and ( + self.dynamic_batch_size + or self.data_parallel_mode == "unet" + or self.data_parallel_mode == "transformer" + ) + ) + + def __call__(self, *args, **kwargs): + # Height and width to unet (static shapes) + height = self.unet.config.neuron["static_height"] * self.vae_scale_factor + width = self.unet.config.neuron["static_width"] * self.vae_scale_factor + kwargs.pop("height", None) + kwargs.pop("width", None) + if kwargs.get("image", None): + kwargs["image"] = self.image_processor.preprocess(kwargs["image"], height=height, width=width) + return self.auto_model_class.__call__(self, height=height, width=width, *args, **kwargs) + class _NeuronDiffusionModelPart: """ - For multi-file Neuron models, represents a part of the model. + For multi-file Neuron models, represents a part / a model in the pipeline. 
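As a usage note on the `__call__` override above: any `height`/`width` passed by the caller is dropped and replaced by the static shapes the model was compiled with. A sketch of what that means in practice (the repository id below is hypothetical and assumed to contain an already-exported Neuron checkpoint):

```python
from optimum.neuron import NeuronStableDiffusionPipeline

# data_parallel_mode is resolved by set_default_dp_mode: "unet" for a regular SD
# checkpoint, so the U-Net is duplicated on both NeuronCores.
pipe = NeuronStableDiffusionPipeline.from_pretrained("my-org/sd15-neuron-512x512")

# The requested 768x768 is ignored; generation runs at the compiled static resolution.
image = pipe("a cabin in a snowy forest", height=768, width=768).images[0]
```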
""" def __init__( self, model: torch.jit._script.ScriptModule, - parent_model: NeuronTracedModel, + parent_pipeline: NeuronDiffusionPipelineBase, config: Optional[Union[DiffusersPretrainedConfig, PretrainedConfig]] = None, neuron_config: Optional["NeuronDefaultConfig"] = None, model_type: str = "unet", device: Optional[int] = None, ): self.model = model - self.parent_model = parent_model + self.parent_pipeline = parent_pipeline self.config = config self.neuron_config = neuron_config self.model_type = model_type @@ -986,23 +1136,30 @@ def forward(self, *args, **kwargs): def __call__(self, *args, **kwargs): return self.forward(*args, **kwargs) + @property + def dtype(self): + return None + + def to(self, *args, **kwargs): + pass + class NeuronModelTextEncoder(_NeuronDiffusionModelPart): def __init__( self, model: torch.jit._script.ScriptModule, - parent_model: NeuronTracedModel, + parent_pipeline: NeuronDiffusionPipelineBase, config: Optional[DiffusersPretrainedConfig] = None, neuron_config: Optional[Dict[str, str]] = None, ): - super().__init__(model, parent_model, config, neuron_config, DIFFUSION_MODEL_TEXT_ENCODER_NAME) + super().__init__(model, parent_pipeline, config, neuron_config, DIFFUSION_MODEL_TEXT_ENCODER_NAME) def forward( self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, + return_dict: Optional[bool] = True, ): if attention_mask is not None: assert torch.equal( @@ -1023,16 +1180,20 @@ def forward( return outputs + def modules(self): + # dummy func for passing `unscale_lora_layers`. + return [] + class NeuronModelUnet(_NeuronDiffusionModelPart): def __init__( self, model: torch.jit._script.ScriptModule, - parent_model: NeuronTracedModel, + parent_pipeline: NeuronDiffusionPipelineBase, config: Optional[DiffusersPretrainedConfig] = None, neuron_config: Optional[Dict[str, str]] = None, ): - super().__init__(model, parent_model, config, neuron_config, DIFFUSION_MODEL_UNET_NAME) + super().__init__(model, parent_pipeline, config, neuron_config, DIFFUSION_MODEL_UNET_NAME) if hasattr(self.model, "device"): self.device = self.model.device @@ -1042,10 +1203,14 @@ def forward( timestep: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, - timestep_cond: Optional[torch.Tensor] = None, down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, mid_block_additional_residual: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, ): + if cross_attention_kwargs is not None: + logger.warning("`cross_attention_kwargs` is not yet supported during the tracing and it will be ignored.") timestep = timestep.float().expand((sample.shape[0],)) inputs = (sample, timestep, encoder_hidden_states) if timestep_cond is not None: @@ -1061,41 +1226,75 @@ def forward( inputs = inputs + (text_embeds, time_ids) outputs = self.model(*inputs) + if return_dict: + outputs = ModelOutput(dict(zip(self.neuron_config.outputs, outputs))) return outputs +class NeuronModelTransformer(_NeuronDiffusionModelPart): + def __init__( + self, + model: torch.jit._script.ScriptModule, + parent_pipeline: NeuronDiffusionPipelineBase, + config: Optional[DiffusersPretrainedConfig] = None, + neuron_config: Optional[Dict[str, str]] = None, + ): + super().__init__(model, parent_pipeline, config, neuron_config, DIFFUSION_MODEL_TRANSFORMER_NAME) + + def 
forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + added_cond_kwargs: Dict[str, torch.Tensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ): + pass + + class NeuronModelVaeEncoder(_NeuronDiffusionModelPart): def __init__( self, model: torch.jit._script.ScriptModule, - parent_model: NeuronTracedModel, + parent_pipeline: NeuronDiffusionPipelineBase, config: Optional[DiffusersPretrainedConfig] = None, neuron_config: Optional[Dict[str, str]] = None, ): - super().__init__(model, parent_model, config, neuron_config, DIFFUSION_MODEL_VAE_ENCODER_NAME) + super().__init__(model, parent_pipeline, config, neuron_config, DIFFUSION_MODEL_VAE_ENCODER_NAME) - def forward(self, sample: torch.Tensor): + def forward(self, sample: torch.Tensor, return_dict: bool = True): inputs = (sample,) outputs = self.model(*inputs) - return tuple(output for output in outputs.values()) + if "latent_parameters" in outputs: + outputs["latent_dist"] = DiagonalGaussianDistribution(parameters=outputs.pop("latent_parameters")) + + if not return_dict: + return tuple(output for output in outputs.values()) + else: + return AutoencoderKLOutput(latent_dist=outputs["latent_dist"]) class NeuronModelVaeDecoder(_NeuronDiffusionModelPart): def __init__( self, model: torch.jit._script.ScriptModule, - parent_model: NeuronTracedModel, + parent_pipeline: NeuronDiffusionPipelineBase, config: Optional[DiffusersPretrainedConfig] = None, neuron_config: Optional[Dict[str, str]] = None, ): - super().__init__(model, parent_model, config, neuron_config, DIFFUSION_MODEL_VAE_DECODER_NAME) + super().__init__(model, parent_pipeline, config, neuron_config, DIFFUSION_MODEL_VAE_DECODER_NAME) def forward( self, latent_sample: torch.Tensor, image: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + generator=None, ): inputs = (latent_sample,) if image is not None: @@ -1104,7 +1303,30 @@ def forward( inputs += (mask,) outputs = self.model(*inputs) - return tuple(output for output in outputs.values()) + if not return_dict: + return tuple(output for output in outputs.values()) + else: + return DecoderOutput(**outputs) + + +class NeuronModelVae(_NeuronDiffusionModelPart): + def __init__( + self, + encoder: Optional[NeuronModelVaeEncoder], + decoder: NeuronModelVaeDecoder, + ): + self.encoder = encoder + self.decoder = decoder + + @property + def config(self): + return self.decoder.config + + def encode(self, *args, **kwargs): + return self.encoder(*args, **kwargs) + + def decode(self, *args, **kwargs): + return self.decoder(*args, **kwargs) class NeuronControlNetModel(_NeuronDiffusionModelPart): @@ -1117,11 +1339,11 @@ class NeuronControlNetModel(_NeuronDiffusionModelPart): def __init__( self, model: torch.jit._script.ScriptModule, - parent_model: NeuronTracedModel, + parent_pipeline: NeuronDiffusionPipelineBase, config: Optional[DiffusersPretrainedConfig] = None, neuron_config: Optional[Dict[str, str]] = None, ): - super().__init__(model, parent_model, config, neuron_config, DIFFUSION_MODEL_CONTROLNET_NAME) + super().__init__(model, parent_pipeline, config, neuron_config, DIFFUSION_MODEL_CONTROLNET_NAME) def forward( self, @@ -1152,6 +1374,10 @@ def forward( return outputs + @property + def __class__(self): + return ControlNetModel + class 
NeuronMultiControlNetModel(_NeuronDiffusionModelPart): auto_model_class = MultiControlNetModel @@ -1163,12 +1389,12 @@ class NeuronMultiControlNetModel(_NeuronDiffusionModelPart): def __init__( self, models: List[torch.jit._script.ScriptModule], - parent_model: NeuronTracedModel, + parent_pipeline: NeuronTracedModel, config: Optional[DiffusersPretrainedConfig] = None, neuron_config: Optional[Dict[str, str]] = None, ): self.nets = models - self.parent_model = parent_model + self.parent_pipeline = parent_pipeline self.config = config self.neuron_config = neuron_config self.model_type = DIFFUSION_MODEL_CONTROLNET_NAME @@ -1180,7 +1406,7 @@ def forward( timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, controlnet_cond: torch.Tensor, - conditioning_scale: float = 1.0, + conditioning_scale: List[float], guess_mode: bool = False, return_dict: bool = True, ) -> Union["ControlNetOutput", Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]: @@ -1209,152 +1435,74 @@ def forward( return down_block_res_samples, mid_block_res_sample + @property + def __class__(self): + return MultiControlNetModel -class NeuronStableDiffusionPipeline(NeuronStableDiffusionPipelineBase, NeuronStableDiffusionPipelineMixin): - __call__ = NeuronStableDiffusionPipelineMixin.__call__ +class NeuronStableDiffusionPipeline(NeuronDiffusionPipelineBase, StableDiffusionPipeline): + main_input_name = "prompt" + auto_model_class = StableDiffusionPipeline -class NeuronStableDiffusionImg2ImgPipeline( - NeuronStableDiffusionPipelineBase, NeuronStableDiffusionImg2ImgPipelineMixin -): - __call__ = NeuronStableDiffusionImg2ImgPipelineMixin.__call__ +class NeuronStableDiffusionImg2ImgPipeline(NeuronDiffusionPipelineBase, StableDiffusionImg2ImgPipeline): + main_input_name = "image" + auto_model_class = StableDiffusionImg2ImgPipeline -class NeuronStableDiffusionInpaintPipeline( - NeuronStableDiffusionPipelineBase, NeuronStableDiffusionInpaintPipelineMixin -): - __call__ = NeuronStableDiffusionInpaintPipelineMixin.__call__ + +class NeuronStableDiffusionInpaintPipeline(NeuronDiffusionPipelineBase, StableDiffusionInpaintPipeline): + main_input_name = "prompt" + auto_model_class = StableDiffusionInpaintPipeline class NeuronStableDiffusionInstructPix2PixPipeline( - NeuronStableDiffusionPipelineBase, NeuronStableDiffusionInstructPix2PixPipelineMixin + NeuronDiffusionPipelineBase, StableDiffusionInstructPix2PixPipeline ): - __call__ = NeuronStableDiffusionInstructPix2PixPipelineMixin.__call__ + main_input_name = "prompt" + task = "task-to-image" + auto_model_class = StableDiffusionInstructPix2PixPipeline -class NeuronLatentConsistencyModelPipeline(NeuronStableDiffusionPipelineBase, NeuronLatentConsistencyPipelineMixin): - __call__ = NeuronLatentConsistencyPipelineMixin.__call__ +class NeuronLatentConsistencyModelPipeline(NeuronDiffusionPipelineBase, LatentConsistencyModelPipeline): + main_input_name = "prompt" + auto_model_class = LatentConsistencyModelPipeline class NeuronStableDiffusionControlNetPipeline( - NeuronStableDiffusionPipelineBase, NeuronStableDiffusionControlNetPipelineMixin + NeuronStableDiffusionControlNetPipelineMixin, NeuronDiffusionPipelineBase, StableDiffusionControlNetPipeline ): - __call__ = NeuronStableDiffusionControlNetPipelineMixin.__call__ + main_input_name = "prompt" + auto_model_class = StableDiffusionControlNetPipeline -class NeuronStableDiffusionXLPipelineBase(NeuronStableDiffusionPipelineBase): - # `TasksManager` registered img2ime pipeline for `stable-diffusion-xl`: 
https://github.com/huggingface/optimum/blob/v1.12.0/optimum/exporters/tasks.py#L174 - auto_model_class = StableDiffusionXLImg2ImgPipeline - - def __init__( - self, - text_encoder: torch.jit._script.ScriptModule, - unet: torch.jit._script.ScriptModule, - vae_decoder: torch.jit._script.ScriptModule, - config: Dict[str, Any], - tokenizer: CLIPTokenizer, - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], - data_parallel_mode: Literal["none", "unet", "all"], - vae_encoder: Optional[torch.jit._script.ScriptModule] = None, - text_encoder_2: Optional[torch.jit._script.ScriptModule] = None, - tokenizer_2: Optional[CLIPTokenizer] = None, - feature_extractor: Optional[CLIPFeatureExtractor] = None, - controlnet: Optional[ - Union[ - torch.jit._script.ScriptModule, - List[torch.jit._script.ScriptModule], - "NeuronControlNetModel", - "NeuronMultiControlNetModel", - ] - ] = None, - configs: Optional[Dict[str, "PretrainedConfig"]] = None, - neuron_configs: Optional[Dict[str, "NeuronDefaultConfig"]] = None, - model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, - model_and_config_save_paths: Optional[Dict[str, Tuple[str, Path]]] = None, - add_watermarker: Optional[bool] = None, - ): - super().__init__( - text_encoder=text_encoder, - unet=unet, - vae_decoder=vae_decoder, - config=config, - tokenizer=tokenizer, - scheduler=scheduler, - data_parallel_mode=data_parallel_mode, - vae_encoder=vae_encoder, - text_encoder_2=text_encoder_2, - tokenizer_2=tokenizer_2, - feature_extractor=feature_extractor, - controlnet=controlnet, - configs=configs, - neuron_configs=neuron_configs, - model_save_dir=model_save_dir, - model_and_config_save_paths=model_and_config_save_paths, - ) - - add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() - - if add_watermarker: - if not is_invisible_watermark_available(): - raise ImportError( - "`add_watermarker` requires invisible-watermark to be installed, which can be installed with `pip install invisible-watermark`." 
- ) - from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker - - self.watermark = StableDiffusionXLWatermarker() - else: - self.watermark = None +class NeuronPixArtAlphaPipeline(NeuronDiffusionPipelineBase, PixArtAlphaPipeline): + main_input_name = "prompt" + auto_model_class = PixArtAlphaPipeline -class NeuronStableDiffusionXLPipeline(NeuronStableDiffusionXLPipelineBase, NeuronStableDiffusionXLPipelineMixin): - __call__ = NeuronStableDiffusionXLPipelineMixin.__call__ +class NeuronStableDiffusionXLPipeline( + NeuronStableDiffusionXLPipelineMixin, NeuronDiffusionPipelineBase, StableDiffusionXLPipeline +): + main_input_name = "prompt" + auto_model_class = StableDiffusionXLPipeline class NeuronStableDiffusionXLImg2ImgPipeline( - NeuronStableDiffusionXLPipelineBase, NeuronStableDiffusionXLImg2ImgPipelineMixin + NeuronStableDiffusionXLPipelineMixin, NeuronDiffusionPipelineBase, StableDiffusionXLImg2ImgPipeline ): - __call__ = NeuronStableDiffusionXLImg2ImgPipelineMixin.__call__ + main_input_name = "prompt" + auto_model_class = StableDiffusionXLImg2ImgPipeline class NeuronStableDiffusionXLInpaintPipeline( - NeuronStableDiffusionXLPipelineBase, NeuronStableDiffusionXLInpaintPipelineMixin + NeuronStableDiffusionXLPipelineMixin, NeuronDiffusionPipelineBase, StableDiffusionXLInpaintPipeline ): - __call__ = NeuronStableDiffusionXLInpaintPipelineMixin.__call__ + main_input_name = "image" + auto_model_class = StableDiffusionXLInpaintPipeline class NeuronStableDiffusionXLControlNetPipeline( - NeuronStableDiffusionXLPipelineBase, NeuronStableDiffusionXLControlNetPipelineMixin + NeuronStableDiffusionXLControlNetPipelineMixin, NeuronDiffusionPipelineBase, StableDiffusionXLControlNetPipeline ): - __call__ = NeuronStableDiffusionXLControlNetPipelineMixin.__call__ - - -if is_neuronx_available(): - # TO REMOVE: This class will be included directly in the DDP API of Neuron SDK 2.20 - class WeightSeparatedDataParallel(torch_neuronx.DataParallel): - - def _load_modules(self, module): - try: - self.device_ids.sort() - - loaded_modules = [module] - # If device_ids is non-consecutive, perform deepcopy's and load onto each core independently. - for i in range(len(self.device_ids) - 1): - loaded_modules.append(copy.deepcopy(module)) - for i, nc_index in enumerate(self.device_ids): - torch_neuronx.experimental.placement.set_neuron_cores(loaded_modules[i], nc_index, 1) - torch_neuronx.move_trace_to_device(loaded_modules[i], nc_index) - - except ValueError as err: - self.dynamic_batching_failed = True - logger.warning(f"Automatic dynamic batching failed due to {err}.") - logger.warning( - "Please disable dynamic batching by calling `disable_dynamic_batching()` " - "on your DataParallel module." 
- ) - self.num_workers = 2 * len(loaded_modules) - return loaded_modules - -else: - - class WeightSeparatedDataParallel: - pass + main_input_name = "prompt" + auto_model_class = StableDiffusionXLControlNetPipeline diff --git a/optimum/neuron/pipelines/__init__.py b/optimum/neuron/pipelines/__init__.py index d4da684ab..d45c704ff 100644 --- a/optimum/neuron/pipelines/__init__.py +++ b/optimum/neuron/pipelines/__init__.py @@ -21,35 +21,19 @@ _import_structure = { "transformers": ["pipeline"], "diffusers": [ - "NeuronStableDiffusionPipelineMixin", - "NeuronStableDiffusionImg2ImgPipelineMixin", - "NeuronStableDiffusionInpaintPipelineMixin", - "NeuronStableDiffusionInstructPix2PixPipelineMixin", - "NeuronLatentConsistencyPipelineMixin", - "NeuronStableDiffusionControlNetPipelineMixin", "NeuronStableDiffusionXLPipelineMixin", - "NeuronStableDiffusionXLImg2ImgPipelineMixin", - "NeuronStableDiffusionXLInpaintPipelineMixin", + "NeuronStableDiffusionControlNetPipelineMixin", "NeuronStableDiffusionXLControlNetPipelineMixin", ], } if TYPE_CHECKING: from .diffusers import ( - NeuronLatentConsistencyPipelineMixin, NeuronStableDiffusionControlNetPipelineMixin, - NeuronStableDiffusionImg2ImgPipelineMixin, - NeuronStableDiffusionInpaintPipelineMixin, - NeuronStableDiffusionInstructPix2PixPipelineMixin, - NeuronStableDiffusionPipelineMixin, NeuronStableDiffusionXLControlNetPipelineMixin, - NeuronStableDiffusionXLImg2ImgPipelineMixin, - NeuronStableDiffusionXLInpaintPipelineMixin, NeuronStableDiffusionXLPipelineMixin, ) - from .transformers import ( - pipeline, - ) + from .transformers import pipeline else: import sys diff --git a/optimum/neuron/pipelines/diffusers/__init__.py b/optimum/neuron/pipelines/diffusers/__init__.py index eefe130a0..361398d8d 100644 --- a/optimum/neuron/pipelines/diffusers/__init__.py +++ b/optimum/neuron/pipelines/diffusers/__init__.py @@ -12,14 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- from .pipeline_controlnet import NeuronStableDiffusionControlNetPipelineMixin from .pipeline_controlnet_sd_xl import NeuronStableDiffusionXLControlNetPipelineMixin -from .pipeline_latent_consistency_text2img import NeuronLatentConsistencyPipelineMixin -from .pipeline_stable_diffusion import NeuronStableDiffusionPipelineMixin -from .pipeline_stable_diffusion_img2img import NeuronStableDiffusionImg2ImgPipelineMixin -from .pipeline_stable_diffusion_inpaint import NeuronStableDiffusionInpaintPipelineMixin -from .pipeline_stable_diffusion_instruct_pix2pix import NeuronStableDiffusionInstructPix2PixPipelineMixin -from .pipeline_stable_diffusion_xl import NeuronStableDiffusionXLPipelineMixin -from .pipeline_stable_diffusion_xl_img2img import NeuronStableDiffusionXLImg2ImgPipelineMixin -from .pipeline_stable_diffusion_xl_inpaint import NeuronStableDiffusionXLInpaintPipelineMixin +from .pipeline_utils import NeuronStableDiffusionXLPipelineMixin diff --git a/optimum/neuron/pipelines/diffusers/pipeline_controlnet.py b/optimum/neuron/pipelines/diffusers/pipeline_controlnet.py index daaa4f9a4..24beab45a 100644 --- a/optimum/neuron/pipelines/diffusers/pipeline_controlnet.py +++ b/optimum/neuron/pipelines/diffusers/pipeline_controlnet.py @@ -18,158 +18,18 @@ from typing import Any, Callable, Dict, List, Optional, Union import torch -from diffusers import StableDiffusionControlNetPipeline +from diffusers import ControlNetModel from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback from diffusers.image_processor import PipelineImageInput +from diffusers.pipelines.controlnet import MultiControlNetModel from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps -from .pipeline_utils import StableDiffusionPipelineMixin - logger = logging.getLogger(__name__) -class NeuronStableDiffusionControlNetPipelineMixin(StableDiffusionPipelineMixin, StableDiffusionControlNetPipeline): - # Adapted from https://github.com/huggingface/diffusers/blob/de9528ebc7725012cf097e43f565aeff24940eda/src/diffusers/pipelines/controlnet/pipeline_controlnet.py#L594 - # Replace class types with Neuron ones - def check_inputs( - self, - prompt, - image, - negative_prompt=None, - prompt_embeds=None, - negative_prompt_embeds=None, - ip_adapter_image=None, - ip_adapter_image_embeds=None, - controlnet_conditioning_scale=1.0, - control_guidance_start=0.0, - control_guidance_end=1.0, - callback_on_step_end_tensor_inputs=None, - ): - if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs - ): - raise ValueError( - f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
- ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) - - # Check `image` - if self.controlnet.__class__.__name__ == "NeuronControlNetModel": - self.check_image(image, prompt, prompt_embeds) - elif self.controlnet.__class__.__name__ == "NeuronMultiControlNetModel": - if not isinstance(image, list): - raise TypeError("For multiple controlnets: `image` must be type `list`") - - # When `image` is a nested list: - # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) - elif any(isinstance(i, list) for i in image): - transposed_image = [list(t) for t in zip(*image)] - if len(transposed_image) != len(self.controlnet.nets): - raise ValueError( - f"For multiple controlnets: if you pass`image` as a list of list, each sublist must have the same length as the number of controlnets, but the sublists in `image` got {len(transposed_image)} images and {len(self.controlnet.nets)} ControlNets." - ) - for image_ in transposed_image: - self.check_image(image_, prompt, prompt_embeds) - elif len(image) != len(self.controlnet.nets): - raise ValueError( - f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." - ) - else: - for image_ in image: - self.check_image(image_, prompt, prompt_embeds) - else: - assert False - - # Check `controlnet_conditioning_scale` - if self.controlnet.__class__.__name__ == "NeuronControlNetModel": - if not isinstance(controlnet_conditioning_scale, float): - raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") - elif self.controlnet.__class__.__name__ == "NeuronMultiControlNetModel": - if isinstance(controlnet_conditioning_scale, list): - if any(isinstance(i, list) for i in controlnet_conditioning_scale): - raise ValueError( - "A single batch of varying conditioning scale settings (e.g. [[1.0, 0.5], [0.2, 0.8]]) is not supported at the moment. " - "The conditioning scale must be fixed across the batch." 
- ) - elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( - self.controlnet.nets - ): - raise ValueError( - "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" - " the same length as the number of controlnets" - ) - else: - assert False - - if not isinstance(control_guidance_start, (tuple, list)): - control_guidance_start = [control_guidance_start] - - if not isinstance(control_guidance_end, (tuple, list)): - control_guidance_end = [control_guidance_end] - - if len(control_guidance_start) != len(control_guidance_end): - raise ValueError( - f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." - ) - - if self.controlnet.__class__.__name__ == "NeuronMultiControlNetModel": - if len(control_guidance_start) != len(self.controlnet.nets): - raise ValueError( - f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." - ) - - for start, end in zip(control_guidance_start, control_guidance_end): - if start >= end: - raise ValueError( - f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." - ) - if start < 0.0: - raise ValueError(f"control guidance start: {start} can't be smaller than 0.") - if end > 1.0: - raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") - - if ip_adapter_image is not None and ip_adapter_image_embeds is not None: - raise ValueError( - "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." 
- ) - - if ip_adapter_image_embeds is not None: - if not isinstance(ip_adapter_image_embeds, list): - raise ValueError( - f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" - ) - elif ip_adapter_image_embeds[0].ndim not in [3, 4]: - raise ValueError( - f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" - ) - +class NeuronStableDiffusionControlNetPipelineMixin: def __call__( self, prompt: Union[str, List[str]] = None, @@ -308,7 +168,7 @@ def __call__( elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): control_guidance_end = len(control_guidance_start) * [control_guidance_end] elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): - mult = len(controlnet.nets) if controlnet.__class__.__name__ == "NeuronMultiControlNetModel" else 1 + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 control_guidance_start, control_guidance_end = ( mult * [control_guidance_start], mult * [control_guidance_end], @@ -318,6 +178,7 @@ def __call__( self.check_inputs( prompt=prompt, image=image, + callback_steps=None, negative_prompt=negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, @@ -341,14 +202,12 @@ def __call__( else: batch_size = prompt_embeds.shape[0] - if controlnet.__class__.__name__ == "NeuronMultiControlNetModel" and isinstance( - controlnet_conditioning_scale, float - ): + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) global_pool_conditions = ( controlnet.config.global_pool_conditions - if controlnet.__class__.__name__ == "NeuronControlNetModel" + if isinstance(controlnet, ControlNetModel) else controlnet.config[0].global_pool_conditions ) guess_mode = guess_mode or global_pool_conditions @@ -361,22 +220,21 @@ def __call__( text_encoder_lora_scale = ( self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None ) - do_classifier_free_guidance = guidance_scale > 1.0 and ( - self.dynamic_batch_size or self.data_parallel_mode == "unet" - ) prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, + None, num_images_per_prompt, - do_classifier_free_guidance, + self.do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, ) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes - if do_classifier_free_guidance: + if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # TODO: support ip adapter @@ -386,9 +244,9 @@ def __call__( ) # 4. 
Prepare image - height = self.vae_encoder.config.neuron["static_height"] - width = self.vae_encoder.config.neuron["static_width"] - if controlnet.__class__.__name__ == "NeuronControlNetModel": + height = self.vae.config.neuron["static_height"] * self.vae_scale_factor + width = self.vae.config.neuron["static_width"] * self.vae_scale_factor + if isinstance(controlnet, ControlNetModel): image = self.prepare_image( image=image, width=width, @@ -397,11 +255,11 @@ def __call__( num_images_per_prompt=num_images_per_prompt, device=None, dtype=None, - do_classifier_free_guidance=do_classifier_free_guidance, + do_classifier_free_guidance=self.do_classifier_free_guidance, guess_mode=guess_mode, ) height, width = image.shape[-2:] - elif controlnet.__class__.__name__ == "NeuronMultiControlNetModel": + elif isinstance(controlnet, MultiControlNetModel): images = [] # Nested lists as ControlNet condition @@ -418,7 +276,7 @@ def __call__( num_images_per_prompt=num_images_per_prompt, device=None, dtype=None, - do_classifier_free_guidance=do_classifier_free_guidance, + do_classifier_free_guidance=self.do_classifier_free_guidance, guess_mode=guess_mode, ) @@ -447,6 +305,7 @@ def __call__( height, width, prompt_embeds.dtype, + None, generator, latents, ) @@ -472,18 +331,18 @@ def __call__( 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) for s, e in zip(control_guidance_start, control_guidance_end) ] - controlnet_keep.append(keeps[0] if controlnet.__class__.__name__ == "NeuronControlNetModel" else keeps) + controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) # 8. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # controlnet(s) inference - if guess_mode and do_classifier_free_guidance: + if guess_mode and self.do_classifier_free_guidance: # Infer ControlNet only for the conditional batch. control_model_input = latents control_model_input = self.scheduler.scale_model_input(control_model_input, t) @@ -502,7 +361,7 @@ def __call__( # Duplicate inputs for ddp t = torch.tensor([t] * 2) if self.data_parallel_mode == "unet" else t - if controlnet.__class__.__name__ == "NeuronControlNetModel": + if isinstance(controlnet, ControlNetModel): cond_scale = ( torch.tensor([cond_scale]).repeat(2) if self.data_parallel_mode == "unet" @@ -527,7 +386,7 @@ def __call__( return_dict=False, ) - if guess_mode and do_classifier_free_guidance: + if guess_mode and self.do_classifier_free_guidance: # Infered ControlNet only for the conditional batch. # To apply the output of ControlNet to both the unconditional and conditional batches, # add 0 to the unconditional batch to keep it unchanged. 
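The denoising loop above relies on the standard classifier-free-guidance batching: negative and positive prompt embeddings are concatenated so the denoiser runs once per step on a doubled batch, and its output is split and recombined with the guidance scale. The removed lines computed this flag from `dynamic_batch_size` / `data_parallel_mode`; the refactored code reads the `do_classifier_free_guidance` property instead. A self-contained sketch of the combine step, where `denoiser`, `neg_embeds` and `pos_embeds` are stand-in names rather than the Neuron-compiled UNet:

import torch


def cfg_step(denoiser, latents, t, neg_embeds, pos_embeds, guidance_scale):
    # One batch carrying both conditionings: unconditional first, conditional second.
    latent_model_input = torch.cat([latents] * 2)
    prompt_embeds = torch.cat([neg_embeds, pos_embeds])
    noise_pred = denoiser(latent_model_input, t, prompt_embeds)
    noise_uncond, noise_text = noise_pred.chunk(2)
    # Push the prediction away from the unconditional direction.
    return noise_uncond + guidance_scale * (noise_text - noise_uncond)


def toy_denoiser(x, t, emb):
    # Shape-preserving stand-in so the sketch runs end to end.
    return x * 0.1


out = cfg_step(
    toy_denoiser,
    torch.randn(1, 4, 64, 64),
    10,
    torch.randn(1, 77, 768),
    torch.randn(1, 77, 768),
    guidance_scale=7.5,
)
print(out.shape)  # torch.Size([1, 4, 64, 64])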
@@ -540,13 +399,15 @@ def __call__( t, encoder_hidden_states=prompt_embeds, timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, down_block_additional_residuals=down_block_res_samples, mid_block_additional_residual=mid_block_res_sample, added_cond_kwargs=added_cond_kwargs, + return_dict=False, )[0] # perform guidance - if do_classifier_free_guidance: + if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) @@ -570,8 +431,10 @@ def __call__( progress_bar.update() if not output_type == "latent": - image = self.vae_decoder(latents / getattr(self.vae_decoder.config, "scaling_factor", 0.18215))[0] - image, has_nsfw_concept = self.run_safety_checker(image, dtype=prompt_embeds.dtype) + image = self.vae.decode( + latents / getattr(self.vae.config, "scaling_factor", 0.18215), return_dict=False, generator=generator + )[0] + image, has_nsfw_concept = self.run_safety_checker(image, None, dtype=prompt_embeds.dtype) else: image = latents has_nsfw_concept = None diff --git a/optimum/neuron/pipelines/diffusers/pipeline_controlnet_sd_xl.py b/optimum/neuron/pipelines/diffusers/pipeline_controlnet_sd_xl.py index 5555add8e..f95e1b8b9 100644 --- a/optimum/neuron/pipelines/diffusers/pipeline_controlnet_sd_xl.py +++ b/optimum/neuron/pipelines/diffusers/pipeline_controlnet_sd_xl.py @@ -19,181 +19,20 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch -from diffusers import StableDiffusionXLControlNetPipeline +from diffusers import ControlNetModel from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback from diffusers.image_processor import PipelineImageInput +from diffusers.pipelines.controlnet import MultiControlNetModel from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput -from .pipeline_utils import StableDiffusionXLPipelineMixin +from .pipeline_utils import NeuronStableDiffusionXLPipelineMixin logger = logging.getLogger(__name__) -class NeuronStableDiffusionXLControlNetPipelineMixin( - StableDiffusionXLPipelineMixin, StableDiffusionXLControlNetPipeline -): - # Adapted from https://github.com/huggingface/diffusers/blob/v0.29.2/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py#L625 - # Replace class types with Neuron ones - def check_inputs( - self, - prompt, - prompt_2, - image, - negative_prompt=None, - negative_prompt_2=None, - prompt_embeds=None, - negative_prompt_embeds=None, - pooled_prompt_embeds=None, - ip_adapter_image=None, - ip_adapter_image_embeds=None, - negative_pooled_prompt_embeds=None, - controlnet_conditioning_scale=1.0, - control_guidance_start=0.0, - control_guidance_end=1.0, - callback_on_step_end_tensor_inputs=None, - ): - if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs - ): - raise ValueError( - f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." 
- ) - elif prompt_2 is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): - raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - elif negative_prompt_2 is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if prompt_embeds is not None and pooled_prompt_embeds is None: - raise ValueError( - "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." - ) - - if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: - raise ValueError( - "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." - ) - - # Check `image` - if self.controlnet.__class__.__name__ == "NeuronControlNetModel": - self.check_image(image, prompt, prompt_embeds) - elif self.controlnet.__class__.__name__ == "NeuronMultiControlNetModel": - if not isinstance(image, list): - raise TypeError("For multiple controlnets: `image` must be type `list`") - - # When `image` is a nested list: - # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) - elif any(isinstance(i, list) for i in image): - raise ValueError("A single batch of multiple conditionings are not supported at the moment.") - elif len(image) != len(self.controlnet.nets): - raise ValueError( - f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." - ) - else: - for image_ in image: - self.check_image(image_, prompt, prompt_embeds) - else: - raise ValueError( - f"{self.controlnet.__class__.__name__} is not a supported class for ControlNet. The class must be either `NeuronControlNetModel` or `NeuronMultiControlNetModel`." 
- ) - - # Check `controlnet_conditioning_scale` - if self.controlnet.__class__.__name__ == "NeuronControlNetModel": - if not isinstance(controlnet_conditioning_scale, float): - raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") - elif self.controlnet.__class__.__name__ == "NeuronMultiControlNetModel": - if isinstance(controlnet_conditioning_scale, list): - if any(isinstance(i, list) for i in controlnet_conditioning_scale): - raise ValueError("A single batch of multiple conditionings are not supported at the moment.") - elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( - self.controlnet.nets - ): - raise ValueError( - "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" - " the same length as the number of controlnets" - ) - else: - raise ValueError( - f"{self.controlnet.__class__.__name__} is not a supported class for ControlNet. The class must be either `NeuronControlNetModel` or `NeuronMultiControlNetModel`." - ) - - if not isinstance(control_guidance_start, (tuple, list)): - control_guidance_start = [control_guidance_start] - - if not isinstance(control_guidance_end, (tuple, list)): - control_guidance_end = [control_guidance_end] - - if len(control_guidance_start) != len(control_guidance_end): - raise ValueError( - f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." - ) - - if self.controlnet.__class__.__name__ == "NeuronMultiControlNetModel": - if len(control_guidance_start) != len(self.controlnet.nets): - raise ValueError( - f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." - ) - - for start, end in zip(control_guidance_start, control_guidance_end): - if start >= end: - raise ValueError( - f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." - ) - if start < 0.0: - raise ValueError(f"control guidance start: {start} can't be smaller than 0.") - if end > 1.0: - raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") - - if ip_adapter_image is not None and ip_adapter_image_embeds is not None: - raise ValueError( - "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." 
- ) - - if ip_adapter_image_embeds is not None: - if not isinstance(ip_adapter_image_embeds, list): - raise ValueError( - f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" - ) - elif ip_adapter_image_embeds[0].ndim not in [3, 4]: - raise ValueError( - f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" - ) - - # Adapted from https://github.com/huggingface/diffusers/blob/v0.30.0/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py#L899 - def _get_add_time_ids( - self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None - ): - add_time_ids = list(original_size + crops_coords_top_left + target_size) - add_time_ids = torch.tensor([add_time_ids], dtype=dtype) - return add_time_ids - +class NeuronStableDiffusionXLControlNetPipelineMixin(NeuronStableDiffusionXLPipelineMixin): # Adapted from https://github.com/huggingface/diffusers/blob/1f81fbe274e67c843283e69eb8f00bb56f75ffc4/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py#L1001 def __call__( self, @@ -390,7 +229,7 @@ def __call__( elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): control_guidance_end = len(control_guidance_start) * [control_guidance_end] elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): - mult = len(controlnet.nets) if controlnet.__class__.__name__ == "NeuronMultiControlNetModel" else 1 + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 control_guidance_start, control_guidance_end = ( mult * [control_guidance_start], mult * [control_guidance_end], @@ -401,6 +240,7 @@ def __call__( prompt=prompt, prompt_2=prompt_2, image=image, + callback_steps=None, negative_prompt=negative_prompt, negative_prompt_2=negative_prompt_2, prompt_embeds=prompt_embeds, @@ -430,12 +270,12 @@ def __call__( if isinstance(controlnet_conditioning_scale, float): controlnet_conditioning_scale = torch.tensor([controlnet_conditioning_scale]) - if controlnet.__class__.__name__ == "NeuronMultiControlNetModel": + if isinstance(controlnet, MultiControlNetModel): controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) global_pool_conditions = ( controlnet.config.global_pool_conditions - if controlnet.__class__.__name__ == "NeuronControlNetModel" + if isinstance(controlnet, ControlNetModel) else controlnet.nets[0].config.global_pool_conditions ) guess_mode = guess_mode or global_pool_conditions @@ -448,10 +288,6 @@ def __call__( text_encoder_lora_scale = ( self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None ) - lora_scale = None - do_classifier_free_guidance = guidance_scale > 1.0 and ( - self.dynamic_batch_size or self.data_parallel_mode == "unet" - ) ( prompt_embeds, negative_prompt_embeds, @@ -460,15 +296,16 @@ def __call__( ) = self.encode_prompt( prompt=prompt, prompt_2=prompt_2, + device=None, num_images_per_prompt=num_images_per_prompt, - do_classifier_free_guidance=do_classifier_free_guidance, + do_classifier_free_guidance=self.do_classifier_free_guidance, negative_prompt=negative_prompt, negative_prompt_2=negative_prompt_2, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, - lora_scale=lora_scale, + lora_scale=text_encoder_lora_scale, clip_skip=self.clip_skip, ) @@ 
-480,9 +317,9 @@ def __call__( ) # 4. Prepare image - height = self.vae_encoder.config.neuron["static_height"] - width = self.vae_encoder.config.neuron["static_width"] - if controlnet.__class__.__name__ == "NeuronControlNetModel": + height = self.vae.config.neuron["static_height"] * self.vae_scale_factor + width = self.vae.config.neuron["static_width"] * self.vae_scale_factor + if isinstance(controlnet, ControlNetModel): image = self.prepare_image( image=image, width=width, @@ -491,11 +328,11 @@ def __call__( num_images_per_prompt=num_images_per_prompt, device=None, dtype=None, - do_classifier_free_guidance=do_classifier_free_guidance, + do_classifier_free_guidance=self.do_classifier_free_guidance, guess_mode=guess_mode, ) height, width = image.shape[-2:] - elif controlnet.__class__.__name__ == "NeuronMultiControlNetModel": + elif isinstance(controlnet, MultiControlNetModel): images = [] for image_ in image: @@ -507,7 +344,7 @@ def __call__( num_images_per_prompt=num_images_per_prompt, device=None, dtype=None, - do_classifier_free_guidance=do_classifier_free_guidance, + do_classifier_free_guidance=self.do_classifier_free_guidance, guess_mode=guess_mode, ) @@ -536,6 +373,7 @@ def __call__( height, width, prompt_embeds.dtype, + None, generator, latents, ) @@ -558,7 +396,7 @@ def __call__( 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) for s, e in zip(control_guidance_start, control_guidance_end) ] - controlnet_keep.append(keeps[0] if controlnet.__class__.__name__ == "NeuronControlNetModel" else keeps) + controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) # 7.2 Prepare added time ids & embeddings if isinstance(image, list): @@ -592,7 +430,7 @@ def __call__( else: negative_add_time_ids = add_time_ids - if do_classifier_free_guidance: + if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) @@ -623,13 +461,13 @@ def __call__( # Relevant thread: # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} # controlnet(s) inference - if guess_mode and do_classifier_free_guidance: + if guess_mode and self.do_classifier_free_guidance: # Infer ControlNet only for the conditional batch. 
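The `controlnet_keep` list built in the hunk above turns `control_guidance_start` / `control_guidance_end` into a per-step multiplier for the ControlNet conditioning scale: the residuals are only applied while the normalized denoising progress lies inside the [start, end] window. A small sketch of the same schedule, using a hypothetical helper name:

def controlnet_keep_schedule(num_steps, starts, ends):
    keep = []
    for i in range(num_steps):
        keeps = [
            1.0 - float(i / num_steps < s or (i + 1) / num_steps > e)
            for s, e in zip(starts, ends)
        ]
        # Single ControlNet -> a float per step; multiple ControlNets -> a list per step.
        keep.append(keeps[0] if len(keeps) == 1 else keeps)
    return keep


print(controlnet_keep_schedule(10, starts=[0.0], ends=[0.5]))
# [1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]

With `control_guidance_end=0.5`, the ControlNet conditioning is only active during the first half of the denoising steps.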
                    control_model_input = latents
                    control_model_input = self.scheduler.scale_model_input(control_model_input, t)
@@ -641,7 +479,9 @@ def __call__(
                else:
                    control_model_input = latent_model_input
                    controlnet_prompt_embeds = prompt_embeds
-                controlnet_added_cond_kwargs = copy.deepcopy(added_cond_kwargs)
+                controlnet_added_cond_kwargs = copy.deepcopy(
+                    added_cond_kwargs
+                )  # will be moved to neuron for the controlnet, thus not reusable by the unet
                if isinstance(controlnet_keep[i], list):
                    cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
@@ -670,7 +510,7 @@ def __call__(
                    return_dict=False,
                )
-                if guess_mode and do_classifier_free_guidance:
+                if guess_mode and self.do_classifier_free_guidance:
                    # Infered ControlNet only for the conditional batch.
                    # To apply the output of ControlNet to both the unconditional and conditional batches,
                    # add 0 to the unconditional batch to keep it unchanged.
@@ -688,13 +528,15 @@ def __call__(
                    t,
                    encoder_hidden_states=prompt_embeds,
                    timestep_cond=timestep_cond,
+                    cross_attention_kwargs=self.cross_attention_kwargs,
                    down_block_additional_residuals=down_block_res_samples,
                    mid_block_additional_residual=mid_block_res_sample,
                    added_cond_kwargs=added_cond_kwargs,
+                    return_dict=False,
                )[0]
                # perform guidance
-                if do_classifier_free_guidance:
+                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
@@ -724,12 +566,8 @@ def __call__(
        if not output_type == "latent":
            # unscale/denormalize the latents
            # denormalize with the mean and std if available and not None
-            has_latents_mean = (
-                hasattr(self.vae_decoder.config, "latents_mean") and self.vae_decoder.config.latents_mean is not None
-            )
-            has_latents_std = (
-                hasattr(self.vae_decoder.config, "latents_std") and self.vae_decoder.config.latents_std is not None
-            )
+            has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
+            has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
            if has_latents_mean and has_latents_std:
                latents_mean = (
                    torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
@@ -738,12 +576,13 @@ def __call__(
                    torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
                )
                latents = (
-                    latents * latents_std / getattr(self.vae_decoder.config, "scaling_factor", 0.18215) + latents_mean
+                    latents * latents_std / getattr(self.vae.config, "scaling_factor", 0.18215)
+                    + latents_mean
                )
            else:
-                latents = latents / getattr(self.vae_decoder.config, "scaling_factor", 0.18215)
+                latents = latents / getattr(self.vae.config, "scaling_factor", 0.18215)
-            image = self.vae_decoder(latents)[0]
+            image = self.vae.decode(latents, return_dict=False)[0]
        else:
            image = latents
diff --git a/optimum/neuron/pipelines/diffusers/pipeline_latent_consistency_text2img.py b/optimum/neuron/pipelines/diffusers/pipeline_latent_consistency_text2img.py
deleted file mode 100644
index 8d3a08253..000000000
--- a/optimum/neuron/pipelines/diffusers/pipeline_latent_consistency_text2img.py
+++ /dev/null
@@ -1,273 +0,0 @@
-# coding=utf-8
-# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Override some diffusers API for NeuronLatentConsistencyModelPipeline""" - -import logging -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union - -import torch -from diffusers import LatentConsistencyModelPipeline -from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput - -from .pipeline_utils import StableDiffusionPipelineMixin - - -if TYPE_CHECKING: - pass - - -logger = logging.getLogger(__name__) - - -class NeuronLatentConsistencyPipelineMixin(StableDiffusionPipelineMixin, LatentConsistencyModelPipeline): - # Adapted from https://github.com/huggingface/diffusers/blob/v0.23.0/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py#L470 - def check_inputs( - self, - prompt: Union[str, List[str]], - height: int, - width: int, - prompt_embeds: Optional[torch.FloatTensor] = None, - callback_on_step_end_tensor_inputs: List[str] = None, - ): - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if callback_on_step_end_tensor_inputs is not None and not set(callback_on_step_end_tensor_inputs) <= set( - self._callback_tensor_inputs - ): - raise ValueError( - f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - # Adapted from https://github.com/huggingface/diffusers/blob/v0.23.0/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py#L525 - def __call__( - self, - prompt: Optional[Union[str, List[str]]] = None, - num_inference_steps: int = 50, - original_inference_steps: Optional[int] = None, - guidance_scale: float = 8.5, - num_images_per_prompt: int = 1, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: str = "pil", - return_dict: bool = True, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - clip_skip: Optional[int] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, - callback_on_step_end_tensor_inputs: List[str] = ["latents"], - ): - r""" - The call function to the pipeline for generation. - - Args: - prompt (`Optional[Union[str, List[str]]]`, defaults to `None`): - The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. 
- num_inference_steps (`int`, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - original_inference_steps (`Optional[int]`, defaults to `None`): - The original number of inference steps use to generate a linearly-spaced timestep schedule, from which - we will draw `num_inference_steps` evenly spaced timesteps from as our final timestep schedule, - following the Skipping-Step method in the paper (see Section 4.3). If not set this will default to the - scheduler's `original_inference_steps` attribute. - guidance_scale (`float`, defaults to 8.5): - A higher guidance scale value encourages the model to generate images closely linked to the text - `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. - Note that the original latent consistency models paper uses a different CFG formulation where the - guidance scales are decreased by 1 (so in the paper formulation CFG is enabled when `guidance_scale > - 0`). - num_images_per_prompt (`int`, defaults to 1): - The number of images to generate per prompt. - generator (`Optional[Union[torch.Generator, List[torch.Generator]]]`, defaults to `None`): - A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make - generation deterministic. - latents (`Optional[torch.FloatTensor]`, defaults to `None`): - Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor is generated by sampling using the supplied random `generator`. - prompt_embeds (`Optional[torch.FloatTensor]`, defaults to `None`): - Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not - provided, text embeddings are generated from the `prompt` input argument. - output_type (`str`, defaults to `"pil"`): - The output format of the generated image. Choose between `PIL.Image` or `np.array`. - return_dict (`bool`, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - cross_attention_kwargs (`Optional[Dict[str, Any]]`, defaults to `None`): - A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in - [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - clip_skip (`Optional[int]`, defaults to `None`): - Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that - the output of the pre-final layer will be used for computing the prompt embeddings. - callback_on_step_end (`Optional[Callable]`, defaults to `None`): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. - callback_on_step_end_tensor_inputs (`List[str]`, defaults to `["latents"]`): - The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list - will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the - `._callback_tensor_inputs` attribute of your pipeine class. 
- - Examples: - - Returns: - [`diffusers.pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - If `return_dict` is `True`, [`diffusers.pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, - otherwise a `tuple` is returned where the first element is a list with the generated images and the - second element is a list of `bool`s indicating whether the corresponding generated image contains - "not-safe-for-work" (nsfw) content. - """ - # -1. Check `num_images_per_prompt` - if self.num_images_per_prompt != num_images_per_prompt and not self.dynamic_batch_size: - logger.warning( - f"Overriding `num_images_per_prompt({num_images_per_prompt})` to {self.num_images_per_prompt} used for the compilation. Please recompile the models with your " - f"custom `num_images_per_prompt` or turn on `dynamic_batch_size`, if you wish generating {num_images_per_prompt} per prompt." - ) - num_images_per_prompt = self.num_images_per_prompt - - # 0. Height and width to unet (static shapes) - height = self.unet.config.neuron["static_height"] * self.vae_scale_factor - width = self.unet.config.neuron["static_width"] * self.vae_scale_factor - - # 1. Check inputs. Raise error if not correct - self.check_inputs(prompt, height, width, prompt_embeds, callback_on_step_end_tensor_inputs) - self._guidance_scale = guidance_scale - self._clip_skip = clip_skip - self._cross_attention_kwargs = cross_attention_kwargs - - # 2. Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - neuron_batch_size = self.unet.config.neuron["static_batch_size"] - self.check_num_images_per_prompt(batch_size, neuron_batch_size, num_images_per_prompt) - - # 3. Encode input prompt - if cross_attention_kwargs is not None and cross_attention_kwargs.get("scale", None) is not None: - logger.warning( - "Lora scale need to be fused with model weights during the compilation. The scale passed through the pipeline during inference will be ignored." - ) - lora_scale = None - - # NOTE: when a LCM is distilled from an LDM via latent consistency distillation (Algorithm 1) with guided - # distillation, the forward pass of the LCM learns to approximate sampling from the LDM using CFG with the - # unconditional prompt "" (the empty string). Due to this, LCMs currently do not support negative prompts. - prompt_embeds, _ = self.encode_prompt( - prompt, - num_images_per_prompt, - False, # do_classifier_free_guidance set to False - negative_prompt=None, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=None, - lora_scale=lora_scale, - ) - - # 4. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, original_inference_steps=original_inference_steps) - timesteps = self.scheduler.timesteps - - # 5. Prepare latent variable - num_channels_latents = self.unet.config.in_channels - latents = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - prompt_embeds.dtype, - generator, - latents, - ) - bs = batch_size * num_images_per_prompt - - # 6. Get Guidance Scale Embedding - # NOTE: We use the Imagen CFG formulation that StableDiffusionPipeline uses rather than the original LCM paper - # CFG formulation, so we need to subtract 1 from the input guidance_scale. 
- # LCM CFG formulation: cfg_noise = noise_cond + cfg_scale * (noise_cond - noise_uncond), (cfg_scale > 0.0 using CFG) - w = torch.tensor(self.guidance_scale - 1).repeat(bs) - w_embedding = self.get_guidance_scale_embedding(w, embedding_dim=self.unet.config.time_cond_proj_dim).to( - dtype=latents.dtype - ) - - # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, None) - - # 8. LCM MultiStep Sampling Loop: - self._num_timesteps = len(timesteps) - num_warmup_steps = self._num_timesteps - num_inference_steps * self.scheduler.order - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - latents = latents.to(prompt_embeds.dtype) - - # model prediction (v-prediction, eps, x) - model_pred = self.unet( - sample=latents, - timestep=t, - encoder_hidden_states=prompt_embeds, - timestep_cond=w_embedding, - )[0] - - # compute the previous noisy sample x_t -> x_t-1 - latents, denoised = self.scheduler.step(model_pred, t, latents, **extra_step_kwargs, return_dict=False) - if callback_on_step_end is not None: - callback_kwargs = {} - for k in callback_on_step_end_tensor_inputs: - callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) - - latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - w_embedding = callback_outputs.pop("w_embedding", w_embedding) - denoised = callback_outputs.pop("denoised", denoised) - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - - denoised = denoised.to(prompt_embeds.dtype) - if not output_type == "latent": - image = self.vae_decoder(denoised / getattr(self.vae_decoder.config, "scaling_factor", 0.18215))[0] - image, has_nsfw_concept = self.run_safety_checker(image, prompt_embeds.dtype) - else: - image = denoised - has_nsfw_concept = None - - if has_nsfw_concept is None: - do_denormalize = [True] * image.shape[0] - else: - do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - - image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) - - if not return_dict: - return (image, has_nsfw_concept) - - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion.py b/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion.py deleted file mode 100644 index 572cea262..000000000 --- a/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion.py +++ /dev/null @@ -1,259 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
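The latent-consistency pipeline deleted above does not run two CFG passes; `encode_prompt` is called with classifier-free guidance disabled and the UNet is instead conditioned on an embedding of the guidance weight `w = guidance_scale - 1`, passed through `timestep_cond` and sized by the UNet's `time_cond_proj_dim`. A sketch of such a Fourier-style embedding follows; it may differ in detail from the diffusers `get_guidance_scale_embedding` helper, and the function name here is illustrative.

import torch


def guidance_scale_embedding(w: torch.Tensor, embedding_dim: int = 256) -> torch.Tensor:
    """Sinusoidal embedding of the guidance weight, one row per batch element (sketch)."""
    w = w * 1000.0
    half_dim = embedding_dim // 2
    freqs = torch.exp(-torch.log(torch.tensor(10000.0)) * torch.arange(half_dim) / (half_dim - 1))
    args = w.float()[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=1)
    if embedding_dim % 2 == 1:  # zero-pad to the requested size for odd dims
        emb = torch.nn.functional.pad(emb, (0, 1))
    return emb


# `guidance_scale - 1`, following the Imagen-style CFG formulation noted in the deleted code.
w = torch.tensor([8.5 - 1.0]).repeat(2)
print(guidance_scale_embedding(w, embedding_dim=256).shape)  # torch.Size([2, 256])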
-"""Override some diffusers API for NeuroStableDiffusionPipeline""" - -import logging -from typing import Any, Callable, Dict, List, Optional, Union - -import torch -from diffusers import StableDiffusionPipeline -from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput -from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg - -from .pipeline_utils import StableDiffusionPipelineMixin - - -logger = logging.getLogger(__name__) - - -class NeuronStableDiffusionPipelineMixin(StableDiffusionPipelineMixin, StableDiffusionPipeline): - # Adapted from https://github.com/huggingface/diffusers/blob/v0.18.2/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L566 - def __call__( - self, - prompt: Optional[Union[str, List[str]]] = None, - num_inference_steps: int = 50, - guidance_scale: float = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: int = 1, - eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - guidance_rescale: float = 0.0, - ): - r""" - The call function to the pipeline for generation. - - Args: - prompt (`Optional[Union[str, List[str]]]`, defaults to `None`): - The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. - num_inference_steps (`int`, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, defaults to 7.5): - A higher guidance scale value encourages the model to generate images closely linked to the text - `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. - negative_prompt (`Optional[Union[str, List[str]]]`, defaults to `None`): - The prompt or prompts to guide what to not include in image generation. If not defined, you need to - pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). - num_images_per_prompt (`int`, defaults to 1): - The number of images to generate per prompt. If it is different from the batch size used for the compiltaion, - it will be overriden by the static batch size of neuron (except for dynamic batching). - eta (`float`, defaults to 0.0): - Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies - to the [`diffusers.schedulers.DDIMScheduler`], and is ignored in other schedulers. - generator (`Optional[Union[torch.Generator, List[torch.Generator]]]`, defaults to `None`): - A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make - generation deterministic. - latents (`Optional[torch.FloatTensor]`, defaults to `None`): - Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor is generated by sampling using the supplied random `generator`. 
- prompt_embeds (`Optional[torch.FloatTensor]`, defaults to `None`): - Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not - provided, text embeddings are generated from the `prompt` input argument. - negative_prompt_embeds (`Optional[torch.FloatTensor]`, defaults to `None`): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If - not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. - output_type (`Optional[str]`, defaults to `"pil"`): - The output format of the generated image. Choose between `PIL.Image` or `np.array`. - return_dict (`bool`, defaults to `True`): - Whether or not to return a [`diffusers.pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Optional[Callable]`, defaults to `None`): - A function that calls every `callback_steps` steps during inference. The function is called with the - following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, defaults to 1): - The frequency at which the `callback` function is called. If not specified, the callback is called at - every step. - cross_attention_kwargs (`dict`, defaults to `None`): - A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in - [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - guidance_rescale (`float`, defaults to 0.0): - Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are - Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when - using zero terminal SNR. - - Examples: - - ```py - >>> from optimum.neuron import NeuronStableDiffusionPipeline - - >>> compiler_args = {"auto_cast": "matmul", "auto_cast_type": "bf16"} - >>> input_shapes = {"batch_size": 1, "height": 512, "width": 512} - - >>> stable_diffusion = NeuronStableDiffusionPipeline.from_pretrained( - ... "runwayml/stable-diffusion-v1-5", export=True, **compiler_args, **input_shapes - ... ) - >>> stable_diffusion.save_pretrained("sd_neuron/") - - >>> prompt = "a photo of an astronaut riding a horse on mars" - >>> image = stable_diffusion(prompt).images[0] - ``` - - Returns: - [`diffusers.pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - If `return_dict` is `True`, [`diffusers.pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, - otherwise a `tuple` is returned where the first element is a list with the generated images and the - second element is a list of `bool`s indicating whether the corresponding generated image contains - "not-safe-for-work" (nsfw) content. - """ - # -1. Check `num_images_per_prompt` - if self.num_images_per_prompt != num_images_per_prompt and not self.dynamic_batch_size: - logger.warning( - f"Overriding `num_images_per_prompt({num_images_per_prompt})` to {self.num_images_per_prompt} used for the compilation. Please recompile the models with your " - f"custom `num_images_per_prompt` or turn on `dynamic_batch_size`, if you wish generating {num_images_per_prompt} per prompt." - ) - num_images_per_prompt = self.num_images_per_prompt - - # 0. Height and width to unet (static shapes) - height = self.unet.config.neuron["static_height"] * self.vae_scale_factor - width = self.unet.config.neuron["static_width"] * self.vae_scale_factor - - # 1. Check inputs. 
Raise error if not correct - self.check_inputs( - prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds - ) - - # 2. Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - neuron_batch_size = self.unet.config.neuron["static_batch_size"] - self.check_num_images_per_prompt(batch_size, neuron_batch_size, num_images_per_prompt) - - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 and ( - self.dynamic_batch_size or self.data_parallel_mode == "unet" - ) - - # 3. Encode input prompt - if cross_attention_kwargs is not None and cross_attention_kwargs.get("scale", None) is not None: - logger.warning( - "Lora scale need to be fused with model weights during the compilation. The scale passed through the pipeline during inference will be ignored." - ) - text_encoder_lora_scale = None - prompt_embeds, negative_prompt_embeds = self.encode_prompt( - prompt, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - lora_scale=text_encoder_lora_scale, - ) - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - if do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - # 4. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps) - timesteps = self.scheduler.timesteps - - # 5. Prepare latent variables - num_channels_latents = self.unet.config.in_channels - latents = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - prompt_embeds.dtype, - generator, - latents, - ) - - # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # 7. Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # predict the noise residual - # [modified for neuron] Remove not traced inputs: cross_attention_kwargs, return_dict - noise_pred = self.unet( - latent_model_input, - t, - encoder_hidden_states=prompt_embeds, - )[0] - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - if do_classifier_free_guidance and guidance_rescale > 0.0: - # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) - - if not output_type == "latent": - # [Modified] Replace with pre-compiled vae decoder - image = self.vae_decoder(latents / getattr(self.vae_decoder.config, "scaling_factor", 0.18215))[0] - image, has_nsfw_concept = self.run_safety_checker(image, prompt_embeds.dtype) - else: - image = latents - has_nsfw_concept = None - - if has_nsfw_concept is None: - do_denormalize = [True] * image.shape[0] - else: - do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - - image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) - - # Offload last model to CPU - if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: - self.final_offload_hook.offload() - - if not return_dict: - return (image, has_nsfw_concept) - - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_img2img.py b/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_img2img.py deleted file mode 100644 index d818d9605..000000000 --- a/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_img2img.py +++ /dev/null @@ -1,305 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
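The text-to-image pipeline removed above also supported `guidance_rescale` from "Common Diffusion Noise Schedules and Sample Steps are Flawed" (arXiv:2305.08891), which pulls the CFG output back toward the standard deviation of the text-conditioned prediction to counter overexposure. A sketch of that rescaling, written from the paper's Section 3.4; the diffusers `rescale_noise_cfg` helper may differ in minor details:

import torch


def rescale_noise_cfg(noise_cfg: torch.Tensor, noise_pred_text: torch.Tensor, guidance_rescale: float = 0.0):
    """Rescale the CFG prediction toward the std of the conditional prediction (sketch)."""
    dims = list(range(1, noise_pred_text.ndim))
    std_text = noise_pred_text.std(dim=dims, keepdim=True)
    std_cfg = noise_cfg.std(dim=dims, keepdim=True)
    rescaled = noise_cfg * (std_text / std_cfg)
    # Blend between the rescaled and the original CFG prediction.
    return guidance_rescale * rescaled + (1.0 - guidance_rescale) * noise_cfg


noise_text = torch.randn(1, 4, 64, 64)
noise_uncond = torch.randn(1, 4, 64, 64)
noise_cfg = noise_uncond + 7.5 * (noise_text - noise_uncond)
print(rescale_noise_cfg(noise_cfg, noise_text, guidance_rescale=0.7).std())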
-"""Override some diffusers API for NeuroStableDiffusionImg2ImgPipeline""" - -import logging -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union - -import PIL -import torch -from diffusers import StableDiffusionImg2ImgPipeline -from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput -from diffusers.utils import deprecate -from diffusers.utils.torch_utils import randn_tensor - -from .pipeline_utils import StableDiffusionPipelineMixin - - -if TYPE_CHECKING: - from diffusers.image_processor import PipelineImageInput - - -logger = logging.getLogger(__name__) - - -class NeuronStableDiffusionImg2ImgPipelineMixin(StableDiffusionPipelineMixin, StableDiffusionImg2ImgPipeline): - # Adapted from diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.prepare_latents - def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, generator=None): - if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): - raise ValueError( - f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" - ) - - batch_size = batch_size * num_images_per_prompt - - if image.shape[1] == 4: - init_latents = image - else: - # [Modified] Replace with pre-compiled vae encoder, encode the init image into latents and scale the latents - init_latents = self.vae_encoder(sample=image)[0] - scaling_factor = self.vae_encoder.config.scaling_factor or 0.18215 - init_latents = scaling_factor * init_latents - - if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: - # expand init_latents for batch_size - deprecation_message = ( - f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" - " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" - " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" - " your script to pass as many initial images as text prompts to suppress this warning." - ) - deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) - additional_image_per_prompt = batch_size // init_latents.shape[0] - init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) - elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." 
- ) - else: - init_latents = torch.cat([init_latents], dim=0) - - shape = init_latents.shape - noise = randn_tensor(shape, generator=generator, dtype=dtype) - - # get latents - init_latents = self.scheduler.add_noise(init_latents, noise, timestep) - latents = init_latents - - return latents - - # Adapted from diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.__call__ - def __call__( - self, - prompt: Optional[Union[str, List[str]]] = None, - image: Optional["PipelineImageInput"] = None, - strength: float = 0.8, - num_inference_steps: int = 50, - guidance_scale: float = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: int = 1, - eta: float = 0.0, - generator: Optional[torch.Generator] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: str = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - ): - r""" - The call function to the pipeline for generation. - - Args: - prompt (`Optional[Union[str, List[str]]]`, defaults to `None`): - The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. - image (`Optional["PipelineImageInput"]`, defaults to `None`): - `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both - numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list - or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a - list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image - latents as `image`, but if passing latents directly it is not encoded again. - strength (`float`, defaults to 0.8): - Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a - starting point and more noise is added the higher the `strength`. The number of denoising steps depends - on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising - process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 - essentially ignores `image`. - num_inference_steps (`int`, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. This parameter is modulated by `strength`. - guidance_scale (`float`, defaults to 7.5): - A higher guidance scale value encourages the model to generate images closely linked to the text - `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. - negative_prompt (`Optional[Union[str, List[str]`, defaults to `None`): - The prompt or prompts to guide what to not include in image generation. If not defined, you need to - pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). - num_images_per_prompt (`int`, defaults to 1): - The number of images to generate per prompt. If it is different from the batch size used for the compiltaion, - it will be overriden by the static batch size of neuron (except for dynamic batching). - eta (`float`, defaults to 0.0): - Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. 
Only applies - to the [`diffusers.schedulers.DDIMScheduler`], and is ignored in other schedulers. - generator (`Optional[Union[torch.Generator, List[torch.Generator]]]`, defaults to `None`): - A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make - generation deterministic. - prompt_embeds (`Optional[torch.FloatTensor]`, defaults to `None`): - Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not - provided, text embeddings are generated from the `prompt` input argument. - negative_prompt_embeds (`Optional[torch.FloatTensor]`, defaults to `None`): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If - not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. - output_type (`Optional[str]`, defaults to `"pil"`): - The output format of the generated image. Choose between `PIL.Image` or `np.array`. - return_dict (`bool`, defaults to `True`): - Whether or not to return a [`diffusers.pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Optional[Callable]`, defaults to `None`): - A function that calls every `callback_steps` steps during inference. The function is called with the - following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, defaults to 1): - The frequency at which the `callback` function is called. If not specified, the callback is called at - every step. - cross_attention_kwargs (`dict`, defaults to `None`): - A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in - [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - - Examples: - - ```py - >>> from optimum.neuron import NeuronStableDiffusionImg2ImgPipeline - >>> from diffusers.utils import load_image - - >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" - >>> init_image = load_image(url).convert("RGB") - - >>> compiler_args = {"auto_cast": "matmul", "auto_cast_type": "bf16"} - >>> input_shapes = {"batch_size": 1, "height": 512, "width": 512} - >>> pipeline = NeuronStableDiffusionImg2ImgPipeline.from_pretrained( - ... "nitrosocke/Ghibli-Diffusion", export=True, **compiler_args, **input_shapes, - ... ) - >>> pipeline.save_pretrained("sd_img2img/") - - >>> prompt = "ghibli style, a fantasy landscape with snowcapped mountains, trees, lake with detailed reflection." - >>> image = pipeline(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images[0] - ``` - - Returns: - [`diffusers.pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - If `return_dict` is `True`, [`diffusers.pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, - otherwise a `tuple` is returned where the first element is a list with the generated images and the - second element is a list of `bool`s indicating whether the corresponding generated image contains - "not-safe-for-work" (nsfw) content. - """ - # 0. Check `num_images_per_prompt` - if self.num_images_per_prompt != num_images_per_prompt and not self.dynamic_batch_size: - logger.warning( - f"Overriding `num_images_per_prompt({num_images_per_prompt})` to {self.num_images_per_prompt} used for the compilation. 
Please recompile the models with your " - f"custom `num_images_per_prompt` or turn on `dynamic_batch_size`, if you wish generating {num_images_per_prompt} image per prompt." - ) - num_images_per_prompt = self.num_images_per_prompt - - # 1. Check inputs. Raise error if not correct - self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) - - # 2. Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - neuron_batch_size = self.unet.config.neuron["static_batch_size"] - self.check_num_images_per_prompt(batch_size, neuron_batch_size, num_images_per_prompt) - - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 and ( - self.dynamic_batch_size or self.data_parallel_mode == "unet" - ) - - # 3. Encode input prompt - if cross_attention_kwargs is not None and cross_attention_kwargs.get("scale", None) is not None: - logger.warning( - "Lora scale need to be fused with model weights during the compilation. The scale passed through the pipeline during inference will be ignored." - ) - text_encoder_lora_scale = None - prompt_embeds, negative_prompt_embeds = self.encode_prompt( - prompt, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - lora_scale=text_encoder_lora_scale, - ) - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - if do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - # 4. Preprocess image - height = self.vae_encoder.config.neuron["static_height"] - width = self.vae_encoder.config.neuron["static_width"] - image = self.image_processor.preprocess(image, height=height, width=width) - - # 5. set timesteps - self.scheduler.set_timesteps(num_inference_steps) - timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device=None) - latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) - - # 6. Prepare latent variables - latents = self.prepare_latents( - image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, generator - ) - - # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # 8. 
Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # predict the noise residual - # [modified for neuron] Remove not traced inputs: cross_attention_kwargs, return_dict - noise_pred = self.unet( - latent_model_input, - t, - encoder_hidden_states=prompt_embeds, - )[0] - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) - - if not output_type == "latent": - # [Modified] Replace with pre-compiled vae decoder - image = self.vae_decoder(latents / getattr(self.vae_decoder.config, "scaling_factor", 0.18215))[0] - image, has_nsfw_concept = self.run_safety_checker(image, prompt_embeds.dtype) - else: - image = latents - has_nsfw_concept = None - - if has_nsfw_concept is None: - do_denormalize = [True] * image.shape[0] - else: - do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - - image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) - - # Offload last model to CPU - if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: - self.final_offload_hook.offload() - - if not return_dict: - return (image, has_nsfw_concept) - - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_inpaint.py b/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_inpaint.py deleted file mode 100644 index b757f5936..000000000 --- a/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_inpaint.py +++ /dev/null @@ -1,402 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
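The img2img `prepare_latents` deleted above encodes the init image with the pre-compiled VAE encoder, applies the VAE scaling factor (falling back to 0.18215 when the config has none), and noises the result at the timestep implied by `strength`. A minimal sketch of that flow, assuming an already-loaded Neuron VAE encoder and scheduler (names here are illustrative, not the pipeline API):

```py
# Illustrative sketch of the image -> noised-latents path from the removed prepare_latents.
from typing import Optional

import torch
from diffusers.utils.torch_utils import randn_tensor


def image_to_noised_latents(vae_encoder, scheduler, image: torch.Tensor, timestep: torch.Tensor,
                            generator: Optional[torch.Generator] = None) -> torch.Tensor:
    # Encode with the pre-compiled VAE encoder and apply the scaling factor.
    init_latents = vae_encoder(sample=image)[0]
    scaling_factor = getattr(vae_encoder.config, "scaling_factor", None) or 0.18215
    init_latents = scaling_factor * init_latents
    # Add noise at the timestep corresponding to the requested `strength`.
    noise = randn_tensor(init_latents.shape, generator=generator, dtype=init_latents.dtype)
    return scheduler.add_noise(init_latents, noise, timestep)
```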
-"""Override some diffusers API for NeuroStableDiffusionInpaintPipeline""" - -import logging -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union - -import torch -from diffusers import StableDiffusionInpaintPipeline -from diffusers.image_processor import VaeImageProcessor -from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput - -from .pipeline_utils import StableDiffusionPipelineMixin - - -if TYPE_CHECKING: - from diffusers.image_processor import PipelineImageInput - - -logger = logging.getLogger(__name__) - - -class NeuronStableDiffusionInpaintPipelineMixin(StableDiffusionPipelineMixin, StableDiffusionInpaintPipeline): - prepare_latents = StableDiffusionInpaintPipeline.prepare_latents - - # Adapted from https://github.com/huggingface/diffusers/blob/v0.21.2/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py#L629 - def _encode_vae_image( - self, image: torch.Tensor, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None - ): - image_latents = self.vae_encoder(sample=image)[0] - image_latents = self.vae_encoder.config.scaling_factor * image_latents - - return image_latents - - # Adapted from https://github.com/huggingface/diffusers/blob/v0.21.2/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py#L699 - def __call__( - self, - prompt: Optional[Union[str, List[str]]] = None, - image: Optional["PipelineImageInput"] = None, - mask_image: Optional["PipelineImageInput"] = None, - masked_image_latents: Optional[torch.FloatTensor] = None, - strength: float = 1.0, - num_inference_steps: int = 50, - guidance_scale: float = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - clip_skip: int = None, - ): - r""" - The call function to the pipeline for generation. - - Args: - prompt (`Optional[Union[str, List[str]]]`, defaults to `None`): - The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. - image (`Optional["PipelineImageInput"]`, defaults to `None`): - `Image`, numpy array or tensor representing an image batch to be inpainted (which parts of the image to - be masked out with `mask_image` and repainted according to `prompt`). For both numpy array and pytorch - tensor, the expected value range is between `[0, 1]` If it's a tensor or a list or tensors, the - expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a list of arrays, the - expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image latents as `image`, but - if passing latents directly it is not encoded again. - mask_image (`Optional["PipelineImageInput"]`, defaults to `None`): - `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask - are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a - single channel (luminance) before use. 
If it's a numpy array or pytorch tensor, it should contain one - color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B, - H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W, - 1)`, or `(H, W)`. - strength (`float`, defaults to 1.0): - Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a - starting point and more noise is added the higher the `strength`. The number of denoising steps depends - on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising - process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 - essentially ignores `image`. - num_inference_steps (`int`, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. This parameter is modulated by `strength`. - guidance_scale (`float`, defaults to 7.5): - A higher guidance scale value encourages the model to generate images closely linked to the text - `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. - negative_prompt (`Optional[Union[str, List[str]`, defaults to `None`): - The prompt or prompts to guide what to not include in image generation. If not defined, you need to - pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). - num_images_per_prompt (`int`, defaults to 1): - The number of images to generate per prompt. If it is different from the batch size used for the compiltaion, - it will be overriden by the static batch size of neuron (except for dynamic batching). - eta (`float`, defaults to 0.0): - Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies - to the [`diffusers.schedulers.DDIMScheduler`], and is ignored in other schedulers. - generator (`Optional[Union[torch.Generator, List[torch.Generator]]]`, defaults to `None`): - A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make - generation deterministic. - latents (`Optional[torch.FloatTensor]`, defaults to `None`): - Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor is generated by sampling using the supplied random `generator`. - prompt_embeds (`Optional[torch.FloatTensor]`, defaults to `None`): - Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not - provided, text embeddings are generated from the `prompt` input argument. - negative_prompt_embeds (`Optional[torch.FloatTensor]`, defaults to `None`): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If - not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. - output_type (`Optional[str]`, defaults to `"pil"`): - The output format of the generated image. Choose between `PIL.Image` or `np.array`. - return_dict (`bool`, defaults to `True`): - Whether or not to return a [`diffusers.pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Optional[Callable]`, defaults to `None`): - A function that calls every `callback_steps` steps during inference. 
The function is called with the - following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, defaults to 1): - The frequency at which the `callback` function is called. If not specified, the callback is called at - every step. - cross_attention_kwargs (`dict`, defaults to `None`): - A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in - [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - clip_skip (`int`, defaults to `None`): - Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that - the output of the pre-final layer will be used for computing the prompt embeddings. - - Examples: - - ```py - >>> from optimum.neuron import NeuronStableDiffusionInpaintPipeline - >>> from diffusers.utils import load_image - - >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" - >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" - - >>> init_image = load_image(img_url).convert("RGB") - >>> mask_image = load_image(mask_url).convert("RGB") - - >>> compiler_args = {"auto_cast": "matmul", "auto_cast_type": "bf16"} - >>> input_shapes = {"batch_size": 1, "height": 1024, "width": 1024} - >>> pipeline = NeuronStableDiffusionInpaintPipeline.from_pretrained( - ... "runwayml/stable-diffusion-inpainting", export=True, **compiler_args, **input_shapes, - ... ) - >>> pipeline.save_pretrained("sd_inpaint/") - - >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench" - >>> image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image).images[0] - ``` - - Returns: - [`diffusers.pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - If `return_dict` is `True`, [`diffusers.pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, - otherwise a `tuple` is returned where the first element is a list with the generated images and the - second element is a list of `bool`s indicating whether the corresponding generated image contains - "not-safe-for-work" (nsfw) content. - """ - # -1. Check `num_images_per_prompt` - if self.num_images_per_prompt != num_images_per_prompt and not self.dynamic_batch_size: - logger.warning( - f"Overriding `num_images_per_prompt({num_images_per_prompt})` to {self.num_images_per_prompt} used for the compilation. Please recompile the models with your " - f"custom `num_images_per_prompt` or turn on `dynamic_batch_size`, if you wish generating {num_images_per_prompt} per prompt." - ) - num_images_per_prompt = self.num_images_per_prompt - - # 0. Height and width to unet (static shapes) - height = self.unet.config.neuron["static_height"] * self.vae_scale_factor - width = self.unet.config.neuron["static_width"] * self.vae_scale_factor - - # 1. Check inputs - self.check_inputs( - prompt, - image, - mask_image, - height, - width, - strength, - callback_steps, - negative_prompt, - prompt_embeds, - negative_prompt_embeds, - ) - - # 2. 
Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - neuron_batch_size = self.unet.config.neuron["static_batch_size"] - self.check_num_images_per_prompt(batch_size, neuron_batch_size, num_images_per_prompt) - - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 and ( - self.dynamic_batch_size or self.data_parallel_mode == "unet" - ) - - # 3. Encode input prompt - if cross_attention_kwargs is not None and cross_attention_kwargs.get("scale", None) is not None: - logger.warning( - "Lora scale need to be fused with model weights during the compilation. The scale passed through the pipeline during inference will be ignored." - ) - text_encoder_lora_scale = None - prompt_embeds, negative_prompt_embeds = self.encode_prompt( - prompt, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - lora_scale=text_encoder_lora_scale, - ) - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - if do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - # 4. set timesteps - self.scheduler.set_timesteps(num_inference_steps) - timesteps, num_inference_steps = self.get_timesteps( - num_inference_steps=num_inference_steps, strength=strength, device=None - ) - # check that number of inference steps is not < 1 - as this doesn't make sense - if num_inference_steps < 1: - raise ValueError( - f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" - f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." - ) - # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) - latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) - # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise - is_strength_max = strength == 1.0 - - # 5. Preprocess mask and image - init_image = self.image_processor.preprocess(image, height=height, width=width) - init_image = init_image.to(dtype=torch.float32) - - # 6. Prepare latent variables - num_channels_latents = self.vae_encoder.config.latent_channels - num_channels_unet = self.unet.config.in_channels - return_image_latents = num_channels_unet == 4 - - latents_outputs = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - prompt_embeds.dtype, - None, - generator, - latents, - image=init_image, - timestep=latent_timestep, - is_strength_max=is_strength_max, - return_noise=True, - return_image_latents=return_image_latents, - ) - - if return_image_latents: - latents, noise, image_latents = latents_outputs - else: - latents, noise = latents_outputs - - # 7. 
Prepare mask latent variables - self.mask_processor = VaeImageProcessor( - vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True - ) - mask_condition = self.mask_processor.preprocess(mask_image, height=height, width=width) - - if masked_image_latents is None: - masked_image = init_image * (mask_condition < 0.5) - else: - masked_image = masked_image_latents - - mask, masked_image_latents = self.prepare_mask_latents( - mask_condition, - masked_image, - batch_size * num_images_per_prompt, - height, - width, - prompt_embeds.dtype, - None, - generator, - do_classifier_free_guidance, - ) - - # 8. Check that sizes of mask, masked image and latents match - if num_channels_unet == 9: - # default case for runwayml/stable-diffusion-inpainting - num_channels_mask = mask.shape[1] - num_channels_masked_image = masked_image_latents.shape[1] - if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: - raise ValueError( - f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" - f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" - f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" - f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" - " `pipeline.unet` or your `mask_image` or `image` input." - ) - elif num_channels_unet != 4: - raise ValueError( - f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}." - ) - - # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # 10. 
Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - - # concat latents, mask, masked_image_latents in the channel dimension - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - if num_channels_unet == 9: - latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) - - # predict the noise residual - # [modified for neuron] Remove not traced inputs: cross_attention_kwargs, return_dict - noise_pred = self.unet( - latent_model_input, - t, - encoder_hidden_states=prompt_embeds, - )[0] - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] - - if num_channels_unet == 4: - init_latents_proper = image_latents[:1] - init_mask = mask[:1] - - if i < len(timesteps) - 1: - noise_timestep = timesteps[i + 1] - init_latents_proper = self.scheduler.add_noise( - init_latents_proper, noise, torch.tensor([noise_timestep]) - ) - - latents = (1 - init_mask) * init_latents_proper + init_mask * latents - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) - - if not output_type == "latent": - condition_kwargs = {} - if "AsymmetricAutoencoderKL" in self.vae_decoder.config._class_name: - init_image = init_image.to(dtype=masked_image_latents.dtype) - init_image_condition = init_image.clone() - # [modified for neuron] Remove generator which is not an input for the compilation - init_image = self._encode_vae_image(init_image) - mask_condition = mask_condition.to(dtype=masked_image_latents.dtype) - condition_kwargs = {"image": init_image_condition, "mask": mask_condition} - # [Modified] Replace with pre-compiled vae decoder - image = self.vae_decoder( - latents / getattr(self.vae_decoder.config, "scaling_factor", 0.18215), **condition_kwargs - )[0] - image, has_nsfw_concept = self.run_safety_checker(image, prompt_embeds.dtype) - else: - image = latents - has_nsfw_concept = None - - if has_nsfw_concept is None: - do_denormalize = [True] * image.shape[0] - else: - do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - - image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) - - if not return_dict: - return (image, has_nsfw_concept) - - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_instruct_pix2pix.py b/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_instruct_pix2pix.py deleted file mode 100644 index 9eb659aea..000000000 --- a/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_instruct_pix2pix.py +++ /dev/null @@ -1,475 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The HuggingFace Inc. team. All rights reserved. 
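The inpainting `__call__` deleted above handles two UNet layouts: a 9-channel UNet receives latents, mask and masked-image latents concatenated on the channel axis (with a channel-count check against `unet.config.in_channels`), while a 4-channel UNet instead blends the re-noised original latents back into the unmasked region after every scheduler step. A rough sketch of both helpers, with hypothetical names:

```py
# Illustrative sketch of the two inpainting paths shown in the removed code.
import torch


def build_inpaint_unet_input(latents: torch.Tensor, mask: torch.Tensor,
                             masked_image_latents: torch.Tensor, unet_in_channels: int) -> torch.Tensor:
    # 9-channel UNets expect latents + mask + masked-image latents stacked on dim 1.
    stacked = torch.cat([latents, mask, masked_image_latents], dim=1)
    if stacked.shape[1] != unet_in_channels:
        raise ValueError(
            f"UNet expects {unet_in_channels} input channels but got {stacked.shape[1]}; "
            "check the mask/image inputs or the UNet config."
        )
    return stacked


def blend_known_region(latents: torch.Tensor, init_latents_proper: torch.Tensor, init_mask: torch.Tensor) -> torch.Tensor:
    # 4-channel UNets: keep the re-noised original content where the mask is 0,
    # the freshly denoised content where the mask is 1.
    return (1 - init_mask) * init_latents_proper + init_mask * latents
```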
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Override some diffusers API for NeuronStableDiffusionInstructPix2PixPipeline""" - -import logging -from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union - -import PIL -import torch -from diffusers import StableDiffusionInstructPix2PixPipeline -from diffusers.loaders import TextualInversionLoaderMixin -from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput -from diffusers.utils.deprecation_utils import deprecate - -from .pipeline_utils import StableDiffusionPipelineMixin - - -if TYPE_CHECKING: - from diffusers.image_processor import PipelineImageInput - - -logger = logging.getLogger(__name__) - - -class NeuronStableDiffusionInstructPix2PixPipelineMixin( - StableDiffusionPipelineMixin, StableDiffusionInstructPix2PixPipeline -): - @torch.no_grad() - def __call__( - self, - prompt: Optional[Union[str, List[str]]] = None, - image: Optional["PipelineImageInput"] = None, - num_inference_steps: int = 100, - guidance_scale: float = 7.5, - image_guidance_scale: float = 1.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: int = 1, - eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: str = "pil", - return_dict: bool = True, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, - callback_on_step_end_tensor_inputs: List[str] = ["latents"], - ): - r""" - The call function to the pipeline for generation. - - Args: - prompt (`Optional[Union[str, List[str]]]`, defaults to `None`): - The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. - image (`Optional["PipelineImageInput"]`, defaults to `None`): - `Image` or tensor representing an image batch to be repainted according to `prompt`. Can also accept - image latents as `image`, but if passing latents directly it is not encoded again. - num_inference_steps (`int`, defaults to 100): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, defaults to 7.5): - A higher guidance scale value encourages the model to generate images closely linked to the text - `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. - image_guidance_scale (`float`, defaults to 1.5): - Push the generated image towards the inital `image`. Image guidance scale is enabled by setting - `image_guidance_scale > 1`. Higher image guidance scale encourages generated images that are closely - linked to the source `image`, usually at the expense of lower image quality. This pipeline requires a - value of at least `1`. 
- negative_prompt (`Optional[Union[str, List[str]`, defaults to `None`): - The prompt or prompts to guide what to not include in image generation. If not defined, you need to - pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). - num_images_per_prompt (`int`, defaults to 1): - The number of images to generate per prompt. - eta (`float`, defaults to 0.0): - Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies - to the [`diffusers.schedulers.DDIMScheduler`], and is ignored in other schedulers. - generator (`Optional[Union[torch.Generator, List[torch.Generator]]]`, defaults to `None`): - A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make - generation deterministic. - latents (`Optional[torch.FloatTensor]`, defaults to `None`): - Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor is generated by sampling using the supplied random `generator`. - prompt_embeds (`Optional[torch.FloatTensor]`, defaults to `None`): - Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not - provided, text embeddings are generated from the `prompt` input argument. - negative_prompt_embeds (`Optional[torch.FloatTensor]`, defaults to `None`): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If - not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. - output_type (`str`, defaults to `"pil"`): - The output format of the generated image. Choose between `PIL.Image` or `np.array`. - return_dict (`bool`, defaults to `True`): - Whether or not to return a [`diffusers.pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback_on_step_end (`Optional[Callable[[int, int, Dict], None]]`, defaults to `None`): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. - callback_on_step_end_tensor_inputs (`List[str]`, defaults to `["latents"]`): - The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list - will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the - `._callback_tensor_inputs` attribute of your pipeine class. - - Examples: - - ```py - >>> import PIL - >>> import requests - >>> from io import BytesIO - - >>> from optimum.neuron import NeuronStableDiffusionInstructPix2PixPipeline - - - >>> def download_image(url): - ... response = requests.get(url) - ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") - - - >>> img_url = "https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png" - - >>> init_image = download_image(img_url).resize((512, 512)) - >>> compiler_args = {"auto_cast": "matmul", "auto_cast_type": "bf16"} - >>> input_shapes = {"batch_size": 1, "height": 512, "width": 512} - >>> pipe = NeuronStableDiffusionInstructPix2PixPipeline.from_pretrained( - ... 
"timbrooks/instruct-pix2pix", export=True, dynamic_batch_size=True, **compiler_args, **input_shapes, - ... ) - >>> pipe.save_pretrained("sd_ip2p/") - - >>> prompt = "Add a beautiful sunset" - >>> image = pipe(prompt=prompt, image=init_image).images[0] - ``` - - Returns: - [`diffusers.pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - If `return_dict` is `True`, [`diffusers.pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, - otherwise a `tuple` is returned where the first element is a list with the generated images and the - second element is a list of `bool`s indicating whether the corresponding generated image contains - "not-safe-for-work" (nsfw) content. - """ - - # 0. Check inputs - self.check_inputs( - prompt, - None, - negative_prompt, - prompt_embeds, - negative_prompt_embeds, - callback_on_step_end_tensor_inputs, - ) - self._guidance_scale = guidance_scale - self._image_guidance_scale = image_guidance_scale - - if image is None: - raise ValueError("`image` input cannot be undefined.") - - # 1. Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - neuron_batch_size = self.unet.config.neuron["static_batch_size"] - self.check_num_images_per_prompt(batch_size, neuron_batch_size, num_images_per_prompt) - - # 2. Encode input prompt - prompt_embeds = self._encode_prompt( - prompt, - num_images_per_prompt, - self.do_classifier_free_guidance, - negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - ) - - # 3. Preprocess image - height = self.vae_encoder.config.neuron["static_height"] - width = self.vae_encoder.config.neuron["static_width"] - image = self.image_processor.preprocess(image, height=height, width=width) - - # 4. set timesteps - self.scheduler.set_timesteps(num_inference_steps) - timesteps = self.scheduler.timesteps - - # 5. Prepare Image latents - image_latents = self.prepare_image_latents( - image, - batch_size, - num_images_per_prompt, - self.do_classifier_free_guidance, - generator, - ) - - height, width = image_latents.shape[-2:] - height = height * self.vae_scale_factor - width = width * self.vae_scale_factor - - # 6. Prepare latent variables - num_channels_latents = self.vae_decoder.config.latent_channels - latents = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - prompt_embeds.dtype, - generator, - latents, - ) - - # 7. Check that shapes of latents and image match the UNet channels - num_channels_image = image_latents.shape[1] - if num_channels_latents + num_channels_image != self.unet.config.in_channels: - raise ValueError( - f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" - f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" - f" `num_channels_image`: {num_channels_image} " - f" = {num_channels_latents+num_channels_image}. Please verify the config of" - " `pipeline.unet` or your `image` input." - ) - - # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # 9. 
Denoising loop - self._num_timesteps = len(timesteps) - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - # Expand the latents if we are doing classifier free guidance. - # The latents are expanded 3 times because for pix2pix the guidance\ - # is applied for both the text and the input image. - latent_model_input = torch.cat([latents] * 3) if self.do_classifier_free_guidance else latents - - # concat latents, image_latents in the channel dimension - scaled_latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - scaled_latent_model_input = torch.cat([scaled_latent_model_input, image_latents], dim=1) - - # predict the noise residual - noise_pred = self.unet( - scaled_latent_model_input, - t, - encoder_hidden_states=prompt_embeds, - )[0] - - # perform guidance - if self.do_classifier_free_guidance: - noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3) - noise_pred = ( - noise_pred_uncond - + self.guidance_scale * (noise_pred_text - noise_pred_image) - + self.image_guidance_scale * (noise_pred_image - noise_pred_uncond) - ) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] - - if callback_on_step_end is not None: - callback_kwargs = {} - for k in callback_on_step_end_tensor_inputs: - callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) - - latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) - image_latents = callback_outputs.pop("image_latents", image_latents) - - if not output_type == "latent": - image = self.vae_decoder(latents / getattr(self.vae_decoder.config, "scaling_factor", 0.18215))[0] - image, has_nsfw_concept = self.run_safety_checker(image, prompt_embeds.dtype) - else: - image = latents - has_nsfw_concept = None - - if has_nsfw_concept is None: - do_denormalize = [True] * image.shape[0] - else: - do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - - image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) - - if not return_dict: - return (image, has_nsfw_concept) - - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) - - # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionInstructPix2PixPipeline.prepare_image_latents - def prepare_image_latents( - self, image, batch_size, num_images_per_prompt, do_classifier_free_guidance, generator=None - ): - if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): - raise ValueError( - f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" - ) - - batch_size = batch_size * num_images_per_prompt - - if image.shape[1] == 4: - image_latents = image - else: - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
- ) - - if isinstance(generator, list): - image_latents = [self.vae_encoder(sample=image[i : i + 1])[0] for i in range(batch_size)] - image_latents = torch.cat(image_latents, dim=0) - else: - image_latents = self.vae_encoder(sample=image)[0] - - if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: - # expand image_latents for batch_size - deprecation_message = ( - f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial" - " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" - " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" - " your script to pass as many initial images as text prompts to suppress this warning." - ) - deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) - additional_image_per_prompt = batch_size // image_latents.shape[0] - image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) - elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." - ) - else: - image_latents = torch.cat([image_latents], dim=0) - - if do_classifier_free_guidance: - uncond_image_latents = torch.zeros_like(image_latents) - image_latents = torch.cat([image_latents, image_latents, uncond_image_latents], dim=0) - - return image_latents - - # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionInstructPix2PixPipeline._encode_prompt - def _encode_prompt( - self, - prompt: Union[str, List[str]], - num_images_per_prompt: int, - do_classifier_free_guidance: bool, - negative_prompt: Optional[Union[str, List]] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_ prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. 
- """ - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if prompt_embeds is None: - # textual inversion: procecss multi-vector tokens if necessary - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer) - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = self.tokenizer.batch_decode( - untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] - ) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - - prompt_embeds = self.text_encoder(text_input_ids) - prompt_embeds = prompt_embeds[0] - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - else: - uncond_tokens = negative_prompt - - # textual inversion: procecss multi-vector tokens if necessary - if isinstance(self, TextualInversionLoaderMixin): - uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) - - max_length = prompt_embeds.shape[1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - negative_prompt_embeds = self.text_encoder(uncond_input.input_ids) - negative_prompt_embeds = negative_prompt_embeds[0] - - if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - # For classifier free guidance, we need to do two forward passes. 
- # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - # pix2pix has two negative embeddings, and unlike in other pipelines latents are ordered [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds] - prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds, negative_prompt_embeds]) - - return prompt_embeds - - @property - def do_classifier_free_guidance(self): - return self.guidance_scale > 1.0 and self.image_guidance_scale >= 1.0 and self.dynamic_batch_size diff --git a/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_xl.py b/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_xl.py deleted file mode 100644 index bd930e994..000000000 --- a/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_xl.py +++ /dev/null @@ -1,401 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Override some diffusers API for NeuronStableDiffusionXLPipelineMixin""" - -import logging -from typing import Any, Callable, Dict, List, Optional, Tuple, Union - -import torch -from diffusers import StableDiffusionXLPipeline -from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput -from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg - -from .pipeline_utils import StableDiffusionXLPipelineMixin - - -logger = logging.getLogger(__name__) - - -class NeuronStableDiffusionXLPipelineMixin(StableDiffusionXLPipelineMixin, StableDiffusionXLPipeline): - # Adapted from https://github.com/huggingface/diffusers/blob/v0.23.0/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L573 - def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype): - add_time_ids = list(original_size + crops_coords_top_left + target_size) - add_time_ids = torch.tensor([add_time_ids], dtype=dtype) - return add_time_ids - - # Adapted from https://github.com/huggingface/diffusers/blob/v0.20.2/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L557 - def __call__( - self, - prompt: Optional[Union[str, List[str]]] = None, - prompt_2: Optional[Union[str, List[str]]] = None, - num_inference_steps: int = 50, - denoising_end: Optional[float] = None, - guidance_scale: float = 5.0, - negative_prompt: Optional[Union[str, List[str]]] = None, - negative_prompt_2: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: int = 1, - eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = 
None, - callback_steps: int = 1, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - guidance_rescale: float = 0.0, - original_size: Optional[Tuple[int, int]] = None, - crops_coords_top_left: Tuple[int, int] = (0, 0), - target_size: Optional[Tuple[int, int]] = None, - negative_original_size: Optional[Tuple[int, int]] = None, - negative_crops_coords_top_left: Tuple[int, int] = (0, 0), - negative_target_size: Optional[Tuple[int, int]] = None, - clip_skip: Optional[int] = None, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt (`Optional[Union[str, List[str]]]`, defaults to `None`): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. - prompt_2 (`Optional[Union[str, List[str]]]`, defaults to `None`): - The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is - used in both text-encoders - num_inference_steps (`int`, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - denoising_end (`Optional[float]`, defaults to `None`): - When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be - completed before it is intentionally prematurely terminated. As a result, the returned sample will - still retain a substantial amount of noise as determined by the discrete timesteps selected by the - scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a - "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image - Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) - guidance_scale (`float`, defaults to 5.0): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - negative_prompt (`Optional[Union[str, List[str]]]`, defaults to `None`): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - negative_prompt_2 (`Optional[Union[str, List[str]]]`, defaults to `None`): - The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and - `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders - num_images_per_prompt (`int`, defaults to 1): - The number of images to generate per prompt. If it is different from the batch size used for the compiltaion, - it will be overriden by the static batch size of neuron (except for dynamic batching). - eta (`float`, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`Optional[Union[torch.Generator, List[torch.Generator]]]`, defaults to `None`): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. 
- latents (`Optional[torch.FloatTensor]`, defaults to `None`): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - prompt_embeds (`Optional[torch.FloatTensor]`, defaults to `None`): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`Optional[torch.FloatTensor]`, defaults to `None`): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - pooled_prompt_embeds (`Optional[torch.FloatTensor]`, defaults to `None`): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. - negative_pooled_prompt_embeds (`Optional[torch.FloatTensor]`, defaults to `None`): - Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` - input argument. - output_type (`Optional[str]`, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, defaults to `True`): - Whether or not to return a [`diffusers.pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead - of a plain tuple. - callback (`Optional[Callable]`, defaults to `None`): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - cross_attention_kwargs (`dict`, defaults to `None`): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - guidance_rescale (`float`, *optional*, defaults to 0.0): - Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are - Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of - [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). - Guidance rescale factor should fix overexposure when using zero terminal SNR. - original_size (`Optional[Tuple[int, int]]`, defaults to (1024, 1024)): - If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. - `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as - explained in section 2.2 of - [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). 
- crops_coords_top_left (`Tuple[int]`, defaults to (0, 0)): - `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position - `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting - `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of - [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). - target_size (`Tuple[int]`, defaults to (1024, 1024)): - For most cases, `target_size` should be set to the desired height and width of the generated image. If - not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in - section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). - negative_original_size (`Tuple[int]`, defaults to (1024, 1024)): - To negatively condition the generation process based on a specific image resolution. Part of SDXL's - micro-conditioning as explained in section 2.2 of - [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more - information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. - negative_crops_coords_top_left (`Tuple[int]`, defaults to (0, 0)): - To negatively condition the generation process based on specific crop coordinates. Part of SDXL's - micro-conditioning as explained in section 2.2 of - [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more - information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. - negative_target_size (`Tuple[int]`, defaults to (1024, 1024)): - To negatively condition the generation process based on a target image resolution. It should be the same - as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of - [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more - information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. - clip_skip (`Optional[int]`, defaults to `None`): - Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that - the output of the pre-final layer will be used for computing the prompt embeddings. - - Examples: - - ```py - >>> from optimum.neuron import NeuronStableDiffusionXLPipeline - - >>> compiler_args = {"auto_cast": "matmul", "auto_cast_type": "bf16"} - >>> input_shapes = {"batch_size": 1, "height": 1024, "width": 1024} - - >>> stable_diffusion_xl = NeuronStableDiffusionXLPipeline.from_pretrained( - ... "stabilityai/stable-diffusion-xl-base-1.0", export=True, **compiler_args, **input_shapes, - ... ) - >>> stable_diffusion_xl.save_pretrained("sd_neuron_xl/") - - >>> prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" - >>> image = stable_diffusion_xl(prompt).images[0] - ``` - - Returns: - [`diffusers.pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: - [`diffusers.pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a - `tuple`. When returning a tuple, the first element is a list with the generated images. - """ - # -1.
Check `num_images_per_prompt` - if self.num_images_per_prompt != num_images_per_prompt and not self.dynamic_batch_size: - logger.warning( - f"Overriding `num_images_per_prompt({num_images_per_prompt})` to {self.num_images_per_prompt} used for the compilation. Please recompile the models with your " - f"custom `num_images_per_prompt` or turn on `dynamic_batch_size`, if you wish generating {num_images_per_prompt} per prompt." - ) - num_images_per_prompt = self.num_images_per_prompt - - # 0. Default height and width to unet (static shapes) - height = self.unet.config.neuron["static_height"] * self.vae_scale_factor - width = self.unet.config.neuron["static_width"] * self.vae_scale_factor - - original_size = original_size or (height, width) - target_size = target_size or (height, width) - - # 1. Check inputs. Raise error if not correct - self.check_inputs( - prompt, - prompt_2, - height, - width, - callback_steps, - negative_prompt, - negative_prompt_2, - prompt_embeds, - negative_prompt_embeds, - pooled_prompt_embeds, - negative_pooled_prompt_embeds, - ) - - # 2. Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - neuron_batch_size = self.unet.config.neuron["static_batch_size"] - self.check_num_images_per_prompt(batch_size, neuron_batch_size, num_images_per_prompt) - - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = ( - guidance_scale > 1.0 - and (self.dynamic_batch_size or self.data_parallel_mode == "unet") - and self.unet.config.time_cond_proj_dim is None - ) - - # 3. Encode input prompt - if cross_attention_kwargs is not None and cross_attention_kwargs.get("scale", None) is not None: - logger.warning( - "Lora scale need to be fused with model weights during the compilation. The scale passed through the pipeline during inference will be ignored." - ) - lora_scale = None - ( - prompt_embeds, - negative_prompt_embeds, - pooled_prompt_embeds, - negative_pooled_prompt_embeds, - ) = self.encode_prompt( - prompt=prompt, - prompt_2=prompt_2, - num_images_per_prompt=num_images_per_prompt, - do_classifier_free_guidance=do_classifier_free_guidance, - negative_prompt=negative_prompt, - negative_prompt_2=negative_prompt_2, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, - lora_scale=lora_scale, - clip_skip=clip_skip, - ) - - # 4. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps) - - timesteps = self.scheduler.timesteps - - # 5. Prepare latent variables - num_channels_latents = self.unet.config.in_channels - latents = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - prompt_embeds.dtype, - generator, - latents, - ) - - # 6. Prepare extra step kwargs - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # 7. 
Prepare added time ids & embeddings - add_text_embeds = pooled_prompt_embeds - - add_time_ids = self._get_add_time_ids( - original_size, - crops_coords_top_left, - target_size, - dtype=prompt_embeds.dtype, - ) - if negative_original_size is not None and negative_target_size is not None: - negative_add_time_ids = self._get_add_time_ids( - negative_original_size, - negative_crops_coords_top_left, - negative_target_size, - dtype=prompt_embeds.dtype, - ) - else: - negative_add_time_ids = add_time_ids - - if do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) - add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) - - add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) - - # 8. Denoising loop - num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) - - # 8.1 Apply denoising_end - if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1: - discrete_timestep_cutoff = int( - round( - self.scheduler.config.num_train_timesteps - - (denoising_end * self.scheduler.config.num_train_timesteps) - ) - ) - num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) - timesteps = timesteps[:num_inference_steps] - - # 9. Optionally get Guidance Scale Embedding - timestep_cond = None - if self.unet.config.time_cond_proj_dim is not None: - guidance_scale_tensor = torch.tensor(guidance_scale - 1).repeat(batch_size * num_images_per_prompt) - timestep_cond = self.get_guidance_scale_embedding( - guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim - ).to(dtype=latents.dtype) - - self._num_timesteps = len(timesteps) - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # predict the noise residual - added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} - # [modified for neuronx] Remove not traced inputs: cross_attention_kwargs, return_dict - noise_pred = self.unet( - sample=latent_model_input, - timestep=t, - encoder_hidden_states=prompt_embeds, - timestep_cond=timestep_cond, - added_cond_kwargs=added_cond_kwargs, - )[0] - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - if do_classifier_free_guidance and guidance_rescale > 0.0: - # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) - - if not output_type == "latent": - # [Modified] Replace with pre-compiled vae decoder - image = self.vae_decoder(latents / getattr(self.vae_decoder.config, "scaling_factor", 0.18215))[0] - else: - image = latents - return StableDiffusionXLPipelineOutput(images=image) - - # apply watermark if available - if self.watermark is not None: - image = self.watermark.apply_watermark(image) - - image = self.image_processor.postprocess(image, output_type=output_type) - - if not return_dict: - return (image,) - - return StableDiffusionXLPipelineOutput(images=image) diff --git a/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_xl_img2img.py b/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_xl_img2img.py deleted file mode 100644 index 75229be15..000000000 --- a/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_xl_img2img.py +++ /dev/null @@ -1,516 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
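Editor's note: the text2img `__call__` deleted above performs classifier-free guidance by duplicating the latents, running a single UNet pass, then splitting and recombining the two noise predictions. A minimal standalone sketch of that recombination step (illustrative names only, not the pipeline's API):

```py
import torch


def apply_cfg(noise_pred: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # `noise_pred` stacks the unconditional and text-conditioned predictions along the
    # batch dimension, mirroring the `torch.cat([latents] * 2)` duplication above.
    noise_uncond, noise_text = noise_pred.chunk(2)
    return noise_uncond + guidance_scale * (noise_text - noise_uncond)
```

When `guidance_rescale > 0.0`, the deleted code additionally passes this result through `rescale_noise_cfg`, following section 3.4 of https://arxiv.org/pdf/2305.08891.pdf.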
-"""Override some diffusers API for NeuronStableDiffusionXLImg2ImgPipeline""" - -import logging -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union - -import PIL -import torch -from diffusers import StableDiffusionXLImg2ImgPipeline -from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput -from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img import rescale_noise_cfg -from diffusers.utils.torch_utils import randn_tensor - -from .pipeline_utils import StableDiffusionXLPipelineMixin - - -if TYPE_CHECKING: - from diffusers.image_processor import PipelineImageInput - - -logger = logging.getLogger(__name__) - - -class NeuronStableDiffusionXLImg2ImgPipelineMixin(StableDiffusionXLPipelineMixin, StableDiffusionXLImg2ImgPipeline): - # Adapted from https://github.com/huggingface/diffusers/blob/v0.21.2/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py#L515 - def prepare_latents( - self, image, timestep, batch_size, num_images_per_prompt, dtype, generator=None, add_noise=True - ): - if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): - raise ValueError( - f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" - ) - - image = image.to(dtype=dtype) - - batch_size = batch_size * num_images_per_prompt - - if image.shape[1] == 4: - init_latents = image - - else: - # [Modified] Replace with pre-compiled vae encoder, encode the init image into latents and scale the latents - init_latents = self.vae_encoder(sample=image)[0] - scaling_factor = self.vae_encoder.config.scaling_factor or 0.18215 - init_latents = scaling_factor * init_latents - init_latents = init_latents.to(dtype) - - if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: - # expand init_latents for batch_size - additional_image_per_prompt = batch_size // init_latents.shape[0] - init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) - elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." 
- ) - else: - init_latents = torch.cat([init_latents], dim=0) - - if add_noise: - shape = init_latents.shape - noise = randn_tensor(shape, generator=generator, dtype=dtype) - # get latents - init_latents = self.scheduler.add_noise(init_latents, noise, timestep) - - latents = init_latents - - return latents - - # Adapted from https://github.com/huggingface/diffusers/blob/v0.21.4/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py#L582 - def _get_add_time_ids( - self, - original_size, - crops_coords_top_left, - target_size, - aesthetic_score, - negative_aesthetic_score, - negative_original_size, - negative_crops_coords_top_left, - negative_target_size, - dtype, - ): - if self.config.get("requires_aesthetics_score"): - add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) - add_neg_time_ids = list( - negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) - ) - else: - add_time_ids = list(original_size + crops_coords_top_left + target_size) - add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) - - add_time_ids = torch.tensor([add_time_ids], dtype=dtype) - add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) - - return add_time_ids, add_neg_time_ids - - # Adapted from https://github.com/huggingface/diffusers/blob/v0.21.2/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py#L654 - def __call__( - self, - prompt: Optional[Union[str, List[str]]] = None, - prompt_2: Optional[Union[str, List[str]]] = None, - image: Optional["PipelineImageInput"] = None, - strength: float = 0.3, - num_inference_steps: int = 50, - denoising_start: Optional[float] = None, - denoising_end: Optional[float] = None, - guidance_scale: float = 5.0, - negative_prompt: Optional[Union[str, List[str]]] = None, - negative_prompt_2: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - guidance_rescale: float = 0.0, - original_size: Tuple[int, int] = None, - crops_coords_top_left: Tuple[int, int] = (0, 0), - target_size: Tuple[int, int] = None, - negative_original_size: Optional[Tuple[int, int]] = None, - negative_crops_coords_top_left: Tuple[int, int] = (0, 0), - negative_target_size: Optional[Tuple[int, int]] = None, - aesthetic_score: float = 6.0, - negative_aesthetic_score: float = 2.5, - clip_skip: Optional[int] = None, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt (`Optional[Union[str, List[str]]]`, defaults to `None`): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. - prompt_2 (`Optional[Union[str, List[str]]]`, defaults to `None`): - The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. 
If not defined, `prompt` is - used in both text-encoders - image (`Optional["PipelineImageInput"]`, defaults to `None`): - The image(s) to modify with the pipeline. - strength (`float`, defaults to 0.3): - Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` - will be used as a starting point, adding more noise to it the larger the `strength`. The number of - denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will - be maximum and the denoising process will run for the full number of iterations specified in - `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that in the case of - `denoising_start` being declared as an integer, the value of `strength` will be ignored. - num_inference_steps (`int`, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - denoising_start (`Optional[float]`, defaults to `None`): - When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be - bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and - it is assumed that the passed `image` is a partly denoised image. Note that when this is specified, - strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline - is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refining the Image - Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output). - denoising_end (`Optional[float]`, defaults to `None`): - When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be - completed before it is intentionally prematurely terminated. As a result, the returned sample will - still retain a substantial amount of noise (ca. final 20% of timesteps still needed) and should be - denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the - final 20% of the scheduler. The denoising_end parameter should ideally be utilized when this pipeline - forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image - Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output). - guidance_scale (`float`, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - negative_prompt (`Optional[Union[str, List[str]]]`, defaults to `None`): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - negative_prompt_2 (`Optional[Union[str, List[str]]]`, defaults to `None`): - The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and - `text_encoder_2`. 
If not defined, `negative_prompt` is used in both text-encoders - num_images_per_prompt (`int`, defaults to 1): - The number of images to generate per prompt. If it is different from the batch size used for the compilation, - it will be overridden by the static batch size of neuron (except for dynamic batching). - eta (`float`, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`Optional[Union[torch.Generator, List[torch.Generator]]]`, defaults to `None`): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - latents (`Optional[torch.FloatTensor]`, defaults to `None`): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will be generated by sampling using the supplied random `generator`. - prompt_embeds (`Optional[torch.FloatTensor]`, defaults to `None`): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`Optional[torch.FloatTensor]`, defaults to `None`): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - pooled_prompt_embeds (`Optional[torch.FloatTensor]`, defaults to `None`): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. - negative_pooled_prompt_embeds (`Optional[torch.FloatTensor]`, defaults to `None`): - Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` - input argument. - output_type (`Optional[str]`, defaults to `"pil"`): - The output format of the generated image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, defaults to `True`): - Whether or not to return a [`diffusers.pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a - plain tuple. - callback (`Optional[Callable]`, defaults to `None`): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - cross_attention_kwargs (`Optional[Dict[str, Any]]`, defaults to `None`): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
- guidance_rescale (`float`, defaults to 0.0): - Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are - Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of - [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). - Guidance rescale factor should fix overexposure when using zero terminal SNR. - original_size (`Optional[Tuple[int, int]]`, defaults to (1024, 1024)): - If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. - `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as - explained in section 2.2 of - [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). - crops_coords_top_left (`Tuple[int]`, defaults to (0, 0)): - `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position - `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting - `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of - [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). - target_size (`Tuple[int]`,defaults to (1024, 1024)): - For most cases, `target_size` should be set to the desired height and width of the generated image. If - not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in - section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). - negative_original_size (`Tuple[int]`, defaults to (1024, 1024)): - To negatively condition the generation process based on a specific image resolution. Part of SDXL's - micro-conditioning as explained in section 2.2 of - [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more - information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. - negative_crops_coords_top_left (`Tuple[int]`, defaults to (0, 0)): - To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's - micro-conditioning as explained in section 2.2 of - [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more - information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. - negative_target_size (`Tuple[int]`, defaults to (1024, 1024)): - To negatively condition the generation process based on a target image resolution. It should be as same - as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of - [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more - information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. - aesthetic_score (`float`, defaults to 6.0): - Used to simulate an aesthetic score of the generated image by influencing the positive text condition. - Part of SDXL's micro-conditioning as explained in section 2.2 of - [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). - negative_aesthetic_score (`float`, defaults to 2.5): - Part of SDXL's micro-conditioning as explained in section 2.2 of - [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). 
Can be used to - simulate an aesthetic score of the generated image by influencing the negative text condition. - clip_skip (`Optional[int]`, defaults to `None`): - Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that - the output of the pre-final layer will be used for computing the prompt embeddings. - - Examples: - - ```py - >>> from optimum.neuron import NeuronStableDiffusionXLImg2ImgPipeline - >>> from diffusers.utils import load_image - - >>> url = "https://huggingface.co/datasets/optimum/documentation-images/resolve/main/intel/openvino/sd_xl/castle_friedrich.png" - >>> init_image = load_image(url).convert("RGB") - - >>> compiler_args = {"auto_cast": "matmul", "auto_cast_type": "bf16"} - >>> input_shapes = {"batch_size": 1, "height": 512, "width": 512} - >>> pipeline = NeuronStableDiffusionXLImg2ImgPipeline.from_pretrained( - ... "stabilityai/stable-diffusion-xl-base-1.0", export=True, **compiler_args, **input_shapes, - ... ) - >>> pipeline.save_pretrained("sdxl_img2img/") - - >>> prompt = "a dog running, lake, moat" - >>> image = pipeline(prompt=prompt, image=init_image).images[0] - ``` - - Returns: - [`diffusers.pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: - [`diffusers.pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a - `tuple. When returning a tuple, the first element is a list with the generated images. - """ - # 0. Check batch size - if self.num_images_per_prompt != num_images_per_prompt and not self.dynamic_batch_size: - logger.warning( - f"Overriding `num_images_per_prompt({num_images_per_prompt})` to {self.num_images_per_prompt} used for the compilation. Please recompile the models with your " - f"custom `num_images_per_prompt` or turn on `dynamic_batch_size`, if you wish generating {num_images_per_prompt} image per prompt." - ) - num_images_per_prompt = self.num_images_per_prompt - - # 1. Check inputs. Raise error if not correct - self.check_inputs( - prompt, - prompt_2, - strength, - num_inference_steps, - callback_steps, - negative_prompt, - negative_prompt_2, - prompt_embeds, - negative_prompt_embeds, - ) - - # 2. Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - neuron_batch_size = self.unet.config.neuron["static_batch_size"] - self.check_num_images_per_prompt(batch_size, neuron_batch_size, num_images_per_prompt) - - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 and ( - self.dynamic_batch_size or self.data_parallel_mode == "unet" - ) - - # 3. Encode input prompt - if cross_attention_kwargs is not None and cross_attention_kwargs.get("scale", None) is not None: - logger.warning( - "Lora scale need to be fused with model weights during the compilation. The scale passed through the pipeline during inference will be ignored." 
- ) - text_encoder_lora_scale = None - ( - prompt_embeds, - negative_prompt_embeds, - pooled_prompt_embeds, - negative_pooled_prompt_embeds, - ) = self.encode_prompt( - prompt=prompt, - prompt_2=prompt_2, - num_images_per_prompt=num_images_per_prompt, - do_classifier_free_guidance=do_classifier_free_guidance, - negative_prompt=negative_prompt, - negative_prompt_2=negative_prompt_2, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, - lora_scale=text_encoder_lora_scale, - clip_skip=clip_skip, - ) - - # 4. Preprocess image - height = self.vae_encoder.config.neuron["static_height"] - width = self.vae_encoder.config.neuron["static_width"] - image = self.image_processor.preprocess(image, height=height, width=width) - - # 5. Prepare timesteps - def denoising_value_valid(dnv): - return isinstance(denoising_end, float) and 0 < dnv < 1 - - self.scheduler.set_timesteps(num_inference_steps, device=None) - timesteps, num_inference_steps = self.get_timesteps( - num_inference_steps, strength, None, denoising_start=denoising_start if denoising_value_valid else None - ) - latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) - - add_noise = True if denoising_start is None else False - - # 6. Prepare latent variables - latents = self.prepare_latents( - image, - latent_timestep, - batch_size, - num_images_per_prompt, - prompt_embeds.dtype, - generator, - add_noise, - ) - - # 7. Prepare extra step kwargs. - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - height, width = latents.shape[-2:] - height = height * self.vae_scale_factor - width = width * self.vae_scale_factor - - original_size = original_size or (height, width) - target_size = target_size or (height, width) - - # 8. Prepare added time ids & embeddings - if negative_original_size is None: - negative_original_size = original_size - if negative_target_size is None: - negative_target_size = target_size - - add_text_embeds = pooled_prompt_embeds - add_time_ids, add_neg_time_ids = self._get_add_time_ids( - original_size, - crops_coords_top_left, - target_size, - aesthetic_score, - negative_aesthetic_score, - negative_original_size, - negative_crops_coords_top_left, - negative_target_size, - dtype=prompt_embeds.dtype, - ) - add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) - - if do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) - add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1) - add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) - - prompt_embeds = prompt_embeds - add_text_embeds = add_text_embeds - add_time_ids = add_time_ids - - # 9. Denoising loop - num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) - - # 9.1 Apply denoising_end - if ( - denoising_end is not None - and denoising_start is not None - and denoising_value_valid(denoising_end) - and denoising_value_valid(denoising_start) - and denoising_start >= denoising_end - ): - raise ValueError( - f"`denoising_start`: {denoising_start} cannot be larger than or equal to `denoising_end`: " - + f" {denoising_end} when using type float." 
- ) - elif denoising_end is not None and denoising_value_valid(denoising_end): - discrete_timestep_cutoff = int( - round( - self.scheduler.config.num_train_timesteps - - (denoising_end * self.scheduler.config.num_train_timesteps) - ) - ) - num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) - timesteps = timesteps[:num_inference_steps] - - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # predict the noise residual - added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} - # [modified for neuron] Remove not traced inputs: cross_attention_kwargs, return_dict - noise_pred = self.unet( - sample=latent_model_input, - timestep=t, - encoder_hidden_states=prompt_embeds, - added_cond_kwargs=added_cond_kwargs, - )[0] - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - if do_classifier_free_guidance and guidance_rescale > 0.0: - # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) - - if not output_type == "latent": - # [Modified] Replace with pre-compiled vae decoder - image = self.vae_decoder(latents / getattr(self.vae_decoder.config, "scaling_factor", 0.18215))[0] - else: - return StableDiffusionXLPipelineOutput(images=latents) - - # apply watermark if available - if self.watermark is not None: - image = self.watermark.apply_watermark(image) - - image = self.image_processor.postprocess(image, output_type=output_type) - - if not return_dict: - return (image,) - - return StableDiffusionXLPipelineOutput(images=image) diff --git a/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_xl_inpaint.py b/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_xl_inpaint.py deleted file mode 100644 index 02394e922..000000000 --- a/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_xl_inpaint.py +++ /dev/null @@ -1,679 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
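Editor's note: both the img2img pipeline deleted above and the inpaint pipeline deleted below start from encoded image latents rather than pure noise: the init image goes through the pre-compiled VAE encoder, is scaled by the VAE scaling factor, and is then noised at the timestep selected from `strength`. A minimal sketch of that preparation, assuming a `vae_encoder` callable and a `scheduler` with `add_noise` as in the deleted code (the 0.18215 default mirrors the fallback used there):

```py
import torch
from diffusers.utils.torch_utils import randn_tensor


def prepare_init_latents(vae_encoder, scheduler, image, timestep, scaling_factor=0.18215, generator=None):
    # Encode the preprocessed init image with the pre-compiled VAE encoder and scale the latents.
    init_latents = scaling_factor * vae_encoder(sample=image)[0]
    # Add noise at the strength-dependent timestep so denoising starts partway through the schedule.
    noise = randn_tensor(init_latents.shape, generator=generator, dtype=init_latents.dtype)
    return scheduler.add_noise(init_latents, noise, timestep)
```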
-"""Override some diffusers API for NeuronStableDiffusionXLInpaintPipeline""" - - -import logging -from typing import Any, Callable, Dict, List, Optional, Tuple, Union - -import torch -from diffusers import StableDiffusionXLInpaintPipeline -from diffusers.image_processor import PipelineImageInput, VaeImageProcessor -from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput -from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint import ( - rescale_noise_cfg, - retrieve_timesteps, -) -from diffusers.utils import deprecate - -from .pipeline_utils import StableDiffusionXLPipelineMixin - - -logger = logging.getLogger(__name__) - - -class NeuronStableDiffusionXLInpaintPipelineMixin(StableDiffusionXLPipelineMixin, StableDiffusionXLInpaintPipeline): - prepare_latents = StableDiffusionXLInpaintPipeline.prepare_latents - - # Adapted from https://github.com/huggingface/diffusers/blob/v0.21.2/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py#L629 - def _encode_vae_image( - self, image: torch.Tensor, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None - ): - image_latents = self.vae_encoder(sample=image)[0] - image_latents = self.vae_encoder.config.scaling_factor * image_latents - - return image_latents - - # Adapted from https://github.com/huggingface/diffusers/blob/v0.21.4/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py#L582 - def _get_add_time_ids( - self, - original_size, - crops_coords_top_left, - target_size, - aesthetic_score, - negative_aesthetic_score, - negative_original_size, - negative_crops_coords_top_left, - negative_target_size, - dtype, - ): - if self.config.get("requires_aesthetics_score"): - add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) - add_neg_time_ids = list( - negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) - ) - else: - add_time_ids = list(original_size + crops_coords_top_left + target_size) - add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) - - add_time_ids = torch.tensor([add_time_ids], dtype=dtype) - add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) - - return add_time_ids, add_neg_time_ids - - # Adapted from https://github.com/huggingface/diffusers/blob/v0.21.2/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py#L871 - def __call__( - self, - prompt: Optional[Union[str, List[str]]] = None, - prompt_2: Optional[Union[str, List[str]]] = None, - image: Optional["PipelineImageInput"] = None, - mask_image: Optional["PipelineImageInput"] = None, - masked_image_latents: Optional[torch.FloatTensor] = None, - padding_mask_crop: Optional[int] = None, - strength: float = 0.9999, - num_inference_steps: int = 50, - timesteps: Optional[List[int]] = None, - denoising_start: Optional[float] = None, - denoising_end: Optional[float] = None, - guidance_scale: float = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - negative_prompt_2: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - negative_pooled_prompt_embeds: 
Optional[torch.FloatTensor] = None, - ip_adapter_image: Optional[PipelineImageInput] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - guidance_rescale: float = 0.0, - original_size: Tuple[int, int] = None, - crops_coords_top_left: Tuple[int, int] = (0, 0), - target_size: Tuple[int, int] = None, - negative_original_size: Optional[Tuple[int, int]] = None, - negative_crops_coords_top_left: Tuple[int, int] = (0, 0), - negative_target_size: Optional[Tuple[int, int]] = None, - aesthetic_score: float = 6.0, - negative_aesthetic_score: float = 2.5, - clip_skip: Optional[int] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, - callback_on_step_end_tensor_inputs: List[str] = ["latents"], - **kwargs, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt (`Optional[Union[str, List[str]]]`, defaults to `None`): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` - instead. - prompt_2 (`Optional[Union[str, List[str]]]`, defaults to `None`): - The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is - used in both text-encoders - image (`Optional["PipelineImageInput"]`, defaults to `None`): - `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will - be masked out with `mask_image` and repainted according to `prompt`. - mask_image (`Optional["PipelineImageInput"]`, defaults to `None`): - `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be - repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted - to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) - instead of 3, so the expected shape would be `(B, H, W, 1)`. - padding_mask_crop (`Optional[int]`, defaults to `None`): - The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to image and mask_image. If - `padding_mask_crop` is not `None`, it will first find a rectangular region with the same aspect ratio of the image and - contains all masked area, and then expand that area based on `padding_mask_crop`. The image and mask_image will then be cropped based on - the expanded area before resizing to the original image size for inpainting. This is useful when the masked area is small while the image is large - and contains information irrelevant for inpainting, such as background. - strength (`float`, defaults to 0.9999): - Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be - between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the - `strength`. The number of denoising steps depends on the amount of noise initially added. When - `strength` is 1, added noise will be maximum and the denoising process will run for the full number of - iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked - portion of the reference `image`. Note that in the case of `denoising_start` being declared as an - integer, the value of `strength` will be ignored. - num_inference_steps (`int`, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference.
- timesteps (`Optional[List[int]]`, defaults to `None`): - Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument - in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is - passed will be used. Must be in descending order. - denoising_start (`Optional[float]`, defaults to `None`): - When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be - bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and - it is assumed that the passed `image` is a partly denoised image. Note that when this is specified, - strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline - is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refining the Image - Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output). - denoising_end (`Optional[float]`, defaults to `None`): - When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be - completed before it is intentionally prematurely terminated. As a result, the returned sample will - still retain a substantial amount of noise (ca. final 20% of timesteps still needed) and should be - denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the - final 20% of the scheduler. The denoising_end parameter should ideally be utilized when this pipeline - forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image - Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output). - guidance_scale (`float`, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - negative_prompt (`Optional[Union[str, List[str]]]`, defaults to `None`): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - negative_prompt_2 (`Optional[Union[str, List[str]]]`, defaults to `None`): - The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and - `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders - prompt_embeds (`Optional[torch.FloatTensor]`, defaults to `None`): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`Optional[torch.FloatTensor]`, defaults to `None`): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - pooled_prompt_embeds (`Optional[torch.FloatTensor]`, defaults to `None`): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 
- If not provided, pooled text embeddings will be generated from `prompt` input argument. - negative_pooled_prompt_embeds (`Optional[torch.FloatTensor]`, defaults to `None`): - Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` - input argument. - ip_adapter_image: (`Optional[PipelineImageInput]`, defaults to `None`): Optional image input to work with IP Adapters. - num_images_per_prompt (`int`, defaults to 1): - The number of images to generate per prompt. - eta (`float`, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`Optional[Union[torch.Generator, List[torch.Generator]]]`, defaults to `None`): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - latents (`Optional[torch.FloatTensor]`, defaults to `None`): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - output_type (`Optional[str]`, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - cross_attention_kwargs (`Optional[Dict[str, Any]]`, defaults to `None`): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - original_size (`Tuple[int]`, defaults to (1024, 1024)): - If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. - `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as - explained in section 2.2 of - [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). - crops_coords_top_left (`Tuple[int]`, defaults to (0, 0)): - `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position - `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting - `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of - [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). - target_size (`Tuple[int]`, defaults to (1024, 1024)): - For most cases, `target_size` should be set to the desired height and width of the generated image. If - not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in - section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). - negative_original_size (`Tuple[int]`, defaults to (1024, 1024)): - To negatively condition the generation process based on a specific image resolution. 
Part of SDXL's - micro-conditioning as explained in section 2.2 of - [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more - information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. - negative_crops_coords_top_left (`Tuple[int]`, defaults to (0, 0)): - To negatively condition the generation process based on specific crop coordinates. Part of SDXL's - micro-conditioning as explained in section 2.2 of - [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more - information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. - negative_target_size (`Tuple[int]`, defaults to (1024, 1024)): - To negatively condition the generation process based on a target image resolution. It should be the same - as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of - [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more - information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. - aesthetic_score (`float`, defaults to 6.0): - Used to simulate an aesthetic score of the generated image by influencing the positive text condition. - Part of SDXL's micro-conditioning as explained in section 2.2 of - [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). - negative_aesthetic_score (`float`, defaults to 2.5): - Part of SDXL's micro-conditioning as explained in section 2.2 of - [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to - simulate an aesthetic score of the generated image by influencing the negative text condition. - clip_skip (`Optional[int]`, defaults to `None`): - Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that - the output of the pre-final layer will be used for computing the prompt embeddings. - callback_on_step_end (`Optional[Callable[[int, int, Dict], None]]`, defaults to `None`): - A function that is called at the end of each denoising step during inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. - callback_on_step_end_tensor_inputs (`List[str]`, defaults to ["latents"]): - The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list - will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the - `._callback_tensor_inputs` attribute of your pipeline class. - - Examples: - - ```py - >>> from optimum.neuron import NeuronStableDiffusionXLInpaintPipeline - >>> from diffusers.utils import load_image - - >>> img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-text2img.png" - >>> mask_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-inpaint-mask.png" - - >>> init_image = load_image(img_url).convert("RGB") - >>> mask_image = load_image(mask_url).convert("RGB") - - >>> compiler_args = {"auto_cast": "matmul", "auto_cast_type": "bf16"} - >>> input_shapes = {"batch_size": 1, "height": 1024, "width": 1024} - >>> pipeline = NeuronStableDiffusionXLInpaintPipeline.from_pretrained( - ...
"stabilityai/stable-diffusion-xl-base-1.0", export=True, **compiler_args, **input_shapes, - ... ) - >>> pipeline.save_pretrained("sdxl_inpaint/") - - >>> prompt = "A deep sea diver floating" - >>> image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, strength=0.85, guidance_scale=12.5).images[0] - ``` - - Returns: - [`diffusers.pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: - [`diffusers.pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a - `tuple`. When returning a tuple, the first element is a list with the generated images. - """ - callback = kwargs.pop("callback", None) - callback_steps = kwargs.pop("callback_steps", None) - - if callback is not None: - deprecate( - "callback", - "1.0.0", - "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", - ) - if callback_steps is not None: - deprecate( - "callback_steps", - "1.0.0", - "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", - ) - - # -1. Check `num_images_per_prompt` - if self.num_images_per_prompt != num_images_per_prompt and not self.dynamic_batch_size: - logger.warning( - f"Overriding `num_images_per_prompt({num_images_per_prompt})` to {self.num_images_per_prompt} used for the compilation. Please recompile the models with your " - f"custom `num_images_per_prompt` or turn on `dynamic_batch_size`, if you wish to generate {num_images_per_prompt} images per prompt." - ) - num_images_per_prompt = self.num_images_per_prompt - - # 0. Height and width to unet (static shapes) - height = self.unet.config.neuron["static_height"] * self.vae_scale_factor - width = self.unet.config.neuron["static_width"] * self.vae_scale_factor - - # 1. Check inputs - self.check_inputs( - prompt, - prompt_2, - image, - mask_image, - height, - width, - strength, - callback_steps, - output_type, - negative_prompt, - negative_prompt_2, - prompt_embeds, - negative_prompt_embeds, - callback_on_step_end_tensor_inputs, - padding_mask_crop, - ) - - self._guidance_scale = guidance_scale - self._guidance_rescale = guidance_rescale - self._clip_skip = clip_skip - self._cross_attention_kwargs = cross_attention_kwargs - self._denoising_end = denoising_end - self._denoising_start = denoising_start - self._interrupt = False - - # 2. Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - neuron_batch_size = self.unet.config.neuron["static_batch_size"] - self.check_num_images_per_prompt(batch_size, neuron_batch_size, num_images_per_prompt) - - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 and ( - self.dynamic_batch_size or self.data_parallel_mode == "unet" - ) - - # 3. Encode input prompt - if cross_attention_kwargs is not None and cross_attention_kwargs.get("scale", None) is not None: - logger.warning( - "Lora scale needs to be fused with model weights during the compilation. The scale passed through the pipeline during inference will be ignored."
- ) - text_encoder_lora_scale = None - - ( - prompt_embeds, - negative_prompt_embeds, - pooled_prompt_embeds, - negative_pooled_prompt_embeds, - ) = self.encode_prompt( - prompt=prompt, - prompt_2=prompt_2, - num_images_per_prompt=num_images_per_prompt, - do_classifier_free_guidance=do_classifier_free_guidance, - negative_prompt=negative_prompt, - negative_prompt_2=negative_prompt_2, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, - lora_scale=text_encoder_lora_scale, - clip_skip=self.clip_skip, - ) - - # 4. set timesteps - def denoising_value_valid(dnv): - return isinstance(denoising_end, float) and 0 < dnv < 1 - - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, None, timesteps) - timesteps, num_inference_steps = self.get_timesteps( - num_inference_steps, strength, None, denoising_start=denoising_start if denoising_value_valid else None - ) - # check that number of inference steps is not < 1 - as this doesn't make sense - if num_inference_steps < 1: - raise ValueError( - f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" - f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." - ) - # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) - latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) - # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise - is_strength_max = strength == 1.0 - - # 5. Preprocess mask and image - if padding_mask_crop is not None: - crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop) - resize_mode = "fill" - else: - crops_coords = None - resize_mode = "default" - - original_image = image - init_image = self.image_processor.preprocess( - image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode - ) - init_image = init_image.to(dtype=torch.float32) - - self.mask_processor = VaeImageProcessor( - vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True - ) - mask = self.mask_processor.preprocess( - mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords - ) - - if masked_image_latents is not None: - masked_image = masked_image_latents - elif init_image.shape[1] == 4: - # if images are in latent space, we can't mask it - masked_image = None - else: - masked_image = init_image * (mask < 0.5) - - # 6. Prepare latent variables - num_channels_latents = self.vae_encoder.config.latent_channels - num_channels_unet = self.unet.config.in_channels - return_image_latents = num_channels_unet == 4 - - add_noise = True if denoising_start is None else False - latents_outputs = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - prompt_embeds.dtype, - None, - generator, - latents, - image=init_image, - timestep=latent_timestep, - is_strength_max=is_strength_max, - add_noise=add_noise, - return_noise=True, - return_image_latents=return_image_latents, - ) - - if return_image_latents: - latents, noise, image_latents = latents_outputs - else: - latents, noise = latents_outputs - - # 7. 
Prepare mask latent variables - mask, masked_image_latents = self.prepare_mask_latents( - mask, - masked_image, - batch_size * num_images_per_prompt, - height, - width, - prompt_embeds.dtype, - None, - generator, - do_classifier_free_guidance, - ) - - # 8. Check that sizes of mask, masked image and latents match - if num_channels_unet == 9: - # default case for runwayml/stable-diffusion-inpainting - num_channels_mask = mask.shape[1] - num_channels_masked_image = masked_image_latents.shape[1] - if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: - raise ValueError( - f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" - f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" - f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" - f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" - " `pipeline.unet` or your `mask_image` or `image` input." - ) - elif num_channels_unet != 4: - raise ValueError( - f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}." - ) - # 8.1 Prepare extra step kwargs. - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - height, width = latents.shape[-2:] - height = height * self.vae_scale_factor - width = width * self.vae_scale_factor - - original_size = original_size or (height, width) - target_size = target_size or (height, width) - - # 10. Prepare added time ids & embeddings - if negative_original_size is None: - negative_original_size = original_size - if negative_target_size is None: - negative_target_size = target_size - - add_text_embeds = pooled_prompt_embeds - add_time_ids, add_neg_time_ids = self._get_add_time_ids( - original_size, - crops_coords_top_left, - target_size, - aesthetic_score, - negative_aesthetic_score, - negative_original_size, - negative_crops_coords_top_left, - negative_target_size, - dtype=prompt_embeds.dtype, - ) - add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) - - if do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) - add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1) - add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) - - prompt_embeds = prompt_embeds - add_text_embeds = add_text_embeds - add_time_ids = add_time_ids - - if ip_adapter_image is not None: - image_embeds = self.prepare_ip_adapter_image_embeds( - ip_adapter_image, None, batch_size * num_images_per_prompt - ) - - # 11. Denoising loop - num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) - - if ( - denoising_end is not None - and denoising_start is not None - and denoising_value_valid(denoising_end) - and denoising_value_valid(denoising_start) - and denoising_start >= denoising_end - ): - raise ValueError( - f"`denoising_start`: {denoising_start} cannot be larger than or equal to `denoising_end`: " - + f" {denoising_end} when using type float." 
- ) - elif self.denoising_end is not None and denoising_value_valid(self.denoising_end): - discrete_timestep_cutoff = int( - round( - self.scheduler.config.num_train_timesteps - - (self.denoising_end * self.scheduler.config.num_train_timesteps) - ) - ) - num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) - timesteps = timesteps[:num_inference_steps] - - # 11.1 Optionally get Guidance Scale Embedding - timestep_cond = None - if self.unet.config.time_cond_proj_dim is not None: - guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) - timestep_cond = self.get_guidance_scale_embedding( - guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim - ).to(dtype=latents.dtype) - - self._num_timesteps = len(timesteps) - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - if self.interrupt: - continue - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - - # concat latents, mask, masked_image_latents in the channel dimension - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - if num_channels_unet == 9: - latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) - - # predict the noise residual - added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} - if ip_adapter_image is not None: - added_cond_kwargs["image_embeds"] = image_embeds - # [modified for neuron] Remove not traced inputs: cross_attention_kwargs, return_dict - noise_pred = self.unet( - sample=latent_model_input, - timestep=t, - encoder_hidden_states=prompt_embeds, - added_cond_kwargs=added_cond_kwargs, - )[0] - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - if do_classifier_free_guidance and guidance_rescale > 0.0: - # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] - - if num_channels_unet == 4: - init_latents_proper = image_latents[:1] - if do_classifier_free_guidance: - init_mask, _ = mask.chunk(2) - else: - init_mask = mask - - if i < len(timesteps) - 1: - noise_timestep = timesteps[i + 1] - init_latents_proper = self.scheduler.add_noise( - init_latents_proper, noise, torch.tensor([noise_timestep]) - ) - - latents = (1 - init_mask) * init_latents_proper + init_mask * latents - - if callback_on_step_end is not None: - callback_kwargs = {} - for k in callback_on_step_end_tensor_inputs: - callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) - - latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) - add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) - negative_pooled_prompt_embeds = callback_outputs.pop( - "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds - ) - add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) - add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids) - mask = callback_outputs.pop("mask", mask) - masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents) - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - if callback is not None and i % callback_steps == 0: - step_idx = i // getattr(self.scheduler, "order", 1) - callback(step_idx, t, latents) - - if not output_type == "latent": - # [Modified] Replace with pre-compiled vae decoder - image = self.vae_decoder(latents / getattr(self.vae_decoder.config, "scaling_factor", 0.18215))[0] - else: - return StableDiffusionXLPipelineOutput(images=latents) - - # apply watermark if available - if self.watermark is not None: - image = self.watermark.apply_watermark(image) - - image = self.image_processor.postprocess(image, output_type=output_type) - - if padding_mask_crop is not None: - image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image] - - if not return_dict: - return (image,) - - return StableDiffusionXLPipelineOutput(images=image) diff --git a/optimum/neuron/pipelines/diffusers/pipeline_utils.py b/optimum/neuron/pipelines/diffusers/pipeline_utils.py index 6d0645b5e..b1461be59 100644 --- a/optimum/neuron/pipelines/diffusers/pipeline_utils.py +++ b/optimum/neuron/pipelines/diffusers/pipeline_utils.py @@ -12,399 +12,62 @@ # See the License for the specific language governing permissions and # limitations under the License. 
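The denoising loop above hands `callback_kwargs` to `callback_on_step_end` and pops the returned tensors back out, so the hook follows the diffusers-style signature `(pipeline, step, timestep, callback_kwargs) -> dict`. As a sketch only (the callback below is illustrative and not part of this diff), a hook that merely inspects the latents at each step could look like this, assuming `pipeline`, `init_image` and `mask_image` from the docstring example above:

```py
def log_latents_norm(pipeline, step, timestep, callback_kwargs):
    # "latents" is available because it is in the default
    # `callback_on_step_end_tensor_inputs` list (["latents"]).
    latents = callback_kwargs["latents"]
    print(f"step {step}: latents norm = {latents.norm().item():.2f}")
    # The loop pops tensors out of the returned dict, so return it unchanged.
    return callback_kwargs


image = pipeline(
    prompt="A deep sea diver floating",
    image=init_image,
    mask_image=mask_image,
    callback_on_step_end=log_latents_norm,
).images[0]
```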
import logging -from typing import List, Optional, Union import torch -from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin -from diffusers.utils.torch_utils import randn_tensor +from diffusers import ( + StableDiffusionXLControlNetPipeline, + StableDiffusionXLImg2ImgPipeline, + StableDiffusionXLInpaintPipeline, + StableDiffusionXLPipeline, +) logger = logging.getLogger(__name__) -class DiffusionBasePipelineMixin: - def check_num_images_per_prompt(self, prompt_batch_size: int, neuron_batch_size: int, num_images_per_prompt: int): - if ( - not self.data_parallel_mode == "all" - and not self.dynamic_batch_size - and neuron_batch_size != prompt_batch_size * num_images_per_prompt - ): - raise ValueError( - f"Models in the pipeline were compiled with `batch_size` {neuron_batch_size} which does not equal the number of" - f" prompt({prompt_batch_size}) multiplied by `num_images_per_prompt`({num_images_per_prompt}). You need to enable" - " `dynamic_batch_size` or precisely configure `num_images_per_prompt` during the compilation." - ) - - # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker - def run_safety_checker(self, image, dtype): - if self.safety_checker is None: - has_nsfw_concept = None - else: - if torch.is_tensor(image): - feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") - else: - feature_extractor_input = self.image_processor.numpy_to_pil(image) - safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt") - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_checker_input.pixel_values.to(dtype) - ) - return image, has_nsfw_concept - - -class StableDiffusionPipelineMixin(DiffusionBasePipelineMixin): - # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, generator, latents=None): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
- ) - - if latents is None: - latents = randn_tensor(shape, generator=generator, dtype=dtype) - elif latents.shape != shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - return latents - - # Adapted from https://github.com/huggingface/diffusers/blob/v0.18.2/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L302 - def encode_prompt( - self, - prompt: Union[str, List[str]], - num_images_per_prompt: int, - do_classifier_free_guidance: bool, - negative_prompt: Optional[Union[str, list]] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - lora_scale: Optional[float] = None, +class NeuronStableDiffusionXLPipelineMixin: + # Adapted from https://github.com/huggingface/diffusers/blob/v0.23.0/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L573 + def _get_add_time_ids_text_to_image( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`Union[str, List[str]]`): - prompt to be encoded - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`Optional[Union[str, list]]`, defaults to `None`): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - prompt_embeds (`Optional[torch.FloatTensor]`, defaults to `None`): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`Optional[torch.FloatTensor]`, defaults to `None`): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - lora_scale (`Optional[float]`, defaults to `None`): - A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. - """ - if lora_scale is not None and isinstance(self, LoraLoaderMixin): - self._lora_scale = lora_scale # TODO: remove as lora_scale should static. 
- - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if prompt_embeds is None: - # textual inversion: process multi-vector tokens if necessary - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer) - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = self.tokenizer.batch_decode( - untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] - ) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - - # [Modified] Input and its dtype constraints - prompt_embeds = self.text_encoder(input_ids=text_input_ids) - prompt_embeds = prompt_embeds[0] - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif prompt is not None and type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." 
- ) - else: - uncond_tokens = negative_prompt - - # textual inversion: process multi-vector tokens if necessary - if isinstance(self, TextualInversionLoaderMixin): - uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) - - max_length = prompt_embeds.shape[1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - negative_prompt_embeds = self.text_encoder(uncond_input.input_ids) - negative_prompt_embeds = negative_prompt_embeds[0] - - if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - return prompt_embeds, negative_prompt_embeds - - -class StableDiffusionXLPipelineMixin(DiffusionBasePipelineMixin): - # Adapted from https://github.com/huggingface/diffusers/blob/v0.20.2/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L502 - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, generator, latents=None): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - if latents is None: - latents = randn_tensor(shape, generator=generator, dtype=dtype) - elif latents.shape != shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - return latents - - # Adapted from https://github.com/huggingface/diffusers/blob/v0.20.2/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L219 - def encode_prompt( + # Adapted from https://github.com/huggingface/diffusers/blob/v0.21.4/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py#L582 + def _get_add_time_ids_image_to_image( self, - prompt: str, - prompt_2: Optional[str] = None, - num_images_per_prompt: int = 1, - do_classifier_free_guidance: bool = True, - negative_prompt: Optional[str] = None, - negative_prompt_2: Optional[str] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - lora_scale: Optional[float] = None, - clip_skip: Optional[int] = None, + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype, + text_encoder_projection_dim=None, ): - r""" - Encodes the prompt into text encoder hidden states. 
- - Args: - prompt (`str`): - prompt to be encoded - prompt_2 (`Optional[str]`, defaults to `None`): - The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is - used in both text-encoders - num_images_per_prompt (`int`, defaults to 1): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`, defaults to `True`): - whether to use classifier free guidance or not - negative_prompt (`Optional[str]`, defaults to `None`): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - negative_prompt_2 (`Optional[str]`, defaults to `None`): - The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and - `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders - prompt_embeds (`Optional[torch.FloatTensor]`, defaults to `None`): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`Optional[torch.FloatTensor]`, defaults to `None`): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - pooled_prompt_embeds (`Optional[torch.FloatTensor]`, defaults to `None`): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. - negative_pooled_prompt_embeds (`Optional[torch.FloatTensor]`, defaults to `None`): - Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` - input argument. - lora_scale (`Optional[float]`, defaults to `None`): - A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. - clip_skip (`Optional[int]`, defaults to `None`): - Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that - the output of the pre-final layer will be used for computing the prompt embeddings. 
- """ - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, LoraLoaderMixin): - self._lora_scale = lora_scale - - if isinstance(prompt, str): - batch_size = 1 - elif isinstance(prompt, list): - batch_size = len(prompt) + if self.config.get("requires_aesthetics_score"): + add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) + add_neg_time_ids = list( + negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) + ) else: - batch_size = prompt_embeds.shape[0] - - # Define tokenizers and text encoders - tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] - text_encoders = ( - [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] - ) - - if prompt_embeds is None: - prompt_2 = prompt_2 or prompt - # textual inversion: procecss multi-vector tokens if necessary - prompt_embeds_list = [] - prompts = [prompt, prompt_2] - for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, tokenizer) - - text_inputs = tokenizer( - prompt, - padding="max_length", - max_length=tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) - text_input_ids = text_inputs.input_ids - untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {tokenizer.model_max_length} tokens: {removed_text}" - ) + return add_time_ids, add_neg_time_ids - prompt_embeds = text_encoder(input_ids=text_input_ids) - - # We are only ALWAYS interested in the pooled output of the final text encoder - pooled_prompt_embeds = prompt_embeds[0] - if clip_skip is None: - prompt_embeds = prompt_embeds[-1][-2] # hidden_states - else: - # "2" because SDXL always indexes from the penultimate layer. 
- prompt_embeds = prompt_embeds[-1][-(clip_skip + 2)] - - prompt_embeds_list.append(prompt_embeds) - - prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) - - # get unconditional embeddings for classifier free guidance - zero_out_negative_prompt = negative_prompt is None and getattr( - self.config, "force_zeros_for_empty_prompt", False - ) - if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: - negative_prompt_embeds = torch.zeros_like(prompt_embeds) - negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) - elif do_classifier_free_guidance and negative_prompt_embeds is None: - if negative_prompt is None: - negative_prompt = "" if isinstance(prompt, str) else [""] * batch_size - else: - negative_prompt = negative_prompt - # negative_prompt = negative_prompt or "" - negative_prompt_2 = negative_prompt_2 or negative_prompt - - uncond_tokens: List[str] - if prompt is not None and type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt, negative_prompt_2] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - else: - uncond_tokens = [negative_prompt, negative_prompt_2] - - negative_prompt_embeds_list = [] - for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): - if isinstance(self, TextualInversionLoaderMixin): - negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) - - max_length = prompt_embeds.shape[1] - uncond_input = tokenizer( - negative_prompt, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - negative_prompt_embeds = text_encoder(input_ids=uncond_input.input_ids) - # We are only ALWAYS interested in the pooled output of the final text encoder - negative_pooled_prompt_embeds = negative_prompt_embeds[0] - negative_prompt_embeds = negative_prompt_embeds[-1][-2] # hidden_states - - negative_prompt_embeds_list.append(negative_prompt_embeds) - - negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) - - prompt_embeds = prompt_embeds - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - negative_prompt_embeds = negative_prompt_embeds - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( - bs_embed * num_images_per_prompt, -1 - ) - if do_classifier_free_guidance: - negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( - bs_embed * num_images_per_prompt, -1 + def _get_add_time_ids(self, *args, **kwargs): + 
if self.auto_model_class in [StableDiffusionXLPipeline, StableDiffusionXLControlNetPipeline]: + return self._get_add_time_ids_text_to_image(*args, **kwargs) + elif self.auto_model_class in [StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline]: + return self._get_add_time_ids_image_to_image(*args, **kwargs) + else: + raise ValueError( + f"The pipeline type {self.auto_model_class} is not yet supported by Optimum Neuron, please open an issue on: https://github.com/huggingface/optimum-neuron/issues." ) - - return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds diff --git a/optimum/neuron/utils/__init__.py b/optimum/neuron/utils/__init__.py index 0c4e60209..19562b16a 100644 --- a/optimum/neuron/utils/__init__.py +++ b/optimum/neuron/utils/__init__.py @@ -25,6 +25,7 @@ "DIFFUSION_MODEL_TEXT_ENCODER_2_NAME", "DIFFUSION_MODEL_TEXT_ENCODER_NAME", "DIFFUSION_MODEL_UNET_NAME", + "DIFFUSION_MODEL_TRANSFORMER_NAME", "DIFFUSION_MODEL_VAE_DECODER_NAME", "DIFFUSION_MODEL_VAE_ENCODER_NAME", "DIFFUSION_MODEL_CONTROLNET_NAME", @@ -84,6 +85,7 @@ DIFFUSION_MODEL_CONTROLNET_NAME, DIFFUSION_MODEL_TEXT_ENCODER_2_NAME, DIFFUSION_MODEL_TEXT_ENCODER_NAME, + DIFFUSION_MODEL_TRANSFORMER_NAME, DIFFUSION_MODEL_UNET_NAME, DIFFUSION_MODEL_VAE_DECODER_NAME, DIFFUSION_MODEL_VAE_ENCODER_NAME, diff --git a/optimum/neuron/utils/constant.py b/optimum/neuron/utils/constant.py index 82f8f134f..dbf600bd7 100644 --- a/optimum/neuron/utils/constant.py +++ b/optimum/neuron/utils/constant.py @@ -20,6 +20,7 @@ DIFFUSION_MODEL_TEXT_ENCODER_NAME = "text_encoder" DIFFUSION_MODEL_TEXT_ENCODER_2_NAME = "text_encoder_2" DIFFUSION_MODEL_UNET_NAME = "unet" +DIFFUSION_MODEL_TRANSFORMER_NAME = "transformer" DIFFUSION_MODEL_VAE_ENCODER_NAME = "vae_encoder" DIFFUSION_MODEL_VAE_DECODER_NAME = "vae_decoder" DIFFUSION_MODEL_CONTROLNET_NAME = "controlnet" diff --git a/setup.py b/setup.py index 45bb37672..4976ce2f1 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,7 @@ "sentencepiece", "datasets", "sacremoses", - "diffusers>=0.28.0, <0.29.0", + "diffusers>=0.28.0, <=0.30.3", "safetensors", "sentence-transformers >= 2.2.0", "peft", @@ -72,7 +72,7 @@ "neuronx_distributed==0.9.0", "libneuronxla==2.0.4115.0", ], - "diffusers": ["diffusers>=0.28.0, <0.29.0", "peft"], + "diffusers": ["diffusers>=0.28.0, <=0.30.3", "peft"], "sentence-transformers": ["sentence-transformers >= 2.2.0"], } diff --git a/tests/inference/test_stable_diffusion_pipeline.py b/tests/inference/test_stable_diffusion_pipeline.py index baca33668..37f2e278d 100644 --- a/tests/inference/test_stable_diffusion_pipeline.py +++ b/tests/inference/test_stable_diffusion_pipeline.py @@ -419,3 +419,19 @@ def test_compatibility_with_compel(self, model_arch): num_inference_steps=1, ).images[0] self.assertIsInstance(image, PIL.Image.Image) + + @parameterized.expand(SUPPORTED_ARCHITECTURES, skip_on_empty=True) + def test_from_pipe(self, model_arch): + txt2img_pipeline = NeuronStableDiffusionXLPipeline.from_pretrained( + MODEL_NAMES[model_arch], + export=True, + dynamic_batch_size=False, + **self.STATIC_INPUTS_SHAPES, + **self.COMPILER_ARGS, + ) + img2img_pipeline = NeuronStableDiffusionXLImg2ImgPipeline.from_pipe(txt2img_pipeline) + url = "https://huggingface.co/datasets/optimum/documentation-images/resolve/main/intel/openvino/sd_xl/castle_friedrich.png" + init_image = download_image(url) + prompt = "a dog running, lake, moat" + image = img2img_pipeline(prompt=prompt, image=init_image).images[0] + self.assertIsInstance(image, PIL.Image.Image)
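The new `test_from_pipe` test exercises component re-use across pipeline classes. A minimal usage sketch under the same assumptions (the checkpoint, compiler arguments and input shapes are illustrative, taken from the docstring example earlier in this diff) could look like:

```py
from diffusers.utils import load_image

from optimum.neuron import NeuronStableDiffusionXLImg2ImgPipeline, NeuronStableDiffusionXLPipeline

# Compile the SDXL text-to-image pipeline once (illustrative shapes and args).
compiler_args = {"auto_cast": "matmul", "auto_cast_type": "bf16"}
input_shapes = {"batch_size": 1, "height": 1024, "width": 1024}
txt2img = NeuronStableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", export=True, **compiler_args, **input_shapes
)

# Build an image-to-image pipeline from the already exported one via `from_pipe`.
img2img = NeuronStableDiffusionXLImg2ImgPipeline.from_pipe(txt2img)

init_image = load_image(
    "https://huggingface.co/datasets/optimum/documentation-images/resolve/main/intel/openvino/sd_xl/castle_friedrich.png"
)
image = img2img(prompt="a dog running, lake, moat", image=init_image).images[0]
image.save("dog.png")
```

As the test suggests, the intent is that the second pipeline re-uses the submodels compiled for the first one rather than triggering a new export.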