Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

lycoris model card updates #820

Merged
merged 3 commits into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions OPTIONS.md
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ usage: train.py [-h] [--snr_gamma SNR_GAMMA] [--use_soft_min_snr]
[--flow_matching_loss {diffusers,compatible,diffusion}]
[--pixart_sigma] [--sd3]
[--sd3_t5_mask_behaviour {do-nothing,mask}]
[--lora_type {Standard,lycoris}]
[--lora_type {standard,lycoris}]
[--lora_init_type {default,gaussian,loftq,olora,pissa}]
[--init_lora INIT_LORA] [--lora_rank LORA_RANK]
[--lora_alpha LORA_ALPHA] [--lora_dropout LORA_DROPOUT]
Expand Down Expand Up @@ -505,9 +505,9 @@ options:
prevents expansion of SD3 Medium's prompt length, as
it will unnecessarily attend to every token in the
prompt embed, even masked positions.
--lora_type {Standard,lycoris}
--lora_type {standard,lycoris}
When training using --model_type=lora, you may specify
a different type of LoRA to train here. Standard
a different type of LoRA to train here. standard
refers to training a vanilla LoRA via PEFT, lycoris
refers to training with KohakuBlueleaf's library of
the same name.
Expand Down
2 changes: 1 addition & 1 deletion helpers/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def parse_args(input_args=None):
default="standard",
help=(
"When training using --model_type=lora, you may specify a different type of LoRA to train here."
" Standard refers to training a vanilla LoRA via PEFT, lycoris refers to training with KohakuBlueleaf's library of the same name."
" standard refers to training a vanilla LoRA via PEFT, lycoris refers to training with KohakuBlueleaf's library of the same name."
),
)
parser.add_argument(
Expand Down
158 changes: 138 additions & 20 deletions helpers/publishing/metadata.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,106 @@
import os
import logging
import json
import torch
from helpers.training.state_tracker import StateTracker

logger = logging.getLogger(__name__)
logger.setLevel(os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO"))

# Default model-card license identifier per model family. Any value that is
# not present in `allowed_licenses` below is coerced to "other" by the loop
# at the bottom of this table (e.g. "flux-1-dev-non-commercial-license" and
# "stabilityai-ai-community" are not Hub-recognised identifiers).
licenses = {
    "flux": "flux-1-dev-non-commercial-license",
    "sdxl": "creativeml-openrail-m",
    "legacy": "openrail++",
    "pixart_sigma": "openrail++",
    "kolors": "apache-2.0",
    "smoldit": "apache-2.0",
    "sd3": "stabilityai-ai-community",
}
# License identifiers accepted in Hugging Face Hub model-card YAML metadata;
# anything outside this list must be published as "other".
allowed_licenses = [
    "apache-2.0",
    "mit",
    "openrail",
    "bigscience-openrail-m",
    "creativeml-openrail-m",
    "bigscience-bloom-rail-1.0",
    "bigcode-openrail-m",
    "afl-3.0",
    "artistic-2.0",
    "bsl-1.0",
    "bsd",
    "bsd-2-clause",
    "bsd-3-clause",
    "bsd-3-clause-clear",
    "c-uda",
    "cc",
    "cc0-1.0",
    "cc-by-2.0",
    "cc-by-2.5",
    "cc-by-3.0",
    "cc-by-4.0",
    "cc-by-sa-3.0",
    "cc-by-sa-4.0",
    "cc-by-nc-2.0",
    "cc-by-nc-3.0",
    "cc-by-nc-4.0",
    "cc-by-nd-4.0",
    "cc-by-nc-nd-3.0",
    "cc-by-nc-nd-4.0",
    "cc-by-nc-sa-2.0",
    "cc-by-nc-sa-3.0",
    "cc-by-nc-sa-4.0",
    "cdla-sharing-1.0",
    "cdla-permissive-1.0",
    "cdla-permissive-2.0",
    "wtfpl",
    "ecl-2.0",
    "epl-1.0",
    "epl-2.0",
    "etalab-2.0",
    "eupl-1.1",
    "agpl-3.0",
    "gfdl",
    "gpl",
    "gpl-2.0",
    "gpl-3.0",
    "lgpl",
    "lgpl-2.1",
    "lgpl-3.0",
    "isc",
    "lppl-1.3c",
    "ms-pl",
    "apple-ascl",
    "mpl-2.0",
    "odc-by",
    "odbl",
    "openrail++",
    "osl-3.0",
    "postgresql",
    "ofl-1.1",
    "ncsa",
    "unlicense",
    "zlib",
    "pddl",
    "lgpl-lr",
    "deepfloyd-if-license",
    "llama2",
    "llama3",
    "llama3.1",
    "gemma",
    "unknown",
    "other",
    "array",
]
# Normalise the table at import time so `licenses[...]` is always a value
# the Hub will accept in the generated model card's `license:` field.
for _model, _license in licenses.items():
    if _license not in allowed_licenses:
        licenses[_model] = "other"


def _model_imports(args):
    """Return the import statements for the model card's usage example.

    Args:
        args: Parsed training arguments; reads ``lora_type`` and
            ``model_type`` to decide whether the LyCORIS import is needed.

    Returns:
        str: Newline-separated Python import lines for the code example.
    """
    output = "import torch\n"
    output += "from diffusers import DiffusionPipeline"
    # LyCORIS adapters are merged via the `lycoris` library rather than PEFT,
    # so the example needs its weight-loading helper.
    if "lycoris" == args.lora_type.lower() and "lora" in args.model_type:
        output += "\nfrom lycoris import create_lycoris_from_weights"

    # The original wrapped this in a redundant f-string (`f"{output}"`);
    # `output` is already a plain str.
    return output

Expand All @@ -18,17 +110,27 @@ def _model_load(args, repo_id: str = None):
if hf_user_name is not None:
repo_id = f"{hf_user_name}/{repo_id}" if hf_user_name else repo_id
if "lora" in args.model_type:
output = (
f"model_id = '{args.pretrained_model_name_or_path}'"
f"\nadapter_id = '{repo_id if repo_id is not None else args.output_dir}'"
f"\npipeline = DiffusionPipeline.from_pretrained(model_id)"
f"\npipeline.load_lora_weights(adapter_id)"
)
if args.lora_type.lower() == "standard":
output = (
f"model_id = '{args.pretrained_model_name_or_path}'"
f"\nadapter_id = '{repo_id if repo_id is not None else args.output_dir}'"
f"\npipeline = DiffusionPipeline.from_pretrained(model_id)"
f"\npipeline.load_lora_weights(adapter_id)"
)
elif args.lora_type.lower() == "lycoris":
output = (
f"model_id = '{args.pretrained_model_name_or_path}'"
f"\nadapter_id = 'pytorch_lora_weights.safetensors' # you will have to download this manually"
"\nlora_scale = 1.0"
)
else:
output = (
f"model_id = '{repo_id if repo_id else os.path.join(args.output_dir, 'pipeline')}'"
f"\npipeline = DiffusionPipeline.from_pretrained(model_id)"
)
if args.model_type == "lora" and args.lora_type.lower() == "lycoris":
output += f"\nwrapper, _ = create_lycoris_from_weights(lora_scale, adapter_id, pipeline.transformer)"
output += "\nwrapper.merge_to()"

return output

Expand Down Expand Up @@ -77,7 +179,6 @@ def code_example(args, repo_id: str = None):

prompt = "{args.validation_prompt if args.validation_prompt else 'An astronaut is riding a horse through the jungles of Thailand.'}"
{_negative_prompt(args)}

pipeline.to({_torch_device()})
image = pipeline(
prompt=prompt,{_negative_prompt(args, in_call=True) if not args.flux else ''}
Expand All @@ -92,15 +193,32 @@ def code_example(args, repo_id: str = None):
return code_example


def model_type(args):
    """Return a human-readable description of the kind of training run.

    Args:
        args: Parsed training arguments; reads ``model_type`` and, for LoRA
            runs, ``lora_type``.

    Returns:
        str: "standard PEFT LoRA", "LyCORIS adapter", a generic LoRA label
        for unrecognised adapter types, or "full rank finetune".
    """
    if "lora" in args.model_type:
        adapter = args.lora_type.lower()
        if adapter == "standard":
            return "standard PEFT LoRA"
        if adapter == "lycoris":
            return "LyCORIS adapter"
        # Fix: the original fell through and returned None for an
        # unrecognised lora_type, which renders as "None" in the model card.
        return f"{args.lora_type} LoRA"
    return "full rank finetune"


def lora_info(args):
    """Return a markdown fragment describing the LoRA/LyCORIS configuration.

    Args:
        args: Parsed training arguments. For standard LoRA, reads
            ``lora_rank``, ``lora_alpha``, ``lora_dropout`` and
            ``lora_init_type``; for LyCORIS, reads ``lycoris_config``
            (path to the user-supplied JSON config).

    Returns:
        str: A markdown bullet list for standard LoRA, a fenced JSON block
        for LyCORIS, or an empty string for full-rank training and
        unrecognised adapter types.

    Raises:
        OSError: If the LyCORIS config file cannot be opened.
        json.JSONDecodeError: If the LyCORIS config is not valid JSON.
    """
    if "lora" not in args.model_type:
        return ""
    adapter = args.lora_type.lower()
    if adapter == "standard":
        return f"""- LoRA Rank: {args.lora_rank}
- LoRA Alpha: {args.lora_alpha}
- LoRA Dropout: {args.lora_dropout}
- LoRA initialisation style: {args.lora_init_type}
"""
    if adapter == "lycoris":
        # The LyCORIS hyperparameters live in a user-supplied JSON file;
        # embed it verbatim so the card reproduces the exact configuration.
        with open(args.lycoris_config, "r") as file:
            lycoris_config = json.load(file)
        return f"""- LyCORIS Config:\n```json\n{json.dumps(lycoris_config, indent=4)}\n```"""
    # Fix: the original implicitly returned None for an unrecognised
    # lora_type, which would render as the string "None" in the model card.
    return ""


def model_card_note(args):
Expand Down Expand Up @@ -169,30 +287,30 @@ def save_model_card(
sub_idx += 1

shortname_idx += 1
args = StateTracker.get_args()
yaml_content = f"""---
license: creativeml-openrail-m
license: {licenses[StateTracker.get_model_type()]}
base_model: "{base_model}"
tags:
- {'stable-diffusion' if 'deepfloyd' not in StateTracker.get_args().model_type else 'deepfloyd-if'}
- {'stable-diffusion-diffusers' if 'deepfloyd' not in StateTracker.get_args().model_type else 'deepfloyd-if-diffusers'}
- {StateTracker.get_model_type()}
- {f'{StateTracker.get_model_type()}-diffusers' if 'deepfloyd' not in args.model_type else 'deepfloyd-if-diffusers'}
- text-to-image
- diffusers
- simpletuner
- {StateTracker.get_args().model_type}
{' - template:sd-lora' if 'lora' in StateTracker.get_args().model_type else ''}
- {args.model_type}
{' - template:sd-lora' if 'lora' in args.model_type else ''}
inference: true
{widget_str}
---

"""
model_card_content = f"""# {repo_id}

This is a {'LoRA' if 'lora' in StateTracker.get_args().model_type else 'full rank finetune'} derived from [{base_model}](https://huggingface.co/{base_model}).

{'This is a **diffusion** model trained using DDPM objective instead of Flow matching. **Be sure to set the appropriate scheduler configuration.**' if StateTracker.get_args().sd3 and StateTracker.get_args().flow_matching_loss == "diffusion" else ''}
This is a {model_type(args)} derived from [{base_model}](https://huggingface.co/{base_model}).

{'The main validation prompt used during training was:' if prompt else 'Validation used ground-truth images as an input for partial denoising (img2img).' if StateTracker.get_args().validation_using_datasets else 'No validation prompt was used during training.'}
{model_card_note(StateTracker.get_args())}
{'This is a **diffusion** model trained using DDPM objective instead of Flow matching. **Be sure to set the appropriate scheduler configuration.**' if args.sd3 and args.flow_matching_loss == "diffusion" else ''}
{'The main validation prompt used during training was:' if prompt else 'Validation used ground-truth images as an input for partial denoising (img2img).' if args.validation_using_datasets else 'No validation prompt was used during training.'}
{model_card_note(args)}
{'```' if prompt else ''}
{prompt}
{'```' if prompt else ''}
Expand Down Expand Up @@ -227,7 +345,7 @@ def save_model_card(
- Prediction type: {'flow-matching' if (StateTracker.get_args().sd3 or StateTracker.get_args().flux) else StateTracker.get_args().prediction_type}
- Rescaled betas zero SNR: {StateTracker.get_args().rescale_betas_zero_snr}
- Optimizer: {StateTracker.get_args().optimizer}{optimizer_config if optimizer_config is not None else ''}
- Precision: {'Pure BF16' if StateTracker.get_args().adam_bfloat16 else StateTracker.get_args().mixed_precision}
- Precision: {'Pure BF16' if (StateTracker.get_args().adam_bfloat16 or torch.backends.mps.is_available()) else StateTracker.get_args().mixed_precision}
- Quantised: {f'Yes: {StateTracker.get_args().base_model_precision}' if StateTracker.get_args().base_model_precision != "no_change" else 'No'}
- Xformers: {'Enabled' if StateTracker.get_args().enable_xformers_memory_efficient_attention else 'Not used'}
{lora_info(args=StateTracker.get_args())}
Expand Down
5 changes: 5 additions & 0 deletions helpers/training/save_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,11 @@ def _save_lycoris(self, models, weights, output_dir):
{"lycoris_config": json.dumps(lycoris_config)}, # metadata
)

# copy the config into the repo
shutil.copy2(
self.args.lycoris_config, os.path.join(output_dir, "lycoris_config.json")
)

logger.info("LyCORIS weights have been saved to disk")

def _save_full_model(self, models, weights, output_dir):
Expand Down
4 changes: 4 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -2273,6 +2273,10 @@ def main():
].dtype,
{"lycoris_config": json.dumps(lycoris_config)}, # metadata
)
shutil.copy2(
args.lycoris_config,
os.path.join(args.output_dir, "lycoris_config.json"),
)

elif args.use_ema:
if unet is not None:
Expand Down
Loading