rework compute_clip_image_embedding overloads + improve docstring
Laurent2916 committed Sep 27, 2024
1 parent 9e2a499 commit 07fc619
Showing 1 changed file with 24 additions and 19 deletions.
43 changes: 24 additions & 19 deletions src/refiners/foundationals/latent_diffusion/image_prompt.py
@@ -1,5 +1,5 @@
 import math
-from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload
+from typing import TYPE_CHECKING, Any, Generic, TypeVar
 
 import torch
 from jaxtyping import Float
@@ -454,37 +454,42 @@ def set_clip_image_embedding(self, image_embedding: Tensor) -> None:
"""
self.set_context("ip_adapter", {"clip_image_embedding": image_embedding})

@overload
def compute_clip_image_embedding(self, image_prompt: Tensor, weights: list[float] | None = None) -> Tensor: ...

@overload
def compute_clip_image_embedding(self, image_prompt: Image.Image) -> Tensor: ...

@overload
def compute_clip_image_embedding(
self, image_prompt: list[Image.Image], weights: list[float] | None = None
) -> Tensor: ...

def compute_clip_image_embedding(
self,
image_prompt: Tensor | Image.Image | list[Image.Image],
image_prompt: Image.Image | list[Image.Image] | Tensor,
weights: list[float] | None = None,
concat_batches: bool = True,
) -> Tensor:
"""Compute the CLIP image embedding.
"""Compute CLIP image embeddings from the provided image prompts.
Args:
image_prompt: The image prompt to use.
weights: The scale to use for the image prompt.
concat_batches: Whether to concatenate the batches.
image_prompt: A single image or a list of images to compute embeddings for.
This can be a PIL Image, a list of PIL Images, or a Tensor.
weights: An optional list of scaling factors for the conditional embeddings.
If provided, it must have the same length as the number of images in `image_prompt`.
Each weight scales the corresponding image's conditional embedding, allowing you to
adjust the influence of each image. Defaults to uniform weights of 1.0.
concat_batches: Determines how embeddings are concatenated when multiple images are provided:
- If `True`, embeddings from multiple images are concatenated along the feature
dimension to form a longer sequence of image tokens. This is useful when you want to
treat multiple images as a single combined input.
- If `False`, embeddings are kept separate along the batch dimension, treating each image
independently.
Returns:
The CLIP image embedding.
A Tensor containing the CLIP image embeddings.
The structure of the returned Tensor depends on the `concat_batches` parameter:
- If `concat_batches` is `True` and multiple images are provided, the embeddings are
concatenated along the feature dimension.
- If `concat_batches` is `False` or a single image is provided, the embeddings are returned
as a batch, with one embedding per image.
"""
         if isinstance(image_prompt, Image.Image):
             image_prompt = self.preprocess_image(image_prompt)
         elif isinstance(image_prompt, list):
-            assert all(isinstance(image, Image.Image) for image in image_prompt)
+            assert all(
+                isinstance(image, Image.Image) for image in image_prompt
+            ), "All elements of `image_prompt` must be PIL Images."
             image_prompt = torch.cat([self.preprocess_image(image) for image in image_prompt])
 
         negative_embedding, conditional_embedding = self._compute_clip_image_embedding(image_prompt)
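
For context, here is a minimal usage sketch of the reworked API. This is not part of the commit: the `ip_adapter` variable stands in for an already-constructed IP-Adapter instance exposing the method above, and the image file names are placeholders.

```python
from PIL import Image

# Hypothetical setup: `ip_adapter` is an existing IP-Adapter instance
# exposing compute_clip_image_embedding; the file names are placeholders.
style = Image.open("style.png")
subject = Image.open("subject.png")

# Single image: weights default to uniform 1.0.
embedding = ip_adapter.compute_clip_image_embedding(style)

# Two images, the first weighted more heavily. With concat_batches=True
# (the default), the conditional embeddings are concatenated along the
# feature dimension into one longer sequence of image tokens.
combined = ip_adapter.compute_clip_image_embedding([style, subject], weights=[1.5, 0.5])

# With concat_batches=False, each image keeps its own batch entry,
# yielding one embedding per image.
per_image = ip_adapter.compute_clip_image_embedding([style, subject], concat_batches=False)

# The result can then be injected via the companion setter shown above.
ip_adapter.set_clip_image_embedding(combined)
```

Note that with the `@overload` declarations removed, all three input forms (single PIL Image, list of PIL Images, or Tensor) now flow through the single signature, so `weights` and `concat_batches` are available for every form.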
