Skip to content
This repository has been archived by the owner on Jul 2, 2021. It is now read-only.

Commit

Permalink
add PSPResNet50 and input-size option
Browse files Browse the repository at this point in the history
  • Loading branch information
yuyu2172 committed Dec 19, 2018
1 parent 5f1b506 commit 0052dde
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
16 changes: 13 additions & 3 deletions examples/semantic_segmentation/eval_semantic_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@

from chainercv.evaluations import eval_semantic_segmentation
from chainercv.experimental.links import PSPNetResNet101
from chainercv.experimental.links import PSPNetResNet50
from chainercv.links import SegNetBasic
from chainercv.utils import apply_to_iterator
from chainercv.utils import ProgressHook


def get_dataset_and_model(dataset_name, model_name, pretrained_model):
def get_dataset_and_model(dataset_name, model_name, pretrained_model,
input_size):
if dataset_name == 'cityscapes':
dataset = CityscapesSemanticSegmentationDataset(
split='val', label_resolution='fine')
Expand All @@ -39,7 +41,13 @@ def get_dataset_and_model(dataset_name, model_name, pretrained_model):
model = PSPNetResNet101(
n_class=n_class,
pretrained_model=pretrained_model,
input_size=(713, 713)
input_size=input_size
)
elif model_name == 'pspnet_resnet50':
model = PSPNetResNet50(
n_class=n_class,
pretrained_model=pretrained_model,
input_size=input_size
)
elif model_name == 'segnet':
model = SegNetBasic(
Expand All @@ -56,10 +64,12 @@ def main():
'pspnet_resnet101', 'segnet'))
parser.add_argument('--gpu', type=int, default=-1)
parser.add_argument('--pretrained-model')
parser.add_argument(
'--input-size', default=None)
args = parser.parse_args()

dataset, label_names, model = get_dataset_and_model(
args.dataset, args.model, args.pretrained_model)
args.dataset, args.model, args.pretrained_model, args.input_size)

if args.gpu >= 0:
chainer.cuda.get_device_from_id(args.gpu).use()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,15 @@ def main():
'--model', choices=(
'pspnet_resnet101', 'segnet'))
parser.add_argument('--pretrained-model')
parser.add_argument(
'--input-size', default=None)
args = parser.parse_args()

comm = chainermn.create_communicator()
device = comm.intra_rank

dataset, label_names, model = get_dataset_and_model(
args.dataset, args.model, args.pretrained_model)
args.dataset, args.model, args.pretrained_model, args.input_size)
assert (len(dataset) % comm.size == 0, \
"The size of the dataset should be a multiple "\
"of the number of GPUs")
Expand Down

0 comments on commit 0052dde

Please sign in to comment.