Skip to content

Commit

Permalink
Eliminate total_batch_size variable (#3697)
Browse files Browse the repository at this point in the history
* Eliminate `total_batch_size` variable

* cleanup

* Update train.py
  • Loading branch information
glenn-jocher authored Jun 19, 2021
1 parent fad27c0 commit b3e2f4e
Showing 1 changed file with 12 additions and 13 deletions.
25 changes: 12 additions & 13 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,11 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
opt,
device,
):
save_dir, epochs, batch_size, total_batch_size, weights, single_cls = \
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.single_cls
save_dir, epochs, batch_size, weights, single_cls = \
opt.save_dir, opt.epochs, opt.batch_size, opt.weights, opt.single_cls

# Directories
save_dir = Path(save_dir)
wdir = save_dir / 'weights'
wdir.mkdir(parents=True, exist_ok=True) # make dir
last = wdir / 'last.pt'
Expand Down Expand Up @@ -127,8 +128,8 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary

# Optimizer
nbs = 64 # nominal batch size
accumulate = max(round(nbs / total_batch_size), 1) # accumulate loss before optimizing
hyp['weight_decay'] *= total_batch_size * accumulate / nbs # scale weight_decay
accumulate = max(round(nbs / batch_size), 1) # accumulate loss before optimizing
hyp['weight_decay'] *= batch_size * accumulate / nbs # scale weight_decay
logger.info(f"Scaled weight_decay = {hyp['weight_decay']}")

pg0, pg1, pg2 = [], [], [] # optimizer parameter groups
Expand Down Expand Up @@ -205,7 +206,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
logger.info('Using SyncBatchNorm()')

# Trainloader
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, single_cls,
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls,
hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=RANK,
workers=opt.workers,
image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: '))
Expand All @@ -215,7 +216,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary

# Process 0
if RANK in [-1, 0]:
testloader = create_dataloader(test_path, imgsz_test, batch_size * 2, gs, single_cls,
testloader = create_dataloader(test_path, imgsz_test, batch_size // WORLD_SIZE * 2, gs, single_cls,
hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1,
workers=opt.workers,
pad=0.5, prefix=colorstr('val: '))[0]
Expand Down Expand Up @@ -302,7 +303,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
if ni <= nw:
xi = [0, nw] # x interp
# model.gr = np.interp(ni, xi, [0.0, 1.0]) # iou loss ratio (obj_loss = 1.0 or iou)
accumulate = max(1, np.interp(ni, xi, [1, nbs / total_batch_size]).round())
accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
for j, x in enumerate(optimizer.param_groups):
# bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
Expand Down Expand Up @@ -371,7 +372,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
if not opt.notest or final_epoch: # Calculate mAP
wandb_logger.current_epoch = epoch + 1
results, maps, _ = test.test(data_dict,
batch_size=batch_size * 2,
batch_size=batch_size // WORLD_SIZE * 2,
imgsz=imgsz_test,
model=ema.ema,
single_cls=single_cls,
Expand Down Expand Up @@ -439,7 +440,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
if is_coco: # COCO dataset
for m in [last, best] if best.exists() else [last]: # speed, mAP tests
results, _, _ = test.test(opt.data,
batch_size=batch_size * 2,
batch_size=batch_size // WORLD_SIZE * 2,
imgsz=imgsz_test,
conf_thres=0.001,
iou_thres=0.7,
Expand Down Expand Up @@ -518,7 +519,7 @@ def main(opt):
assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
with open(Path(ckpt).parent.parent / 'opt.yaml') as f:
opt = argparse.Namespace(**yaml.safe_load(f)) # replace
opt.cfg, opt.weights, opt.resume, opt.batch_size = '', ckpt, True, opt.total_batch_size # reinstate
opt.cfg, opt.weights, opt.resume = '', ckpt, True # reinstate
logger.info('Resuming training from %s' % ckpt)
else:
# opt.hyp = opt.hyp or ('hyp.finetune.yaml' if opt.weights else 'hyp.scratch.yaml')
Expand All @@ -529,17 +530,15 @@ def main(opt):
opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok | opt.evolve))

# DDP mode
opt.total_batch_size = opt.batch_size
device = select_device(opt.device, batch_size=opt.batch_size)
if LOCAL_RANK != -1:
from datetime import timedelta
assert torch.cuda.device_count() > LOCAL_RANK, 'too few GPUS for DDP command'
assert torch.cuda.device_count() > LOCAL_RANK, 'insufficient CUDA devices for DDP command'
torch.cuda.set_device(LOCAL_RANK)
device = torch.device('cuda', LOCAL_RANK)
dist.init_process_group(backend="gloo", timeout=timedelta(seconds=60))
assert opt.batch_size % WORLD_SIZE == 0, '--batch-size must be multiple of CUDA device count'
assert not opt.image_weights, '--image-weights argument is not compatible with DDP training'
opt.batch_size = opt.total_batch_size // WORLD_SIZE

# Train
if not opt.evolve:
Expand Down

0 comments on commit b3e2f4e

Please sign in to comment.