Skip to content

Commit

Permalink
Merge pull request #1016 from Disty0/dev
Browse files Browse the repository at this point in the history
Fix convert_diffusers20_original_sd.py and add --variant option for loading
  • Loading branch information
kohya-ss authored Dec 24, 2023
2 parents 11ed8e2 + 7080e1a commit 9a2e385
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions tools/convert_diffusers20_original_sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def convert(args):
is_load_ckpt = os.path.isfile(args.model_to_load)
is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0

assert not is_load_ckpt or args.v1 != args.v2, f"v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です"
assert not is_load_ckpt or args.v1 != args.v2, "v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です"
# assert (
# is_save_ckpt or args.reference_model is not None
# ), f"reference model is required to save as Diffusers / Diffusers形式での保存には参照モデルが必要です"
Expand All @@ -37,7 +37,7 @@ def convert(args):
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(v2_model, args.model_to_load, unet_use_linear_projection_in_v2=args.unet_use_linear_projection)
else:
pipe = StableDiffusionPipeline.from_pretrained(
args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None
args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None, variant=args.variant
)
text_encoder = pipe.text_encoder
vae = pipe.vae
Expand All @@ -57,15 +57,15 @@ def convert(args):
if is_save_ckpt:
original_model = args.model_to_load if is_load_ckpt else None
key_count = model_util.save_stable_diffusion_checkpoint(
v2_model, args.model_to_save, text_encoder, unet, original_model, args.epoch, args.global_step, save_dtype, vae
v2_model, args.model_to_save, text_encoder, unet, original_model, args.epoch, args.global_step, None if args.metadata is None else eval(args.metadata), save_dtype=save_dtype, vae=vae
)
print(f"model saved. total converted state_dict keys: {key_count}")
else:
print(f"copy scheduler/tokenizer config from: {args.reference_model if args.reference_model is not None else 'default model'}")
model_util.save_diffusers_checkpoint(
v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors
)
print(f"model saved.")
print("model saved.")


def setup_parser() -> argparse.ArgumentParser:
Expand Down Expand Up @@ -99,6 +99,18 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--global_step", type=int, default=0, help="global_step to write to checkpoint / checkpointに記録するglobal_stepの値"
)
parser.add_argument(
"--metadata",
type=str,
default=None,
help='metadata: metadata written in to the model in Python Dictionary. Example metadata: \'{"name": "model_name", "resolution": "512x512"}\'',
)
parser.add_argument(
"--variant",
type=str,
default=None,
help="variant: Diffusers variant to load. Example: fp16",
)
parser.add_argument(
"--reference_model",
type=str,
Expand Down

0 comments on commit 9a2e385

Please sign in to comment.