This repository has been archived by the owner on Jul 2, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 304
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #523 from Hakuyume/improve-apply-prediction
Improve apply_prediction
- Loading branch information
Showing
14 changed files
with
354 additions
and
300 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
from chainercv.utils.iterator.apply_prediction_to_iterator import apply_prediction_to_iterator # NOQA | ||
from chainercv.utils.iterator.apply_to_iterator import apply_to_iterator # NOQA | ||
from chainercv.utils.iterator.progress_hook import ProgressHook # NOQA | ||
from chainercv.utils.iterator.unzip import unzip # NOQA |
141 changes: 0 additions & 141 deletions
141
chainercv/utils/iterator/apply_prediction_to_iterator.py
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
from chainercv.utils.iterator.unzip import unzip | ||
|
||
|
||
def apply_to_iterator(func, iterator, n_input=1, hook=None): | ||
"""Apply a function/method to batches from an iterator. | ||
This function applies a function/method to an iterator of batches. | ||
It assumes that the iterator iterates over a collection of tuples | ||
that contain inputs to :func:`func`. | ||
Additionally, the tuples may contain values | ||
that are not used by :func:`func`. | ||
For convenience, we allow the iterator to iterate over a collection of | ||
inputs that are not tuple. | ||
Here is an illustration of the expected behavior of the iterator. | ||
This behaviour is the same as :class:`chainer.Iterator`. | ||
>>> batch = next(iterator) | ||
>>> # batch: [in_val] | ||
or | ||
>>> # batch: [(in_val0, ..., in_val{n_input - 1})] | ||
or | ||
>>> # batch: [(in_val0, ..., in_val{n_input - 1}, rest_val0, ...)] | ||
:func:`func` should take batch(es) of data and | ||
return batch(es) of computed values. | ||
Here is an illustration of the expected behavior of the function. | ||
>>> out_vals = func([in_val0], ..., [in_val{n_input - 1}]) | ||
>>> # out_vals: [out_val] | ||
or | ||
>>> out_vals0, out_vals1, ... = func([in_val0], ..., [in_val{n_input - 1}]) | ||
>>> # out_vals0: [out_val0] | ||
>>> # out_vals1: [out_val1] | ||
With :func:`apply_to_iterator`, users can get iterator(s) of values | ||
returned by :func:`func`. It also returns iterator(s) of input values and | ||
values that are not used for computation. | ||
>>> in_values, out_values, rest_values = apply_to_iterator( | ||
>>> func, iterator, n_input) | ||
>>> # in_values: (iter of in_val0, ..., iter of in_val{n_input - 1}) | ||
>>> # out_values: (iter of out_val0, ...) | ||
>>> # rest_values: (iter of rest_val0, ...) | ||
Here is an exmple, which applies a pretrained Faster R-CNN to | ||
PASCAL VOC dataset. | ||
>>> from chainer import iterators | ||
>>> | ||
>>> from chainercv.datasets import VOCBBoxDataset | ||
>>> from chainercv.links import FasterRCNNVGG16 | ||
>>> from chainercv.utils import apply_to_iterator | ||
>>> | ||
>>> dataset = VOCBBoxDataset(year='2007', split='test') | ||
>>> # next(iterator) -> [(img, gt_bbox, gt_label)] | ||
>>> iterator = iterators.SerialIterator( | ||
... dataset, 2, repeat=False, shuffle=False) | ||
>>> | ||
>>> # model.predict([img]) -> ([pred_bbox], [pred_label], [pred_score]) | ||
>>> model = FasterRCNNVGG16(pretrained_model='voc07') | ||
>>> | ||
>>> in_values, out_values, rest_values = apply_to_iterator( | ||
... model.predict, iterator) | ||
>>> | ||
>>> # in_values contains one iterator | ||
>>> imgs, = in_values | ||
>>> # out_values contains three iterators | ||
>>> pred_bboxes, pred_labels, pred_scores = out_values | ||
>>> # rest_values contains two iterators | ||
>>> gt_bboxes, gt_labels = rest_values | ||
Args: | ||
func: A callable that takes batch(es) of input data and returns | ||
computed data. | ||
iterator (iterator): An iterator of batches. | ||
The first :obj:`n_input` elements in each sample are | ||
treated as input values. They are passed to :obj:`func`. | ||
n_input (int): The number of input data. The default value is :obj:`1`. | ||
hook: A callable that is called after each iteration. | ||
:obj:`in_values`, :obj:`out_values`, and :obj:`rest_values` | ||
are passed as arguments. | ||
Note that these values do not contain data from the previous | ||
iterations. | ||
Returns: | ||
Three tuples of iterators: | ||
This function returns three tuples of iterators: | ||
:obj:`in_values`, :obj:`out_values` and :obj:`rest_values`. | ||
* :obj:`in_values`: A tuple of iterators. Each iterator \ | ||
returns a corresponding input value. \ | ||
For example, if :func:`func` takes \ | ||
:obj:`[in_val0], [in_val1]`, :obj:`next(in_values[0])` \ | ||
and :obj:`next(in_values[1])` will be \ | ||
:obj:`in_val0` and :obj:`in_val1`. | ||
* :obj:`out_values`: A tuple of iterators. Each iterator \ | ||
returns a corresponding computed value. \ | ||
For example, if :func:`func` returns \ | ||
:obj:`([out_val0], [out_val1])`, :obj:`next(out_values[0])` \ | ||
and :obj:`next(out_values[1])` will be \ | ||
:obj:`out_val0` and :obj:`out_val1`. | ||
* :obj:`rest_values`: A tuple of iterators. Each iterator \ | ||
returns a corresponding rest value. \ | ||
For example, if the :obj:`iterator` returns \ | ||
:obj:`[(in_val0, in_val1, rest_val0, rest_val1)]`, \ | ||
:obj:`next(rest_values[0])` \ | ||
and :obj:`next(rest_values[1])` will be \ | ||
:obj:`rest_val0` and :obj:`rest_val1`. \ | ||
If the input \ | ||
iterator does not give any rest values, this tuple \ | ||
will be empty. | ||
""" | ||
|
||
in_values, out_values, rest_values = unzip( | ||
_apply(func, iterator, n_input, hook)) | ||
|
||
# in_values: iter of ([in_val0], [in_val1], ...) | ||
# -> (iter of in_val0, iter of in_val1, ...) | ||
in_values = tuple(map(_flatten, unzip(in_values))) | ||
|
||
# out_values: iter of ([out_val0], [out_val1], ...) | ||
# -> (iter of out_val0, iter of out_val1, ...) | ||
out_values = tuple(map(_flatten, unzip(out_values))) | ||
|
||
# rest_values: iter of ([rest_val0], [rest_val1], ...) | ||
# -> (iter of rest_val0, iter of rest_val1, ...) | ||
rest_values = tuple(map(_flatten, unzip(rest_values))) | ||
|
||
return in_values, out_values, rest_values | ||
|
||
|
||
def _apply(func, iterator, n_input, hook): | ||
for batch in iterator: | ||
# batch: [(in_val0, in_val1, ... , rest_val0, rest_val1, ...)] or | ||
# [in_val] | ||
|
||
in_values = [] | ||
rest_values = [] | ||
for sample in batch: | ||
if isinstance(sample, tuple): | ||
in_values.append(sample[0:n_input]) | ||
rest_values.append(sample[n_input:]) | ||
else: | ||
in_values.append((sample,)) | ||
rest_values.append(()) | ||
|
||
# in_values: [(in_val0, in_val1, ...)] | ||
# -> ([in_val0], [in_val1], ...) | ||
in_values = tuple(list(v) for v in zip(*in_values)) | ||
|
||
# rest_values: [(rest_val0, rest_val1, ...)] | ||
# -> ([rest_val0], [rest_val1], ...) | ||
rest_values = tuple(list(v) for v in zip(*rest_values)) | ||
|
||
# out_values: ([out_val0], [out_val1], ...) or [out_val] | ||
out_values = func(*in_values) | ||
if not isinstance(out_values, tuple): | ||
# pred_values: [out_val] -> ([out_val],) | ||
out_values = out_values, | ||
|
||
if hook: | ||
hook(in_values, out_values, rest_values) | ||
|
||
yield in_values, out_values, rest_values | ||
|
||
|
||
def _flatten(iterator): | ||
return (sample for batch in iterator for sample in batch) |
Oops, something went wrong.