Skip to content

Commit

Permalink
Let tokenizers return weights to be stored in the saved checkpoint.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Jul 25, 2024
1 parent 10c919f commit f87810c
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 1 deletion.
6 changes: 5 additions & 1 deletion comfy/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,11 @@ def load_sd(self, sd, full_model=False):
return self.cond_stage_model.load_sd(sd)

def get_sd(self):
return self.cond_stage_model.state_dict()
sd_clip = self.cond_stage_model.state_dict()
sd_tokenizer = self.tokenizer.state_dict()
for k in sd_tokenizer:
sd_clip[k] = sd_tokenizer[k]
return sd_clip

def load_model(self):
model_management.load_model_gpu(self.patcher)
Expand Down
4 changes: 4 additions & 0 deletions comfy/sd1_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,8 @@ def tokenize_with_weights(self, text:str, return_word_ids=False):
def untokenize(self, token_weight_pair):
return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))

def state_dict(self):
return {}

class SD1Tokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer):
Expand All @@ -534,6 +536,8 @@ def tokenize_with_weights(self, text:str, return_word_ids=False):
def untokenize(self, token_weight_pair):
return getattr(self, self.clip).untokenize(token_weight_pair)

def state_dict(self):
return {}

class SD1ClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, name=None, **kwargs):
Expand Down
3 changes: 3 additions & 0 deletions comfy/sdxl_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ def tokenize_with_weights(self, text:str, return_word_ids=False):
def untokenize(self, token_weight_pair):
return self.clip_g.untokenize(token_weight_pair)

def state_dict(self):
return {}

class SDXLClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None):
super().__init__()
Expand Down
3 changes: 3 additions & 0 deletions comfy/text_encoders/sd3_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ def tokenize_with_weights(self, text:str, return_word_ids=False):
def untokenize(self, token_weight_pair):
return self.clip_g.untokenize(token_weight_pair)

def state_dict(self):
return {}

class SD3ClipModel(torch.nn.Module):
def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None):
super().__init__()
Expand Down

0 comments on commit f87810c

Please sign in to comment.