Skip to content

Commit

Permalink
Tag where runs are being executed
Browse files Browse the repository at this point in the history
  • Loading branch information
fmassa committed Mar 24, 2019
1 parent b1bfc6e commit 5d0ecca
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 11 deletions.
23 changes: 12 additions & 11 deletions references/segmentation/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,14 @@ def __init__(self, block, layers, dilated=False, return_layers=None):

def _add_dilation(self):
d = (2, 2)
self.layer3[0].downsample[0].stride = (1, 1)
self.layer3[0].conv2.stride = (1, 1)
for b in self.layer3[1:]:
b.conv2.padding = d
b.conv2.dilation = d
self.layer4[0].downsample[0].stride = (1, 1)
self.layer4[0].conv2.stride = (1, 1)
self.layer4[0].conv2.dilation = (2, 2)
self.layer4[0].conv2.padding = d
self.layer4[0].conv2.dilation = d
d = (4, 4)
Expand Down Expand Up @@ -89,20 +94,16 @@ def __init__(self, in_channels, channels):
super(FCNHead, self).__init__(*layers)


class DeepLabHead(nn.Module):
class DeepLabHead(nn.Sequential):
def __init__(self, in_channels, num_classes):
super(DeepLabHead, self).__init__()
self.aspp = ASPP(in_channels, [12, 24, 36])
self.block = nn.Sequential(nn.Conv2d(256, 256, 3, padding=1, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(),
nn.Conv2d(256, num_classes, 1)
super(DeepLabHead, self).__init__(
ASPP(in_channels, [12, 24, 36]),
nn.Conv2d(256, 256, 3, padding=1, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(),
nn.Conv2d(256, num_classes, 1)
)

def forward(self, x):
x = self.aspp(x)
return self.block(x)


class ASPPConv(nn.Sequential):
def __init__(self, in_channels, out_channels, atrous_rate):
Expand Down
7 changes: 7 additions & 0 deletions references/segmentation/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def main(args):

model = models.get_model(args.model, args.backbone, num_classes=dataset.num_classes, aux=args.aux_loss)
model.to(device)
model = torch.nn.utils.convert_sync_batchnorm(model)

model_without_ddp = model
if args.distributed:
Expand All @@ -146,6 +147,8 @@ def main(args):
with torch.no_grad():
confmat = evaluate(model, data_loader_test, device=device, num_classes=dataset.num_classes)
print(confmat)
torch.save({'model': model_without_ddp.state_dict(), 'optimizer': optimizer.state_dict(), 'args': args},
os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))

total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
Expand Down Expand Up @@ -174,11 +177,15 @@ def main(args):
metavar='W', help='weight decay (default: 1e-4)',
dest='weight_decay')
parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
parser.add_argument('--output-dir', default='.', help='path where to save')
parser.add_argument('--local_rank', default=0, type=int, help='print frequency')

args = parser.parse_args()
print(args)

if args.output_dir:
utils.mkdir(args.output_dir)


import os
num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
Expand Down
11 changes: 11 additions & 0 deletions references/segmentation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
import time
import torch

import errno
import os



class SmoothedValue(object):
"""Track a series of values and provide access to smoothed values over a
Expand Down Expand Up @@ -162,3 +166,10 @@ def collate_fn(batch):
batched_imgs = cat_list(images, fill_value=0)
batched_targets = cat_list(targets, fill_value=255)
return batched_imgs, batched_targets

def mkdir(path):
try:
os.makedirs(path)
except OSError as e:
if e.errno != errno.EEXIST:
raise

0 comments on commit 5d0ecca

Please sign in to comment.