Skip to content
This repository has been archived by the owner on Jul 2, 2021. It is now read-only.

Commit

Permalink
Merge pull request #523 from Hakuyume/improve-apply-prediction
Browse files Browse the repository at this point in the history
Improve apply_prediction
  • Loading branch information
yuyu2172 authored Mar 5, 2018
2 parents 8252b2e + 09f7ade commit 4d239ad
Show file tree
Hide file tree
Showing 14 changed files with 354 additions and 300 deletions.
18 changes: 9 additions & 9 deletions chainercv/extensions/evaluator/detection_voc_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import chainer.training.extensions

from chainercv.evaluations import eval_detection_voc
from chainercv.utils import apply_prediction_to_iterator
from chainercv.utils import apply_to_iterator


class DetectionVOCEvaluator(chainer.training.extensions.Evaluator):
Expand Down Expand Up @@ -72,17 +72,17 @@ def evaluate(self):
else:
it = copy.copy(iterator)

imgs, pred_values, gt_values = apply_prediction_to_iterator(
in_values, out_values, rest_values = apply_to_iterator(
target.predict, it)
# delete unused iterator explicitly
del imgs
# delete unused iterators explicitly
del in_values

pred_bboxes, pred_labels, pred_scores = pred_values
pred_bboxes, pred_labels, pred_scores = out_values

if len(gt_values) == 3:
gt_bboxes, gt_labels, gt_difficults = gt_values
elif len(gt_values) == 2:
gt_bboxes, gt_labels = gt_values
if len(rest_values) == 3:
gt_bboxes, gt_labels, gt_difficults = rest_values
elif len(rest_values) == 2:
gt_bboxes, gt_labels = rest_values
gt_difficults = None

result = eval_detection_voc(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import chainer.training.extensions

from chainercv.evaluations import eval_semantic_segmentation
from chainercv.utils import apply_prediction_to_iterator
from chainercv.utils import apply_to_iterator


class SemanticSegmentationEvaluator(chainer.training.extensions.Evaluator):
Expand Down Expand Up @@ -79,13 +79,13 @@ def evaluate(self):
else:
it = copy.copy(iterator)

imgs, pred_values, gt_values = apply_prediction_to_iterator(
in_values, out_values, rest_values = apply_to_iterator(
target.predict, it)
# delete unused iterator explicitly
del imgs
# delete unused iterators explicitly
del in_values

pred_labels, = pred_values
gt_labels, = gt_values
pred_labels, = out_values
gt_labels, = rest_values

result = eval_semantic_segmentation(pred_labels, gt_labels)

Expand Down
2 changes: 1 addition & 1 deletion chainercv/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from chainercv.utils.image import read_image # NOQA
from chainercv.utils.image import tile_images # NOQA
from chainercv.utils.image import write_image # NOQA
from chainercv.utils.iterator import apply_prediction_to_iterator # NOQA
from chainercv.utils.iterator import apply_to_iterator # NOQA
from chainercv.utils.iterator import ProgressHook # NOQA
from chainercv.utils.iterator import unzip # NOQA
from chainercv.utils.testing import assert_is_bbox # NOQA
Expand Down
2 changes: 1 addition & 1 deletion chainercv/utils/iterator/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from chainercv.utils.iterator.apply_prediction_to_iterator import apply_prediction_to_iterator # NOQA
from chainercv.utils.iterator.apply_to_iterator import apply_to_iterator # NOQA
from chainercv.utils.iterator.progress_hook import ProgressHook # NOQA
from chainercv.utils.iterator.unzip import unzip # NOQA
141 changes: 0 additions & 141 deletions chainercv/utils/iterator/apply_prediction_to_iterator.py

This file was deleted.

169 changes: 169 additions & 0 deletions chainercv/utils/iterator/apply_to_iterator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
from chainercv.utils.iterator.unzip import unzip


def apply_to_iterator(func, iterator, n_input=1, hook=None):
"""Apply a function/method to batches from an iterator.
This function applies a function/method to an iterator of batches.
It assumes that the iterator iterates over a collection of tuples
that contain inputs to :func:`func`.
Additionally, the tuples may contain values
that are not used by :func:`func`.
For convenience, we allow the iterator to iterate over a collection of
inputs that are not tuple.
Here is an illustration of the expected behavior of the iterator.
This behaviour is the same as :class:`chainer.Iterator`.
>>> batch = next(iterator)
>>> # batch: [in_val]
or
>>> # batch: [(in_val0, ..., in_val{n_input - 1})]
or
>>> # batch: [(in_val0, ..., in_val{n_input - 1}, rest_val0, ...)]
:func:`func` should take batch(es) of data and
return batch(es) of computed values.
Here is an illustration of the expected behavior of the function.
>>> out_vals = func([in_val0], ..., [in_val{n_input - 1}])
>>> # out_vals: [out_val]
or
>>> out_vals0, out_vals1, ... = func([in_val0], ..., [in_val{n_input - 1}])
>>> # out_vals0: [out_val0]
>>> # out_vals1: [out_val1]
With :func:`apply_to_iterator`, users can get iterator(s) of values
returned by :func:`func`. It also returns iterator(s) of input values and
values that are not used for computation.
>>> in_values, out_values, rest_values = apply_to_iterator(
>>> func, iterator, n_input)
>>> # in_values: (iter of in_val0, ..., iter of in_val{n_input - 1})
>>> # out_values: (iter of out_val0, ...)
>>> # rest_values: (iter of rest_val0, ...)
Here is an exmple, which applies a pretrained Faster R-CNN to
PASCAL VOC dataset.
>>> from chainer import iterators
>>>
>>> from chainercv.datasets import VOCBBoxDataset
>>> from chainercv.links import FasterRCNNVGG16
>>> from chainercv.utils import apply_to_iterator
>>>
>>> dataset = VOCBBoxDataset(year='2007', split='test')
>>> # next(iterator) -> [(img, gt_bbox, gt_label)]
>>> iterator = iterators.SerialIterator(
... dataset, 2, repeat=False, shuffle=False)
>>>
>>> # model.predict([img]) -> ([pred_bbox], [pred_label], [pred_score])
>>> model = FasterRCNNVGG16(pretrained_model='voc07')
>>>
>>> in_values, out_values, rest_values = apply_to_iterator(
... model.predict, iterator)
>>>
>>> # in_values contains one iterator
>>> imgs, = in_values
>>> # out_values contains three iterators
>>> pred_bboxes, pred_labels, pred_scores = out_values
>>> # rest_values contains two iterators
>>> gt_bboxes, gt_labels = rest_values
Args:
func: A callable that takes batch(es) of input data and returns
computed data.
iterator (iterator): An iterator of batches.
The first :obj:`n_input` elements in each sample are
treated as input values. They are passed to :obj:`func`.
n_input (int): The number of input data. The default value is :obj:`1`.
hook: A callable that is called after each iteration.
:obj:`in_values`, :obj:`out_values`, and :obj:`rest_values`
are passed as arguments.
Note that these values do not contain data from the previous
iterations.
Returns:
Three tuples of iterators:
This function returns three tuples of iterators:
:obj:`in_values`, :obj:`out_values` and :obj:`rest_values`.
* :obj:`in_values`: A tuple of iterators. Each iterator \
returns a corresponding input value. \
For example, if :func:`func` takes \
:obj:`[in_val0], [in_val1]`, :obj:`next(in_values[0])` \
and :obj:`next(in_values[1])` will be \
:obj:`in_val0` and :obj:`in_val1`.
* :obj:`out_values`: A tuple of iterators. Each iterator \
returns a corresponding computed value. \
For example, if :func:`func` returns \
:obj:`([out_val0], [out_val1])`, :obj:`next(out_values[0])` \
and :obj:`next(out_values[1])` will be \
:obj:`out_val0` and :obj:`out_val1`.
* :obj:`rest_values`: A tuple of iterators. Each iterator \
returns a corresponding rest value. \
For example, if the :obj:`iterator` returns \
:obj:`[(in_val0, in_val1, rest_val0, rest_val1)]`, \
:obj:`next(rest_values[0])` \
and :obj:`next(rest_values[1])` will be \
:obj:`rest_val0` and :obj:`rest_val1`. \
If the input \
iterator does not give any rest values, this tuple \
will be empty.
"""

in_values, out_values, rest_values = unzip(
_apply(func, iterator, n_input, hook))

# in_values: iter of ([in_val0], [in_val1], ...)
# -> (iter of in_val0, iter of in_val1, ...)
in_values = tuple(map(_flatten, unzip(in_values)))

# out_values: iter of ([out_val0], [out_val1], ...)
# -> (iter of out_val0, iter of out_val1, ...)
out_values = tuple(map(_flatten, unzip(out_values)))

# rest_values: iter of ([rest_val0], [rest_val1], ...)
# -> (iter of rest_val0, iter of rest_val1, ...)
rest_values = tuple(map(_flatten, unzip(rest_values)))

return in_values, out_values, rest_values


def _apply(func, iterator, n_input, hook):
for batch in iterator:
# batch: [(in_val0, in_val1, ... , rest_val0, rest_val1, ...)] or
# [in_val]

in_values = []
rest_values = []
for sample in batch:
if isinstance(sample, tuple):
in_values.append(sample[0:n_input])
rest_values.append(sample[n_input:])
else:
in_values.append((sample,))
rest_values.append(())

# in_values: [(in_val0, in_val1, ...)]
# -> ([in_val0], [in_val1], ...)
in_values = tuple(list(v) for v in zip(*in_values))

# rest_values: [(rest_val0, rest_val1, ...)]
# -> ([rest_val0], [rest_val1], ...)
rest_values = tuple(list(v) for v in zip(*rest_values))

# out_values: ([out_val0], [out_val1], ...) or [out_val]
out_values = func(*in_values)
if not isinstance(out_values, tuple):
# pred_values: [out_val] -> ([out_val],)
out_values = out_values,

if hook:
hook(in_values, out_values, rest_values)

yield in_values, out_values, rest_values


def _flatten(iterator):
return (sample for batch in iterator for sample in batch)
Loading

0 comments on commit 4d239ad

Please sign in to comment.