diff --git a/tools/convert_diffusers20_original_sd.py b/tools/convert_diffusers20_original_sd.py index b9365b519..a6774328e 100644 --- a/tools/convert_diffusers20_original_sd.py +++ b/tools/convert_diffusers20_original_sd.py @@ -23,7 +23,7 @@ def convert(args): is_load_ckpt = os.path.isfile(args.model_to_load) is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0 - assert not is_load_ckpt or args.v1 != args.v2, f"v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です" + assert not is_load_ckpt or args.v1 != args.v2, "v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です" # assert ( # is_save_ckpt or args.reference_model is not None # ), f"reference model is required to save as Diffusers / Diffusers形式での保存には参照モデルが必要です" @@ -37,7 +37,7 @@ def convert(args): text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(v2_model, args.model_to_load, unet_use_linear_projection_in_v2=args.unet_use_linear_projection) else: pipe = StableDiffusionPipeline.from_pretrained( - args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None + args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None, variant=args.variant ) text_encoder = pipe.text_encoder vae = pipe.vae @@ -57,7 +57,7 @@ def convert(args): if is_save_ckpt: original_model = args.model_to_load if is_load_ckpt else None key_count = model_util.save_stable_diffusion_checkpoint( - v2_model, args.model_to_save, text_encoder, unet, original_model, args.epoch, args.global_step, save_dtype, vae + v2_model, args.model_to_save, text_encoder, unet, original_model, args.epoch, args.global_step, None if args.metadata is None else eval(args.metadata), save_dtype=save_dtype, vae=vae ) print(f"model saved. total converted state_dict keys: {key_count}") else: @@ -65,7 +65,7 @@ def convert(args): model_util.save_diffusers_checkpoint( v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors ) - print(f"model saved.") + print("model saved.") def setup_parser() -> argparse.ArgumentParser: @@ -99,6 +99,18 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--global_step", type=int, default=0, help="global_step to write to checkpoint / checkpointに記録するglobal_stepの値" ) + parser.add_argument( + "--metadata", + type=str, + default=None, + help='metadata: metadata written in to the model in Python Dictionary. Example metadata: \'{"name": "model_name", "resolution": "512x512"}\'', + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="variant: Diffusers variant to load. Example: fp16", + ) parser.add_argument( "--reference_model", type=str,