From eb9f00460c39dad6aaf44e8b6c181d97824acf9c Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 10 Jul 2023 13:31:42 +0000 Subject: [PATCH 01/20] first draft --- src/compel/compel.py | 58 +++-- src/compel/embeddings_provider.py | 127 ++++++++++- test/@ | 363 ++++++++++++++++++++++++++++++ test/prompting_test_utils.py | 12 +- test/test_compel.py | 15 ++ 5 files changed, 547 insertions(+), 28 deletions(-) create mode 100644 test/@ diff --git a/src/compel/compel.py b/src/compel/compel.py index 53b7841..5bf573b 100644 --- a/src/compel/compel.py +++ b/src/compel/compel.py @@ -6,7 +6,7 @@ from . import cross_attention_control from .conditioning_scheduler import ConditioningScheduler, StaticConditioningScheduler -from .embeddings_provider import EmbeddingsProvider, BaseTextualInversionManager +from .embeddings_provider import EmbeddingsProvider, BaseTextualInversionManager, EmbeddingsProviderMulti from .prompt_parser import Blend, FlattenedPrompt, PromptParser, CrossAttentionControlSubstitute __all__ = ["Compel"] @@ -18,15 +18,34 @@ class ExtraConditioningInfo: class Compel: def __init__(self, - tokenizer: CLIPTokenizer, - text_encoder: CLIPTextModel, + tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]], + text_encoder: Union[CLIPTextModel, List[CLIPTextModel]], textual_inversion_manager: Optional[BaseTextualInversionManager] = None, - dtype_for_device_getter: Callable[[torch.device], torch.dtype] = lambda device: torch.float32): - self.conditioning_provider = EmbeddingsProvider(tokenizer=tokenizer, - text_encoder=text_encoder, - textual_inversion_manager=textual_inversion_manager, - dtype_for_device_getter=dtype_for_device_getter - ) + dtype_for_device_getter: Callable[[torch.device], torch.dtype] = lambda device: torch.float32, + hidden_states_type: Union[str, List[str]] = "final", + return_pooled: Union[str, List[bool]] = False, + ): + + if isinstance(tokenizer, (tuple, list)) and isinstance(text_encoder, (tuple, list)): + self.conditioning_provider = EmbeddingsProviderMulti(tokenizers=tokenizer, + text_encoders=text_encoder, + textual_inversion_manager=textual_inversion_manager, + dtype_for_device_getter=dtype_for_device_getter, + hidden_states_types=hidden_states_type, + return_pooled=return_pooled, + ) + elif isinstance(tokenizer, (tuple, list)) and not isinstance(text_encoder, (tuple, list)): + raise ValueError("Cannot provide list of tokenizers, but not of text encoders.") + elif not isinstance(tokenizer, (tuple, list)) and isinstance(text_encoder, (tuple, list)): + raise ValueError("Cannot provide list of text encoders, but not of tokenizers.") + else: + self.conditioning_provider = EmbeddingsProvider(tokenizer=tokenizer, + text_encoder=text_encoder, + textual_inversion_manager=textual_inversion_manager, + dtype_for_device_getter=dtype_for_device_getter, + hidden_states_type=hidden_states_type, + return_pooled=return_pooled, + ) @property def device(self): @@ -91,18 +110,27 @@ def build_conditioning_tensor_for_prompt_object(self, prompt: Union[Blend, Flatt raise ValueError(f"unsupported prompt type: {type(prompt).__name__}") - def _get_conditioning_for_flattened_prompt(self, prompt: FlattenedPrompt, should_return_tokens: bool=False + def _get_conditioning_for_flattened_prompt(self, prompt: FlattenedPrompt, should_return_tokens: bool=False, should_return_pooled: bool=False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: if type(prompt) is not FlattenedPrompt: raise ValueError(f"embeddings can only be made from FlattenedPrompts, got {type(prompt).__name__} 
instead") fragments = [x.text for x in prompt.children] weights = [x.weight for x in prompt.children] - conditioning, tokens = self.conditioning_provider.get_embeddings_for_weighted_prompt_fragments( - text_batch=[fragments], fragment_weights_batch=[weights], should_return_tokens=True, device=self.device) + conditioning, tokens, pooled = self.conditioning_provider.get_embeddings_for_weighted_prompt_fragments( + text_batch=[fragments], fragment_weights_batch=[weights], should_return_tokens=True, should_return_pooled=True, device=self.device) + + outputs = (conditioning,) + + if should_return_pooled: + outputs += (pooled,) + if should_return_tokens: - return conditioning, tokens - else: - return conditioning + outputs += (tokens,) + + if len(outputs) == 1: + return outputs[0] + + return outputs def _get_conditioning_for_blend(self, blend: Blend): conditionings_to_blend = None diff --git a/src/compel/embeddings_provider.py b/src/compel/embeddings_provider.py index 3839217..9abdb57 100644 --- a/src/compel/embeddings_provider.py +++ b/src/compel/embeddings_provider.py @@ -3,7 +3,7 @@ from typing import Callable, Union import torch -from transformers import CLIPTokenizer, CLIPTextModel +from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection from typing import List, Tuple __all__ = ["EmbeddingsProvider"] @@ -18,14 +18,19 @@ class EmbeddingsProvider: def __init__(self, tokenizer: CLIPTokenizer, # converts strings to lists of int token ids - text_encoder: CLIPTextModel, # convert a list of int token ids to a tensor of embeddings + text_encoder: Union[CLIPTextModel, CLIPTextModelWithProjection], # convert a list of int token ids to a tensor of embeddings textual_inversion_manager: BaseTextualInversionManager = None, dtype_for_device_getter: Callable[[torch.device], torch.dtype] = lambda device: torch.float32, + hidden_states_type: str = "final", # final, penultimate, + requires_pooled: bool = False, ): self.tokenizer = tokenizer self.text_encoder = text_encoder self.textual_inversion_manager = textual_inversion_manager + self.hidden_states_type = hidden_states_type + self.requires_pooled = requires_pooled + # by default always use float32 self.get_dtype_for_device = dtype_for_device_getter @@ -51,6 +56,7 @@ def get_embeddings_for_weighted_prompt_fragments(self, text_batch: List[List[str]], fragment_weights_batch: List[List[float]], should_return_tokens: bool = False, + should_return_pooled: bool = False, device='cpu', ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ @@ -81,7 +87,10 @@ def get_embeddings_for_weighted_prompt_fragments(self, # handle weights >=1 tokens, per_token_weights = self.get_token_ids_and_expand_weights(fragments, weights, device=device) - base_embedding = self.build_weighted_embedding_tensor(tokens, per_token_weights) + base_embedding_outputs = self.build_weighted_embedding_tensor(tokens, per_token_weights) + + base_embedding = base_embedding_outputs[0] if self.requires_pooled else base_embedding_outputs + pooled = base_embedding_outputs[1] if self.requires_pooled else None # this is our starting point embeddings = base_embedding.unsqueeze(0) @@ -131,11 +140,18 @@ def get_embeddings_for_weighted_prompt_fragments(self, # should have shape (B, 77, 768) #print(f"assembled all tokens into tensor of shape {batch_z.shape}") + outputs = (batch_z,) if should_return_tokens: - return batch_z, batch_tokens - else: - return batch_z + outputs += (batch_tokens,) + + if should_return_pooled: + outputs += (pooled,) + + if len(outputs) == 1: + return 
outputs[0] + + return outputs def get_token_ids(self, texts: List[str], include_start_and_end_markers: bool = True) -> List[List[int]]: """ @@ -233,7 +249,7 @@ def get_token_ids_and_expand_weights(self, fragments: List[str], weights: List[f return all_token_ids_tensor, per_token_weights_tensor - def build_weighted_embedding_tensor(self, token_ids: torch.Tensor, per_token_weights: torch.Tensor) -> torch.Tensor: + def build_weighted_embedding_tensor(self, token_ids: torch.Tensor, per_token_weights: torch.Tensor): """ Build a tensor that embeds the passed-in token IDs and applies the given per_token weights :param token_ids: A tensor of shape `[self.max_length]` containing token IDs (ints) @@ -245,14 +261,107 @@ def build_weighted_embedding_tensor(self, token_ids: torch.Tensor, per_token_wei if token_ids.shape != torch.Size([self.max_token_count]): raise ValueError(f"token_ids has shape {token_ids.shape} - expected [{self.max_token_count}]") - z = self.text_encoder(token_ids.unsqueeze(0), return_dict=False)[0] + text_encoder_output = self.text_encoder(token_ids.unsqueeze(0), return_dict=False, output_hidden_states=self.hidden_states_type == "penultimate") + + pooled = None + if self.hidden_states_type == "final" and not self.requires_pooled: + z = text_encoder_output[0] if isinstance(self.text_encoder, CLIPTextModel) else text_encoder_output[1] + elif self.hidden_states_type == "final" and self.requires_pooled: + z = text_encoder_output[0] if isinstance(self.text_encoder, CLIPTextModel) else text_encoder_output[1] + pooled = text_encoder_output[1] if isinstance(self.text_encoder, CLIPTextModel) else text_encoder_output[0] + elif self.hidden_states_type == "penultimate" and not self.requires_pooled: + z = text_encoder_output[2][-2] + elif self.hidden_states_type == "penultimate" and self.requires_pooled: + z = text_encoder_output[2][-2] + pooled = text_encoder_output[1] if isinstance(self.text_encoder, CLIPTextModel) else text_encoder_output[0] + empty_token_ids = torch.tensor([self.tokenizer.bos_token_id] + [self.tokenizer.eos_token_id] + [self.tokenizer.pad_token_id] * (self.max_token_count - 2), dtype=torch.int, device=z.device).unsqueeze(0) - empty_z = self.text_encoder(empty_token_ids, return_dict=False)[0] + + empty_text_encoder_output = self.text_encoder(empty_token_ids, return_dict=False, output_hidden_states=self.hidden_states_type == "penultimate") + if self.hidden_states_type == "final": + empty_z = empty_text_encoder_output[0] if isinstance(self.text_encoder, CLIPTextModel) else empty_text_encoder_output[1] + elif self.hidden_states_type == "penultimate": + empty_z = empty_text_encoder_output[2][-2] + batch_weights_expanded = per_token_weights.reshape(per_token_weights.shape + (1,)).expand(z.shape).to(z) z_delta_from_empty = z - empty_z weighted_z = empty_z + (z_delta_from_empty * batch_weights_expanded) + if self.requires_pooled: + return weighted_z, pooled + return weighted_z + + +class EmbeddingsProviderMulti: + + def __init__(self, + tokenizers: List[CLIPTokenizer], # converts strings to lists of int token ids + text_encoders: List[Union[CLIPTextModel, CLIPTextModelWithProjection]], # convert a list of int token ids to a tensor of embeddings + textual_inversion_manager: BaseTextualInversionManager = None, + dtype_for_device_getter: Callable[[torch.device], torch.dtype] = lambda device: torch.float32, + hidden_states_types: Union[str, List[str]] = "final", + return_pooled: Union[str, List[bool]] = False, + ): + + hidden_states_types = len(text_encoders) * [hidden_states_types] 
if not isinstance(hidden_states_types, (list, tuple)) else hidden_states_types + return_pooled = len(text_encoders) * [return_pooled] if not isinstance(return_pooled, (list, tuple)) else return_pooled + + self.embedding_providers = [ + EmbeddingsProvider(tokenizer, text_encoder, textual_inversion_manager, dtype_for_device_getter, hidden_states_type, pooled) + for tokenizer, text_encoder, hidden_states_type, pooled in zip(tokenizers, text_encoders, hidden_states_types, return_pooled) + ] + + @property + def text_encoder(self): + return self.embedding_providers[0].text_encoder + + @property + def tokenizer(self): + return self.embedding_providers[0].tokenizer + + def get_token_ids(self, *args, **kwargs): + # get token ids does not use padding. The padding ID is the only ID that can differ between tokenizers + # so for simplicity, we just return `get_token_ids` of the first tokenizer + return self.embedding_providers[0].get_token_ids(self, *args, **kwargs) + + def get_embeddings_for_weighted_prompt_fragments(self, + text_batch: List[List[str]], + fragment_weights_batch: List[List[float]], + should_return_tokens: bool = False, + should_return_pooled: bool = False, + device='cpu', + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + + outputs = [provider.get_embeddings_for_weighted_prompt_fragments(text_batch, fragment_weights_batch, should_return_tokens=should_return_tokens, should_return_pooled=should_return_pooled, device=device) for provider in self.embedding_providers] + + pooled_list = [] + text_embeddings_list = [] + tokens = [] + + for i, output in enumerate(outputs): + text_embeddings_list.append(output[0]) + + if should_return_tokens: + tokens.append(output[1]) + + if should_return_pooled: + pooled_list.append(output[-1]) + + text_embeddings = torch.cat(text_embeddings_list, dim=-1) + + pooled_list = [p for p in pooled_list if p is not None] + pooled = torch.cat(pooled_list, dim=-1) if len(pooled_list) > 0 else None + + outputs = (text_embeddings,) + + if should_return_tokens: + outputs += (tokens,) + + if should_return_pooled: + outputs += (pooled,) + + return outputs diff --git a/test/@ b/test/@ new file mode 100644 index 0000000..17f0bf7 --- /dev/null +++ b/test/@ @@ -0,0 +1,363 @@ +import math +from abc import ABC +from typing import Callable, Union + +import torch +from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection +from typing import List, Tuple + +__all__ = ["EmbeddingsProvider"] + + +class BaseTextualInversionManager(ABC): + def expand_textual_inversion_token_ids_if_necessary(self, token_ids: List[int]) -> List[int]: + raise NotImplementedError() + + +class EmbeddingsProvider: + + def __init__(self, + tokenizer: CLIPTokenizer, # converts strings to lists of int token ids + text_encoder: Union[CLIPTextModel, CLIPTextModelWithProjection], # convert a list of int token ids to a tensor of embeddings + textual_inversion_manager: BaseTextualInversionManager = None, + dtype_for_device_getter: Callable[[torch.device], torch.dtype] = lambda device: torch.float32, + hidden_states_type: str = "final", # final, penultimate, + return_pooled: bool = False, + ): + self.tokenizer = tokenizer + self.text_encoder = text_encoder + self.textual_inversion_manager = textual_inversion_manager + + self.hidden_states_type = hidden_states_type + self.return_pooled = return_pooled + + # by default always use float32 + self.get_dtype_for_device = dtype_for_device_getter + + + @property + def max_token_count(self) -> int: + return self.tokenizer.model_max_length 
+ + + @classmethod + def apply_embedding_weights(cls, embeddings: torch.Tensor, per_embedding_weights: List[float], + normalize: bool) -> torch.Tensor: + per_embedding_weights = torch.tensor(per_embedding_weights, dtype=embeddings.dtype, device=embeddings.device) + if normalize: + per_embedding_weights = per_embedding_weights / torch.sum(per_embedding_weights) + + reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1, 1,)) + blended_embeddings = torch.sum(embeddings * reshaped_weights, dim=1) + # blended_embeddings now has shape (77, 768) + return blended_embeddings + + def get_embeddings_for_weighted_prompt_fragments(self, + text_batch: List[List[str]], + fragment_weights_batch: List[List[float]], + should_return_tokens: bool = False, + device='cpu', + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + + :param text_batch: A list of fragments of text to which different weights are to be applied. + :param fragment_weights_batch: A list of weights, one for each entry in `fragments`. + :return: A tensor of shape `[1, 77, token_dim]` containing weighted embeddings where token_dim is 768 for SD1 + and 1280 for SD2 + """ + if len(text_batch) != len(fragment_weights_batch): + raise ValueError( + f"lengths of text and fragment_weights lists are not the same "+ + f"({len(text_batch)} != {len(fragment_weights_batch)})") + + batch_z = None + batch_tokens = None + for fragments, weights in zip(text_batch, fragment_weights_batch): + + # First, weight tokens in individual fragments by scaling the feature vectors as requested (effectively + # applying a multiplier to the CFG scale on a per-token basis). + # For tokens weighted<1, intuitively we want SD to become not merely *less* interested in the concept + # captured by the fragment but actually *dis*interested in it (a 0.01 interest in "red" is still an active + # interest, however small, in redness; what the user probably intends when they attach the number 0.01 to + # "red" is to tell SD that it should almost completely *ignore* redness). + # To do this, the embedding is lerped away from base_embedding in the direction of an embedding for a prompt + # string from which the low-weighted fragment has been simply removed. The closer the weight is to zero, the + # closer the resulting embedding is to an embedding for a prompt that simply lacks this fragment. + + # handle weights >=1 + tokens, per_token_weights = self.get_token_ids_and_expand_weights(fragments, weights, device=device) + base_embedding_outputs = self.build_weighted_embedding_tensor(tokens, per_token_weights) + + base_embedding = base_embedding_outputs[0] if self.return_pooled else base_embedding_outputs[0] + pooled = base_embedding_outputs[1] if self.return_pooled else None + + # this is our starting point + embeddings = base_embedding.unsqueeze(0) + per_embedding_weights = [1.0] + + # now handle weights <1 + # Do this by building extra embeddings tensors that lack the words being <1 weighted. These will be lerped + # with the embeddings tensors that have the words, such that if the weight of a word is 0.5, the resulting + # embedding will be exactly half-way between the unweighted prompt and the prompt with the <1 weighted words + # removed. + # eg for "mountain:1 man:0.5", intuitively the "man" should be "half-gone". therefore, append an embedding + # for "mountain" (i.e. 
without "man") to the already-produced embedding for "mountain man", and weight it + # such that the resulting lerped embedding is exactly half-way between "mountain man" and "mountain". + for index, fragment_weight in enumerate(weights): + if fragment_weight < 1: + fragments_without_this = fragments[:index] + fragments[index+1:] + weights_without_this = weights[:index] + weights[index+1:] + tokens, per_token_weights = self.get_token_ids_and_expand_weights(fragments_without_this, weights_without_this, device=device) + embedding_without_this = self.build_weighted_embedding_tensor(tokens, per_token_weights) + + embeddings = torch.cat((embeddings, embedding_without_this.unsqueeze(0)), dim=1) + # weight of the embedding *without* this fragment gets *stronger* as its weight approaches 0 + # if fragment_weight = 0, basically we want embedding_without_this to completely overwhelm base_embedding + # therefore: + # fragment_weight = 1: we are at base_z => lerp weight 0 + # fragment_weight = 0.5: we are halfway between base_z and here => lerp weight 1 + # fragment_weight = 0: we're now entirely overriding base_z ==> lerp weight inf + # so let's use tan(), because: + # tan is 0.0 at 0, + # 1.0 at PI/4, and + # inf at PI/2 + # -> tan((1-weight)*PI/2) should give us ideal lerp weights + epsilon = 1e-9 + fragment_weight = max(epsilon, fragment_weight) # inf is bad + embedding_lerp_weight = math.tan((1.0 - fragment_weight) * math.pi / 2) + # todo handle negative weight? + + per_embedding_weights.append(embedding_lerp_weight) + + lerped_embeddings = self.apply_embedding_weights(embeddings, per_embedding_weights, normalize=True).squeeze(0) + + #print(f"assembled tokens for '{fragments}' into tensor of shape {lerped_embeddings.shape}") + + # append to batch + batch_z = lerped_embeddings.unsqueeze(0) if batch_z is None else torch.cat([batch_z, lerped_embeddings.unsqueeze(0)], dim=1) + batch_tokens = tokens.unsqueeze(0) if batch_tokens is None else torch.cat([batch_tokens, tokens.unsqueeze(0)], dim=1) + + # should have shape (B, 77, 768) + #print(f"assembled all tokens into tensor of shape {batch_z.shape}") + outputs = (batch_z,) + + if should_return_tokens: + outputs += (batch_tokens,) + + if self.return_pooled: + outputs += (pooled,) + + if len(outputs) == 1: + return outputs[0] + + return outputs + + def get_token_ids(self, texts: List[str], include_start_and_end_markers: bool = True) -> List[List[int]]: + """ + Convert a list of strings like `["a cat", "a dog", "monkey riding a bicycle"]` into a list of lists of token + ids like `[[bos, 0, 1, eos], [bos, 0, 2, eos], [bos, 3, 4, 0, 5, eos]]`. bos/eos markers are skipped if + `include_start_and_end_markers` is `False`. Each list will be restricted to the maximum permitted length + (typically 75 tokens + eos/bos markers). + + :param texts: The strings to convert. + :param include_start_and_end_markers: If True (default), returned token id lists will start with the beginning + of sequence marker and end with the end-of-sequence marker (`eos`). + :return: A list of lists of token ids corresponding to the input strings. 
+ """ + # for args documentation of self.tokenizer() see ENCODE_KWARGS_DOCSTRING in tokenization_utils_base.py + # (part of `transformers` lib) + token_ids_list = self.tokenizer( + texts, + truncation=True, + max_length=self.max_token_count, + return_overflowing_tokens=False, + padding='do_not_pad', + return_tensors=None, # just give me lists of ints + )['input_ids'] + + result = [] + for token_ids in token_ids_list: + # trim eos/bos + token_ids = token_ids[1:-1] + # pad for textual inversions with vector length >1 + if self.textual_inversion_manager is not None: + token_ids = self.textual_inversion_manager.expand_textual_inversion_token_ids_if_necessary(token_ids) + # truncate if necessary to max_length-2 (leaving room for bos/eos) + token_ids = token_ids[0:self.max_token_count - 2] + # add back eos/bos if requested + if include_start_and_end_markers: + token_ids = [self.tokenizer.bos_token_id] + token_ids + [self.tokenizer.eos_token_id] + + result.append(token_ids) + + return result + + def get_token_ids_and_expand_weights(self, fragments: List[str], weights: List[float], device: str) -> (torch.Tensor, torch.Tensor): + ''' + Given a list of text fragments and corresponding weights: tokenize each fragment, append the token sequences + together and return a padded token sequence starting with the bos marker, ending with the eos marker, and padded + or truncated as appropriate to `self.max_length`. Also return a list of weights expanded from the passed-in + weights to match each token. + + :param fragments: Text fragments to tokenize and concatenate. May be empty. + :param weights: Per-fragment weights (i.e. quasi-CFG scaling). Values from 0 to inf are permitted. In practise with SD1.5 + values >1.6 tend to produce garbage output. Must have same length as `fragment`. + :return: A tuple of tensors `(token_ids, weights)`. `token_ids` is ints, `weights` is floats, both have shape `[self.max_length]`. + ''' + if len(fragments) != len(weights): + raise ValueError(f"lengths of text and fragment_weights lists are not the same ({len(fragments)} != {len(weights)})") + + # empty is meaningful + if len(fragments) == 0: + fragments = [''] + weights = [1.0] + per_fragment_token_ids = self.get_token_ids(fragments, include_start_and_end_markers=False) + all_token_ids = [] + per_token_weights = [] + #print("all fragments:", fragments, weights) + for this_fragment_token_ids, weight in zip(per_fragment_token_ids, weights): + # append + all_token_ids += this_fragment_token_ids + # fill out weights tensor with one float per token + per_token_weights += [float(weight)] * len(this_fragment_token_ids) + + # leave room for bos/eos + max_token_count_without_bos_eos_markers = self.max_token_count - 2 + if len(all_token_ids) > max_token_count_without_bos_eos_markers: + excess_token_count = len(all_token_ids) - max_token_count_without_bos_eos_markers + # TODO build nice description string of how the truncation was applied + # this should be done by calling self.tokenizer.convert_ids_to_tokens() then passing the result to + # self.tokenizer.convert_tokens_to_string() for the token_ids on each side of the truncation limit. + print(f">> Prompt is {excess_token_count} token(s) too long and has been truncated") + all_token_ids = all_token_ids[0:max_token_count_without_bos_eos_markers] + per_token_weights = per_token_weights[0:max_token_count_without_bos_eos_markers] + + # pad out to a self.max_length-entry array: [bos_token, , eos_token, pad_token...] 
+ # (typically self.max_length == 77) + all_token_ids = [self.tokenizer.bos_token_id] + all_token_ids + [self.tokenizer.eos_token_id] + per_token_weights = [1.0] + per_token_weights + [1.0] + pad_length = self.max_token_count - len(all_token_ids) + all_token_ids += [self.tokenizer.pad_token_id] * pad_length + per_token_weights += [1.0] * pad_length + + all_token_ids_tensor = torch.tensor(all_token_ids, dtype=torch.long, device=device) + per_token_weights_tensor = torch.tensor(per_token_weights, + dtype=self.get_dtype_for_device(self.text_encoder.device), + device=device) + #print(f"assembled all_token_ids_tensor with shape {all_token_ids_tensor.shape}") + return all_token_ids_tensor, per_token_weights_tensor + + + def build_weighted_embedding_tensor(self, token_ids: torch.Tensor, per_token_weights: torch.Tensor): + """ + Build a tensor that embeds the passed-in token IDs and applies the given per_token weights + :param token_ids: A tensor of shape `[self.max_length]` containing token IDs (ints) + :param per_token_weights: A tensor of shape `[self.max_length]` containing weights (floats) + :return: A tensor of shape `[1, self.max_length, token_dim]` representing the requested weighted embeddings + where `token_dim` is 768 for SD1 and 1280 for SD2. + """ + # print(f"building weighted embedding tensor for {tokens} with weights {per_token_weights}") + if token_ids.shape != torch.Size([self.max_token_count]): + raise ValueError(f"token_ids has shape {token_ids.shape} - expected [{self.max_token_count}]") + + text_encoder_output = self.text_encoder(token_ids.unsqueeze(0), return_dict=False, output_hidden_states=self.hidden_states_type == "penultimate") + + pooled = None + if self.hidden_states_type == "final" and not self.return_pooled: + z = text_encoder_output[0] if isinstance(self.text_encoder, CLIPTextModel) else text_encoder_output[1] + elif self.hidden_states_type == "final" and self.return_pooled: + z = text_encoder_output[0] if isinstance(self.text_encoder, CLIPTextModel) else text_encoder_output[1] + pooled = text_encoder_output[1] if isinstance(self.text_encoder, CLIPTextModel) else text_encoder_output[0] + elif self.hidden_states_type == "penultimate" and not self.return_pooled: + z = text_encoder_output[2][-2] + elif self.hidden_states_type == "penultimate" and self.return_pooled: + z = text_encoder_output[2][-2] + pooled = text_encoder_output[1] if isinstance(self.text_encoder, CLIPTextModel) else text_encoder_output[0] + + empty_token_ids = torch.tensor([self.tokenizer.bos_token_id] + + [self.tokenizer.eos_token_id] + + [self.tokenizer.pad_token_id] * (self.max_token_count - 2), + dtype=torch.int, device=z.device).unsqueeze(0) + + empty_text_encoder_output = self.text_encoder(empty_token_ids, return_dict=False, output_hidden_states=self.hidden_states_type == "penultimate") + if self.hidden_states_type == "final": + empty_z = empty_text_encoder_output[0] if isinstance(self.text_encoder, CLIPTextModel) else empty_text_encoder_output[1] + elif self.hidden_states_type == "penultimate": + empty_z = empty_text_encoder_output[2][-2] + + batch_weights_expanded = per_token_weights.reshape(per_token_weights.shape + (1,)).expand(z.shape).to(z) + z_delta_from_empty = z - empty_z + weighted_z = empty_z + (z_delta_from_empty * batch_weights_expanded) + + if self.return_pooled: + return weighted_z, pooled + + return weighted_z + + +class EmbeddingsProviderMulti: + + def __init__(self, + tokenizers: List[CLIPTokenizer], # converts strings to lists of int token ids + text_encoders: 
List[Union[CLIPTextModel, CLIPTextModelWithProjection]], # convert a list of int token ids to a tensor of embeddings + textual_inversion_manager: BaseTextualInversionManager = None, + dtype_for_device_getter: Callable[[torch.device], torch.dtype] = lambda device: torch.float32, + hidden_states_types: Union[str, List[str]] = "final", + return_pooled: Union[str, List[bool]] = False, + ): + + hidden_states_types = len(text_encoders) * [hidden_states_types] if not isinstance(hidden_states_types, (list, tuple)) else hidden_states_types + return_pooled = len(text_encoders) * [return_pooled] if not isinstance(return_pooled, (list, tuple)) else return_pooled + + self.embedding_providers = [ + EmbeddingsProvider(tokenizer, text_encoder, textual_inversion_manager, dtype_for_device_getter, hidden_states_type, pooled) + for tokenizer, text_encoder, hidden_states_type, pooled in zip(tokenizers, text_encoders, hidden_states_types, return_pooled) + ] + + @property + def text_encoder(self): + return self.embedding_providers[0].text_encoder + + @property + def tokenizer(self): + return self.embedding_providers[0].tokenizer + + def get_token_ids(self, *args, **kwargs): + # get token ids does not use padding. The padding ID is the only ID that can differ between tokenizers + # so for simplicity, we just return `get_token_ids` of the first tokenizer + return self.embedding_providers[0].get_token_ids(self, *args, **kwargs) + + def get_embeddings_for_weighted_prompt_fragments(self, + text_batch: List[List[str]], + fragment_weights_batch: List[List[float]], + should_return_tokens: bool = False, + device='cpu', + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + + outputs = [provider.get_embeddings_for_weighted_prompt_fragments(text_batch, fragment_weights_batch, should_return_tokens, device=device) for provider in self.embedding_providers] + + pooled_list = [] + text_embeddings_list = [] + tokens = [] + + for i, output in enumerate(outputs): + text_embeddings_list.append(output[0]) + + if should_return_tokens: + tokens.append(output[1]) + + if self.embedding_providers[i].return_pooled: + pooled_list.append(output[-1]) + + text_embeddings = torch.cat(text_embeddings_list, dim=-1) + pooled = torch.cat(pooled_list, dim=-1) if len(pooled_list) > 0 else None + + outputs = (text_embeddings,) + + if pooled is not None: + outputs += (pooled) + + if should_return_tokens: + outputs += (tokens) + + return outputs diff --git a/test/prompting_test_utils.py b/test/prompting_test_utils.py index 3cfa41b..419838d 100644 --- a/test/prompting_test_utils.py +++ b/test/prompting_test_utils.py @@ -41,7 +41,7 @@ def resize_token_embeddings(self, new_size=None): def get_input_embeddings(self): return self.embeddings - def forward(self, input_ids: torch.Tensor, return_dict: bool=True): + def forward(self, input_ids: torch.Tensor, return_dict: bool=True, output_hidden_states: bool=False): if input_ids.shape[0] > 1: raise AssertionError("for unit testing, only batch size =1 is supported") all_embeddings = torch.cat([e.unsqueeze(0) for e in self.embeddings]).to(self.device) @@ -55,14 +55,18 @@ def __init__(self, last_hidden_state): self.last_hidden_state = last_hidden_state def __getitem__(self, item): - assert item == 0 - return self.last_hidden_state + if item == 0: + return self.last_hidden_state[:, -1, :] + if item == 1: + return self.last_hidden_state + if item == 2: + return 2 * [self.last_hidden_state] o = EmbeddingsObject(embeddings) return o def __call__(self, input_ids, **kwargs): - return 
self.forward(input_ids=input_ids, return_dict=True) + return self.forward(input_ids=input_ids, return_dict=True, output_hidden_states=kwargs.pop("output_hidden_states", False)) class DummyTokenizer(): def __init__(self, model_max_length=77): diff --git a/test/test_compel.py b/test/test_compel.py index 9824c3a..4fdf51d 100644 --- a/test/test_compel.py +++ b/test/test_compel.py @@ -45,6 +45,21 @@ def test_basic_prompt(self): expected_positive_conditioning, expected_negative_conditioning) + def test_basic_prompt_multi_text_encoder(self): + tokenizer_1 = DummyTokenizer() + text_encoder_1 = DummyTransformer() + + tokenizer_2 = DummyTokenizer() + text_encoder_2 = DummyTransformer() + + compel = Compel(tokenizer=[tokenizer_1, tokenizer_2], text_encoder=[text_encoder_1, text_encoder_2], hidden_states_type="penultimate", return_pooled=[False, True]) + + # test "a b c" makes it to the Conditioning intact for t=0, t=0.5, t=1 + prompt = " ".join(KNOWN_WORDS[:3]) + output = compel(prompt) + + assert output.shape == (1, 77, 2 * 768) + def test_basic_negative_prompt(self): tokenizer = DummyTokenizer() From 1e5ce7b2d52d7cff8a3fc3b31a88388d2288fad5 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 10 Jul 2023 13:31:53 +0000 Subject: [PATCH 02/20] remove bogus --- test/@ | 363 --------------------------------------------------------- 1 file changed, 363 deletions(-) delete mode 100644 test/@ diff --git a/test/@ b/test/@ deleted file mode 100644 index 17f0bf7..0000000 --- a/test/@ +++ /dev/null @@ -1,363 +0,0 @@ -import math -from abc import ABC -from typing import Callable, Union - -import torch -from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection -from typing import List, Tuple - -__all__ = ["EmbeddingsProvider"] - - -class BaseTextualInversionManager(ABC): - def expand_textual_inversion_token_ids_if_necessary(self, token_ids: List[int]) -> List[int]: - raise NotImplementedError() - - -class EmbeddingsProvider: - - def __init__(self, - tokenizer: CLIPTokenizer, # converts strings to lists of int token ids - text_encoder: Union[CLIPTextModel, CLIPTextModelWithProjection], # convert a list of int token ids to a tensor of embeddings - textual_inversion_manager: BaseTextualInversionManager = None, - dtype_for_device_getter: Callable[[torch.device], torch.dtype] = lambda device: torch.float32, - hidden_states_type: str = "final", # final, penultimate, - return_pooled: bool = False, - ): - self.tokenizer = tokenizer - self.text_encoder = text_encoder - self.textual_inversion_manager = textual_inversion_manager - - self.hidden_states_type = hidden_states_type - self.return_pooled = return_pooled - - # by default always use float32 - self.get_dtype_for_device = dtype_for_device_getter - - - @property - def max_token_count(self) -> int: - return self.tokenizer.model_max_length - - - @classmethod - def apply_embedding_weights(cls, embeddings: torch.Tensor, per_embedding_weights: List[float], - normalize: bool) -> torch.Tensor: - per_embedding_weights = torch.tensor(per_embedding_weights, dtype=embeddings.dtype, device=embeddings.device) - if normalize: - per_embedding_weights = per_embedding_weights / torch.sum(per_embedding_weights) - - reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1, 1,)) - blended_embeddings = torch.sum(embeddings * reshaped_weights, dim=1) - # blended_embeddings now has shape (77, 768) - return blended_embeddings - - def get_embeddings_for_weighted_prompt_fragments(self, - text_batch: List[List[str]], - 
fragment_weights_batch: List[List[float]], - should_return_tokens: bool = False, - device='cpu', - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - """ - - :param text_batch: A list of fragments of text to which different weights are to be applied. - :param fragment_weights_batch: A list of weights, one for each entry in `fragments`. - :return: A tensor of shape `[1, 77, token_dim]` containing weighted embeddings where token_dim is 768 for SD1 - and 1280 for SD2 - """ - if len(text_batch) != len(fragment_weights_batch): - raise ValueError( - f"lengths of text and fragment_weights lists are not the same "+ - f"({len(text_batch)} != {len(fragment_weights_batch)})") - - batch_z = None - batch_tokens = None - for fragments, weights in zip(text_batch, fragment_weights_batch): - - # First, weight tokens in individual fragments by scaling the feature vectors as requested (effectively - # applying a multiplier to the CFG scale on a per-token basis). - # For tokens weighted<1, intuitively we want SD to become not merely *less* interested in the concept - # captured by the fragment but actually *dis*interested in it (a 0.01 interest in "red" is still an active - # interest, however small, in redness; what the user probably intends when they attach the number 0.01 to - # "red" is to tell SD that it should almost completely *ignore* redness). - # To do this, the embedding is lerped away from base_embedding in the direction of an embedding for a prompt - # string from which the low-weighted fragment has been simply removed. The closer the weight is to zero, the - # closer the resulting embedding is to an embedding for a prompt that simply lacks this fragment. - - # handle weights >=1 - tokens, per_token_weights = self.get_token_ids_and_expand_weights(fragments, weights, device=device) - base_embedding_outputs = self.build_weighted_embedding_tensor(tokens, per_token_weights) - - base_embedding = base_embedding_outputs[0] if self.return_pooled else base_embedding_outputs[0] - pooled = base_embedding_outputs[1] if self.return_pooled else None - - # this is our starting point - embeddings = base_embedding.unsqueeze(0) - per_embedding_weights = [1.0] - - # now handle weights <1 - # Do this by building extra embeddings tensors that lack the words being <1 weighted. These will be lerped - # with the embeddings tensors that have the words, such that if the weight of a word is 0.5, the resulting - # embedding will be exactly half-way between the unweighted prompt and the prompt with the <1 weighted words - # removed. - # eg for "mountain:1 man:0.5", intuitively the "man" should be "half-gone". therefore, append an embedding - # for "mountain" (i.e. without "man") to the already-produced embedding for "mountain man", and weight it - # such that the resulting lerped embedding is exactly half-way between "mountain man" and "mountain". 
- for index, fragment_weight in enumerate(weights): - if fragment_weight < 1: - fragments_without_this = fragments[:index] + fragments[index+1:] - weights_without_this = weights[:index] + weights[index+1:] - tokens, per_token_weights = self.get_token_ids_and_expand_weights(fragments_without_this, weights_without_this, device=device) - embedding_without_this = self.build_weighted_embedding_tensor(tokens, per_token_weights) - - embeddings = torch.cat((embeddings, embedding_without_this.unsqueeze(0)), dim=1) - # weight of the embedding *without* this fragment gets *stronger* as its weight approaches 0 - # if fragment_weight = 0, basically we want embedding_without_this to completely overwhelm base_embedding - # therefore: - # fragment_weight = 1: we are at base_z => lerp weight 0 - # fragment_weight = 0.5: we are halfway between base_z and here => lerp weight 1 - # fragment_weight = 0: we're now entirely overriding base_z ==> lerp weight inf - # so let's use tan(), because: - # tan is 0.0 at 0, - # 1.0 at PI/4, and - # inf at PI/2 - # -> tan((1-weight)*PI/2) should give us ideal lerp weights - epsilon = 1e-9 - fragment_weight = max(epsilon, fragment_weight) # inf is bad - embedding_lerp_weight = math.tan((1.0 - fragment_weight) * math.pi / 2) - # todo handle negative weight? - - per_embedding_weights.append(embedding_lerp_weight) - - lerped_embeddings = self.apply_embedding_weights(embeddings, per_embedding_weights, normalize=True).squeeze(0) - - #print(f"assembled tokens for '{fragments}' into tensor of shape {lerped_embeddings.shape}") - - # append to batch - batch_z = lerped_embeddings.unsqueeze(0) if batch_z is None else torch.cat([batch_z, lerped_embeddings.unsqueeze(0)], dim=1) - batch_tokens = tokens.unsqueeze(0) if batch_tokens is None else torch.cat([batch_tokens, tokens.unsqueeze(0)], dim=1) - - # should have shape (B, 77, 768) - #print(f"assembled all tokens into tensor of shape {batch_z.shape}") - outputs = (batch_z,) - - if should_return_tokens: - outputs += (batch_tokens,) - - if self.return_pooled: - outputs += (pooled,) - - if len(outputs) == 1: - return outputs[0] - - return outputs - - def get_token_ids(self, texts: List[str], include_start_and_end_markers: bool = True) -> List[List[int]]: - """ - Convert a list of strings like `["a cat", "a dog", "monkey riding a bicycle"]` into a list of lists of token - ids like `[[bos, 0, 1, eos], [bos, 0, 2, eos], [bos, 3, 4, 0, 5, eos]]`. bos/eos markers are skipped if - `include_start_and_end_markers` is `False`. Each list will be restricted to the maximum permitted length - (typically 75 tokens + eos/bos markers). - - :param texts: The strings to convert. - :param include_start_and_end_markers: If True (default), returned token id lists will start with the beginning - of sequence marker and end with the end-of-sequence marker (`eos`). - :return: A list of lists of token ids corresponding to the input strings. 
- """ - # for args documentation of self.tokenizer() see ENCODE_KWARGS_DOCSTRING in tokenization_utils_base.py - # (part of `transformers` lib) - token_ids_list = self.tokenizer( - texts, - truncation=True, - max_length=self.max_token_count, - return_overflowing_tokens=False, - padding='do_not_pad', - return_tensors=None, # just give me lists of ints - )['input_ids'] - - result = [] - for token_ids in token_ids_list: - # trim eos/bos - token_ids = token_ids[1:-1] - # pad for textual inversions with vector length >1 - if self.textual_inversion_manager is not None: - token_ids = self.textual_inversion_manager.expand_textual_inversion_token_ids_if_necessary(token_ids) - # truncate if necessary to max_length-2 (leaving room for bos/eos) - token_ids = token_ids[0:self.max_token_count - 2] - # add back eos/bos if requested - if include_start_and_end_markers: - token_ids = [self.tokenizer.bos_token_id] + token_ids + [self.tokenizer.eos_token_id] - - result.append(token_ids) - - return result - - def get_token_ids_and_expand_weights(self, fragments: List[str], weights: List[float], device: str) -> (torch.Tensor, torch.Tensor): - ''' - Given a list of text fragments and corresponding weights: tokenize each fragment, append the token sequences - together and return a padded token sequence starting with the bos marker, ending with the eos marker, and padded - or truncated as appropriate to `self.max_length`. Also return a list of weights expanded from the passed-in - weights to match each token. - - :param fragments: Text fragments to tokenize and concatenate. May be empty. - :param weights: Per-fragment weights (i.e. quasi-CFG scaling). Values from 0 to inf are permitted. In practise with SD1.5 - values >1.6 tend to produce garbage output. Must have same length as `fragment`. - :return: A tuple of tensors `(token_ids, weights)`. `token_ids` is ints, `weights` is floats, both have shape `[self.max_length]`. - ''' - if len(fragments) != len(weights): - raise ValueError(f"lengths of text and fragment_weights lists are not the same ({len(fragments)} != {len(weights)})") - - # empty is meaningful - if len(fragments) == 0: - fragments = [''] - weights = [1.0] - per_fragment_token_ids = self.get_token_ids(fragments, include_start_and_end_markers=False) - all_token_ids = [] - per_token_weights = [] - #print("all fragments:", fragments, weights) - for this_fragment_token_ids, weight in zip(per_fragment_token_ids, weights): - # append - all_token_ids += this_fragment_token_ids - # fill out weights tensor with one float per token - per_token_weights += [float(weight)] * len(this_fragment_token_ids) - - # leave room for bos/eos - max_token_count_without_bos_eos_markers = self.max_token_count - 2 - if len(all_token_ids) > max_token_count_without_bos_eos_markers: - excess_token_count = len(all_token_ids) - max_token_count_without_bos_eos_markers - # TODO build nice description string of how the truncation was applied - # this should be done by calling self.tokenizer.convert_ids_to_tokens() then passing the result to - # self.tokenizer.convert_tokens_to_string() for the token_ids on each side of the truncation limit. - print(f">> Prompt is {excess_token_count} token(s) too long and has been truncated") - all_token_ids = all_token_ids[0:max_token_count_without_bos_eos_markers] - per_token_weights = per_token_weights[0:max_token_count_without_bos_eos_markers] - - # pad out to a self.max_length-entry array: [bos_token, , eos_token, pad_token...] 
- # (typically self.max_length == 77) - all_token_ids = [self.tokenizer.bos_token_id] + all_token_ids + [self.tokenizer.eos_token_id] - per_token_weights = [1.0] + per_token_weights + [1.0] - pad_length = self.max_token_count - len(all_token_ids) - all_token_ids += [self.tokenizer.pad_token_id] * pad_length - per_token_weights += [1.0] * pad_length - - all_token_ids_tensor = torch.tensor(all_token_ids, dtype=torch.long, device=device) - per_token_weights_tensor = torch.tensor(per_token_weights, - dtype=self.get_dtype_for_device(self.text_encoder.device), - device=device) - #print(f"assembled all_token_ids_tensor with shape {all_token_ids_tensor.shape}") - return all_token_ids_tensor, per_token_weights_tensor - - - def build_weighted_embedding_tensor(self, token_ids: torch.Tensor, per_token_weights: torch.Tensor): - """ - Build a tensor that embeds the passed-in token IDs and applies the given per_token weights - :param token_ids: A tensor of shape `[self.max_length]` containing token IDs (ints) - :param per_token_weights: A tensor of shape `[self.max_length]` containing weights (floats) - :return: A tensor of shape `[1, self.max_length, token_dim]` representing the requested weighted embeddings - where `token_dim` is 768 for SD1 and 1280 for SD2. - """ - # print(f"building weighted embedding tensor for {tokens} with weights {per_token_weights}") - if token_ids.shape != torch.Size([self.max_token_count]): - raise ValueError(f"token_ids has shape {token_ids.shape} - expected [{self.max_token_count}]") - - text_encoder_output = self.text_encoder(token_ids.unsqueeze(0), return_dict=False, output_hidden_states=self.hidden_states_type == "penultimate") - - pooled = None - if self.hidden_states_type == "final" and not self.return_pooled: - z = text_encoder_output[0] if isinstance(self.text_encoder, CLIPTextModel) else text_encoder_output[1] - elif self.hidden_states_type == "final" and self.return_pooled: - z = text_encoder_output[0] if isinstance(self.text_encoder, CLIPTextModel) else text_encoder_output[1] - pooled = text_encoder_output[1] if isinstance(self.text_encoder, CLIPTextModel) else text_encoder_output[0] - elif self.hidden_states_type == "penultimate" and not self.return_pooled: - z = text_encoder_output[2][-2] - elif self.hidden_states_type == "penultimate" and self.return_pooled: - z = text_encoder_output[2][-2] - pooled = text_encoder_output[1] if isinstance(self.text_encoder, CLIPTextModel) else text_encoder_output[0] - - empty_token_ids = torch.tensor([self.tokenizer.bos_token_id] + - [self.tokenizer.eos_token_id] + - [self.tokenizer.pad_token_id] * (self.max_token_count - 2), - dtype=torch.int, device=z.device).unsqueeze(0) - - empty_text_encoder_output = self.text_encoder(empty_token_ids, return_dict=False, output_hidden_states=self.hidden_states_type == "penultimate") - if self.hidden_states_type == "final": - empty_z = empty_text_encoder_output[0] if isinstance(self.text_encoder, CLIPTextModel) else empty_text_encoder_output[1] - elif self.hidden_states_type == "penultimate": - empty_z = empty_text_encoder_output[2][-2] - - batch_weights_expanded = per_token_weights.reshape(per_token_weights.shape + (1,)).expand(z.shape).to(z) - z_delta_from_empty = z - empty_z - weighted_z = empty_z + (z_delta_from_empty * batch_weights_expanded) - - if self.return_pooled: - return weighted_z, pooled - - return weighted_z - - -class EmbeddingsProviderMulti: - - def __init__(self, - tokenizers: List[CLIPTokenizer], # converts strings to lists of int token ids - text_encoders: 
List[Union[CLIPTextModel, CLIPTextModelWithProjection]], # convert a list of int token ids to a tensor of embeddings - textual_inversion_manager: BaseTextualInversionManager = None, - dtype_for_device_getter: Callable[[torch.device], torch.dtype] = lambda device: torch.float32, - hidden_states_types: Union[str, List[str]] = "final", - return_pooled: Union[str, List[bool]] = False, - ): - - hidden_states_types = len(text_encoders) * [hidden_states_types] if not isinstance(hidden_states_types, (list, tuple)) else hidden_states_types - return_pooled = len(text_encoders) * [return_pooled] if not isinstance(return_pooled, (list, tuple)) else return_pooled - - self.embedding_providers = [ - EmbeddingsProvider(tokenizer, text_encoder, textual_inversion_manager, dtype_for_device_getter, hidden_states_type, pooled) - for tokenizer, text_encoder, hidden_states_type, pooled in zip(tokenizers, text_encoders, hidden_states_types, return_pooled) - ] - - @property - def text_encoder(self): - return self.embedding_providers[0].text_encoder - - @property - def tokenizer(self): - return self.embedding_providers[0].tokenizer - - def get_token_ids(self, *args, **kwargs): - # get token ids does not use padding. The padding ID is the only ID that can differ between tokenizers - # so for simplicity, we just return `get_token_ids` of the first tokenizer - return self.embedding_providers[0].get_token_ids(self, *args, **kwargs) - - def get_embeddings_for_weighted_prompt_fragments(self, - text_batch: List[List[str]], - fragment_weights_batch: List[List[float]], - should_return_tokens: bool = False, - device='cpu', - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - - outputs = [provider.get_embeddings_for_weighted_prompt_fragments(text_batch, fragment_weights_batch, should_return_tokens, device=device) for provider in self.embedding_providers] - - pooled_list = [] - text_embeddings_list = [] - tokens = [] - - for i, output in enumerate(outputs): - text_embeddings_list.append(output[0]) - - if should_return_tokens: - tokens.append(output[1]) - - if self.embedding_providers[i].return_pooled: - pooled_list.append(output[-1]) - - text_embeddings = torch.cat(text_embeddings_list, dim=-1) - pooled = torch.cat(pooled_list, dim=-1) if len(pooled_list) > 0 else None - - outputs = (text_embeddings,) - - if pooled is not None: - outputs += (pooled) - - if should_return_tokens: - outputs += (tokens) - - return outputs From 6e549d01d76689c96e8868dc14fe20d32f5629fd Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 10 Jul 2023 16:05:42 +0200 Subject: [PATCH 03/20] Apply suggestions from code review --- src/compel/embeddings_provider.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/compel/embeddings_provider.py b/src/compel/embeddings_provider.py index 889b606..5e9bfec 100644 --- a/src/compel/embeddings_provider.py +++ b/src/compel/embeddings_provider.py @@ -183,15 +183,12 @@ def get_embeddings_for_weighted_prompt_fragments(self, # should have shape (B, 77, 768) #print(f"assembled all tokens into tensor of shape {batch_z.shape}") - outputs = (batch_z,) if should_return_tokens: - outputs += (batch_tokens,) - - if len(outputs) == 1: - return outputs[0] - - return outputs + if should_return_tokens: + return batch_z, batch_tokens + else: + return batch_z def get_token_ids(self, texts: List[str], include_start_and_end_markers: bool = True) -> List[List[int]]: """ From 604eaf89dc7493b7775ddcfd2609240e572d8d4b Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 10 Jul 
2023 16:06:21 +0200 Subject: [PATCH 04/20] Apply suggestions from code review --- src/compel/embeddings_provider.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/compel/embeddings_provider.py b/src/compel/embeddings_provider.py index 5e9bfec..b715a79 100644 --- a/src/compel/embeddings_provider.py +++ b/src/compel/embeddings_provider.py @@ -114,7 +114,7 @@ def get_embeddings_for_weighted_prompt_fragments(self, # handle weights >=1 tokens, per_token_weights, mask = self.get_token_ids_and_expand_weights(fragments, weights, device=device) - base_embedding = self.build_weighted_embedding_tensor(tokens, per_token_weights, mask)[0] + base_embedding = self.build_weighted_embedding_tensor(tokens, per_token_weights, mask, device=device) # this is our starting point embeddings = base_embedding.unsqueeze(0) @@ -184,7 +184,6 @@ def get_embeddings_for_weighted_prompt_fragments(self, # should have shape (B, 77, 768) #print(f"assembled all tokens into tensor of shape {batch_z.shape}") - if should_return_tokens: if should_return_tokens: return batch_z, batch_tokens else: @@ -312,6 +311,8 @@ def build_weighted_embedding_tensor(self, return_pooled: bool = False, device: Optional[str] = None) -> torch.Tensor: """ + Build a tensor that embeds the passed-in token IDs and applies the given per_token weights + :param token_ids: A tensor of shape `n*[self.max_length]` containing token IDs (ints) where n is some arbitrary integer (i.e. n==1 for shorter prompts, or it may be >1 if there are more than max_length tokens in the original prompt) From ac1b78c4293a2bc0b7cf209b774de55cd471a50b Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 10 Jul 2023 14:08:28 +0000 Subject: [PATCH 05/20] merge conflict --- src/compel/embeddings_provider.py | 46 +++++++++++++++---------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/src/compel/embeddings_provider.py b/src/compel/embeddings_provider.py index 889b606..ab0cf05 100644 --- a/src/compel/embeddings_provider.py +++ b/src/compel/embeddings_provider.py @@ -378,6 +378,29 @@ def build_weighted_embedding_tensor(self, return weighted_z + def _encode_token_ids_to_embeddings(self, token_ids: torch.Tensor, + attention_mask: Optional[torch.Tensor]=None) -> torch.Tensor: + text_encoder_output = self.text_encoder(token_ids, + attention_mask, + output_hidden_states=self.use_penultimate_clip_layer, + return_dict=True) + + if self.requires_pooled: + pooled = text_encoder_output.pooler_output if "pooler_output" in text_encoder_output else text_encoder_output.text_embeds + else: + pooled = None + + if self.use_penultimate_clip_layer: + # needs normalizing + penultimate_hidden_state = text_encoder_output.hidden_states[-2] + + if self.use_penultimate_layer_norm: + penultimate_hidden_state = self.text_encoder.text_model.final_layer_norm(penultimate_hidden_state) + return (penultimate_hidden_state, pooled) + else: + # already normalized + return (text_encoder_output.last_hidden_state, pooled) + def _get_token_ranges_for_fragments(self, chunked_and_padded_token_ids: List[int], fragments: List[str]) -> List[Tuple[int, int]]: """ Match token id sequences for the strings in `fragments` with token id sequences in `chunked_and_padded_token_ids`, @@ -511,26 +534,3 @@ def get_embeddings_for_weighted_prompt_fragments(self, outputs += (pooled,) return outputs - - def _encode_token_ids_to_embeddings(self, token_ids: torch.Tensor, - attention_mask: Optional[torch.Tensor]=None) -> torch.Tensor: - text_encoder_output = self.text_encoder(token_ids, - 
attention_mask, - output_hidden_states=self.use_penultimate_clip_layer, - return_dict=True) - - if self.requires_pooled: - pooled = text_encoder_output.pooler_output if "pooler_output" in text_encoder_output else text_encoder_output.text_embeds - else: - pooled = None - - if self.use_penultimate_clip_layer: - # needs normalizing - penultimate_hidden_state = text_encoder_output.hidden_states[-2] - - if self.use_penultimate_layer_norm: - penultimate_hidden_state = self.text_encoder.text_model.final_layer_norm(penultimate_hidden_state) - return (penultimate_hidden_state, pooled) - else: - # already normalized - return (text_encoder_output.last_hidden_state, pooled) From b1d54ce50d803cfeb41a454876d10afda0282634 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 10 Jul 2023 14:09:15 +0000 Subject: [PATCH 06/20] merge conflict --- src/compel/embeddings_provider.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/compel/embeddings_provider.py b/src/compel/embeddings_provider.py index 012bfce..ab0cf05 100644 --- a/src/compel/embeddings_provider.py +++ b/src/compel/embeddings_provider.py @@ -114,7 +114,7 @@ def get_embeddings_for_weighted_prompt_fragments(self, # handle weights >=1 tokens, per_token_weights, mask = self.get_token_ids_and_expand_weights(fragments, weights, device=device) - base_embedding = self.build_weighted_embedding_tensor(tokens, per_token_weights, mask, device=device) + base_embedding = self.build_weighted_embedding_tensor(tokens, per_token_weights, mask)[0] # this is our starting point embeddings = base_embedding.unsqueeze(0) @@ -183,11 +183,15 @@ def get_embeddings_for_weighted_prompt_fragments(self, # should have shape (B, 77, 768) #print(f"assembled all tokens into tensor of shape {batch_z.shape}") + outputs = (batch_z,) if should_return_tokens: - return batch_z, batch_tokens - else: - return batch_z + outputs += (batch_tokens,) + + if len(outputs) == 1: + return outputs[0] + + return outputs def get_token_ids(self, texts: List[str], include_start_and_end_markers: bool = True) -> List[List[int]]: """ @@ -311,8 +315,6 @@ def build_weighted_embedding_tensor(self, return_pooled: bool = False, device: Optional[str] = None) -> torch.Tensor: """ - Build a tensor that embeds the passed-in token IDs and applies the given per_token weights - :param token_ids: A tensor of shape `n*[self.max_length]` containing token IDs (ints) where n is some arbitrary integer (i.e. n==1 for shorter prompts, or it may be >1 if there are more than max_length tokens in the original prompt) From ab15819491ed78ed7a49ebbe12a14df0e23c0e7f Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 10 Jul 2023 14:15:34 +0000 Subject: [PATCH 07/20] Fix naming --- src/compel/compel.py | 6 +++--- src/compel/embeddings_provider.py | 10 +++++----- test/test_compel.py | 2 +- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/compel/compel.py b/src/compel/compel.py index a7c2b68..c8e86c2 100644 --- a/src/compel/compel.py +++ b/src/compel/compel.py @@ -30,7 +30,7 @@ def __init__(self, downweight_mode: DownweightMode = DownweightMode.MASK, use_penultimate_clip_layer: bool=False, device: Optional[str] = None): - return_pooled: Union[str, List[bool]] = False, + requires_pooled: Union[str, List[bool]] = False, """ Initialize Compel. The tokenizer and text_encoder can be lifted directly from any DiffusionPipeline. 
@@ -64,7 +64,7 @@ def __init__(self, padding_attention_mask_value = padding_attention_mask_value, downweight_mode=downweight_mode, use_penultimate_clip_layer=use_penultimate_clip_layer, - return_pooled=return_pooled, + requires_pooled=requires_pooled, ) else: self.conditioning_provider = EmbeddingsProvider(tokenizer=tokenizer, @@ -75,7 +75,7 @@ def __init__(self, padding_attention_mask_value = padding_attention_mask_value, downweight_mode=downweight_mode, use_penultimate_clip_layer=use_penultimate_clip_layer, - return_pooled=return_pooled, + requires_pooled=requires_pooled, ) self._device = device diff --git a/src/compel/embeddings_provider.py b/src/compel/embeddings_provider.py index ab0cf05..87c1772 100644 --- a/src/compel/embeddings_provider.py +++ b/src/compel/embeddings_provider.py @@ -312,7 +312,7 @@ def build_weighted_embedding_tensor(self, token_ids: torch.Tensor, per_token_weights: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - return_pooled: bool = False, + should_return_pooled: bool = False, device: Optional[str] = None) -> torch.Tensor: """ :param token_ids: A tensor of shape `n*[self.max_length]` containing token IDs (ints) where n is some arbitrary @@ -373,7 +373,7 @@ def build_weighted_embedding_tensor(self, chunk_start_index += chunk_size - if self.requires_pooled: + if should_return_pooled: return weighted_z, pooled return weighted_z @@ -473,15 +473,15 @@ def __init__(self, textual_inversion_manager: BaseTextualInversionManager = None, dtype_for_device_getter: Callable[[torch.device], torch.dtype] = lambda device: torch.float32, hidden_states_types: Union[str, List[str]] = "final", - return_pooled: Union[str, List[bool]] = False, + requires_pooled: Union[str, List[bool]] = False, ): hidden_states_types = len(text_encoders) * [hidden_states_types] if not isinstance(hidden_states_types, (list, tuple)) else hidden_states_types - return_pooled = len(text_encoders) * [return_pooled] if not isinstance(return_pooled, (list, tuple)) else return_pooled + requires_pooled = len(text_encoders) * [requires_pooled] if not isinstance(requires_pooled, (list, tuple)) else requires_pooled self.embedding_providers = [ EmbeddingsProvider(tokenizer, text_encoder, textual_inversion_manager, dtype_for_device_getter, hidden_states_type, pooled) - for tokenizer, text_encoder, hidden_states_type, pooled in zip(tokenizers, text_encoders, hidden_states_types, return_pooled) + for tokenizer, text_encoder, hidden_states_type, pooled in zip(tokenizers, text_encoders, hidden_states_types, requires_pooled) ] @property diff --git a/test/test_compel.py b/test/test_compel.py index 0924656..67744de 100644 --- a/test/test_compel.py +++ b/test/test_compel.py @@ -80,7 +80,7 @@ def test_basic_prompt_multi_text_encoder(self): tokenizer_2 = DummyTokenizer() text_encoder_2 = DummyTransformer() - compel = Compel(tokenizer=[tokenizer_1, tokenizer_2], text_encoder=[text_encoder_1, text_encoder_2], hidden_states_type="penultimate", return_pooled=[False, True]) + compel = Compel(tokenizer=[tokenizer_1, tokenizer_2], text_encoder=[text_encoder_1, text_encoder_2], hidden_states_type="penultimate", requires_pooled=[False, True]) # test "a b c" makes it to the Conditioning intact for t=0, t=0.5, t=1 prompt = " ".join(KNOWN_WORDS[:3]) From d58f429cd4aab1c88c60a13bd5fa9ddb12d5d90a Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 10 Jul 2023 14:44:25 +0000 Subject: [PATCH 08/20] Fix more --- src/compel/compel.py | 13 +++++-- src/compel/embeddings_provider.py | 64 ++++++++++++++----------------- 2 
files changed, 37 insertions(+), 40 deletions(-) diff --git a/src/compel/compel.py b/src/compel/compel.py index c8e86c2..5d1b895 100644 --- a/src/compel/compel.py +++ b/src/compel/compel.py @@ -28,9 +28,11 @@ def __init__(self, truncate_long_prompts: bool = True, padding_attention_mask_value: int = 1, downweight_mode: DownweightMode = DownweightMode.MASK, - use_penultimate_clip_layer: bool=False, - device: Optional[str] = None): + use_penultimate_clip_layer: Union[bool, List[bool]]=False, + use_penultimate_layer_norm: Union[bool, List[bool]]=False, requires_pooled: Union[str, List[bool]] = False, + device: Optional[str] = None + ): """ Initialize Compel. The tokenizer and text_encoder can be lifted directly from any DiffusionPipeline. @@ -64,6 +66,7 @@ def __init__(self, padding_attention_mask_value = padding_attention_mask_value, downweight_mode=downweight_mode, use_penultimate_clip_layer=use_penultimate_clip_layer, + use_penultimate_layer_norm=use_penultimate_layer_norm, requires_pooled=requires_pooled, ) else: @@ -75,6 +78,7 @@ def __init__(self, padding_attention_mask_value = padding_attention_mask_value, downweight_mode=downweight_mode, use_penultimate_clip_layer=use_penultimate_clip_layer, + use_penultimate_layer_norm=use_penultimate_layer_norm, requires_pooled=requires_pooled, ) self._device = device @@ -102,6 +106,7 @@ def build_conditioning_tensor(self, text: str) -> torch.Tensor: building a conditioning tensor from that Conjunction. """ conjunction = self.parse_prompt_string(text) + import ipdb; ipdb.set_trace() conditioning, _ = self.build_conditioning_tensor_for_conjunction(conjunction) return conditioning @@ -236,9 +241,9 @@ def _get_conditioning_for_flattened_prompt(self, raise ValueError(f"embeddings can only be made from FlattenedPrompts, got {type(prompt).__name__} instead") fragments = [x.text for x in prompt.children] weights = [x.weight for x in prompt.children] - conditioning, tokens = self.conditioning_provider.get_embeddings_for_weighted_prompt_fragments( + conditioning, tokens, pooled = self.conditioning_provider.get_embeddings_for_weighted_prompt_fragments( text_batch=[fragments], fragment_weights_batch=[weights], - should_return_tokens=True, device=self.device) + should_return_tokens=True, should_return_pooled=True, device=self.device) outputs = (conditioning,) diff --git a/src/compel/embeddings_provider.py b/src/compel/embeddings_provider.py index 87c1772..6958f94 100644 --- a/src/compel/embeddings_provider.py +++ b/src/compel/embeddings_provider.py @@ -53,6 +53,7 @@ def __init__(self, self.padding_attention_mask_value = padding_attention_mask_value self.downweight_mode = downweight_mode self.use_penultimate_clip_layer = use_penultimate_clip_layer + self.use_penultimate_layer_norm = use_penultimate_layer_norm self.requires_pooled = requires_pooled @@ -193,7 +194,7 @@ def get_embeddings_for_weighted_prompt_fragments(self, return outputs - def get_token_ids(self, texts: List[str], include_start_and_end_markers: bool = True) -> List[List[int]]: + def get_token_ids(self, texts: List[str], include_start_and_end_markers: bool = True, padding: str = 'do_not_pad') -> List[List[int]]: """ Convert a list of strings like `["a cat", "a dog", "monkey riding a bicycle"]` into a list of lists of token ids like `[[bos, 0, 1, eos], [bos, 0, 2, eos], [bos, 3, 4, 0, 5, eos]]`. 
bos/eos markers are skipped if @@ -210,7 +211,7 @@ def get_token_ids(self, texts: List[str], include_start_and_end_markers: bool = token_ids_list = self.tokenizer( texts, truncation=self.truncate_to_model_max_length, - padding='do_not_pad', + padding=padding, return_tensors=None, # just give me lists of ints )['input_ids'] @@ -230,6 +231,16 @@ def get_token_ids(self, texts: List[str], include_start_and_end_markers: bool = return result + def maybe_get_pooled(self, texts: List[str], attention_mask: Optional[torch.Tensor]=None) -> Optional[torch.Tensor]: + if not self.requires_pooled: + return None + + token_ids = self.get_token_ids(texts, padding="max_length") + text_encoder_output = self.text_encoder(token_ids, attention_mask, return_dict=True) + + pooled = text_encoder_output.pooler_output if "pooler_output" in text_encoder_output else text_encoder_output.text_embeds + return pooled + def get_token_ids_and_expand_weights(self, fragments: List[str], weights: List[float], device: str ) -> (torch.Tensor, torch.Tensor, torch.Tensor): ''' @@ -312,7 +323,6 @@ def build_weighted_embedding_tensor(self, token_ids: torch.Tensor, per_token_weights: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - should_return_pooled: bool = False, device: Optional[str] = None) -> torch.Tensor: """ :param token_ids: A tensor of shape `n*[self.max_length]` containing token IDs (ints) where n is some arbitrary @@ -337,9 +347,8 @@ def build_weighted_embedding_tensor(self, [self.tokenizer.eos_token_id] + [self.tokenizer.pad_token_id] * (self.max_token_count - 2), dtype=torch.int, device=device).unsqueeze(0) - empty_z, _ = self._encode_token_ids_to_embeddings(empty_token_ids) + empty_z = self._encode_token_ids_to_embeddings(empty_token_ids) weighted_z = None - pooled = None chunk_size = self.max_token_count while chunk_start_index < token_ids.shape[0]: @@ -352,7 +361,7 @@ def build_weighted_embedding_tensor(self, else None ) - z, this_pooled = self._encode_token_ids_to_embeddings(chunk_token_ids, chunk_attention_mask) + z = self._encode_token_ids_to_embeddings(chunk_token_ids, chunk_attention_mask) batch_weights_expanded = chunk_per_token_weights.reshape( chunk_per_token_weights.shape + (1,)).expand(z.shape).to(z) @@ -364,18 +373,8 @@ def build_weighted_embedding_tensor(self, else torch.cat([weighted_z, this_weighted_z], dim=1) ) - if pooled is not None: - pooled = ( - this_pooled - if pooled is None - else torch.mean([pooled, this_pooled], dim=1) - ) - chunk_start_index += chunk_size - if should_return_pooled: - return weighted_z, pooled - return weighted_z def _encode_token_ids_to_embeddings(self, token_ids: torch.Tensor, @@ -385,21 +384,16 @@ def _encode_token_ids_to_embeddings(self, token_ids: torch.Tensor, output_hidden_states=self.use_penultimate_clip_layer, return_dict=True) - if self.requires_pooled: - pooled = text_encoder_output.pooler_output if "pooler_output" in text_encoder_output else text_encoder_output.text_embeds - else: - pooled = None - if self.use_penultimate_clip_layer: # needs normalizing penultimate_hidden_state = text_encoder_output.hidden_states[-2] if self.use_penultimate_layer_norm: penultimate_hidden_state = self.text_encoder.text_model.final_layer_norm(penultimate_hidden_state) - return (penultimate_hidden_state, pooled) + return penultimate_hidden_state else: # already normalized - return (text_encoder_output.last_hidden_state, pooled) + return text_encoder_output.last_hidden_state def _get_token_ranges_for_fragments(self, chunked_and_padded_token_ids: List[int], fragments: 
List[str]) -> List[Tuple[int, int]]: """ @@ -497,40 +491,38 @@ def get_token_ids(self, *args, **kwargs): # so for simplicity, we just return `get_token_ids` of the first tokenizer return self.embedding_providers[0].get_token_ids(self, *args, **kwargs) + def maybe_get_pooled(self, texts: List[str], attention_mask: Optional[torch.Tensor]=None) -> Optional[torch.Tensor]: + pooled = [provider.maybe_get_pooled(texts, attention_mask) for provider in self.embedding_providers] + pooled = [p for p in pooled if p is not None] + + if len(pooled) == 0: + return None + + return torch.cat(pooled, dim=-1) + def get_embeddings_for_weighted_prompt_fragments(self, text_batch: List[List[str]], fragment_weights_batch: List[List[float]], should_return_tokens: bool = False, - should_return_pooled: bool = False, device='cpu', ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - outputs = [provider.get_embeddings_for_weighted_prompt_fragments(text_batch, fragment_weights_batch, should_return_tokens=should_return_tokens, should_return_pooled=should_return_pooled, device=device) for provider in self.embedding_providers] + outputs = [provider.get_embeddings_for_weighted_prompt_fragments(text_batch, fragment_weights_batch, should_return_tokens=should_return_tokens, device=device) for provider in self.embedding_providers] - pooled_list = [] text_embeddings_list = [] tokens = [] - for i, output in enumerate(outputs): + for output in outputs: text_embeddings_list.append(output[0]) if should_return_tokens: tokens.append(output[1]) - if should_return_pooled: - pooled_list.append(output[-1]) - text_embeddings = torch.cat(text_embeddings_list, dim=-1) - pooled_list = [p for p in pooled_list if p is not None] - pooled = torch.cat(pooled_list, dim=-1) if len(pooled_list) > 0 else None - outputs = (text_embeddings,) if should_return_tokens: outputs += (tokens,) - if should_return_pooled: - outputs += (pooled,) - return outputs From cf22828e4603e7a6ebe40705840ac46e7a6751c4 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 10 Jul 2023 14:49:34 +0000 Subject: [PATCH 09/20] Correct more --- src/compel/compel.py | 12 ++---------- src/compel/embeddings_provider.py | 22 +++++++++++----------- 2 files changed, 13 insertions(+), 21 deletions(-) diff --git a/src/compel/compel.py b/src/compel/compel.py index 5d1b895..da569e2 100644 --- a/src/compel/compel.py +++ b/src/compel/compel.py @@ -30,7 +30,6 @@ def __init__(self, downweight_mode: DownweightMode = DownweightMode.MASK, use_penultimate_clip_layer: Union[bool, List[bool]]=False, use_penultimate_layer_norm: Union[bool, List[bool]]=False, - requires_pooled: Union[str, List[bool]] = False, device: Optional[str] = None ): """ @@ -67,7 +66,6 @@ def __init__(self, downweight_mode=downweight_mode, use_penultimate_clip_layer=use_penultimate_clip_layer, use_penultimate_layer_norm=use_penultimate_layer_norm, - requires_pooled=requires_pooled, ) else: self.conditioning_provider = EmbeddingsProvider(tokenizer=tokenizer, @@ -79,7 +77,6 @@ def __init__(self, downweight_mode=downweight_mode, use_penultimate_clip_layer=use_penultimate_clip_layer, use_penultimate_layer_norm=use_penultimate_layer_norm, - requires_pooled=requires_pooled, ) self._device = device @@ -106,7 +103,6 @@ def build_conditioning_tensor(self, text: str) -> torch.Tensor: building a conditioning tensor from that Conjunction. 
""" conjunction = self.parse_prompt_string(text) - import ipdb; ipdb.set_trace() conditioning, _ = self.build_conditioning_tensor_for_conjunction(conjunction) return conditioning @@ -235,24 +231,20 @@ def pad_conditioning_tensors_to_same_length(self, conditionings: List[torch.Tens def _get_conditioning_for_flattened_prompt(self, prompt: FlattenedPrompt, should_return_tokens: bool=False, - should_return_pooled: bool=False ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: if type(prompt) is not FlattenedPrompt: raise ValueError(f"embeddings can only be made from FlattenedPrompts, got {type(prompt).__name__} instead") fragments = [x.text for x in prompt.children] weights = [x.weight for x in prompt.children] - conditioning, tokens, pooled = self.conditioning_provider.get_embeddings_for_weighted_prompt_fragments( + conditioning, tokens = self.conditioning_provider.get_embeddings_for_weighted_prompt_fragments( text_batch=[fragments], fragment_weights_batch=[weights], - should_return_tokens=True, should_return_pooled=True, device=self.device) + should_return_tokens=True, device=self.device) outputs = (conditioning,) if should_return_tokens: outputs += (tokens,) - if should_return_pooled: - outputs += (pooled,) - if len(outputs) == 1: return outputs[0] diff --git a/src/compel/embeddings_provider.py b/src/compel/embeddings_provider.py index 6958f94..7c9ec40 100644 --- a/src/compel/embeddings_provider.py +++ b/src/compel/embeddings_provider.py @@ -31,7 +31,6 @@ def __init__(self, downweight_mode: DownweightMode = DownweightMode.MASK, use_penultimate_clip_layer: bool=False, use_penultimate_layer_norm: bool=True, - requires_pooled: bool = False, ): """ `tokenizer`: converts strings to lists of int token ids @@ -55,8 +54,6 @@ def __init__(self, self.use_penultimate_clip_layer = use_penultimate_clip_layer self.use_penultimate_layer_norm = use_penultimate_layer_norm - self.requires_pooled = requires_pooled - # by default always use float32 self.get_dtype_for_device = dtype_for_device_getter @@ -462,20 +459,23 @@ def _get_token_ranges_for_fragments(self, chunked_and_padded_token_ids: List[int class EmbeddingsProviderMulti: def __init__(self, - tokenizers: List[CLIPTokenizer], # converts strings to lists of int token ids - text_encoders: List[Union[CLIPTextModel, CLIPTextModelWithProjection]], # convert a list of int token ids to a tensor of embeddings + tokenizers: CLIPTokenizer, + text_encoders: Union[CLIPTextModel, CLIPTextModelWithProjection], # convert a list of int token ids to a tensor of embeddings textual_inversion_manager: BaseTextualInversionManager = None, dtype_for_device_getter: Callable[[torch.device], torch.dtype] = lambda device: torch.float32, - hidden_states_types: Union[str, List[str]] = "final", - requires_pooled: Union[str, List[bool]] = False, + truncate: bool = True, + padding_attention_mask_value: int = 1, + downweight_mode: DownweightMode = DownweightMode.MASK, + use_penultimate_clip_layer: List[bool]=False, + use_penultimate_layer_norm: List[bool]=True, ): - hidden_states_types = len(text_encoders) * [hidden_states_types] if not isinstance(hidden_states_types, (list, tuple)) else hidden_states_types - requires_pooled = len(text_encoders) * [requires_pooled] if not isinstance(requires_pooled, (list, tuple)) else requires_pooled + use_penultimate_clip_layer = len(text_encoders) * [use_penultimate_clip_layer] if not isinstance(use_penultimate_clip_layer, (list, tuple)) else use_penultimate_clip_layer + use_penultimate_layer_norm = len(text_encoders) * 
[use_penultimate_layer_norm] if not isinstance(use_penultimate_layer_norm, (list, tuple)) else use_penultimate_layer_norm self.embedding_providers = [ - EmbeddingsProvider(tokenizer, text_encoder, textual_inversion_manager, dtype_for_device_getter, hidden_states_type, pooled) - for tokenizer, text_encoder, hidden_states_type, pooled in zip(tokenizers, text_encoders, hidden_states_types, requires_pooled) + EmbeddingsProvider(tokenizer, text_encoder, textual_inversion_manager, dtype_for_device_getter, truncate, padding_attention_mask_value, downweight_mode, clip_layer, clip_norm) + for tokenizer, text_encoder, clip_layer, clip_norm in zip(tokenizers, text_encoders, use_penultimate_clip_layer, use_penultimate_layer_norm) ] @property From 974c69033014a2f8c1a3e41cf31a467b14cc22fd Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 10 Jul 2023 16:50:45 +0200 Subject: [PATCH 10/20] Apply suggestions from code review --- src/compel/compel.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/src/compel/compel.py b/src/compel/compel.py index da569e2..f16a770 100644 --- a/src/compel/compel.py +++ b/src/compel/compel.py @@ -189,6 +189,8 @@ def build_conditioning_tensor_for_prompt_object(self, prompt: Union[Blend, Flatt raise ValueError(f"unsupported prompt type: {type(prompt).__name__}") + + def pad_conditioning_tensors_to_same_length(self, conditionings: List[torch.Tensor], ) -> List[torch.Tensor]: """ @@ -230,7 +232,7 @@ def pad_conditioning_tensors_to_same_length(self, conditionings: List[torch.Tens def _get_conditioning_for_flattened_prompt(self, prompt: FlattenedPrompt, - should_return_tokens: bool=False, + should_return_tokens: bool=False ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: if type(prompt) is not FlattenedPrompt: raise ValueError(f"embeddings can only be made from FlattenedPrompts, got {type(prompt).__name__} instead") @@ -239,16 +241,10 @@ def _get_conditioning_for_flattened_prompt(self, conditioning, tokens = self.conditioning_provider.get_embeddings_for_weighted_prompt_fragments( text_batch=[fragments], fragment_weights_batch=[weights], should_return_tokens=True, device=self.device) - - outputs = (conditioning,) - if should_return_tokens: - outputs += (tokens,) - - if len(outputs) == 1: - return outputs[0] - - return outputs + return conditioning, tokens + else: + return conditioning def _get_conditioning_for_blend(self, blend: Blend): conditionings_to_blend = [] From 1c731ae01bba1696747a7788d55416d566b56b44 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 10 Jul 2023 16:51:22 +0200 Subject: [PATCH 11/20] Apply suggestions from code review --- src/compel/embeddings_provider.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/compel/embeddings_provider.py b/src/compel/embeddings_provider.py index 7c9ec40..f78feb3 100644 --- a/src/compel/embeddings_provider.py +++ b/src/compel/embeddings_provider.py @@ -79,7 +79,7 @@ def get_embeddings_for_weighted_prompt_fragments(self, text_batch: List[List[str]], fragment_weights_batch: List[List[float]], should_return_tokens: bool = False, - device='cpu', + device='cpu' ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ @@ -97,7 +97,6 @@ def get_embeddings_for_weighted_prompt_fragments(self, batch_z = None batch_tokens = None - for fragments, weights in zip(text_batch, fragment_weights_batch): # First, weight tokens in individual fragments by scaling the feature vectors as requested (effectively @@ -112,7 +111,7 @@ def 
get_embeddings_for_weighted_prompt_fragments(self, # handle weights >=1 tokens, per_token_weights, mask = self.get_token_ids_and_expand_weights(fragments, weights, device=device) - base_embedding = self.build_weighted_embedding_tensor(tokens, per_token_weights, mask)[0] + base_embedding = self.build_weighted_embedding_tensor(tokens, per_token_weights, mask, device=device) # this is our starting point embeddings = base_embedding.unsqueeze(0) From 7c8ed72358e0a69cdf1d5973848df83baa5f171e Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 10 Jul 2023 16:52:12 +0200 Subject: [PATCH 12/20] Apply suggestions from code review --- src/compel/embeddings_provider.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/compel/embeddings_provider.py b/src/compel/embeddings_provider.py index f78feb3..2c0060b 100644 --- a/src/compel/embeddings_provider.py +++ b/src/compel/embeddings_provider.py @@ -321,6 +321,8 @@ def build_weighted_embedding_tensor(self, attention_mask: Optional[torch.Tensor] = None, device: Optional[str] = None) -> torch.Tensor: """ + Build a tensor that embeds the passed-in token IDs and applies the given per_token weights + :param token_ids: A tensor of shape `n*[self.max_length]` containing token IDs (ints) where n is some arbitrary integer (i.e. n==1 for shorter prompts, or it may be >1 if there are more than max_length tokens in the original prompt) From 1386d380e1fe56cf1e0f8f100f19c9dd5fd49574 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 10 Jul 2023 16:52:28 +0200 Subject: [PATCH 13/20] Update src/compel/embeddings_provider.py --- src/compel/embeddings_provider.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/compel/embeddings_provider.py b/src/compel/embeddings_provider.py index 2c0060b..5a23727 100644 --- a/src/compel/embeddings_provider.py +++ b/src/compel/embeddings_provider.py @@ -180,7 +180,6 @@ def get_embeddings_for_weighted_prompt_fragments(self, # should have shape (B, 77, 768) #print(f"assembled all tokens into tensor of shape {batch_z.shape}") - outputs = (batch_z,) if should_return_tokens: outputs += (batch_tokens,) From e01cb5514d34be1199cfac7e143fa0f0d1cf95a9 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 10 Jul 2023 16:52:50 +0200 Subject: [PATCH 14/20] Apply suggestions from code review --- src/compel/embeddings_provider.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/compel/embeddings_provider.py b/src/compel/embeddings_provider.py index 5a23727..68d73af 100644 --- a/src/compel/embeddings_provider.py +++ b/src/compel/embeddings_provider.py @@ -182,12 +182,9 @@ def get_embeddings_for_weighted_prompt_fragments(self, #print(f"assembled all tokens into tensor of shape {batch_z.shape}") if should_return_tokens: - outputs += (batch_tokens,) - - if len(outputs) == 1: - return outputs[0] - - return outputs + return batch_z, batch_tokens + else: + return batch_z def get_token_ids(self, texts: List[str], include_start_and_end_markers: bool = True, padding: str = 'do_not_pad') -> List[List[int]]: """ From a4395a9074f5e9e6d5fa453aa9c334942a3df717 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 10 Jul 2023 16:53:17 +0200 Subject: [PATCH 15/20] Apply suggestions from code review --- src/compel/embeddings_provider.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/compel/embeddings_provider.py b/src/compel/embeddings_provider.py index 68d73af..d49a502 100644 --- a/src/compel/embeddings_provider.py +++ b/src/compel/embeddings_provider.py @@ -366,7 +366,6 @@ def 
build_weighted_embedding_tensor(self, if weighted_z is None else torch.cat([weighted_z, this_weighted_z], dim=1) ) - chunk_start_index += chunk_size return weighted_z @@ -377,7 +376,6 @@ def _encode_token_ids_to_embeddings(self, token_ids: torch.Tensor, attention_mask, output_hidden_states=self.use_penultimate_clip_layer, return_dict=True) - if self.use_penultimate_clip_layer: # needs normalizing penultimate_hidden_state = text_encoder_output.hidden_states[-2] From 94806364e68ebc8b88d609a302292b5780f9f35d Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 10 Jul 2023 16:53:42 +0200 Subject: [PATCH 16/20] Apply suggestions from code review --- test/prompting_test_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/prompting_test_utils.py b/test/prompting_test_utils.py index 409c3b1..31ea3d2 100644 --- a/test/prompting_test_utils.py +++ b/test/prompting_test_utils.py @@ -87,6 +87,7 @@ def text_model(self): tm.final_layer_norm = nn.LayerNorm(normalized_shape=[self.text_model_max_length, self.embedding_length]) return tm + class DummyTokenizer(): def __init__(self, model_max_length=77): self.tokens = KNOWN_WORDS.copy() + ["<|bos|>", "<|pad|>", "<|eos|>"] From a43db80a2d41f63a0980b33f3abf4b08bd0202ab Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 10 Jul 2023 15:03:41 +0000 Subject: [PATCH 17/20] Add to compel --- src/compel/compel.py | 23 ++++++++++++++++++----- src/compel/embeddings_provider.py | 9 +++------ 2 files changed, 21 insertions(+), 11 deletions(-) diff --git a/src/compel/compel.py b/src/compel/compel.py index f16a770..7cae5c9 100644 --- a/src/compel/compel.py +++ b/src/compel/compel.py @@ -3,7 +3,7 @@ import torch from torch import Tensor -from transformers import CLIPTokenizer, CLIPTextModel +from transformers import CLIPTokenizer, CLIPTextModel, ConditionalDetrImageProcessor from . import cross_attention_control from .conditioning_scheduler import ConditioningScheduler, StaticConditioningScheduler @@ -84,7 +84,7 @@ def __init__(self, def device(self): return self._device if self._device else self.conditioning_provider.text_encoder.device - def make_conditioning_scheduler(self, positive_prompt: str, negative_prompt: str='') -> ConditioningScheduler: + def make_conditioning_scheduler(self, positive_prompt: str, negative_prompt: str=''): """ Return a ConditioningScheduler object that provides conditioning tensors for different diffusion steps (currently not fully implemented). @@ -97,17 +97,22 @@ def make_conditioning_scheduler(self, positive_prompt: str, negative_prompt: str return StaticConditioningScheduler(positive_conditioning=positive_conditioning, negative_conditioning=negative_conditioning) - def build_conditioning_tensor(self, text: str) -> torch.Tensor: + def build_conditioning_tensor(self, text: str, return_pooled: bool = False) -> torch.Tensor: """ Build a conditioning tensor by parsing the text for Compel syntax, constructing a Conjunction, and then building a conditioning tensor from that Conjunction. """ conjunction = self.parse_prompt_string(text) conditioning, _ = self.build_conditioning_tensor_for_conjunction(conjunction) + + if return_pooled: + pooled = self.conditioning_provider.get_pooled(text) + return conditioning, pooled + return conditioning @torch.no_grad() - def __call__(self, text: Union[str, List[str]]) -> torch.FloatTensor: + def __call__(self, text: Union[str, List[str]], return_pooled: False) -> torch.FloatTensor: """ Take a string or a list of strings and build conditioning tensors to match. 
@@ -119,12 +124,20 @@ def __call__(self, text: Union[str, List[str]]) -> torch.FloatTensor: text = [text] cond_tensor = [] + pooled = [] for text_input in text: - cond_tensor.append(self.build_conditioning_tensor(text_input)) + output = self.build_conditioning_tensor(text_input, return_pooled=return_pooled) + cond_tensor.append(output[0] if return_pooled else output) + + if return_pooled: + pooled.append(output[1]) cond_tensor = self.pad_conditioning_tensors_to_same_length(conditionings=cond_tensor) cond_tensor = torch.cat(cond_tensor) + if return_pooled: + cond_tensor, torch.cat(pooled) + return cond_tensor @classmethod diff --git a/src/compel/embeddings_provider.py b/src/compel/embeddings_provider.py index d49a502..107152d 100644 --- a/src/compel/embeddings_provider.py +++ b/src/compel/embeddings_provider.py @@ -223,10 +223,7 @@ def get_token_ids(self, texts: List[str], include_start_and_end_markers: bool = return result - def maybe_get_pooled(self, texts: List[str], attention_mask: Optional[torch.Tensor]=None) -> Optional[torch.Tensor]: - if not self.requires_pooled: - return None - + def get_pooled(self, texts: List[str], attention_mask: Optional[torch.Tensor]=None) -> Optional[torch.Tensor]: token_ids = self.get_token_ids(texts, padding="max_length") text_encoder_output = self.text_encoder(token_ids, attention_mask, return_dict=True) @@ -486,8 +483,8 @@ def get_token_ids(self, *args, **kwargs): # so for simplicity, we just return `get_token_ids` of the first tokenizer return self.embedding_providers[0].get_token_ids(self, *args, **kwargs) - def maybe_get_pooled(self, texts: List[str], attention_mask: Optional[torch.Tensor]=None) -> Optional[torch.Tensor]: - pooled = [provider.maybe_get_pooled(texts, attention_mask) for provider in self.embedding_providers] + def get_pooled(self, texts: List[str], attention_mask: Optional[torch.Tensor]=None) -> Optional[torch.Tensor]: + pooled = [provider.get_pooled(texts, attention_mask) if isinstance(provider, CLIPTextModelWithProjection) else None for provider in self.embedding_providers] pooled = [p for p in pooled if p is not None] if len(pooled) == 0: From 64e434f8973ec4962e14f1ad0b16f3df0b2193b7 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 10 Jul 2023 15:36:39 +0000 Subject: [PATCH 18/20] Get test working --- src/compel/compel.py | 25 ++++++++++++++++--------- src/compel/embeddings_provider.py | 25 +++++++++++++++++-------- test/prompting_test_utils.py | 12 ++++++++++-- test/test_compel.py | 5 +++-- 4 files changed, 46 insertions(+), 21 deletions(-) diff --git a/src/compel/compel.py b/src/compel/compel.py index 7cae5c9..a1a99cb 100644 --- a/src/compel/compel.py +++ b/src/compel/compel.py @@ -30,6 +30,7 @@ def __init__(self, downweight_mode: DownweightMode = DownweightMode.MASK, use_penultimate_clip_layer: Union[bool, List[bool]]=False, use_penultimate_layer_norm: Union[bool, List[bool]]=False, + requires_pooled: Union[bool, List[bool]]=False, device: Optional[str] = None ): """ @@ -66,6 +67,7 @@ def __init__(self, downweight_mode=downweight_mode, use_penultimate_clip_layer=use_penultimate_clip_layer, use_penultimate_layer_norm=use_penultimate_layer_norm, + requires_pooled=requires_pooled, ) else: self.conditioning_provider = EmbeddingsProvider(tokenizer=tokenizer, @@ -77,8 +79,10 @@ def __init__(self, downweight_mode=downweight_mode, use_penultimate_clip_layer=use_penultimate_clip_layer, use_penultimate_layer_norm=use_penultimate_layer_norm, + requires_pooled=requires_pooled, ) - self._device = device + + self._device = device 
@property def device(self): @@ -105,14 +109,15 @@ def build_conditioning_tensor(self, text: str, return_pooled: bool = False) -> t conjunction = self.parse_prompt_string(text) conditioning, _ = self.build_conditioning_tensor_for_conjunction(conjunction) - if return_pooled: - pooled = self.conditioning_provider.get_pooled(text) + pooled = self.conditioning_provider.maybe_get_pooled([text]) + + if return_pooled and pooled is not None: return conditioning, pooled return conditioning @torch.no_grad() - def __call__(self, text: Union[str, List[str]], return_pooled: False) -> torch.FloatTensor: + def __call__(self, text: Union[str, List[str]]) -> torch.FloatTensor: """ Take a string or a list of strings and build conditioning tensors to match. @@ -126,17 +131,19 @@ def __call__(self, text: Union[str, List[str]], return_pooled: False) -> torch.F cond_tensor = [] pooled = [] for text_input in text: - output = self.build_conditioning_tensor(text_input, return_pooled=return_pooled) - cond_tensor.append(output[0] if return_pooled else output) + output = self.build_conditioning_tensor(text_input, return_pooled=True) + + requires_pooled = len(output) > 1 + cond_tensor.append(output[0] if requires_pooled else output) - if return_pooled: + if requires_pooled: pooled.append(output[1]) cond_tensor = self.pad_conditioning_tensors_to_same_length(conditionings=cond_tensor) cond_tensor = torch.cat(cond_tensor) - if return_pooled: - cond_tensor, torch.cat(pooled) + if len(pooled) > 0: + return cond_tensor, torch.cat(pooled) return cond_tensor diff --git a/src/compel/embeddings_provider.py b/src/compel/embeddings_provider.py index 107152d..b0b3e83 100644 --- a/src/compel/embeddings_provider.py +++ b/src/compel/embeddings_provider.py @@ -31,6 +31,7 @@ def __init__(self, downweight_mode: DownweightMode = DownweightMode.MASK, use_penultimate_clip_layer: bool=False, use_penultimate_layer_norm: bool=True, + requires_pooled: bool=False, ): """ `tokenizer`: converts strings to lists of int token ids @@ -53,6 +54,7 @@ def __init__(self, self.downweight_mode = downweight_mode self.use_penultimate_clip_layer = use_penultimate_clip_layer self.use_penultimate_layer_norm = use_penultimate_layer_norm + self.requires_pooled = requires_pooled # by default always use float32 self.get_dtype_for_device = dtype_for_device_getter @@ -223,11 +225,16 @@ def get_token_ids(self, texts: List[str], include_start_and_end_markers: bool = return result - def get_pooled(self, texts: List[str], attention_mask: Optional[torch.Tensor]=None) -> Optional[torch.Tensor]: + def maybe_get_pooled(self, texts: List[str], attention_mask: Optional[torch.Tensor]=None) -> Optional[torch.Tensor]: + if not self.requires_pooled: + return None + token_ids = self.get_token_ids(texts, padding="max_length") + token_ids = torch.tensor(token_ids, dtype=torch.long).to(self.text_encoder.device) + text_encoder_output = self.text_encoder(token_ids, attention_mask, return_dict=True) + pooled = text_encoder_output.text_embeds - pooled = text_encoder_output.pooler_output if "pooler_output" in text_encoder_output else text_encoder_output.text_embeds return pooled def get_token_ids_and_expand_weights(self, fragments: List[str], weights: List[float], device: str @@ -458,16 +465,18 @@ def __init__(self, truncate: bool = True, padding_attention_mask_value: int = 1, downweight_mode: DownweightMode = DownweightMode.MASK, - use_penultimate_clip_layer: List[bool]=False, - use_penultimate_layer_norm: List[bool]=True, + use_penultimate_clip_layer: Union[List[bool], bool]=False, + 
use_penultimate_layer_norm: Union[List[bool], bool]=True, + requires_pooled: Union[List[bool], bool]=False, ): use_penultimate_clip_layer = len(text_encoders) * [use_penultimate_clip_layer] if not isinstance(use_penultimate_clip_layer, (list, tuple)) else use_penultimate_clip_layer use_penultimate_layer_norm = len(text_encoders) * [use_penultimate_layer_norm] if not isinstance(use_penultimate_layer_norm, (list, tuple)) else use_penultimate_layer_norm + requires_pooled = len(text_encoders) * [requires_pooled] if not isinstance(requires_pooled, (list, tuple)) else requires_pooled self.embedding_providers = [ - EmbeddingsProvider(tokenizer, text_encoder, textual_inversion_manager, dtype_for_device_getter, truncate, padding_attention_mask_value, downweight_mode, clip_layer, clip_norm) - for tokenizer, text_encoder, clip_layer, clip_norm in zip(tokenizers, text_encoders, use_penultimate_clip_layer, use_penultimate_layer_norm) + EmbeddingsProvider(tokenizer, text_encoder, textual_inversion_manager, dtype_for_device_getter, truncate, padding_attention_mask_value, downweight_mode, clip_layer, clip_norm, pooled) + for tokenizer, text_encoder, clip_layer, clip_norm, pooled in zip(tokenizers, text_encoders, use_penultimate_clip_layer, use_penultimate_layer_norm, requires_pooled) ] @property @@ -483,8 +492,8 @@ def get_token_ids(self, *args, **kwargs): # so for simplicity, we just return `get_token_ids` of the first tokenizer return self.embedding_providers[0].get_token_ids(self, *args, **kwargs) - def get_pooled(self, texts: List[str], attention_mask: Optional[torch.Tensor]=None) -> Optional[torch.Tensor]: - pooled = [provider.get_pooled(texts, attention_mask) if isinstance(provider, CLIPTextModelWithProjection) else None for provider in self.embedding_providers] + def maybe_get_pooled(self, texts: List[str], attention_mask: Optional[torch.Tensor]=None) -> Optional[torch.Tensor]: + pooled = [provider.maybe_get_pooled(texts, attention_mask) for provider in self.embedding_providers] pooled = [p for p in pooled if p is not None] if len(pooled) == 0: diff --git a/test/prompting_test_utils.py b/test/prompting_test_utils.py index 31ea3d2..99b525d 100644 --- a/test/prompting_test_utils.py +++ b/test/prompting_test_utils.py @@ -75,6 +75,10 @@ def __getitem__(self, item): def hidden_states(self): return [-self.last_hidden_state, self.last_hidden_state] + @property + def text_embeds(self): + return self.last_hidden_state[:, -1, :] + o = EmbeddingsObject(embeddings) return o @@ -108,8 +112,12 @@ def __call__(self, fragments, **kwargs): else x for x in tokenized] padding_strategy = kwargs.get('padding', 'do_not_pad') - if padding_strategy != 'do_not_pad': - raise Exception(f"for unit tests only 'do_not_pad' is supported as a padding strategy (got '{padding_strategy}')") + if padding_strategy not in ['do_not_pad', 'max_length']: + raise Exception(f"for unit tests only 'do_not_pad' and 'max_length' is supported as a padding strategy (got '{padding_strategy}')") + + if padding_strategy == "max_length": + tokenized = [(tokens[:-1] + (self.model_max_length - len(tokens)) * [self.pad_token_id] + tokens[1:]) for tokens in tokenized] + return {'input_ids': tokenized} def convert_tokens_to_ids(self, token_str): diff --git a/test/test_compel.py b/test/test_compel.py index 67744de..dac837f 100644 --- a/test/test_compel.py +++ b/test/test_compel.py @@ -80,13 +80,14 @@ def test_basic_prompt_multi_text_encoder(self): tokenizer_2 = DummyTokenizer() text_encoder_2 = DummyTransformer() - compel = Compel(tokenizer=[tokenizer_1, 
tokenizer_2], text_encoder=[text_encoder_1, text_encoder_2], hidden_states_type="penultimate", requires_pooled=[False, True]) + compel = Compel(tokenizer=[tokenizer_1, tokenizer_2], text_encoder=[text_encoder_1, text_encoder_2], use_penultimate_clip_layer=True, use_penultimate_layer_norm=False, requires_pooled=[False, True]) # test "a b c" makes it to the Conditioning intact for t=0, t=0.5, t=1 prompt = " ".join(KNOWN_WORDS[:3]) - output = compel(prompt) + output, pooled = compel(prompt) assert output.shape == (1, 77, 2 * 768) + assert pooled.shape == (1, 768) def test_basic_negative_prompt(self): From 8758cc7667129a1cd874c70dfe1b6ad2ec358c73 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 10 Jul 2023 17:37:41 +0200 Subject: [PATCH 19/20] Apply suggestions from code review --- src/compel/compel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/compel/compel.py b/src/compel/compel.py index a1a99cb..0351f3c 100644 --- a/src/compel/compel.py +++ b/src/compel/compel.py @@ -3,7 +3,7 @@ import torch from torch import Tensor -from transformers import CLIPTokenizer, CLIPTextModel, ConditionalDetrImageProcessor +from transformers import CLIPTokenizer, CLIPTextModel from . import cross_attention_control from .conditioning_scheduler import ConditioningScheduler, StaticConditioningScheduler @@ -88,7 +88,7 @@ def __init__(self, def device(self): return self._device if self._device else self.conditioning_provider.text_encoder.device - def make_conditioning_scheduler(self, positive_prompt: str, negative_prompt: str=''): + def make_conditioning_scheduler(self, positive_prompt: str, negative_prompt: str='') -> ConditioningScheduler: """ Return a ConditioningScheduler object that provides conditioning tensors for different diffusion steps (currently not fully implemented). From 20dcf45b5dc845b596202c94c3816325b6a525a7 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 10 Jul 2023 17:38:26 +0200 Subject: [PATCH 20/20] Apply suggestions from code review --- test/prompting_test_utils.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/test/prompting_test_utils.py b/test/prompting_test_utils.py index 99b525d..361ab0a 100644 --- a/test/prompting_test_utils.py +++ b/test/prompting_test_utils.py @@ -64,12 +64,8 @@ def __init__(self, last_hidden_state): self.last_hidden_state = last_hidden_state def __getitem__(self, item): - if item == 0: - return self.last_hidden_state[:, -1, :] - if item == 1: - return self.last_hidden_state - if item == 2: - return 2 * [self.last_hidden_state] + assert item == 0 + return self.last_hidden_state @property def hidden_states(self):
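
For reference, a minimal usage sketch of the multi-text-encoder path as it stands after PATCH 07-18 follows. It mirrors the configuration exercised in test_basic_prompt_multi_text_encoder (penultimate CLIP layer without the final layer norm, pooled output requested only from the second encoder). Only the Compel constructor arguments and the (conditioning, pooled) return value come from this series; the SDXL checkpoint id and the pipe.tokenizer_2 / pipe.text_encoder_2 attribute names are assumptions used purely for illustration.

import torch
from compel import Compel
from diffusers import StableDiffusionXLPipeline

# Assumed SDXL-style pipeline: the checkpoint id and the tokenizer_2 /
# text_encoder_2 attribute names are illustrative assumptions, not taken
# from this patch series.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0"
)

# Same configuration as test_basic_prompt_multi_text_encoder: penultimate
# hidden states without the final layer norm, and pooled output only from
# the second (projection) text encoder.
compel = Compel(
    tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
    text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
    use_penultimate_clip_layer=True,
    use_penultimate_layer_norm=False,
    requires_pooled=[False, True],
)

# When any provider requires pooled output, __call__ returns a
# (conditioning, pooled) pair; the conditioning concatenates the
# per-encoder embeddings along the feature dimension.
conditioning, pooled = compel("a cat playing with a ball++ in the forest")
print(conditioning.shape, pooled.shape)  # e.g. (1, 77, 2048) and (1, 1280) for SDXL-sized encoders

The sketch assumes downstream code consumes the concatenated conditioning plus the pooled embedding the way the SDXL UNet expects; batching a list of prompts through compel([...]) pads the conditionings to the same length before concatenation, as implemented in __call__ above.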