Skip to content
This repository has been archived by the owner on Jul 2, 2021. It is now read-only.

refactor eval_imagenet.py #874

Merged
merged 1 commit into from
May 19, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 68 additions & 48 deletions examples/classification/eval_imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,70 +23,90 @@
from chainercv.utils import ProgressHook


models = {
# model: (class, dataset -> pretrained_model, default batchsize,
# crop, resnet_arch)
'vgg16': (VGG16, {}, 32, 'center', None),
'resnet50': (ResNet50, {}, 32, 'center', 'fb'),
'resnet101': (ResNet101, {}, 32, 'center', 'fb'),
'resnet152': (ResNet152, {}, 32, 'center', 'fb'),
'se-resnet50': (SEResNet50, {}, 32, 'center', None),
'se-resnet101': (SEResNet101, {}, 32, 'center', None),
'se-resnet152': (SEResNet152, {}, 32, 'center', None),
'se-resnext50': (SEResNeXt50, {}, 32, 'center', None),
'se-resnext101': (SEResNeXt101, {}, 32, 'center', None),
}


def setup(dataset, model, pretrained_model, batchsize, val, crop, resnet_arch):
dataset_name = dataset
if dataset_name == 'imagenet':
dataset = DirectoryParsingLabelDataset(val)
label_names = directory_parsing_label_names(val)

def eval_(out_values, rest_values):
pred_probs, = out_values
gt_labels, = rest_values

accuracy = F.accuracy(
np.array(list(pred_probs)), np.array(list(gt_labels))).data
print()
print('Top 1 Error {}'.format(1. - accuracy))

cls, pretrained_models, default_batchsize = models[model][:3]
if pretrained_model is None:
pretrained_model = pretrained_models.get(dataset_name, dataset_name)
if crop is None:
crop = models[model][3]
kwargs = {
'n_class': len(label_names),
'pretrained_model': pretrained_model,
}
if model in ['resnet50', 'resnet101', 'resnet152']:
if resnet_arch is None:
resnet_arch = models[model][4]
kwargs.update({'arch': resnet_arch})
extractor = cls(**kwargs)
model = FeaturePredictor(
extractor, crop_size=224, scale_size=256, crop=crop)

if batchsize is None:
batchsize = default_batchsize

return dataset, eval_, model, batchsize


def main():
parser = argparse.ArgumentParser(
description='Learning convnet from ILSVRC2012 dataset')
description='Evaluating convnet from ILSVRC2012 dataset')
parser.add_argument('val', help='Path to root of the validation dataset')
parser.add_argument(
'--model', choices=(
'vgg16',
'resnet50', 'resnet101', 'resnet152',
'se-resnet50', 'se-resnet101', 'se-resnet152',
'se-resnext50', 'se-resnext101'))
parser.add_argument('--pretrained-model', default='imagenet')
parser.add_argument('--model', choices=sorted(models.keys()))
parser.add_argument('--pretrained-model')
parser.add_argument('--dataset', choices=('imagenet'))
parser.add_argument('--gpu', type=int, default=-1)
parser.add_argument('--batchsize', type=int, default=32)
parser.add_argument('--crop', choices=('center', '10'), default='center')
parser.add_argument('--resnet-arch', default='fb')
parser.add_argument('--batchsize', type=int)
parser.add_argument('--crop', choices=('center', '10'))
parser.add_argument('--resnet-arch')
args = parser.parse_args()

dataset = DirectoryParsingLabelDataset(args.val)
label_names = directory_parsing_label_names(args.val)
n_class = len(label_names)
iterator = iterators.MultiprocessIterator(
dataset, args.batchsize, repeat=False, shuffle=False,
n_processes=6, shared_mem=300000000)

if args.model == 'vgg16':
extractor = VGG16(n_class, args.pretrained_model)
elif args.model == 'resnet50':
extractor = ResNet50(
n_class, args.pretrained_model, arch=args.resnet_arch)
elif args.model == 'resnet101':
extractor = ResNet101(
n_class, args.pretrained_model, arch=args.resnet_arch)
elif args.model == 'resnet152':
extractor = ResNet152(
n_class, args.pretrained_model, arch=args.resnet_arch)
elif args.model == 'se-resnet50':
extractor = SEResNet50(n_class, args.pretrained_model)
elif args.model == 'se-resnet101':
extractor = SEResNet101(n_class, args.pretrained_model)
elif args.model == 'se-resnet152':
extractor = SEResNet152(n_class, args.pretrained_model)
elif args.model == 'se-resnext50':
extractor = SEResNeXt50(n_class, args.pretrained_model)
elif args.model == 'se-resnext101':
extractor = SEResNeXt101(n_class, args.pretrained_model)
model = FeaturePredictor(
extractor, crop_size=224, scale_size=256, crop=args.crop)
dataset, eval_, model, batchsize = setup(
args.dataset, args.model, args.pretrained_model, args.batchsize,
args.val, args.crop, args.resnet_arch)

if args.gpu >= 0:
chainer.cuda.get_device(args.gpu).use()
model.to_gpu()

iterator = iterators.MultiprocessIterator(
dataset, batchsize, repeat=False, shuffle=False,
n_processes=6, shared_mem=300000000)

print('Model has been prepared. Evaluation starts.')
in_values, out_values, rest_values = apply_to_iterator(
model.predict, iterator, hook=ProgressHook(len(dataset)))
del in_values

pred_probs, = out_values
gt_labels, = rest_values

accuracy = F.accuracy(
np.array(list(pred_probs)), np.array(list(gt_labels))).data
print()
print('Top 1 Error {}'.format(1. - accuracy))
eval_(out_values, rest_values)


if __name__ == '__main__':
Expand Down