Skip to content

Commit

Permalink
train1 to trainA
Browse files Browse the repository at this point in the history
  • Loading branch information
NoamRosenberg committed Jul 23, 2019
1 parent 60842e5 commit 3308efc
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
13 changes: 7 additions & 6 deletions train_autodeeplab.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __init__(self, args):
self.writer = self.summary.create_summary()

kwargs = {'num_workers': args.workers, 'pin_memory': True}
self.train_loader1, self.train_loader2, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)
self.train_loaderA, self.train_loaderB, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

if args.use_balanced_weights:
classes_weights_path = os.path.join(Path.db_root_dir(args.dataset), args.dataset+'_classes_weights.npy')
Expand All @@ -45,7 +45,7 @@ def __init__(self, args):
self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)

# Define network
model = AutoDeeplab (num_classes=self.nclass, num_layers=12, criterion=self.criterion)
model = AutoDeeplab (num_classes=self.nclass, num_layers=12, criterion=self.criterion, filter_multiplier=self.args.filter_multiplier)
optimizer = torch.optim.SGD(
model.parameters(),
args.lr,
Expand All @@ -58,7 +58,7 @@ def __init__(self, args):
self.evaluator = Evaluator(self.nclass)
# Define lr scheduler
self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
args.epochs, len(self.train_loader1))
args.epochs, len(self.train_loaderA))

self.architect = Architect (self.model, args)

Expand Down Expand Up @@ -113,8 +113,8 @@ def __init__(self, args):
def training(self, epoch):
train_loss = 0.0
self.model.train()
tbar = tqdm(self.train_loader1)
num_img_tr = len(self.train_loader1)
tbar = tqdm(self.train_loaderA)
num_img_tr = len(self.train_loaderA)
for i, sample in enumerate(tbar):
image, target = sample['image'], sample['label']
if self.args.cuda:
Expand All @@ -127,7 +127,7 @@ def training(self, epoch):
self.optimizer.step()

if epoch > self.args.alpha_epoch:
search = next(iter(self.train_loader2))
search = next(iter(self.train_loaderB))
image_search, target_search = search['image'], search['label']
if self.args.cuda:
image_search, target_search = image_search.cuda (), target_search.cuda ()
Expand Down Expand Up @@ -247,6 +247,7 @@ def main():
help='number of epochs to train (default: auto)')
parser.add_argument('--start_epoch', type=int, default=0,
metavar='N', help='start epochs (default:0)')
parser.add_argument('--filter_multiplier', type=int, default=8)
parser.add_argument('--alpha_epoch', type=int, default=5,
metavar='N', help='epoch to start training alphas')
parser.add_argument('--batch-size', type=int, default=1,
Expand Down
2 changes: 1 addition & 1 deletion train_cityscapes.sh
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
CUDA_VISIBLE_DEVICES=0,1 python train_autodeeplab.py --batch-size 2 \
--dataset cityscapes --checkname July22 --alpha_epoch 20
--dataset cityscapes --checkname July22 --alpha_epoch 20 --filter_multiplier 4

0 comments on commit 3308efc

Please sign in to comment.