From 5ab2e2669076884fd44255e6fbbff1d8dafff899 Mon Sep 17 00:00:00 2001
From: Ayush Chaurasia
Date: Tue, 23 Mar 2021 21:24:34 +0530
Subject: [PATCH] W&B DDP fix (#2574)

---
 train.py                           | 8 +++++---
 utils/wandb_logging/wandb_utils.py | 5 ++++-
 2 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/train.py b/train.py
index 62a72375c7a3..fd2d6745ab46 100644
--- a/train.py
+++ b/train.py
@@ -66,14 +66,16 @@ def train(hyp, opt, device, tb_writer=None):
     is_coco = opt.data.endswith('coco.yaml')
 
     # Logging- Doing this before checking the dataset. Might update data_dict
+    loggers = {'wandb': None}  # loggers dict
     if rank in [-1, 0]:
         opt.hyp = hyp  # add hyperparameters
         run_id = torch.load(weights).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None
         wandb_logger = WandbLogger(opt, Path(opt.save_dir).stem, run_id, data_dict)
+        loggers['wandb'] = wandb_logger.wandb
         data_dict = wandb_logger.data_dict
         if wandb_logger.wandb:
             weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp  # WandbLogger might update weights, epochs if resuming
-        loggers = {'wandb': wandb_logger.wandb}  # loggers dict
+
     nc = 1 if opt.single_cls else int(data_dict['nc'])  # number of classes
     names = ['item'] if opt.single_cls and len(data_dict['names']) != 1 else data_dict['names']  # class names
     assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data)  # check
@@ -381,6 +383,7 @@ def train(hyp, opt, device, tb_writer=None):
             fi = fitness(np.array(results).reshape(1, -1))  # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
             if fi > best_fitness:
                 best_fitness = fi
+            wandb_logger.end_epoch(best_result=best_fitness == fi)
 
             # Save model
             if (not opt.nosave) or (final_epoch and not opt.evolve):  # if save
@@ -402,7 +405,6 @@ def train(hyp, opt, device, tb_writer=None):
                     wandb_logger.log_model(
                         last.parent, opt, epoch, fi, best_model=best_fitness == fi)
                 del ckpt
-            wandb_logger.end_epoch(best_result=best_fitness == fi)
 
         # end epoch ----------------------------------------------------------------------------------------------------
     # end training
@@ -442,10 +444,10 @@ def train(hyp, opt, device, tb_writer=None):
                 wandb_logger.wandb.log_artifact(str(final), type='model',
                                                 name='run_' + wandb_logger.wandb_run.id + '_model',
                                                 aliases=['last', 'best', 'stripped'])
+        wandb_logger.finish_run()
     else:
         dist.destroy_process_group()
     torch.cuda.empty_cache()
-    wandb_logger.finish_run()
     return results
 
 
diff --git a/utils/wandb_logging/wandb_utils.py b/utils/wandb_logging/wandb_utils.py
index c9a32f5b6026..d6dd256366e0 100644
--- a/utils/wandb_logging/wandb_utils.py
+++ b/utils/wandb_logging/wandb_utils.py
@@ -16,9 +16,9 @@
 
 try:
     import wandb
+    from wandb import init, finish
 except ImportError:
     wandb = None
-    print(f"{colorstr('wandb: ')}Install Weights & Biases for YOLOv5 logging with 'pip install wandb' (recommended)")
 
 WANDB_ARTIFACT_PREFIX = 'wandb-artifact://'
 
@@ -71,6 +71,9 @@ def __init__(self, opt, name, run_id, data_dict, job_type='Training'):
                 self.data_dict = self.setup_training(opt, data_dict)
             if self.job_type == 'Dataset Creation':
                 self.data_dict = self.check_and_upload_dataset(opt)
+        else:
+            print(f"{colorstr('wandb: ')}Install Weights & Biases for YOLOv5 logging with 'pip install wandb' (recommended)")
+
 
     def check_and_upload_dataset(self, opt):
         assert wandb, 'Install wandb to upload dataset'