Add train.run() method #3700

Merged: 3 commits, Jun 20, 2021
train.py: 81 changes (45 additions, 36 deletions)
@@ -46,8 +46,9 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
opt,
device,
):
- save_dir, epochs, batch_size, weights, single_cls = \
- opt.save_dir, opt.epochs, opt.batch_size, opt.weights, opt.single_cls
+ save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, notest, nosave, workers, = \
+ opt.save_dir, opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
+ opt.resume, opt.notest, opt.nosave, opt.workers

# Directories
save_dir = Path(save_dir)
@@ -70,49 +71,49 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
yaml.safe_dump(vars(opt), f, sort_keys=False)

# Configure
- plots = not opt.evolve # create plots
+ plots = not evolve # create plots
cuda = device.type != 'cpu'
init_seeds(2 + RANK)
- with open(opt.data) as f:
+ with open(data) as f:
data_dict = yaml.safe_load(f) # data dict

# Loggers
loggers = {'wandb': None, 'tb': None} # loggers dict
if RANK in [-1, 0]:
# TensorBoard
- if not opt.evolve:
+ if not evolve:
prefix = colorstr('tensorboard: ')
logger.info(f"{prefix}Start with 'tensorboard --logdir {opt.project}', view at http://localhost:6006/")
- loggers['tb'] = SummaryWriter(opt.save_dir)
+ loggers['tb'] = SummaryWriter(str(save_dir))

# W&B
opt.hyp = hyp # add hyperparameters
run_id = torch.load(weights).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None
wandb_logger = WandbLogger(opt, save_dir.stem, run_id, data_dict)
loggers['wandb'] = wandb_logger.wandb
- data_dict = wandb_logger.data_dict
- if wandb_logger.wandb:
+ if loggers['wandb']:
+ data_dict = wandb_logger.data_dict
weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp # may update weights, epochs if resuming

nc = 1 if single_cls else int(data_dict['nc']) # number of classes
names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
- assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check
- is_coco = opt.data.endswith('coco.yaml') and nc == 80 # COCO dataset
+ assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, data) # check
+ is_coco = data.endswith('coco.yaml') and nc == 80 # COCO dataset

# Model
pretrained = weights.endswith('.pt')
if pretrained:
with torch_distributed_zero_first(RANK):
weights = attempt_download(weights) # download if not found locally
ckpt = torch.load(weights, map_location=device) # load checkpoint
- model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
- exclude = ['anchor'] if (opt.cfg or hyp.get('anchors')) and not opt.resume else [] # exclude keys
+ model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
+ exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else [] # exclude keys
state_dict = ckpt['model'].float().state_dict() # to FP32
state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude) # intersect
model.load_state_dict(state_dict, strict=False) # load
logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights)) # report
else:
- model = Model(opt.cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
+ model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
with torch_distributed_zero_first(RANK):
check_dataset(data_dict) # check
train_path = data_dict['train']
@@ -182,7 +183,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary

# Epochs
start_epoch = ckpt['epoch'] + 1
- if opt.resume:
+ if resume:
assert start_epoch > 0, '%s training to %g epochs is finished, nothing to resume.' % (weights, epochs)
if epochs < start_epoch:
logger.info('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
@@ -210,20 +211,20 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
# Trainloader
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls,
hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=RANK,
- workers=opt.workers,
+ workers=workers,
image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: '))
mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
nb = len(dataloader) # number of batches
- assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)
+ assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, data, nc - 1)

# Process 0
if RANK in [-1, 0]:
testloader = create_dataloader(test_path, imgsz_test, batch_size // WORLD_SIZE * 2, gs, single_cls,
- hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1,
- workers=opt.workers,
+ hyp=hyp, cache=opt.cache_images and not notest, rect=True, rank=-1,
+ workers=workers,
pad=0.5, prefix=colorstr('val: '))[0]

- if not opt.resume:
+ if not resume:
labels = np.concatenate(dataset.labels, 0)
c = torch.tensor(labels[:, 0]) # classes
# cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
@@ -356,8 +357,8 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
with warnings.catch_warnings():
warnings.simplefilter('ignore') # suppress jit trace warning
loggers['tb'].add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
- elif plots and ni == 10 and wandb_logger.wandb:
- wandb_logger.log({'Mosaics': [wandb_logger.wandb.Image(str(x), caption=x.name) for x in
+ elif plots and ni == 10 and loggers['wandb']:
+ wandb_logger.log({'Mosaics': [loggers['wandb'].Image(str(x), caption=x.name) for x in
save_dir.glob('train*.jpg') if x.exists()]})

# end batch ------------------------------------------------------------------------------------------------
@@ -371,7 +372,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
# mAP
ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
final_epoch = epoch + 1 == epochs
- if not opt.notest or final_epoch: # Calculate mAP
+ if not notest or final_epoch: # Calculate mAP
wandb_logger.current_epoch = epoch + 1
results, maps, _ = test.test(data_dict,
batch_size=batch_size // WORLD_SIZE * 2,
@@ -398,7 +399,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
if loggers['tb']:
loggers['tb'].add_scalar(tag, x, epoch) # TensorBoard
- if wandb_logger.wandb:
+ if loggers['wandb']:
wandb_logger.log({tag: x}) # W&B

# Update best mAP
@@ -408,21 +409,21 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
wandb_logger.end_epoch(best_result=best_fitness == fi)

# Save model
- if (not opt.nosave) or (final_epoch and not opt.evolve): # if save
+ if (not nosave) or (final_epoch and not evolve): # if save
ckpt = {'epoch': epoch,
'best_fitness': best_fitness,
'training_results': results_file.read_text(),
'model': deepcopy(de_parallel(model)).half(),
'ema': deepcopy(ema.ema).half(),
'updates': ema.updates,
'optimizer': optimizer.state_dict(),
- 'wandb_id': wandb_logger.wandb_run.id if wandb_logger.wandb else None}
+ 'wandb_id': wandb_logger.wandb_run.id if loggers['wandb'] else None}

# Save last, best and delete
torch.save(ckpt, last)
if best_fitness == fi:
torch.save(ckpt, best)
- if wandb_logger.wandb:
+ if loggers['wandb']:
if ((epoch + 1) % opt.save_period == 0 and not final_epoch) and opt.save_period != -1:
wandb_logger.log_model(last.parent, opt, epoch, fi, best_model=best_fitness == fi)
del ckpt
@@ -433,15 +434,15 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
logger.info(f'{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.\n')
if plots:
plot_results(save_dir=save_dir) # save as results.png
- if wandb_logger.wandb:
+ if loggers['wandb']:
files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
wandb_logger.log({"Results": [wandb_logger.wandb.Image(str(save_dir / f), caption=f) for f in files
wandb_logger.log({"Results": [loggers['wandb'].Image(str(save_dir / f), caption=f) for f in files
if (save_dir / f).exists()]})

- if not opt.evolve:
+ if not evolve:
if is_coco: # COCO dataset
for m in [last, best] if best.exists() else [last]: # speed, mAP tests
- results, _, _ = test.test(opt.data,
+ results, _, _ = test.test(data,
batch_size=batch_size // WORLD_SIZE * 2,
imgsz=imgsz_test,
conf_thres=0.001,
@@ -457,17 +458,17 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
for f in last, best:
if f.exists():
strip_optimizer(f) # strip optimizers
- if wandb_logger.wandb: # Log the stripped model
- wandb_logger.wandb.log_artifact(str(best if best.exists() else last), type='model',
- name='run_' + wandb_logger.wandb_run.id + '_model',
- aliases=['latest', 'best', 'stripped'])
+ if loggers['wandb']: # Log the stripped model
+ loggers['wandb'].log_artifact(str(best if best.exists() else last), type='model',
+ name='run_' + wandb_logger.wandb_run.id + '_model',
+ aliases=['latest', 'best', 'stripped'])
wandb_logger.finish_run()

torch.cuda.empty_cache()
return results


- def parse_opt():
+ def parse_opt(known=False):
parser = argparse.ArgumentParser()
parser.add_argument('--weights', type=str, default='yolov5s.pt', help='initial weights path')
parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
@@ -503,7 +504,7 @@ def parse_opt():
parser.add_argument('--save_period', type=int, default=-1, help='Log model after every "save_period" epoch')
parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
- opt = parser.parse_args()
+ opt = parser.parse_known_args()[0] if known else parser.parse_args()
return opt
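
A note on the parse_opt(known=False) change above: parser.parse_known_args() returns a (namespace, leftovers) tuple and, unlike parser.parse_args(), does not abort when it meets arguments it does not recognize. That tolerance is what lets train.py be driven programmatically (see run() below) even when the surrounding process carries unrelated command-line flags. A minimal, self-contained sketch of the difference, using a hypothetical --imgsz option purely for illustration:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--imgsz', type=int, default=640)  # hypothetical option, for illustration only

# parse_known_args() tolerates unknown tokens and hands them back untouched;
# parse_args() would exit with "error: unrecognized arguments" instead.
opt, unknown = parser.parse_known_args(['--imgsz', '320', '--flag-train-py-never-defined'])
print(opt.imgsz)  # 320
print(unknown)    # ['--flag-train-py-never-defined']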


@@ -633,6 +634,14 @@ def main(opt):
f'Command to train a new model with these hyperparameters: $ python train.py --hyp {yaml_file}')


+ def run(**kwargs):
+     # Usage: import train; train.run(imgsz=320, weights='yolov5m.pt')
+     opt = parse_opt(True)
+     for k, v in kwargs.items():
+         setattr(opt, k, v)
+     main(opt)


if __name__ == "__main__":
opt = parse_opt()
main(opt)
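
With run() in place, train.py can be driven from Python as well as from the command line: parse_opt(True) builds the default argument namespace without choking on unknown CLI arguments, each keyword argument overwrites the matching field via setattr(), and main(opt) starts training. A short usage sketch, mirroring the example in the code comment (argument values are illustrative):

import train

# Each keyword is written onto the parsed opt namespace with setattr() before
# main(opt) is called, so any value not supplied here keeps its parse_opt()
# default (e.g. epochs=300).
train.run(imgsz=320, weights='yolov5m.pt')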