Skip to content

Commit

Permalink
Add run method
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher authored Jun 20, 2021
1 parent b943a4b commit 2b23ea0
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
return results


def parse_opt():
def parse_opt(known=False):
parser = argparse.ArgumentParser()
parser.add_argument('--weights', type=str, default='yolov5s.pt', help='initial weights path')
parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
Expand Down Expand Up @@ -504,7 +504,7 @@ def parse_opt():
parser.add_argument('--save_period', type=int, default=-1, help='Log model after every "save_period" epoch')
parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
opt = parser.parse_args()
opt = parser.parse_known_args()[0] if known else parser.parse_args()
return opt


Expand Down Expand Up @@ -634,6 +634,14 @@ def main(opt):
f'Command to train a new model with these hyperparameters: $ python train.py --hyp {yaml_file}')


def run(**kwargs):
# Usage: import train; train.run(imgsz=320, weights='yolov5m.pt')
opt = parse_opt(True)
for k, v in kwargs.items():
setattr(opt, k, v)
main(opt)


if __name__ == "__main__":
opt = parse_opt()
main(opt)

0 comments on commit 2b23ea0

Please sign in to comment.