[from_single_file] fix: tokenizer and config loading #5439

Closed

sayakpaul wants to merge 73 commits into main from feat/workflows
Commits (73)
43e4e84
add: workflows.
sayakpaul Aug 29, 2023
ac29505
add unfinished implementation of _update_call()
sayakpaul Aug 29, 2023
e8e09e4
Apply suggestions from code review
sayakpaul Aug 30, 2023
a8a1378
fix
sayakpaul Oct 15, 2023
a62b77f
include todos.
sayakpaul Oct 15, 2023
d5d31e0
add: support for lora.
sayakpaul Oct 15, 2023
e3611e3
properly set lora_info
sayakpaul Oct 15, 2023
96c55d4
fix
sayakpaul Oct 15, 2023
5f19b66
resolve conflicts
sayakpaul Oct 15, 2023
29d0aa8
remove components from workflows.
sayakpaul Oct 17, 2023
ef94a00
handle torch.tensor.
sayakpaul Oct 17, 2023
d8e6f38
change method desc.
sayakpaul Oct 17, 2023
ba0b1e8
improve docstring
sayakpaul Oct 17, 2023
a6a0277
include pipeline name in the workflow
sayakpaul Oct 17, 2023
1ab81a6
update progress.
sayakpaul Oct 17, 2023
97ae043
update docstrings.
sayakpaul Oct 17, 2023
ad72597
patch call.
sayakpaul Oct 17, 2023
807c2ca
debug
sayakpaul Oct 17, 2023
930ca76
debug
sayakpaul Oct 17, 2023
0bd9773
debug
sayakpaul Oct 17, 2023
50769e0
remove torch tensor warning as it might complicate things
sayakpaul Oct 17, 2023
e710121
save_pretrained() to workflow so that it has push_to_hub
sayakpaul Oct 17, 2023
2b48d85
save_pretrained() to workflow so that it has push_to_hub
sayakpaul Oct 17, 2023
45c5656
make config_name a part of the dict.
sayakpaul Oct 17, 2023
2d1cd20
stronger check
sayakpaul Oct 17, 2023
b149800
remove unneeded comment
sayakpaul Oct 17, 2023
73dcc17
remove unneeded comment
sayakpaul Oct 17, 2023
eff03fd
override method
sayakpaul Oct 17, 2023
7b85bfe
fix: signature
sayakpaul Oct 17, 2023
fc609e3
more fix
sayakpaul Oct 17, 2023
aa7839c
pop from internal dict too.
sayakpaul Oct 17, 2023
e590b73
override pop too for feature compatibility
sayakpaul Oct 17, 2023
f6c0878
callables should not be serialized too.
sayakpaul Oct 17, 2023
3194560
seed.
sayakpaul Oct 17, 2023
c5ff8cd
debug
sayakpaul Oct 17, 2023
f08f40b
debug
sayakpaul Oct 17, 2023
b5fd337
debug
sayakpaul Oct 17, 2023
9ee8b0a
debug
sayakpaul Oct 17, 2023
9d0bcd4
debug
sayakpaul Oct 17, 2023
21d19bb
workflow_filename -> filename
sayakpaul Oct 17, 2023
91c1c1f
apply styling
sayakpaul Oct 17, 2023
800b7a0
more
sayakpaul Oct 17, 2023
3d6637b
partial
sayakpaul Oct 17, 2023
af282b7
debug
sayakpaul Oct 17, 2023
9adaa17
debug
sayakpaul Oct 17, 2023
4731f65
debug.
sayakpaul Oct 17, 2023
a69e3d1
debug.
sayakpaul Oct 17, 2023
ed9acd6
remove print
sayakpaul Oct 17, 2023
4993c8b
Merge branch 'main' into feat/workflows
sayakpaul Oct 17, 2023
18a756f
style.
sayakpaul Oct 17, 2023
1dc9854
copying helps?
sayakpaul Oct 17, 2023
d612b54
debug
sayakpaul Oct 17, 2023
eaae2df
debugging.
sayakpaul Oct 17, 2023
55c47bc
let's see
sayakpaul Oct 18, 2023
452bf4f
hmm almost
sayakpaul Oct 18, 2023
74e766c
quality
sayakpaul Oct 18, 2023
21e5bb6
replace the __call__ attribute of the class, not the instance
sayakpaul Oct 18, 2023
231e831
feat: support passing filename
sayakpaul Oct 18, 2023
c0e1c63
Merge branch 'main' into feat/workflows
sayakpaul Oct 18, 2023
f874578
fix: lora population.
sayakpaul Oct 18, 2023
49e06fd
fix: lora population.
sayakpaul Oct 18, 2023
c1c11a6
fix: lora
sayakpaul Oct 18, 2023
ff5cd58
support basic lora only for non-peft for now.
sayakpaul Oct 18, 2023
03bfdff
Empty-Commit
sayakpaul Oct 18, 2023
ad86788
debug
sayakpaul Oct 18, 2023
c30ffd8
debug
sayakpaul Oct 18, 2023
2521dde
debug
sayakpaul Oct 18, 2023
f744da8
debug
sayakpaul Oct 18, 2023
2f53daf
debug
sayakpaul Oct 18, 2023
98f9877
debug
sayakpaul Oct 18, 2023
85ffaea
debug
sayakpaul Oct 18, 2023
e8f774c
debug
sayakpaul Oct 18, 2023
21938aa
fix: config loading
sayakpaul Oct 18, 2023
7 changes: 6 additions & 1 deletion src/diffusers/configuration_utils.py
@@ -157,7 +157,12 @@ def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool
         os.makedirs(save_directory, exist_ok=True)

         # If we save using the predefined names, we can load using `from_config`
-        output_config_file = os.path.join(save_directory, self.config_name)
+        filename = kwargs.pop("filename", None)
+        if filename is not None:
+            config_name = filename
+        else:
+            config_name = self.config_name
+        output_config_file = os.path.join(save_directory, config_name)

         self.to_json_file(output_config_file)
         logger.info(f"Configuration saved in {output_config_file}")
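
With this change, `save_config()` accepts an optional `filename` keyword argument that overrides the default `self.config_name` when writing the JSON file. A minimal sketch of the resulting behavior (the model id and output directory are placeholders):

from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

# Default behavior: writes ./unet_config/config.json (self.config_name).
unet.save_config("./unet_config")

# With the new kwarg: writes ./unet_config/my_config.json instead.
unet.save_config("./unet_config", filename="my_config.json")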
17 changes: 17 additions & 0 deletions src/diffusers/loaders.py
@@ -1168,6 +1168,8 @@ class LoraLoaderMixin:
     """
     text_encoder_name = TEXT_ENCODER_NAME
     unet_name = UNET_NAME
+    loras_loaded = 0
+    lora_info = {}
     num_fused_loras = 0

     def load_lora_weights(
@@ -1224,6 +1226,11 @@ def load_lora_weights(
             adapter_name=adapter_name,
             _pipeline=self,
         )
+        if not USE_PEFT_BACKEND:
+            self.loras_loaded += 1
+            current_lora_info = {"pretrained_model_name_or_path_or_dict": pretrained_model_name_or_path_or_dict}
+            current_lora_info.update(dict(kwargs.items()))
+            self.lora_info.update({f"lora_{self.loras_loaded}": current_lora_info})

     @classmethod
     def lora_state_dict(
@@ -2256,6 +2263,15 @@ def unload_lora_weights(self):
         # Safe to call the following regardless of LoRA.
         self._remove_text_encoder_monkey_patch()

+        # Housekeeping.
+        # TODO: handle for PEFT backend because adapters can be combined, offloaded, etc.
+        # TODO: handle `fuse_lora()` and `unfuse_lora()` cases.
+        if not USE_PEFT_BACKEND:
+            self.loras_loaded -= 1
+            keys = list(self.lora_info.keys())
+            keys.sort()
+            self.lora_info.pop(keys[-1])
+
     def fuse_lora(
         self,
         fuse_unet: bool = True,
@@ -2832,6 +2848,7 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
             tokenizer=tokenizer,
             original_config_file=original_config_file,
             config_files=config_files,
+            local_files_only=local_files_only,
         )

         if torch_dtype is not None:
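
In the non-PEFT code path, `LoraLoaderMixin` now keeps running `loras_loaded` / `lora_info` bookkeeping so that a pipeline can later serialize which LoRAs were loaded and with which arguments. A rough sketch of the resulting state (the LoRA repo id is a placeholder):

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.load_lora_weights("some-user/some-lora")  # placeholder repo id

# Without the PEFT backend, the mixin records:
#   pipe.loras_loaded == 1
#   pipe.lora_info == {"lora_1": {"pretrained_model_name_or_path_or_dict": "some-user/some-lora"}}

pipe.unload_lora_weights()  # pops "lora_1" and decrements loras_loaded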
53 changes: 53 additions & 0 deletions src/diffusers/pipelines/pipeline_utils.py
@@ -22,6 +22,7 @@
 import sys
 import warnings
 from dataclasses import dataclass
+from functools import partial
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Union

@@ -56,7 +57,9 @@
     logging,
     numpy_to_pil,
 )
+from ..utils.constants import WORKFLOW_NAME
 from ..utils.torch_utils import is_compiled_module
+from ..workflow_utils import _NON_CALL_ARGUMENTS


 if is_transformers_available():
@@ -66,6 +69,7 @@
     from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
     from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME

+
 from ..utils import FLAX_WEIGHTS_NAME, ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, PushToHubMixin
@@ -1977,3 +1981,52 @@ def set_attention_slice(self, slice_size: Optional[int]):

         for module in modules:
             module.set_attention_slice(slice_size)
+
+    def load_workflow(self, workflow_id_or_path: Union[str, dict], filename: Optional[str] = None):
+        r"""Loads a workflow from the Hub or from a local path and patches the pipeline's call arguments with the
+        values stored in the workflow.
+
+        Args:
+            workflow_id_or_path (`str` or `dict`):
+                Can be either:
+
+                    - A string, the workflow id (for example `sayakpaul/sdxl-workflow`) of a workflow hosted on the
+                      Hub.
+                    - A path to a directory (for example `./my_workflow_directory`) containing a workflow file saved
+                      with [`Workflow.save_workflow`] or [`Workflow.push_to_hub`].
+                    - A Python dictionary.
+
+            filename (`str`, *optional*):
+                Optional name of the workflow file to load. Especially useful when working with multiple workflow
+                files.
+        """
+        filename = filename or WORKFLOW_NAME
+
+        # Load workflow.
+        if not isinstance(workflow_id_or_path, dict):
+            if os.path.isdir(workflow_id_or_path):
+                workflow_filepath = os.path.join(workflow_id_or_path, filename)
+            elif os.path.isfile(workflow_id_or_path):
+                workflow_filepath = workflow_id_or_path
+            else:
+                workflow_filepath = hf_hub_download(repo_id=workflow_id_or_path, filename=filename)
+            workflow = self._dict_from_json_file(workflow_filepath)
+        else:
+            workflow = workflow_id_or_path
+
+        # Handle the generator.
+        seed = workflow.pop("seed", None)
+        if seed is not None:
+            generator = torch.manual_seed(seed)
+        else:
+            generator = None
+        workflow.update({"generator": generator})

+        # Handle non-call arguments.
+        # Note: instead of popping the non-call arguments off, it's better to keep them in
+        # the workflow object should it be reused.
+        final_call_args = {k: v for k, v in workflow.items() if k not in _NON_CALL_ARGUMENTS}
+
+        # Handle the call here.
+        partial_call = partial(self.__call__, **final_call_args)
+        setattr(self.__class__, "__call__", partial_call)
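
`load_workflow()` resolves the workflow from a dict, a local file or directory, or a Hub repo, rebuilds a `generator` from a stored `seed`, filters out the non-call entries, and then rebinds the pipeline class's `__call__` with the remaining arguments via `functools.partial`. A sketch of the intended usage (the prompt and seed values are placeholders):

from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# From a plain dict: "seed" becomes generator=torch.manual_seed(33) and the
# remaining call arguments are partial-bound onto __call__.
pipe.load_workflow({"prompt": "an astronaut riding a horse", "num_inference_steps": 30, "seed": 33})

# Or from the Hub / a local directory containing diffusion_workflow.json:
# pipe.load_workflow("sayakpaul/sdxl-workflow")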
112 changes: 85 additions & 27 deletions src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
@@ -787,7 +787,12 @@ def _copy_layers(hf_layers, pt_layers):
 def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder=None):
     if text_encoder is None:
         config_name = "openai/clip-vit-large-patch14"
-        config = CLIPTextConfig.from_pretrained(config_name, local_files_only=local_files_only)
+        try:
+            config = CLIPTextConfig.from_pretrained(config_name, local_files_only=local_files_only)
+        except Exception:
+            raise ValueError(
+                f"With local_files_only set to {local_files_only}, you must first locally save the configuration in the following path: 'openai/clip-vit-large-patch14'."
+            )

     ctx = init_empty_weights if is_accelerate_available() else nullcontext
     with ctx():
@@ -922,7 +927,12 @@ def convert_open_clip_checkpoint(
     # text_model = CLIPTextModelWithProjection.from_pretrained(
     #     "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", projection_dim=1280
     # )
-    config = CLIPTextConfig.from_pretrained(config_name, **config_kwargs, local_files_only=local_files_only)
+    try:
+        config = CLIPTextConfig.from_pretrained(config_name, **config_kwargs, local_files_only=local_files_only)
+    except Exception:
+        raise ValueError(
+            f"With local_files_only set to {local_files_only}, you must first locally save the configuration in the following path: '{config_name}'."
+        )

     ctx = init_empty_weights if is_accelerate_available() else nullcontext
     with ctx():
@@ -1211,7 +1221,6 @@ def download_from_original_stable_diffusion_ckpt(
             - `xl_refiner`: Config file for Stable Diffusion XL Refiner
     return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
     """
-
     # import pipelines here to avoid circular import error when using from_single_file method
     from diffusers import (
         LDMTextToImagePipeline,
@@ -1464,11 +1473,19 @@ def download_from_original_stable_diffusion_ckpt(
         config_name = "stabilityai/stable-diffusion-2"
         config_kwargs = {"subfolder": "text_encoder"}

-        text_model = convert_open_clip_checkpoint(checkpoint, config_name, **config_kwargs)
-        tokenizer = CLIPTokenizer.from_pretrained(
-            "stabilityai/stable-diffusion-2", subfolder="tokenizer", local_files_only=local_files_only
+        text_model = convert_open_clip_checkpoint(
+            checkpoint, config_name, **config_kwargs, local_files_only=local_files_only
         )
+
+        try:
+            tokenizer = CLIPTokenizer.from_pretrained(
+                "stabilityai/stable-diffusion-2", subfolder="tokenizer", local_files_only=local_files_only
+            )
+        except Exception:
+            raise ValueError(
+                f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'stabilityai/stable-diffusion-2'."
+            )

         if stable_unclip is None:
             if controlnet:
                 pipe = pipeline_class(
@@ -1545,10 +1562,14 @@ def download_from_original_stable_diffusion_ckpt(
                prior = PriorTransformer.from_pretrained(
                    karlo_model, subfolder="prior", local_files_only=local_files_only
                )
-
-                prior_tokenizer = CLIPTokenizer.from_pretrained(
-                    "openai/clip-vit-large-patch14", local_files_only=local_files_only
-                )
+                try:
+                    prior_tokenizer = CLIPTokenizer.from_pretrained(
+                        "openai/clip-vit-large-patch14", local_files_only=local_files_only
+                    )
+                except Exception:
+                    raise ValueError(
+                        f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'."
+                    )
                prior_text_model = CLIPTextModelWithProjection.from_pretrained(
                    "openai/clip-vit-large-patch14", local_files_only=local_files_only
                )
@@ -1581,7 +1602,14 @@ def download_from_original_stable_diffusion_ckpt(
             raise NotImplementedError(f"unknown `stable_unclip` type: {stable_unclip}")
     elif model_type == "PaintByExample":
         vision_model = convert_paint_by_example_checkpoint(checkpoint)
-        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only)
+        try:
+            tokenizer = CLIPTokenizer.from_pretrained(
+                "openai/clip-vit-large-patch14", local_files_only=local_files_only
+            )
+        except Exception:
+            raise ValueError(
+                f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'."
+            )
         feature_extractor = AutoFeatureExtractor.from_pretrained(
             "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only
         )
@@ -1597,11 +1625,16 @@ def download_from_original_stable_diffusion_ckpt(
         text_model = convert_ldm_clip_checkpoint(
             checkpoint, local_files_only=local_files_only, text_encoder=text_encoder
         )
-        tokenizer = (
-            CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only)
-            if tokenizer is None
-            else tokenizer
-        )
+        try:
+            tokenizer = (
+                CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only)
+                if tokenizer is None
+                else tokenizer
+            )
+        except Exception:
+            raise ValueError(
+                f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'."
+            )

         if load_safety_checker:
             safety_checker = StableDiffusionSafetyChecker.from_pretrained(
@@ -1637,18 +1670,33 @@ def download_from_original_stable_diffusion_ckpt(
         )
     elif model_type in ["SDXL", "SDXL-Refiner"]:
         if model_type == "SDXL":
-            tokenizer = CLIPTokenizer.from_pretrained(
-                "openai/clip-vit-large-patch14", local_files_only=local_files_only
-            )
+            try:
+                tokenizer = CLIPTokenizer.from_pretrained(
+                    "openai/clip-vit-large-patch14", local_files_only=local_files_only
+                )
+            except Exception:
+                raise ValueError(
+                    f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'."
+                )
             text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only)
-            tokenizer_2 = CLIPTokenizer.from_pretrained(
-                "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!", local_files_only=local_files_only
-            )
+            try:
+                tokenizer_2 = CLIPTokenizer.from_pretrained(
+                    "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!", local_files_only=local_files_only
+                )
+            except Exception:
+                raise ValueError(
+                    f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' with `pad_token` set to '!'."
+                )

             config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
             config_kwargs = {"projection_dim": 1280}
             text_encoder_2 = convert_open_clip_checkpoint(
-                checkpoint, config_name, prefix="conditioner.embedders.1.model.", has_projection=True, **config_kwargs
+                checkpoint,
+                config_name,
+                prefix="conditioner.embedders.1.model.",
+                has_projection=True,
+                local_files_only=local_files_only,
+                **config_kwargs,
             )

             if is_accelerate_available():  # SBM Now move model to cpu.
@@ -1682,14 +1730,24 @@ def download_from_original_stable_diffusion_ckpt(
         else:
             tokenizer = None
             text_encoder = None
-            tokenizer_2 = CLIPTokenizer.from_pretrained(
-                "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!", local_files_only=local_files_only
-            )
+            try:
+                tokenizer_2 = CLIPTokenizer.from_pretrained(
+                    "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!", local_files_only=local_files_only
+                )
+            except Exception:
+                raise ValueError(
+                    f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' with `pad_token` set to '!'."
+                )

             config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
             config_kwargs = {"projection_dim": 1280}
             text_encoder_2 = convert_open_clip_checkpoint(
-                checkpoint, config_name, prefix="conditioner.embedders.0.model.", has_projection=True, **config_kwargs
+                checkpoint,
+                config_name,
+                prefix="conditioner.embedders.0.model.",
+                has_projection=True,
+                local_files_only=local_files_only,
+                **config_kwargs,
             )

             if is_accelerate_available():  # SBM Now move model to cpu.
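
All of the tokenizer and config fetches above rely on fixed Hub repos, so with `local_files_only=True` they fail opaquely when a repo has never been cached; the new `try`/`except` blocks turn that into a `ValueError` naming the repo that must be saved locally first. A sketch of the call this targets (the checkpoint path is a placeholder):

from diffusers import StableDiffusionPipeline

# If, say, 'openai/clip-vit-large-patch14' has never been cached locally, this
# now raises the actionable ValueError above instead of an opaque cache error.
pipe = StableDiffusionPipeline.from_single_file(
    "./v1-5-pruned-emaonly.safetensors",  # placeholder local checkpoint path
    local_files_only=True,
)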
3 changes: 3 additions & 0 deletions src/diffusers/pipelines/stable_diffusion/pipeline_output.py
@@ -19,10 +19,13 @@ class StableDiffusionPipelineOutput(BaseOutput):
         nsfw_content_detected (`List[bool]`)
             List indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content or
             `None` if safety checking could not be performed.
+        workflow (`dict`):
+            Dictionary containing the pipeline component configurations and call arguments.
     """

     images: Union[List[PIL.Image.Image], np.ndarray]
     nsfw_content_detected: Optional[List[bool]]
+    workflow: dict


 if is_flax_available():
22 changes: 20 additions & 2 deletions src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -34,6 +34,7 @@
     unscale_lora_layers,
 )
 from ...utils.torch_utils import randn_tensor
+from ...workflow_utils import populate_workflow_from_pipeline
 from ..pipeline_utils import DiffusionPipeline
 from .pipeline_output import StableDiffusionPipelineOutput
 from .safety_checker import StableDiffusionSafetyChecker
@@ -592,6 +593,7 @@ def __call__(
         callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         guidance_rescale: float = 0.0,
+        return_workflow: bool = False,
         clip_skip: Optional[int] = None,
     ):
         r"""
@@ -649,6 +651,8 @@ def __call__(
                 Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
                 Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
                 using zero terminal SNR.
+            return_workflow (`bool`, *optional*, defaults to `False`):
+                Whether to return the pipeline component configurations and call arguments.
             clip_skip (`int`, *optional*):
                 Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                 the output of the pre-final layer will be used for computing the prompt embeddings.
@@ -779,7 +783,21 @@ def __call__(
         # Offload all models
         self.maybe_free_model_hooks()

+        workflow = None
+        if return_workflow:
+            signature = inspect.signature(self.__call__)
+            argument_names = [param.name for param in signature.parameters.values()]
+            call_arg_values = inspect.getargvalues(inspect.currentframe()).locals
+            workflow = populate_workflow_from_pipeline(
+                argument_names, call_arg_values, self.lora_info, self.config._name_or_path
+            )
+
         if not return_dict:
-            return (image, has_nsfw_concept)
+            outputs = (image, has_nsfw_concept)
+
+            if workflow is not None:
+                outputs += (workflow,)
+
+            return outputs

-        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept, workflow=workflow)
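
With `return_workflow=True`, `__call__` now introspects its own call frame and packages the call arguments (plus any tracked LoRA info) into a `workflow` entry on the output. A sketch of the intended usage (the prompt is a placeholder):

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
out = pipe("an astronaut riding a horse", num_inference_steps=25, return_workflow=True)

out.images[0]  # the generated image, as before
out.workflow   # dict of the call arguments (and LoRA info) that produced it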
1 change: 1 addition & 0 deletions src/diffusers/utils/constants.py
@@ -31,6 +31,7 @@
 ONNX_WEIGHTS_NAME = "model.onnx"
 SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors"
 ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb"
+WORKFLOW_NAME = "diffusion_workflow.json"
 HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co")
 DIFFUSERS_CACHE = default_cache_path
 DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"