Skip to content

Commit

Permalink
Merge pull request #20 from yyua8222/yy01071
Browse files Browse the repository at this point in the history
update the args variable for checkpoint path
  • Loading branch information
haoheliu authored Nov 26, 2023
2 parents 702d96a + b15b551 commit 58cd823
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion audioldm_train/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,13 @@ def infer(dataset_json, configs, config_yaml_path, exp_group_name, exp_name):
required=False,
help="The filelist that contain captions (and optionally filenames)",
)
parser.add_argument(
"-reload_from_ckpt",
"--reload_from_ckpt",
type=str,
required=True,
help="the checkpoint path for the model",
)

args = parser.parse_args()

Expand All @@ -122,7 +129,7 @@ def infer(dataset_json, configs, config_yaml_path, exp_group_name, exp_name):
config_yaml_path = os.path.join(config_yaml)
config_yaml = yaml.load(open(config_yaml_path, "r"), Loader=yaml.FullLoader)

if "reload_from_ckpt" is not None:
if args.reload_from_ckpt != None:
config_yaml["reload_from_ckpt"] = args.reload_from_ckpt

infer(dataset_json, config_yaml, config_yaml_path, exp_group_name, exp_name)

0 comments on commit 58cd823

Please sign in to comment.