-
Notifications
You must be signed in to change notification settings - Fork 304
Improve apply_prediction #523
Changes from 11 commits
4179b00
6ebf6e6
4afd154
14d0973
cdbc657
592d30a
4db7d8b
2a6a4a2
7a57c6c
77acf2e
2d3cb55
a4e7c2e
2e052a4
f920744
cc91196
0bd2f6d
2e449a3
583ed87
09f7ade
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
from chainercv.utils.iterator.apply_prediction_to_iterator import apply_prediction_to_iterator # NOQA | ||
from chainercv.utils.iterator.apply_to_iterator import apply_to_iterator # NOQA | ||
from chainercv.utils.iterator.progress_hook import ProgressHook # NOQA | ||
from chainercv.utils.iterator.unzip import unzip # NOQA |
This file was deleted.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,167 @@ | ||
from chainercv.utils.iterator.unzip import unzip | ||
|
||
|
||
def apply_to_iterator(func, iterator, n_input=1, hook=None): | ||
"""Apply a function/method to batches from an iterator. | ||
|
||
This function applies a function/method to an iterator of batches. | ||
|
||
It assumes that the iterator iterates over a collection of tuples | ||
that contain inputs to :func:`func`. | ||
Additionally, the tuples may contain values that are not used by func. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. s/by func/by the funciton |
||
For convenience, we allow the iterator to iterate over a collection of | ||
inputs that are not tuple. | ||
Here is an illustration of the expected behavior of the iterator. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this comment is helpful for Chainer users. Here is an .... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How about There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This behaviour is the same as chainer.Iterator This one looks good! (“the” was missing) |
||
|
||
>>> batch = next(iterator) | ||
>>> # batch: [in_val] | ||
or | ||
>>> # batch: [(in_val0, ..., in_val{n_input})] | ||
or | ||
>>> # batch: [(in_val0, ..., in_val{n_input}, rest_val0, ...)] | ||
|
||
:func:`func` should take batch(es) of data and | ||
return batch(es) of computed values. | ||
Here is an illustration of the expected behavior of the function. | ||
|
||
>>> out_vals = func(in_val0, ..., in_val{n_input}) | ||
>>> # out_vals: [out_val] | ||
or | ||
>>> out_vals0, out_vals1, ... = func(in_val0, ..., in_val{n_input}) | ||
>>> # out_vals0: [out_val0] | ||
>>> # out_vals1: [out_val1] | ||
|
||
With :func:`apply_to_iterator`, users can get iterator(s) of values | ||
returned by :func:`func`. It also returns iterator(s) of input values and | ||
values that are not used for computation. | ||
|
||
>>> in_values, out_values, rest_values = apply_to_iterator( | ||
>>> func, iterator, n_input) | ||
>>> # in_values: (iter of in_val0, ..., iter of in_val{n_input}) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. n_input --> n_input-1 |
||
>>> # out_values: (iter of out_val0, ...) | ||
>>> # rest_values: (iter of rest_val0, ...) | ||
|
||
Here is an exmple, which applies a pretrained Faster R-CNN to | ||
PASCAL VOC dataset. | ||
|
||
>>> from chainer import iterators | ||
>>> | ||
>>> from chainercv.datasets import VOCBBoxDataset | ||
>>> from chainercv.links import FasterRCNNVGG16 | ||
>>> from chainercv.utils import apply_to_iterator | ||
>>> | ||
>>> dataset = VOCBBoxDataset(year='2007', split='test') | ||
>>> # next(iterator) -> [(img, gt_bbox, gt_label)] | ||
>>> iterator = iterators.SerialIterator( | ||
... dataset, 2, repeat=False, shuffle=False) | ||
>>> | ||
>>> # model.predict([img]) -> ([pred_bbox], [pred_label], [pred_score]) | ||
>>> model = FasterRCNNVGG16(pretrained_model='voc07') | ||
>>> | ||
>>> in_values, out_values, rest_values = apply_to_iterator( | ||
... model.predict, iterator) | ||
>>> | ||
>>> # in_values contains one iterator | ||
>>> imgs, = in_values | ||
>>> # out_values contains three iterators | ||
>>> pred_bboxes, pred_labels, pred_scores = out_values | ||
>>> # rest_values contains two iterators | ||
>>> gt_bboxes, gt_labels = rest_values | ||
|
||
Args: | ||
func: A callable that takes batch(es) of input data and returns | ||
computed data. | ||
iterator (chainer.Iterator): An iterator of batch. | ||
The first :obj:`n_input` elements in each sample are | ||
treated as input values. They are passed to :obj:`func`. | ||
n_input (int): The number of input data. The default value is :obj:`1`. | ||
hook: A callable that is called after each iteration. | ||
:obj:`in_values`, :obj:`out_values`, and :obj:`rest_values` | ||
are passed as arguments. | ||
Note that these values do not contain data from the previous | ||
iterations. | ||
|
||
Returns: | ||
Three tuples of iterators: | ||
This function returns three tuples of iterators: | ||
:obj:`in_values`, :obj:`out_values` and :obj:`rest_values`. | ||
|
||
* :obj:`in_values`: A tuple of iterators. Each iterator \ | ||
returns a corresponding input value. \ | ||
For example, if :func:`func` takes \ | ||
:obj:`[in_val0], [in_val1]`, :obj:`next(in_values[0])` \ | ||
and :obj:`next(in_values[1])` will be \ | ||
:obj:`in_val0` and :obj:`in_val1`. | ||
* :obj:`out_values`: A tuple of iterators. Each iterator \ | ||
returns a corresponding computed value. \ | ||
For example, if :func:`func` returns \ | ||
:obj:`([out_val0], [out_val1])`, :obj:`next(out_values[0])` \ | ||
and :obj:`next(out_values[1])` will be \ | ||
:obj:`out_val0` and :obj:`out_val1`. | ||
* :obj:`rest_values`: A tuple of iterators. Each iterator \ | ||
returns a corresponding rest value. \ | ||
For example, if the :obj:`iterator` returns \ | ||
:obj:`[(in_val0, in_val1, rest_val0, rest_val1)]`, \ | ||
:obj:`next(rest_values[0])` \ | ||
and :obj:`next(rest_values[1])` will be \ | ||
:obj:`rest_val0` and :obj:`rest_val1`. \ | ||
If the input \ | ||
iterator does not give any rest values, this tuple \ | ||
will be empty. | ||
""" | ||
|
||
in_values, out_values, rest_values = unzip( | ||
_apply(func, iterator, n_input, hook)) | ||
|
||
# in_values: iter of ([in_val0], [in_val1], ...) | ||
# -> (iter of in_val0, iter of in_val1, ...) | ||
in_values = tuple(map(_flatten, unzip(in_values))) | ||
|
||
# out_values: iter of ([out_val0], [out_val1], ...) | ||
# -> (iter of out_val0, iter of out_val1, ...) | ||
out_values = tuple(map(_flatten, unzip(out_values))) | ||
|
||
# rest_values: iter of ([rest_val0], [rest_val1], ...) | ||
# -> (iter of rest_val0, iter of rest_val1, ...) | ||
rest_values = tuple(map(_flatten, unzip(rest_values))) | ||
|
||
return in_values, out_values, rest_values | ||
|
||
|
||
def _apply(func, iterator, n_input, hook): | ||
for batch in iterator: | ||
# batch: [(in_val0, in_val1, ... , rest_val0, rest_val1, ...)] or | ||
# [in_val] | ||
|
||
in_values = [] | ||
rest_values = [] | ||
for sample in batch: | ||
if isinstance(sample, tuple): | ||
in_values.append(sample[0:n_input]) | ||
rest_values.append(sample[n_input:]) | ||
else: | ||
in_values.append((sample,)) | ||
rest_values.append(tuple()) | ||
|
||
# in_values: [(in_val0, in_val1, ...)] | ||
# -> ([in_val0], [in_val1], ...) | ||
in_values = tuple(list(v) for v in zip(*in_values)) | ||
|
||
# rest_values: [(rest_val0, rest_val1, ...)] | ||
# -> ([rest_val0], [rest_val1], ...) | ||
rest_values = tuple(list(v) for v in zip(*rest_values)) | ||
|
||
# out_values: ([out_val0], [out_val1], ...) or [out_val] | ||
out_values = func(*in_values) | ||
if not isinstance(out_values, tuple): | ||
# pred_values: [out_val] -> ([out_val],) | ||
out_values = out_values, | ||
|
||
if hook: | ||
hook(in_values, out_values, rest_values) | ||
|
||
yield in_values, out_values, rest_values | ||
|
||
|
||
def _flatten(iterator): | ||
return (sample for batch in iterator for sample in batch) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This function applies a function/method to batches from an iterator of type
chainer.Iterator
.This makes clear that we are talking about Chainer's iterator.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not only for
chainer.Iterator
. It can work with other iterators.an iterator of type chainer.Iterator
sounds it requires some special features.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK. Then we should change args doc.
https://github.com/Hakuyume/chainercv/blob/0bd2f6d970aeaca436e6858b0e53661e364c1fd6/chainercv/utils/iterator/apply_to_iterator.py#L75
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for pointing out.