[WIP] Make compel work with SD-XL #41

Merged
22 commits merged on Jul 18, 2023
Changes from 3 commits
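The diff below teaches Compel to drive SD-XL's two text encoders: the constructor now accepts either a single tokenizer/text encoder or lists of them, routes list inputs to a new EmbeddingsProviderMulti, and adds a return_pooled flag for the pooled embedding that SD-XL's second encoder supplies. For orientation, here is a minimal usage sketch of the API as it stands in this WIP snapshot; it assumes a diffusers StableDiffusionXLPipeline (the pipe, tokenizer_2 and text_encoder_2 names are not part of this diff), and the pooled-output plumbing is still being worked out in these commits.

from diffusers import StableDiffusionXLPipeline
from compel import Compel

# hypothetical setup: an SD-XL pipeline exposes two tokenizers and two text encoders
pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")

compel = Compel(
    tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
    text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
    return_pooled=[False, True],  # only the second encoder contributes a pooled embedding
)

conditioning = compel.build_conditioning_tensor("a forest++ at dawn")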
64 changes: 45 additions & 19 deletions src/compel/compel.py
@@ -7,7 +7,7 @@

from . import cross_attention_control
from .conditioning_scheduler import ConditioningScheduler, StaticConditioningScheduler
from .embeddings_provider import EmbeddingsProvider, BaseTextualInversionManager, DownweightMode
from .embeddings_provider import EmbeddingsProvider, BaseTextualInversionManager, DownweightMode, EmbeddingsProviderMulti
from .prompt_parser import Blend, FlattenedPrompt, PromptParser, CrossAttentionControlSubstitute, Conjunction

__all__ = ["Compel", "DownweightMode"]
@@ -21,15 +21,16 @@ class Compel:


def __init__(self,
tokenizer: CLIPTokenizer,
text_encoder: CLIPTextModel,
tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]],
text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
textual_inversion_manager: Optional[BaseTextualInversionManager] = None,
dtype_for_device_getter: Callable[[torch.device], torch.dtype] = lambda device: torch.float32,
truncate_long_prompts: bool = True,
padding_attention_mask_value: int = 1,
downweight_mode: DownweightMode = DownweightMode.MASK,
use_penultimate_clip_layer: bool=False,
return_pooled: Union[bool, List[bool]] = False,
device: Optional[str] = None):
"""
Initialize Compel. The tokenizer and text_encoder can be lifted directly from any DiffusionPipeline.

@@ -50,16 +51,33 @@ def __init__(self,
`device`: The torch device on which the tensors should be created. If a device is not specified, the device will
be the same as that of the `text_encoder` at the moment when `build_conditioning_tensor()` is called.
"""
self.conditioning_provider = EmbeddingsProvider(tokenizer=tokenizer,
text_encoder=text_encoder,
textual_inversion_manager=textual_inversion_manager,
dtype_for_device_getter=dtype_for_device_getter,
truncate=truncate_long_prompts,
padding_attention_mask_value = padding_attention_mask_value,
downweight_mode=downweight_mode,
use_penultimate_clip_layer=use_penultimate_clip_layer
)
self._device = device
if isinstance(tokenizer, (tuple, list)) and not isinstance(text_encoder, (tuple, list)):
raise ValueError("Cannot provide list of tokenizers, but not of text encoders.")
elif not isinstance(tokenizer, (tuple, list)) and isinstance(text_encoder, (tuple, list)):
raise ValueError("Cannot provide list of text encoders, but not of tokenizers.")
elif isinstance(tokenizer, (tuple, list)) and isinstance(text_encoder, (tuple, list)):
self.conditioning_provider = EmbeddingsProviderMulti(tokenizers=tokenizer,
text_encoders=text_encoder,
textual_inversion_manager=textual_inversion_manager,
dtype_for_device_getter=dtype_for_device_getter,
truncate=truncate_long_prompts,
padding_attention_mask_value = padding_attention_mask_value,
downweight_mode=downweight_mode,
use_penultimate_clip_layer=use_penultimate_clip_layer,
return_pooled=return_pooled,
)
else:
self.conditioning_provider = EmbeddingsProvider(tokenizer=tokenizer,
text_encoder=text_encoder,
textual_inversion_manager=textual_inversion_manager,
dtype_for_device_getter=dtype_for_device_getter,
truncate=truncate_long_prompts,
padding_attention_mask_value = padding_attention_mask_value,
downweight_mode=downweight_mode,
use_penultimate_clip_layer=use_penultimate_clip_layer,
requires_pooled=return_pooled,
)
self._device = device

@property
def device(self):
@@ -170,8 +188,6 @@ def build_conditioning_tensor_for_prompt_object(self, prompt: Union[Blend, Flatt

raise ValueError(f"unsupported prompt type: {type(prompt).__name__}")



def pad_conditioning_tensors_to_same_length(self, conditionings: List[torch.Tensor],
) -> List[torch.Tensor]:
"""
Expand Down Expand Up @@ -213,7 +229,8 @@ def pad_conditioning_tensors_to_same_length(self, conditionings: List[torch.Tens

def _get_conditioning_for_flattened_prompt(self,
prompt: FlattenedPrompt,
should_return_tokens: bool=False
should_return_tokens: bool=False,
should_return_pooled: bool=False
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if type(prompt) is not FlattenedPrompt:
raise ValueError(f"embeddings can only be made from FlattenedPrompts, got {type(prompt).__name__} instead")
@@ -222,10 +239,19 @@ def _get_conditioning_for_flattened_prompt(self,
conditioning, tokens = self.conditioning_provider.get_embeddings_for_weighted_prompt_fragments(
text_batch=[fragments], fragment_weights_batch=[weights],
should_return_tokens=True, device=self.device)

outputs = (conditioning,)

if should_return_tokens:
return conditioning, tokens
else:
return conditioning
outputs += (tokens,)

if should_return_pooled:
outputs += (pooled,)

if len(outputs) == 1:
return outputs[0]

return outputs

def _get_conditioning_for_blend(self, blend: Blend):
conditionings_to_blend = []
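The compel.py changes above do two things: __init__ dispatches to EmbeddingsProviderMulti when lists of tokenizers and text encoders are supplied (raising if only one of the two is a list), and _get_conditioning_for_flattened_prompt switches from fixed return values to an accumulated output tuple that is unwrapped when only the conditioning tensor is requested. The toy sketch below (not compel code) restates that return convention, since the same pattern reappears in embeddings_provider.py further down.

from typing import Tuple, Union

def collect_outputs(conditioning: str,
                    tokens: str,
                    pooled: str,
                    should_return_tokens: bool = False,
                    should_return_pooled: bool = False) -> Union[str, Tuple[str, ...]]:
    # accumulate requested results, then unwrap single-element tuples
    outputs = (conditioning,)
    if should_return_tokens:
        outputs += (tokens,)
    if should_return_pooled:
        outputs += (pooled,)
    return outputs[0] if len(outputs) == 1 else outputs

assert collect_outputs("cond", "tok", "pool") == "cond"
assert collect_outputs("cond", "tok", "pool",
                       should_return_tokens=True,
                       should_return_pooled=True) == ("cond", "tok", "pool")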
153 changes: 127 additions & 26 deletions src/compel/embeddings_provider.py
@@ -4,7 +4,8 @@
from typing import Callable, Union, Tuple, List, Optional

import torch
from transformers import CLIPTokenizer, CLIPTextModel
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
from typing import List, Tuple

__all__ = ["EmbeddingsProvider", "DownweightMode"]

@@ -22,13 +23,15 @@ class EmbeddingsProvider:

def __init__(self,
tokenizer: CLIPTokenizer,
text_encoder: CLIPTextModel,
text_encoder: Union[CLIPTextModel, CLIPTextModelWithProjection], # convert a list of int token ids to a tensor of embeddings
textual_inversion_manager: BaseTextualInversionManager = None,
dtype_for_device_getter: Callable[[torch.device], torch.dtype] = lambda device: torch.float32,
truncate: bool = True,
padding_attention_mask_value: int = 1,
downweight_mode: DownweightMode = DownweightMode.MASK,
use_penultimate_clip_layer: bool=False
use_penultimate_clip_layer: bool=False,
use_penultimate_layer_norm: bool=True,
requires_pooled: bool = False,
):
"""
`tokenizer`: converts strings to lists of int token ids
@@ -51,6 +54,8 @@ def __init__(self,
self.downweight_mode = downweight_mode
self.use_penultimate_clip_layer = use_penultimate_clip_layer

self.requires_pooled = requires_pooled

# by default always use float32
self.get_dtype_for_device = dtype_for_device_getter

@@ -76,7 +81,7 @@ def get_embeddings_for_weighted_prompt_fragments(self,
text_batch: List[List[str]],
fragment_weights_batch: List[List[float]],
should_return_tokens: bool = False,
device='cpu'
device='cpu',
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""

@@ -94,6 +99,7 @@ def get_embeddings_for_weighted_prompt_fragments(self,

batch_z = None
batch_tokens = None

for fragments, weights in zip(text_batch, fragment_weights_batch):

# First, weight tokens in individual fragments by scaling the feature vectors as requested (effectively
@@ -108,7 +114,7 @@ def get_embeddings_for_weighted_prompt_fragments(self,

# handle weights >=1
tokens, per_token_weights, mask = self.get_token_ids_and_expand_weights(fragments, weights, device=device)
base_embedding = self.build_weighted_embedding_tensor(tokens, per_token_weights, mask, device=device)
base_embedding = self.build_weighted_embedding_tensor(tokens, per_token_weights, mask)[0]

# this is our starting point
embeddings = base_embedding.unsqueeze(0)
@@ -177,11 +183,15 @@ def get_embeddings_for_weighted_prompt_fragments(self,

# should have shape (B, 77, 768)
#print(f"assembled all tokens into tensor of shape {batch_z.shape}")
outputs = (batch_z,)

if should_return_tokens:
return batch_z, batch_tokens
else:
return batch_z
outputs += (batch_tokens,)

if len(outputs) == 1:
return outputs[0]

return outputs

def get_token_ids(self, texts: List[str], include_start_and_end_markers: bool = True) -> List[List[int]]:
"""
@@ -302,10 +312,9 @@ def build_weighted_embedding_tensor(self,
token_ids: torch.Tensor,
per_token_weights: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
return_pooled: bool = False,
device: Optional[str] = None) -> torch.Tensor:
"""
Build a tensor that embeds the passed-in token IDs and applies the given per_token weights

:param token_ids: A tensor of shape `n*[self.max_length]` containing token IDs (ints) where n is some arbitrary
integer (i.e. n==1 for shorter prompts, or it may be >1 if there are more than max_length tokens in the
original prompt)
@@ -328,8 +337,9 @@ def build_weighted_embedding_tensor(self,
[self.tokenizer.eos_token_id] +
[self.tokenizer.pad_token_id] * (self.max_token_count - 2),
dtype=torch.int, device=device).unsqueeze(0)
empty_z = self._encode_token_ids_to_embeddings(empty_token_ids)
empty_z, _ = self._encode_token_ids_to_embeddings(empty_token_ids)
weighted_z = None
pooled = None

chunk_size = self.max_token_count
while chunk_start_index < token_ids.shape[0]:
@@ -342,7 +352,7 @@ def build_weighted_embedding_tensor(self,
else None
)

z = self._encode_token_ids_to_embeddings(chunk_token_ids, chunk_attention_mask)
z, this_pooled = self._encode_token_ids_to_embeddings(chunk_token_ids, chunk_attention_mask)
batch_weights_expanded = chunk_per_token_weights.reshape(
chunk_per_token_weights.shape + (1,)).expand(z.shape).to(z)

@@ -353,23 +363,20 @@ def build_weighted_embedding_tensor(self,
if weighted_z is None
else torch.cat([weighted_z, this_weighted_z], dim=1)
)

if this_pooled is not None:
    pooled = (
        this_pooled
        if pooled is None
        else torch.mean(torch.stack([pooled, this_pooled]), dim=0)
    )

chunk_start_index += chunk_size

return weighted_z
if self.requires_pooled:
return weighted_z, pooled

def _encode_token_ids_to_embeddings(self, token_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor]=None) -> torch.Tensor:
text_encoder_output = self.text_encoder(token_ids,
attention_mask,
output_hidden_states=self.use_penultimate_clip_layer,
return_dict=True)
if self.use_penultimate_clip_layer:
# needs normalizing
penultimate_hidden_state = text_encoder_output.hidden_states[-2]
return self.text_encoder.text_model.final_layer_norm(penultimate_hidden_state)
else:
# already normalized
return text_encoder_output.last_hidden_state
return weighted_z

def _get_token_ranges_for_fragments(self, chunked_and_padded_token_ids: List[int], fragments: List[str]) -> List[Tuple[int, int]]:
"""
@@ -433,3 +440,97 @@ def _get_token_ranges_for_fragments(self, chunked_and_padded_token_ids: List[int
fragment_start = fragment_end + 1

return corresponding_indices


class EmbeddingsProviderMulti:

def __init__(self,
tokenizers: List[CLIPTokenizer], # converts strings to lists of int token ids
text_encoders: List[Union[CLIPTextModel, CLIPTextModelWithProjection]], # convert a list of int token ids to a tensor of embeddings
textual_inversion_manager: BaseTextualInversionManager = None,
dtype_for_device_getter: Callable[[torch.device], torch.dtype] = lambda device: torch.float32,
hidden_states_types: Union[str, List[str]] = "final",
return_pooled: Union[bool, List[bool]] = False,
):

hidden_states_types = len(text_encoders) * [hidden_states_types] if not isinstance(hidden_states_types, (list, tuple)) else hidden_states_types
return_pooled = len(text_encoders) * [return_pooled] if not isinstance(return_pooled, (list, tuple)) else return_pooled

self.embedding_providers = [
    # pass the per-encoder settings by keyword so they land on the intended parameters;
    # assumes hidden_states_type is either "final" or "penultimate"
    EmbeddingsProvider(tokenizer, text_encoder, textual_inversion_manager, dtype_for_device_getter,
                       use_penultimate_clip_layer=(hidden_states_type == "penultimate"),
                       requires_pooled=pooled)
    for tokenizer, text_encoder, hidden_states_type, pooled in zip(tokenizers, text_encoders, hidden_states_types, return_pooled)
]

@property
def text_encoder(self):
return self.embedding_providers[0].text_encoder

@property
def tokenizer(self):
return self.embedding_providers[0].tokenizer

def get_token_ids(self, *args, **kwargs):
# get token ids does not use padding. The padding ID is the only ID that can differ between tokenizers
# so for simplicity, we just return `get_token_ids` of the first tokenizer
return self.embedding_providers[0].get_token_ids(*args, **kwargs)

def get_embeddings_for_weighted_prompt_fragments(self,
text_batch: List[List[str]],
fragment_weights_batch: List[List[float]],
should_return_tokens: bool = False,
should_return_pooled: bool = False,
device='cpu',
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:

outputs = [provider.get_embeddings_for_weighted_prompt_fragments(text_batch, fragment_weights_batch, should_return_tokens=should_return_tokens, should_return_pooled=should_return_pooled, device=device) for provider in self.embedding_providers]

pooled_list = []
text_embeddings_list = []
tokens = []

for i, output in enumerate(outputs):
text_embeddings_list.append(output[0])

if should_return_tokens:
tokens.append(output[1])

if should_return_pooled:
pooled_list.append(output[-1])

text_embeddings = torch.cat(text_embeddings_list, dim=-1)

pooled_list = [p for p in pooled_list if p is not None]
pooled = torch.cat(pooled_list, dim=-1) if len(pooled_list) > 0 else None

outputs = (text_embeddings,)

if should_return_tokens:
outputs += (tokens,)

if should_return_pooled:
outputs += (pooled,)

return outputs

def _encode_token_ids_to_embeddings(self, token_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor]=None) -> torch.Tensor:
text_encoder_output = self.text_encoder(token_ids,
attention_mask,
output_hidden_states=self.use_penultimate_clip_layer,
return_dict=True)

if self.requires_pooled:
pooled = text_encoder_output.pooler_output if "pooler_output" in text_encoder_output else text_encoder_output.text_embeds
else:
pooled = None

if self.use_penultimate_clip_layer:
# needs normalizing
penultimate_hidden_state = text_encoder_output.hidden_states[-2]

if self.use_penultimate_layer_norm:
penultimate_hidden_state = self.text_encoder.text_model.final_layer_norm(penultimate_hidden_state)
return (penultimate_hidden_state, pooled)
else:
# already normalized
return (text_encoder_output.last_hidden_state, pooled)
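The new EmbeddingsProviderMulti wraps one EmbeddingsProvider per tokenizer/encoder pair, runs each of them over the same weighted prompt fragments, and concatenates the resulting hidden-state embeddings along the feature dimension; pooled outputs, where an encoder supplies one, are concatenated the same way. The sketch below is illustrative only, with assumed SD-XL shapes rather than compel code, and shows the effect of that concatenation for the usual 768- and 1280-dimensional CLIP encoders.

import torch

batch_size, seq_len = 1, 77
hidden_1 = torch.randn(batch_size, seq_len, 768)    # CLIP ViT-L hidden states
hidden_2 = torch.randn(batch_size, seq_len, 1280)   # OpenCLIP ViT-bigG hidden states
pooled_2 = torch.randn(batch_size, 1280)            # pooled output from the second encoder only

# concatenate per-encoder hidden states on the feature axis, as EmbeddingsProviderMulti does
text_embeddings = torch.cat([hidden_1, hidden_2], dim=-1)

# keep only the pooled tensors that exist, then concatenate those as well
pooled_candidates = [None, pooled_2]
pooled = torch.cat([p for p in pooled_candidates if p is not None], dim=-1)

print(text_embeddings.shape)  # torch.Size([1, 77, 2048])
print(pooled.shape)           # torch.Size([1, 1280])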