From 24b8b5cf5e44b585e8808c34616cfa20e8e39866 Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Tue, 11 Oct 2022 20:30:09 +0200
Subject: [PATCH] `mps`: Alternative implementation for `repeat_interleave`
 (#766)

* mps: alt. implementation for repeat_interleave

* style

* Bump mps version of PyTorch in the documentation.

* Apply suggestions from code review

Co-authored-by: Suraj Patil

* Simplify: do not check for device.

* style

* Fix repeat dimensions:
  - The unconditional embeddings are always created from a single prompt.
  - I was shadowing the batch_size var.

* Split long lines as suggested by Suraj.

Co-authored-by: Patrick von Platen
Co-authored-by: Suraj Patil
---
 docs/source/optimization/mps.mdx                    |  2 +-
 .../stable_diffusion/pipeline_stable_diffusion.py   | 12 ++++++++----
 2 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/docs/source/optimization/mps.mdx b/docs/source/optimization/mps.mdx
index ad347b84e8..ff9d614c87 100644
--- a/docs/source/optimization/mps.mdx
+++ b/docs/source/optimization/mps.mdx
@@ -19,7 +19,7 @@ specific language governing permissions and limitations under the License.
 - Mac computer with Apple silicon (M1/M2) hardware.
 - macOS 12.3 or later.
 - arm64 version of Python.
-- PyTorch [Preview (Nightly)](https://pytorch.org/get-started/locally/), version `1.13.0.dev20220830` or later.
+- PyTorch [Preview (Nightly)](https://pytorch.org/get-started/locally/), version `1.14.0.dev20221007` or later.

 ## Inference Pipeline

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index bd1aade0bc..ca6c580ffc 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -218,8 +218,10 @@ def __call__(
             text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
         text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]

-        # duplicate text embeddings for each generation per prompt
-        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        bs_embed, seq_len, _ = text_embeddings.shape
+        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
+        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)

         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -256,8 +258,10 @@ def __call__(
             )
             uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

-            # duplicate unconditional embeddings for each generation per prompt
-            uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size * num_images_per_prompt, dim=0)
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = uncond_embeddings.shape[1]
+            uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)
+            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)

         # For classifier free guidance, we need to do two forward passes.
         # Here we concatenate the unconditional and text embeddings into a single batch
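
Note (not part of the patch): a minimal, standalone sketch of the equivalence the diff relies on. Tiling a `(bs, seq_len, dim)` tensor along dim 1 with `repeat` and then reshaping with `view` produces the same result as `repeat_interleave` along dim 0, while avoiding the `repeat_interleave` op that was problematic on the mps backend at the time. The shapes below are illustrative values, not taken from the pipeline.

import torch

bs, seq_len, dim, n = 2, 4, 8, 3           # n plays the role of num_images_per_prompt
emb = torch.randn(bs, seq_len, dim)

# Original approach: repeat each batch element n times, consecutively.
a = emb.repeat_interleave(n, dim=0)        # shape (bs * n, seq_len, dim)

# mps-friendly approach from the patch: tile along the sequence
# dimension, then view the buffer back as a larger batch. Because the
# tiled copies of each batch element are contiguous in memory, the
# reshape recovers exactly the repeat_interleave ordering.
b = emb.repeat(1, n, 1).view(bs * n, seq_len, -1)

assert torch.equal(a, b)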