This repository has been archived by the owner on Jul 2, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 304
Add VGG16 #265
Merged
Merged
Add VGG16 #265
Changes from all commits
Commits
Show all changes
156 commits
Select commit
Hold shift + click to select a range
46a07fb
add wip vgg
yuyu2172 4a32cce
Merge remote-tracking branch 'yuyu2172/image-folder-dataset' into HEAD
yuyu2172 8b35001
update on classification models
yuyu2172 1cd87b6
fix the issue with copying vision chain
yuyu2172 af88d89
add tests and pass tests for feature option
yuyu2172 2bbece2
add predict function
yuyu2172 17f6333
update test
yuyu2172 2d4c45e
add caffe trained pretrained weight loader
yuyu2172 55ecaca
improve eval_imagenet
yuyu2172 f570a92
add pretrained_model ooption to eval_imagenet
yuyu2172 f479f21
small fixes
yuyu2172 e23e5f5
pass test_predict
yuyu2172 b45fe22
more informative evaluation
yuyu2172 1f3a999
use scale instead of resize
yuyu2172 1f0353a
fix a bug in scale
yuyu2172 d184963
move crop option to __init__
yuyu2172 c6d62de
fix eval_imagenet
yuyu2172 d1a4e46
add docs
yuyu2172 df66a0a
support automatic download
yuyu2172 84a905d
Work without setting n_class or pretrained_model
yuyu2172 a83d548
remove VGG16FeatureExtractor
yuyu2172 b696389
improve print output
yuyu2172 47326fb
flake8 of eval_imagenet
yuyu2172 30035bb
improve doc
yuyu2172 2ffae2b
flake8 of tests
yuyu2172 8b3337c
fix tests
yuyu2172 fab7a7c
Simplify __init__
yuyu2172 1ac56a1
fix tests
yuyu2172 86f351d
change convert_from_caffe to convert_vgg
yuyu2172 fdf763e
update README
yuyu2172 d37d08d
accept multiple types of features
yuyu2172 7f04616
accept multiple features as input
yuyu2172 ce77737
cosmetic
yuyu2172 2206fc8
change name of a variable
yuyu2172 64efbc0
fix a bug in conditional
yuyu2172 ab2cfee
consistency in variable names
yuyu2172 3caec05
change api to return tuple instead of dict
yuyu2172 7b2d214
fix doc
yuyu2172 96e285c
fix faster_rcnn_vgg
yuyu2172 b0b070c
fix doc
yuyu2172 a1d6927
fix init
yuyu2172 338819a
simplify functions
yuyu2172 700c5f4
fix links.rst
yuyu2172 b1983cd
Merge branch 'image-folder-dataset' into classification
yuyu2172 5b08bd7
fix eval_imagenet
yuyu2172 6a2b43c
fix vgg for python3
yuyu2172 724dcb5
remove predict and use SequentialFeatureExtractionChain
yuyu2172 49bee22
exploit the fact that functions is ordered dict
yuyu2172 5af7506
VGG16Layers -> VGG16
yuyu2172 aa18ad7
add feature_extraction_predictor
yuyu2172 6c5f6f9
use sequential_chain
yuyu2172 b703084
simplify vgg16
yuyu2172 75db194
simplify sequential_chain
yuyu2172 a36323e
sequential_chain --> extraction_chain
yuyu2172 a889d47
delete unnecessary constraint
yuyu2172 a91d29d
stop using unnecessary list
yuyu2172 e0d9147
use property mean
yuyu2172 0972d02
feature_names --> layer_names for VGG
yuyu2172 d7f3f52
mean doc
yuyu2172 e794055
fix tests
yuyu2172 3309415
simplify initialization
yuyu2172 4c8b3e6
delete redundant layerts from extraction_chain
yuyu2172 e51762f
extraction_chain --> sequential_extractor
yuyu2172 10815f7
fix comment
yuyu2172 3e22012
[Sequential Extractor] Change default names for layers
yuyu2172 c8fbda6
[Sequential Extractor] function --> layer
yuyu2172 7ef4623
sequential_extractor --> sequential_feature_extractor and add doc
yuyu2172 427f724
Merge remote-tracking branch 'origin/master' into classification
yuyu2172 1c49938
add doc to feature_extraction_predictor
yuyu2172 467c500
test_feature_extraction_predictor
yuyu2172 a1f43dd
reflect name changes to examples/classification
yuyu2172 c3af69e
Merge remote-tracking branch 'yuyu2172/image-folder-dataset' into cla…
yuyu2172 011b320
update eval_imagenet and README
yuyu2172 ae5fd2d
fix flake8
yuyu2172 c8cdc36
fix error for python3
yuyu2172 268e781
improve doc
yuyu2172 ebe9340
fix feature_extraction_predictor.predict
yuyu2172 8f65c02
fix doc
yuyu2172 80df6a9
reorder arguments
yuyu2172 320174d
use initialize at runtime
yuyu2172 55c50c9
make layer_names dynamically changeable
yuyu2172 57ab2cc
Merge remote-tracking branch 'origin/master' into HEAD
yuyu2172 ec8ddf7
merge sequential-feature-extractor
yuyu2172 a232abc
use updated interface of SequentialFeatureExtractor
yuyu2172 348ae85
fix doc
yuyu2172 f97cd49
Merge remote-tracking branch 'origin/master' into classification
yuyu2172 2940e0d
it is not necessary to do a trick to save initialization time
yuyu2172 3b24821
use block
yuyu2172 8b408c6
simplify init_scope
yuyu2172 578b3aa
add convolution_2d_block
yuyu2172 05b478c
change pretrained weights
yuyu2172 826c183
fix doc
yuyu2172 a5f0be7
fix test_vgg16
yuyu2172 23a42d9
fix doc
yuyu2172 775ab31
fix a mistake in vgg16
yuyu2172 f553948
fix faster_rcnn_vgg
yuyu2172 7602cc7
Merge remote-tracking branch 'yuyu2172/classification' into classific…
yuyu2172 b9c3f5c
fix Faster RCNN train to work
yuyu2172 a982c8e
fix doc
yuyu2172 48db171
use Zero initialization when pretrained model is used
yuyu2172 f78cec3
Merge branch 'classification' of https://github.com/yuyu2172/chainerc…
yuyu2172 459e3c6
flake8
yuyu2172 b8e890d
improve doc of feature_extraction_predictor
yuyu2172 f04cb4b
Merge remote-tracking branch 'origin/master' into classification
yuyu2172 d005cb9
Merge remote-tracking branch 'origin/master' into classification
yuyu2172 3820df9
use remove_unused
yuyu2172 3323ce7
use remove_unused
yuyu2172 9d25473
change init style of convolution_2d_block
yuyu2172 f615940
use crop_size && make crop_size int
yuyu2172 3bc3b09
fix convolution_2d_block
yuyu2172 4233f8e
stop using do_ten_crop
yuyu2172 7608809
merge master
yuyu2172 6b8c0da
fix doc
yuyu2172 b85a3ce
fix doc for VGG16
yuyu2172 20783d6
fix variable names in eval_imagenet
yuyu2172 29f8979
Merge remote-tracking branch 'yuyu2172/image-folder-dataset' into cla…
yuyu2172 20b6eb9
fix vgg16
yuyu2172 a59194a
add convolution2DBlock to doc
yuyu2172 9a64c9b
delete unnecessary declaration of initial_bias
yuyu2172 799b48b
specify n_class in eval_imagenet
yuyu2172 5c83e38
use directory_parsing_label_names
yuyu2172 aafaecc
expose crop option in eval_imagenet
yuyu2172 fdd6f48
fix variable names
yuyu2172 4bbdd5a
support tuple as an argument for shapes (FeatureExtractionPredictor)
yuyu2172 465f1df
use choices kwarg for eval_imagenet
yuyu2172 62edb90
update download link of FasterRCNNVGG16
yuyu2172 895ef8a
flake8
yuyu2172 141bc39
Convolution2DBlock -> Conv2DActiv
yuyu2172 6d4cf29
use activ instead of activation
yuyu2172 b7df511
add a note in Conv2DActiv doc
yuyu2172 fd0dde5
add test on forward and activ
yuyu2172 5b2a257
that of -> those of
yuyu2172 86029ef
fix doc
yuyu2172 93d8e67
fix doc
yuyu2172 53bb7c6
fix doc
yuyu2172 3a13b35
fix doc
yuyu2172 8ea45df
flake8
yuyu2172 c7a14ad
Merge remote-tracking branch 'origin/master' into classification
yuyu2172 9a18739
merge master
yuyu2172 26231db
merge master
yuyu2172 d29172e
update code
yuyu2172 a3a20b2
update README
yuyu2172 10d2e09
use default mean value when unsupecified
yuyu2172 3f0da8e
fix eval_imagenet
yuyu2172 59f53e1
fix eval_imagenet
yuyu2172 4cf289c
fix
yuyu2172 e45bc5a
change convert_vgg to use caffemodel
yuyu2172 49249ed
change name caffee2npz_vgg
yuyu2172 0dfddfe
make vgg directory
yuyu2172 0900011
rename
Hakuyume 4de11f7
update conversion
Hakuyume 9d1e8ee
fix typo
Hakuyume 70bf30b
Merge pull request #4 from Hakuyume/vgg-conversion
yuyu2172 743055f
fix download link for faster_rcnn_vgg
yuyu2172 7bd02b0
update README vgg
yuyu2172 9ca4070
fix typo
yuyu2172 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from chainercv.links.model.vgg.vgg16 import VGG16 # NOQA |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,164 @@ | ||
from __future__ import division | ||
|
||
import numpy as np | ||
|
||
import chainer | ||
from chainer.functions import dropout | ||
from chainer.functions import max_pooling_2d | ||
from chainer.functions import relu | ||
from chainer.functions import softmax | ||
from chainer.initializers import constant | ||
from chainer.initializers import normal | ||
|
||
from chainer.links import Linear | ||
|
||
from chainercv.utils import download_model | ||
|
||
from chainercv.links.connection.conv_2d_activ import Conv2DActiv | ||
from chainercv.links.model.sequential_feature_extractor import \ | ||
SequentialFeatureExtractor | ||
|
||
|
||
# RGB order | ||
_imagenet_mean = np.array( | ||
[123.68, 116.779, 103.939], dtype=np.float32)[:, np.newaxis, np.newaxis] | ||
|
||
|
||
class VGG16(SequentialFeatureExtractor): | ||
|
||
"""VGG-16 Network for classification and feature extraction. | ||
|
||
This is a feature extraction model. | ||
The network can choose output features from set of all | ||
intermediate features. | ||
The value of :obj:`VGG16.feature_names` selects the features that are going | ||
to be collected by :meth:`__call__`. | ||
:obj:`self.all_feature_names` is the list of the names of features | ||
that can be collected. | ||
|
||
Examples: | ||
|
||
>>> model = VGG16() | ||
# By default, __call__ returns a probability score (after Softmax). | ||
>>> prob = model(imgs) | ||
|
||
>>> model.feature_names = 'conv5_3' | ||
# This is feature conv5_3 (after ReLU). | ||
>>> feat5_3 = model(imgs) | ||
|
||
>>> model.feature_names = ['conv5_3', 'fc6'] | ||
>>> # These are features conv5_3 (after ReLU) and fc6 (before ReLU). | ||
>>> feat5_3, feat6 = model(imgs) | ||
|
||
.. seealso:: | ||
:class:`chainercv.links.model.SequentialFeatureExtractor` | ||
|
||
When :obj:`pretrained_model` is the path of a pre-trained chainer model | ||
serialized as a :obj:`.npz` file in the constructor, this chain model | ||
automatically initializes all the parameters with it. | ||
When a string in the prespecified set is provided, a pretrained model is | ||
loaded from weights distributed on the Internet. | ||
The list of pretrained models supported are as follows: | ||
|
||
* :obj:`imagenet`: Loads weights trained with ImageNet and distributed \ | ||
at `Model Zoo \ | ||
<https://github.com/BVLC/caffe/wiki/Model-Zoo>`_. | ||
|
||
Args: | ||
pretrained_model (str): The destination of the pre-trained | ||
chainer model serialized as a :obj:`.npz` file. | ||
If this is one of the strings described | ||
above, it automatically loads weights stored under a directory | ||
:obj:`$CHAINER_DATASET_ROOT/pfnet/chainercv/models/`, | ||
where :obj:`$CHAINER_DATASET_ROOT` is set as | ||
:obj:`$HOME/.chainer/dataset` unless you specify another value | ||
by modifying the environment variable. | ||
n_class (int): The number of classes. If :obj:`None`, | ||
the default values are used. | ||
If a supported pretrained model is used, | ||
the number of classes used to train the pretrained model | ||
is used. Otherwise, the number of classes in ILSVRC 2012 dataset | ||
is used. | ||
mean (numpy.ndarray): A mean value. If :obj:`None`, | ||
the default values are used. | ||
If a supported pretrained model is used, | ||
the mean value used to train the pretrained model is used. | ||
Otherwise, the mean value calculated from ILSVRC 2012 dataset | ||
is used. | ||
initialW (callable): Initializer for the weights. | ||
initial_bias (callable): Initializer for the biases. | ||
|
||
""" | ||
|
||
_models = { | ||
'imagenet': { | ||
'n_class': 1000, | ||
'url': 'https://github.com/yuyu2172/share-weights/releases/' | ||
'download/0.0.4/vgg16_imagenet_convert_2017_07_18.npz', | ||
'mean': _imagenet_mean | ||
} | ||
} | ||
|
||
def __init__(self, | ||
pretrained_model=None, n_class=None, mean=None, | ||
initialW=None, initial_bias=None): | ||
if n_class is None: | ||
if pretrained_model in self._models: | ||
n_class = self._models[pretrained_model]['n_class'] | ||
else: | ||
n_class = 1000 | ||
|
||
if mean is None: | ||
if pretrained_model in self._models: | ||
mean = self._models[pretrained_model]['mean'] | ||
else: | ||
mean = _imagenet_mean | ||
self.mean = mean | ||
|
||
if initialW is None: | ||
# Employ default initializers used in the original paper. | ||
initialW = normal.Normal(0.01) | ||
if pretrained_model: | ||
# As a sampling process is time-consuming, | ||
# we employ a zero initializer for faster computation. | ||
initialW = constant.Zero() | ||
kwargs = {'initialW': initialW, 'initial_bias': initial_bias} | ||
|
||
super(VGG16, self).__init__() | ||
with self.init_scope(): | ||
self.conv1_1 = Conv2DActiv(None, 64, 3, 1, 1, **kwargs) | ||
self.conv1_2 = Conv2DActiv(None, 64, 3, 1, 1, **kwargs) | ||
self.pool1 = _max_pooling_2d | ||
self.conv2_1 = Conv2DActiv(None, 128, 3, 1, 1, **kwargs) | ||
self.conv2_2 = Conv2DActiv(None, 128, 3, 1, 1, **kwargs) | ||
self.pool2 = _max_pooling_2d | ||
self.conv3_1 = Conv2DActiv(None, 256, 3, 1, 1, **kwargs) | ||
self.conv3_2 = Conv2DActiv(None, 256, 3, 1, 1, **kwargs) | ||
self.conv3_3 = Conv2DActiv(None, 256, 3, 1, 1, **kwargs) | ||
self.pool3 = _max_pooling_2d | ||
self.conv4_1 = Conv2DActiv(None, 512, 3, 1, 1, **kwargs) | ||
self.conv4_2 = Conv2DActiv(None, 512, 3, 1, 1, **kwargs) | ||
self.conv4_3 = Conv2DActiv(None, 512, 3, 1, 1, **kwargs) | ||
self.pool4 = _max_pooling_2d | ||
self.conv5_1 = Conv2DActiv(None, 512, 3, 1, 1, **kwargs) | ||
self.conv5_2 = Conv2DActiv(None, 512, 3, 1, 1, **kwargs) | ||
self.conv5_3 = Conv2DActiv(None, 512, 3, 1, 1, **kwargs) | ||
self.pool5 = _max_pooling_2d | ||
self.fc6 = Linear(None, 4096, **kwargs) | ||
self.fc6_relu = relu | ||
self.fc6_dropout = dropout | ||
self.fc7 = Linear(None, 4096, **kwargs) | ||
self.fc7_relu = relu | ||
self.fc7_dropout = dropout | ||
self.fc8 = Linear(None, n_class, **kwargs) | ||
self.prob = softmax | ||
|
||
if pretrained_model in self._models: | ||
path = download_model(self._models[pretrained_model]['url']) | ||
chainer.serializers.load_npz(path, self) | ||
elif pretrained_model: | ||
chainer.serializers.load_npz(pretrained_model, self) | ||
|
||
|
||
def _max_pooling_2d(x): | ||
return max_pooling_2d(x, ksize=2) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
VGG | ||
=== | ||
|
||
.. module:: chainercv.links.model.vgg | ||
|
||
|
||
VGG16 | ||
----- | ||
|
||
.. autoclass:: VGG16 | ||
:members: |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about adding a note about these behaviours? The default values of
n_class
andmean
.