From 24edbb0f097f4ba38e1470836d6acb7f34a03330 Mon Sep 17 00:00:00 2001 From: lkrisztian Date: Thu, 20 Jul 2023 18:53:45 +0200 Subject: [PATCH 1/2] fix/adapt argument name for pretrained model --- src/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/train.py b/src/train.py index 7c7eb0bc..d2bf876a 100644 --- a/src/train.py +++ b/src/train.py @@ -166,7 +166,7 @@ def train(model, train_generator, val_generator, id2code, batch_size, '--model_fn', type=str, help='Output model filename') parser.add_argument( - '--weights_path', type=str, default=None, + '--model_path', type=str, default=None, help='ONLY FOR OPERATION == FINE-TUNE: Input weights path') parser.add_argument( '--visualization_path', type=str, default='/tmp', @@ -270,7 +270,7 @@ def train(model, train_generator, val_generator, id2code, batch_size, '0 and smaller or equal than 1') main(args.operation, args.data_dir, args.output_dir, - args.model, args.model_fn, args.weights_path, args.visualization_path, + args.model, args.model_fn, args.model_path, args.visualization_path, args.nr_epochs, args.initial_epoch, args.batch_size, args.loss_function, args.seed, args.patience, (args.tensor_height, args.tensor_width), args.monitored_value, From ce75ffce94100ff2edca63763175ae732979598c Mon Sep 17 00:00:00 2001 From: lkrisztian Date: Thu, 27 Jul 2023 06:38:15 +0200 Subject: [PATCH 2/2] keep weighs_path instead of model_path parameter --- src/train.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/train.py b/src/train.py index d2bf876a..155f30d6 100644 --- a/src/train.py +++ b/src/train.py @@ -166,7 +166,7 @@ def train(model, train_generator, val_generator, id2code, batch_size, '--model_fn', type=str, help='Output model filename') parser.add_argument( - '--model_path', type=str, default=None, + '--weights_path', type=str, default=None, help='ONLY FOR OPERATION == FINE-TUNE: Input weights path') parser.add_argument( '--visualization_path', type=str, default='/tmp', @@ -251,9 +251,9 @@ def train(model, train_generator, val_generator, id2code, batch_size, args = parser.parse_args() # check required arguments by individual operations - if args.operation == 'fine-tune' and args.model_path is None: + if args.operation == 'fine-tune' and args.weights_path is None: raise parser.error( - 'Argument model_path required for operation == fine-tune') + 'Argument weights_path required for operation == fine-tune') if args.operation == 'train' and args.initial_epoch != 0: raise parser.error( 'Argument initial_epoch must be 0 for operation == train') @@ -270,7 +270,7 @@ def train(model, train_generator, val_generator, id2code, batch_size, '0 and smaller or equal than 1') main(args.operation, args.data_dir, args.output_dir, - args.model, args.model_fn, args.model_path, args.visualization_path, + args.model, args.model_fn, args.weights_path, args.visualization_path, args.nr_epochs, args.initial_epoch, args.batch_size, args.loss_function, args.seed, args.patience, (args.tensor_height, args.tensor_width), args.monitored_value,