Skip to content

Commit

Permalink
PEP8
Browse files Browse the repository at this point in the history
  • Loading branch information
fmassa committed Apr 29, 2019
1 parent 48d55fe commit 02c5b7f
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 23 deletions.
2 changes: 1 addition & 1 deletion references/segmentation/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def _has_valid_annotation(anno):

def get_coco(root, image_set, transforms):
PATHS = {
#"train": ("train2017", os.path.join("annotations", "instances_train2017.json")),
# "train": ("train2017", os.path.join("annotations", "instances_train2017.json")),
"val": ("val2017", os.path.join("annotations", "instances_val2017.json")),
"train": ("val2017", os.path.join("annotations", "instances_val2017.json"))
}
Expand Down
4 changes: 1 addition & 3 deletions references/segmentation/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(self, model, return_layers):
raise ValueError("return_layers are not present in model")

orig_return_layers = return_layers
return_layers = {k:v for k, v in return_layers.items()}
return_layers = {k: v for k, v in return_layers.items()}
layers = OrderedDict()
for name, module in model.named_children():
layers[name] = module
Expand Down Expand Up @@ -76,7 +76,6 @@ def reset_classes(self, new_classes):
self.classes = new_classes



class SegmentationModel(nn.Module):
def __init__(self, backbone, head):
super(SegmentationModel, self).__init__()
Expand Down Expand Up @@ -227,4 +226,3 @@ def get_model(name, backbone, num_classes, aux=False):

model = SegmentationModel(backbone, classifiers)
return model

29 changes: 20 additions & 9 deletions references/segmentation/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,13 @@ def print(*args, **kwargs):
__builtin__.print = print

torch_save = torch.save

def save(*args, **kwargs):
if is_master:
torch_save(*args, **kwargs)
torch.save = save


def main(args):
args.gpu = args.local_rank

Expand All @@ -116,11 +118,15 @@ def main(args):
train_sampler = torch.utils.data.RandomSampler(dataset)
test_sampler = torch.utils.data.SequentialSampler(dataset_test)

data_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size,
sampler=train_sampler, num_workers=args.workers, collate_fn=utils.collate_fn, drop_last=True)
data_loader = torch.utils.data.DataLoader(
dataset, batch_size=args.batch_size,
sampler=train_sampler, num_workers=args.workers,
collate_fn=utils.collate_fn, drop_last=True)

data_loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=1,
sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn)
data_loader_test = torch.utils.data.DataLoader(
dataset_test, batch_size=1,
sampler=test_sampler, num_workers=args.workers,
collate_fn=utils.collate_fn)

model = models.get_model(args.model, args.backbone, num_classes=dataset.num_classes, aux=args.aux_loss)
model.to(device)
Expand All @@ -147,17 +153,23 @@ def main(args):
params_to_optimize,
lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
lambda x: (1 - x / (len(data_loader) * args.epochs)) ** 0.9)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
optimizer,
lambda x: (1 - x / (len(data_loader) * args.epochs)) ** 0.9)

start_time = time.time()
for epoch in range(args.epochs):
train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, args.print_freq)
with torch.no_grad():
confmat = evaluate(model, data_loader_test, device=device, num_classes=dataset.num_classes)
print(confmat)
torch.save({'model': model_without_ddp.state_dict(), 'optimizer': optimizer.state_dict(), 'args': args},
os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
torch.save(
{
'model': model_without_ddp.state_dict(),
'optimizer': optimizer.state_dict(),
'args': args
},
os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))

total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
Expand Down Expand Up @@ -196,7 +208,6 @@ def main(args):
if args.output_dir:
utils.mkdir(args.output_dir)


import os
num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
args.distributed = num_gpus > 1
Expand Down
22 changes: 12 additions & 10 deletions references/segmentation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import os



class SmoothedValue(object):
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
Expand Down Expand Up @@ -77,14 +76,16 @@ def reduce_from_all_processes(self):

def __str__(self):
acc_global, acc, iu = self.compute()
return ('global correct: {:.1f}\n'
'average row correct: {}\n'
'IoU: {}\n'
'mean IoU: {:.1f}').format(
acc_global.item() * 100,
['{:.1f}'.format(i) for i in (acc * 100).tolist()],
['{:.1f}'.format(i) for i in (iu * 100).tolist()],
iu.mean().item() * 100)
return (
'global correct: {:.1f}\n'
'average row correct: {}\n'
'IoU: {}\n'
'mean IoU: {:.1f}').format(
acc_global.item() * 100,
['{:.1f}'.format(i) for i in (acc * 100).tolist()],
['{:.1f}'.format(i) for i in (iu * 100).tolist()],
iu.mean().item() * 100)


class MetricLogger(object):
def __init__(self, delimiter="\t"):
Expand Down Expand Up @@ -152,7 +153,6 @@ def log_every(self, iterable, print_freq, header=None):
print('{} Total time: {}'.format(header, total_time_str))



def cat_list(images, fill_value=0):
max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))
batch_shape = (len(images),) + max_size
Expand All @@ -161,12 +161,14 @@ def cat_list(images, fill_value=0):
pad_img[..., :img.shape[-2], :img.shape[-1]].copy_(img)
return batched_imgs


def collate_fn(batch):
images, targets = list(zip(*batch))
batched_imgs = cat_list(images, fill_value=0)
batched_targets = cat_list(targets, fill_value=255)
return batched_imgs, batched_targets


def mkdir(path):
try:
os.makedirs(path)
Expand Down

0 comments on commit 02c5b7f

Please sign in to comment.