Skip to content

Commit

Permalink
Fix potential issue with non clip text embeddings.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Jul 30, 2024
1 parent 25853d0 commit 82cae45
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 9 deletions.
2 changes: 1 addition & 1 deletion comfy/clip_config_bigg.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"attention_dropout": 0.0,
"bos_token_id": 0,
"dropout": 0.0,
"eos_token_id": 2,
"eos_token_id": 49407,
"hidden_act": "gelu",
"hidden_size": 1280,
"initializer_factor": 1.0,
Expand Down
3 changes: 2 additions & 1 deletion comfy/clip_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def __init__(self, config_dict, dtype, device, operations):
heads = config_dict["num_attention_heads"]
intermediate_size = config_dict["intermediate_size"]
intermediate_activation = config_dict["hidden_act"]
self.eos_token_id = config_dict["eos_token_id"]

super().__init__()
self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device)
Expand All @@ -111,7 +112,7 @@ def forward(self, input_tokens, attention_mask=None, intermediate_output=None, f
if i is not None and final_layer_norm_intermediate:
i = self.final_layer_norm(i)

pooled_output = x[torch.arange(x.shape[0], device=x.device), input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1),]
pooled_output = x[torch.arange(x.shape[0], device=x.device), (torch.round(input_tokens).to(dtype=torch.int, device=x.device) == self.eos_token_id).int().argmax(dim=-1),]
return x, i, pooled_output

class CLIPTextModel(torch.nn.Module):
Expand Down
7 changes: 2 additions & 5 deletions comfy/sd1_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,15 +140,13 @@ def reset_clip_options(self):

def set_up_textual_embeddings(self, tokens, current_embeds):
out_tokens = []
next_new_token = token_dict_size = current_embeds.weight.shape[0] - 1
next_new_token = token_dict_size = current_embeds.weight.shape[0]
embedding_weights = []

for x in tokens:
tokens_temp = []
for y in x:
if isinstance(y, numbers.Integral):
if y == token_dict_size: #EOS token
y = -1
tokens_temp += [int(y)]
else:
if y.shape[0] == current_embeds.weight.shape[1]:
Expand All @@ -164,11 +162,10 @@ def set_up_textual_embeddings(self, tokens, current_embeds):
n = token_dict_size
if len(embedding_weights) > 0:
new_embedding = torch.nn.Embedding(next_new_token + 1, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype)
new_embedding.weight[:token_dict_size] = current_embeds.weight[:-1]
new_embedding.weight[:token_dict_size] = current_embeds.weight
for x in embedding_weights:
new_embedding.weight[n] = x
n += 1
new_embedding.weight[n] = current_embeds.weight[-1] #EOS embedding
self.transformer.set_input_embeddings(new_embedding)

processed_tokens = []
Expand Down
2 changes: 1 addition & 1 deletion comfy/sd1_clip_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"attention_dropout": 0.0,
"bos_token_id": 0,
"dropout": 0.0,
"eos_token_id": 2,
"eos_token_id": 49407,
"hidden_act": "quick_gelu",
"hidden_size": 768,
"initializer_factor": 1.0,
Expand Down
2 changes: 1 addition & 1 deletion comfy/text_encoders/sd2_clip_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"attention_dropout": 0.0,
"bos_token_id": 0,
"dropout": 0.0,
"eos_token_id": 2,
"eos_token_id": 49407,
"hidden_act": "gelu",
"hidden_size": 1024,
"initializer_factor": 1.0,
Expand Down

0 comments on commit 82cae45

Please sign in to comment.