W&B: Restructure code to support the new dataset_check() feature #4197

Merged: 22 commits, Jul 28, 2021
README.md: no content changes (file mode changed 100755 → 100644)
train.py: 17 changes (11 additions, 6 deletions)
@@ -73,24 +73,29 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
yaml.safe_dump(hyp, f, sort_keys=False)
with open(save_dir / 'opt.yaml', 'w') as f:
yaml.safe_dump(vars(opt), f, sort_keys=False)
data_dict = None

# Loggers
if RANK in [-1, 0]:
loggers = Loggers(save_dir, weights, opt, hyp, LOGGER).start() # loggers dict
if loggers.wandb:
data_dict = loggers.wandb.data_dict
if resume:
weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp


# Config
plots = not evolve # create plots
cuda = device.type != 'cpu'
init_seeds(1 + RANK)
with torch_distributed_zero_first(RANK):
data_dict = check_dataset(data) # check
data_dict = data_dict or check_dataset(data) # check if None
train_path, val_path = data_dict['train'], data_dict['val']
nc = 1 if single_cls else int(data_dict['nc']) # number of classes
names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
assert len(names) == nc, f'{len(names)} names found for nc={nc} dataset in {data}' # check
is_coco = data.endswith('coco.yaml') and nc == 80 # COCO dataset

# Loggers
if RANK in [-1, 0]:
loggers = Loggers(save_dir, weights, opt, hyp, data_dict, LOGGER).start() # loggers dict
if loggers.wandb and resume:
weights, epochs, hyp, data_dict = opt.weights, opt.epochs, opt.hyp, loggers.wandb.data_dict

# Model
pretrained = weights.endswith('.pt')
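Note on the train.py change above: the Loggers block now runs before the dataset check, and data_dict = data_dict or check_dataset(data) only parses the local dataset YAML when W&B has not already supplied a dataset dictionary (for example, when the dataset lives in a W&B artifact). A minimal standalone sketch of that precedence, with a hypothetical stub standing in for the real check_dataset():

# Minimal sketch of the new precedence (stub code, not the actual YOLOv5 implementation)
def check_dataset_stub(data_yaml):
    # stand-in for utils.general.check_dataset(): parse and verify the local dataset YAML
    print(f'checking local dataset file: {data_yaml}')
    return {'train': 'images/train', 'val': 'images/val', 'nc': 80, 'names': []}

def resolve_data_dict(logger_data_dict, data_yaml):
    # logger_data_dict mirrors loggers.wandb.data_dict; it is None when W&B is not used
    data_dict = logger_data_dict
    data_dict = data_dict or check_dataset_stub(data_yaml)  # check only if still None
    return data_dict

resolve_data_dict(None, 'data/coco128.yaml')                                # falls back to the local check
resolve_data_dict({'train': 'wandb-artifact://...'}, 'data/coco128.yaml')   # the W&B-supplied dict wins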
utils/loggers/__init__.py: 13 changes (4 additions, 9 deletions)
@@ -1,9 +1,7 @@
# YOLOv5 experiment logging utils

import torch
import warnings
from threading import Thread

import torch
from torch.utils.tensorboard import SummaryWriter

from utils.general import colorstr, emojis
@@ -23,12 +21,11 @@

class Loggers():
# YOLOv5 Loggers class
def __init__(self, save_dir=None, weights=None, opt=None, hyp=None, data_dict=None, logger=None, include=LOGGERS):
def __init__(self, save_dir=None, weights=None, opt=None, hyp=None, logger=None, include=LOGGERS):
self.save_dir = save_dir
self.weights = weights
self.opt = opt
self.hyp = hyp
self.data_dict = data_dict
self.logger = logger # for printing results to console
self.include = include
for k in LOGGERS:
@@ -38,9 +35,7 @@ def start(self):
self.csv = True # always log to csv

# Message
try:
import wandb
except ImportError:
if not wandb:
prefix = colorstr('Weights & Biases: ')
s = f"{prefix}run 'pip install wandb' to automatically track and visualize YOLOv5 🚀 runs (RECOMMENDED)"
print(emojis(s))
@@ -57,7 +52,7 @@ def start(self):
assert 'wandb' in self.include and wandb
run_id = torch.load(self.weights).get('wandb_id') if self.opt.resume else None
self.opt.hyp = self.hyp # add hyperparameters
self.wandb = WandbLogger(self.opt, run_id, self.data_dict)
self.wandb = WandbLogger(self.opt, run_id)
except:
self.wandb = None

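The start() change above replaces the per-call try/except import with a simple truthiness check on the module-level wandb name. A hedged sketch of the kind of guard that check relies on (the exact guard in utils/loggers/__init__.py may differ slightly):

# Module-level import guard (sketch): wandb becomes None when the package is unavailable,
# so later code can simply test "if not wandb:" instead of importing again.
try:
    import wandb
    assert hasattr(wandb, '__version__')  # verify the real package was imported, not a local dir
except (ImportError, AssertionError):
    wandb = None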
utils/loggers/wandb/log_dataset.py: 6 changes (2 additions, 4 deletions)
@@ -1,5 +1,4 @@
import argparse

import yaml

from wandb_utils import WandbLogger
@@ -8,9 +7,7 @@


def create_dataset_artifact(opt):
with open(opt.data, encoding='ascii', errors='ignore') as f:
data = yaml.safe_load(f) # data dict
logger = WandbLogger(opt, '', None, data, job_type='Dataset Creation') # TODO: return value unused
logger = WandbLogger(opt, None, job_type='Dataset Creation') # TODO: return value unused


if __name__ == '__main__':
@@ -19,6 +16,7 @@ def create_dataset_artifact(opt):
parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
parser.add_argument('--project', type=str, default='YOLOv5', help='name of W&B Project')
parser.add_argument('--entity', default=None, help='W&B entity')
parser.add_argument('--name', type=str, default='log dataset', help='name of W&B run')

opt = parser.parse_args()
opt.resume = False # Explicitly disallow resume check for dataset upload job
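With the change above, log_dataset.py no longer parses the dataset YAML itself; WandbLogger resolves it internally when constructed with job_type='Dataset Creation'. A hypothetical invocation sketch, assuming the file's existing --data argument (defined elsewhere in the file, not shown in this hunk):

# Hypothetical sketch of how create_dataset_artifact is driven; values are examples only.
from argparse import Namespace
opt = Namespace(data='data/coco128.yaml', single_cls=False, project='YOLOv5',
                entity=None, name='log dataset', resume=False)
# create_dataset_artifact(opt)  # would construct WandbLogger(opt, None, job_type='Dataset Creation'),
#                               # which now parses and uploads the dataset itself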
utils/loggers/wandb/sweep.py: 3 changes (1 addition, 2 deletions)
@@ -1,7 +1,6 @@
import sys
from pathlib import Path

import wandb
from pathlib import Path

FILE = Path(__file__).absolute()
sys.path.append(FILE.parents[2].as_posix()) # add utils/ to path
utils/loggers/wandb/wandb_utils.py: 53 changes (34 additions, 19 deletions)
@@ -3,10 +3,9 @@
import logging
import os
import sys
import yaml
from contextlib import contextmanager
from pathlib import Path

import yaml
from tqdm import tqdm

FILE = Path(__file__).absolute()
@@ -99,7 +98,7 @@ class WandbLogger():
https://docs.wandb.com/guides/integrations/yolov5
"""

def __init__(self, opt, run_id, data_dict, job_type='Training'):
def __init__(self, opt, run_id, job_type='Training'):
"""
- Initialize WandbLogger instance
- Upload dataset if opt.upload_dataset is True
@@ -108,7 +107,6 @@ def __init__(self, opt, run_id, data_dict, job_type='Training'):
arguments:
opt (namespace) -- Commandline arguments for this run
run_id (str) -- Run ID of W&B run to be resumed
data_dict (Dict) -- Dictionary containing info about the dataset to be used
job_type (str) -- To set the job_type for this run

"""
@@ -119,10 +117,11 @@ def __init__(self, opt, run_id, data_dict, job_type='Training'):
self.train_artifact_path, self.val_artifact_path = None, None
self.result_artifact = None
self.val_table, self.result_table = None, None
self.data_dict = data_dict
self.bbox_media_panel_images = []
self.val_table_path_map = None
self.max_imgs_to_log = 16
self.wandb_artifact_data_dict = None
self.data_dict = None
# It's more elegant to stick to 1 wandb.init call, but useful config data is overwritten in the WandbLogger's wandb.init call
if isinstance(opt.resume, str): # checks resume from artifact
if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
@@ -148,11 +147,23 @@ def __init__(self, opt, run_id, data_dict, job_type='Training'):
if self.wandb_run:
if self.job_type == 'Training':
if not opt.resume:
wandb_data_dict = self.check_and_upload_dataset(opt) if opt.upload_dataset else data_dict
# Info useful for resuming from artifacts
self.wandb_run.config.update({'opt': vars(opt), 'data_dict': wandb_data_dict},
allow_val_change=True)
self.data_dict = self.setup_training(opt, data_dict)
if opt.upload_dataset:
self.wandb_artifact_data_dict = self.check_and_upload_dataset(opt)

elif opt.data.endswith('_wandb.yaml'): # When dataset is W&B artifact
with open(opt.data, encoding='ascii', errors='ignore') as f:
data_dict = yaml.safe_load(f)
self.data_dict = data_dict
else: # Local .yaml dataset file or .zip file
self.data_dict = check_dataset(opt.data)

self.setup_training(opt)
# write data_dict to config. useful for resuming from artifacts
if not self.wandb_artifact_data_dict:
self.wandb_artifact_data_dict = self.data_dict
self.wandb_run.config.update({'data_dict': self.wandb_artifact_data_dict},
allow_val_change=True)

if self.job_type == 'Dataset Creation':
self.data_dict = self.check_and_upload_dataset(opt)
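Reviewer note on the restructured __init__ above: for a fresh (non-resumed) 'Training' run, the logger now resolves its own dataset dictionary from one of three sources instead of receiving data_dict as an argument. A condensed, hedged sketch of that branching (not the literal method; the import mirrors what wandb_utils.py already uses):

import yaml
from utils.general import check_dataset  # same helper train.py uses

def resolve_dataset_source(opt, logger):
    # condensed restatement of the branch added above; logger stands in for the WandbLogger instance
    if opt.upload_dataset:                      # 1) upload now; the returned dict holds wandb-artifact:// links
        logger.wandb_artifact_data_dict = logger.check_and_upload_dataset(opt)
    elif opt.data.endswith('_wandb.yaml'):      # 2) dataset was already logged as a W&B artifact
        with open(opt.data, encoding='ascii', errors='ignore') as f:
            logger.data_dict = yaml.safe_load(f)
    else:                                       # 3) plain local .yaml or .zip dataset
        logger.data_dict = check_dataset(opt.data)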

@@ -167,15 +178,15 @@ def check_and_upload_dataset(self, opt):
Updated dataset info dictionary where local dataset paths are replaced by WANDB_ARTIFACT_PREFIX links.
"""
assert wandb, 'Install wandb to upload dataset'
config_path = self.log_dataset_artifact(check_file(opt.data),
config_path = self.log_dataset_artifact(opt.data,
opt.single_cls,
'YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem)
print("Created dataset config file ", config_path)
with open(config_path, encoding='ascii', errors='ignore') as f:
wandb_data_dict = yaml.safe_load(f)
return wandb_data_dict

def setup_training(self, opt, data_dict):
def setup_training(self, opt):
"""
Setup the necessary processes for training YOLO models:
- Attempt to download model checkpoint and dataset artifacts if opt.resume starts with WANDB_ARTIFACT_PREFIX
@@ -184,10 +195,7 @@ def setup_training(self, opt, data_dict):

arguments:
opt (namespace) -- commandline arguments for this run
data_dict (Dict) -- Dataset dictionary for this run

returns:
data_dict (Dict) -- contains the updated info about the dataset to be used for training
"""
self.log_dict, self.current_epoch = {}, 0
self.bbox_interval = opt.bbox_interval
@@ -198,8 +206,10 @@ def setup_training(self, opt):
config = self.wandb_run.config
opt.weights, opt.save_period, opt.batch_size, opt.bbox_interval, opt.epochs, opt.hyp = str(
self.weights), config.save_period, config.batch_size, config.bbox_interval, config.epochs, \
config.opt['hyp']
config.hyp
data_dict = dict(self.wandb_run.config.data_dict) # eliminates the need for config file to resume
else:
data_dict = self.data_dict
if self.val_artifact is None: # If --upload_dataset is set, use the existing artifact, don't download
self.train_artifact_path, self.train_artifact = self.download_dataset_artifact(data_dict.get('train'),
opt.artifact_alias)
@@ -221,7 +231,10 @@ def setup_training(self, opt):
self.map_val_table_path()
if opt.bbox_interval == -1:
self.bbox_interval = opt.bbox_interval = (opt.epochs // 10) if opt.epochs > 10 else 1
return data_dict
train_from_artifact = self.train_artifact_path is not None and self.val_artifact_path is not None
# Update the data_dict to point to the local artifacts dir
if train_from_artifact:
self.data_dict = data_dict
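Note on setup_training above: with the data_dict parameter removed, the method reads self.data_dict for fresh runs and, when resuming from a W&B run, rebuilds the dictionary from the run config, which removes the need for a local data.yaml. A hedged sketch of that selection (names chosen for illustration):

def pick_training_data_dict(resuming_from_wandb, wandb_run, logger_data_dict):
    # sketch only; resuming_from_wandb corresponds to opt.resume pointing at a W&B run/artifact
    if resuming_from_wandb:
        return dict(wandb_run.config.data_dict)  # stored in the run config by __init__
    return logger_data_dict                      # self.data_dict resolved in __init__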

def download_dataset_artifact(self, path, alias):
"""
@@ -299,7 +312,8 @@ def log_dataset_artifact(self, data_file, single_cls, project, overwrite_config=
returns:
the new .yaml file with artifact links. it can be used to start training directly from artifacts
"""
data = check_dataset(data_file) # parse and check
self.data_dict = check_dataset(data_file) # parse and check
data = dict(self.data_dict)
nc, names = (1, ['item']) if single_cls else (int(data['nc']), data['names'])
names = {k: v for k, v in enumerate(names)} # to index dictionary
self.train_artifact = self.create_dataset_table(LoadImagesAndLabels(
@@ -310,7 +324,8 @@ def log_dataset_artifact(self, data_file, single_cls, project, overwrite_config=
data['train'] = WANDB_ARTIFACT_PREFIX + str(Path(project) / 'train')
if data.get('val'):
data['val'] = WANDB_ARTIFACT_PREFIX + str(Path(project) / 'val')
path = data_file if overwrite_config else '_wandb.'.join(data_file.rsplit('.', 1)) # updated data.yaml path
path = Path(data_file).stem
path = (path if overwrite_config else path + '_wandb') + '.yaml' # updated data.yaml path
data.pop('download', None)
data.pop('path', None)
with open(path, 'w') as f:
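One observable effect of the path-construction change in log_dataset_artifact: the generated config keeps only the stem of the original file name, so it appears to be written relative to the current working directory rather than next to the source YAML. A quick, self-contained check of both expressions:

from pathlib import Path

data_file, overwrite_config = 'data/coco128.yaml', False
old_path = data_file if overwrite_config else '_wandb.'.join(data_file.rsplit('.', 1))
stem = Path(data_file).stem
new_path = (stem if overwrite_config else stem + '_wandb') + '.yaml'
print(old_path)  # data/coco128_wandb.yaml (next to the original file)
print(new_path)  # coco128_wandb.yaml (relative to the current working directory)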