-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Conversation
9fa6ec2
to
322cb05
Compare
dbcaae6
to
3bcaf8e
Compare
@zheng-da @piiswrong @szha @eric-haibin-lin Hey could you help review this PR? |
python/mxnet/ndarray/contrib.py
Outdated
@@ -363,3 +362,97 @@ def _func_wrapper(loop_vars): | |||
[" Step %d, shape is %s" % (i, str(x.shape)) for i, x in enumerate(items)] | |||
)) | |||
return stacked_outputs, list(loop_vars) | |||
|
|||
def ifelse(cond, then_func, else_func, inputs): | |||
"""Run a if-then-else using user-defined condition and computation |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
a => an
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed
python/mxnet/ndarray/contrib.py
Outdated
This operator simulates a if-like branch which chooses to do one of | ||
the two customized computations according to the specified condition. | ||
|
||
`inputs` is a list of NDArrays on which the condition and computations reply on. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
reply => rely
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed
python/mxnet/symbol/contrib.py
Outdated
@@ -556,3 +556,154 @@ def _union_inputs(*graphs): | |||
outputs = [result[i] for i in range(num_out_data)] | |||
final_loop_vars = [result[i] for i in range(num_out_data, num_outputs)] | |||
return outputs, final_loop_vars | |||
|
|||
def ifelse(cond, then_func, else_func, inputs, name="ifelse"): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please fix the same typos as the one above.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed
raise ValueError("%s must be a Symbol, or a tuple or list of Symbol" % (name, )) | ||
return inputs | ||
|
||
def _create_subgraph(graph_vars, graph_func, subgraph_name): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it seems this function and the function below are the same as the one in while_loop. Can you move them out and reuse them?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
They are not exactly the same. One would search for var_locs, another doesn't.
python/mxnet/ndarray/contrib.py
Outdated
outputs = _to_ndarray_tuple(outputs, "outputs of then_func") | ||
else: | ||
outputs = else_func(*inputs) | ||
outputs = _to_ndarray_tuple(outputs, "outputs of else_func") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is there a way of checking if the outputs from the if branch and the else branch have the same number of outputs and the same types, etc.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's give up ><
auto infer_subg = [¶ms, in_shape, out_shape](std::shared_ptr<Symbol> subg, | ||
ShapeVector *_subg_out, | ||
const nnvm::Tuple<dim_t> &input_locs, | ||
bool fill_out_shape) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you also reuse this function?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
They are not identical either.
src/operator/control_flow.cc
Outdated
params.then_input_locs, true); | ||
bool succ_2 = infer_subg(attrs.subgraphs[2], &else_out_shape, \ | ||
params.else_input_locs, true); | ||
return succ_0 && succ_1 && succ_2; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you need to check then_out_shape
and else_out_shape
and see if they are the same.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My bad, fixed
bool succ_1 = InferSubgraphDataType(*attrs.subgraphs[1], &then_in_type, out_type); | ||
CHECK(sync_in_in(params.then_input_locs, in_type, &then_in_type, is_udf)); | ||
bool succ_2 = InferSubgraphDataType(*attrs.subgraphs[2], &else_in_type, out_type); | ||
CHECK(sync_in_in(params.else_input_locs, in_type, &else_in_type, is_udf)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the same here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Two subgraphs write to the same out_type
, so we don't have to worry in this case.
CHECK(sync_in_in(params.then_input_locs, in_attrs, &then_in_attrs, is_udf)); | ||
bool succ_2 = InferSubgraphStorage(*attrs.subgraphs[2], dev_mask, \ | ||
&else_mode, &else_in_attrs, out_attrs); | ||
CHECK(sync_in_in(params.else_input_locs, in_attrs, &else_in_attrs, is_udf)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the same here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Two subgraphs write to the same out_type
, so we don't have to worry in this case.
do you need to rebase to the master since the flaky test has been fixed. |
We are going to rename the operator from |
ifelse
operatorcond
operator
python/mxnet/ndarray/contrib.py
Outdated
@@ -363,3 +362,97 @@ def _func_wrapper(loop_vars): | |||
[" Step %d, shape is %s" % (i, str(x.shape)) for i, x in enumerate(items)] | |||
)) | |||
return stacked_outputs, list(loop_vars) | |||
|
|||
def ifelse(cond, then_func, else_func, inputs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we have both cond and inputs? Can we just have cond which could be true/false?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
to make it consistent with TF?
https://www.tensorflow.org/api_docs/python/tf/cond
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It may help improve user experience I think.
python/mxnet/ndarray/contrib.py
Outdated
@@ -363,3 +362,97 @@ def _func_wrapper(loop_vars): | |||
[" Step %d, shape is %s" % (i, str(x.shape)) for i, x in enumerate(items)] | |||
)) | |||
return stacked_outputs, list(loop_vars) | |||
|
|||
def cond(cond_func, then_func, else_func, inputs): # pylint: disable=redefined-outer-name |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we remove inputs and have tf style interface?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good idea. Just finished the API change. Thanks!
cond
operatorcondition
operator
@zheng-da and I have done an API change, according to the valuable comments from @eric-haibin-lin. Here is the signature of our new API: Although there is no actual difference between the old and new APIs in the backend, we believe that this change will make our API easier to use for customers. Thanks again for @eric-haibin-lin and @zheng-da for the valuable discussion! |
So could someone help merge the code? |
@eric-haibin-lin @szha @piiswrong Do you have more comments? If not, can you merge it? |
We are trying to get this in 1.3. So could someone help merge this PR? |
python/mxnet/ndarray/contrib.py
Outdated
@@ -28,7 +28,7 @@ | |||
except ImportError: | |||
pass | |||
|
|||
__all__ = ["rand_zipfian", "foreach", "while_loop"] | |||
__all__ = ["rand_zipfian", "foreach", "while_loop", "condition"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
condition
is not a good name. Maybe cond
, or conditional
or if_else
or something like that
python/mxnet/ndarray/contrib.py
Outdated
@@ -363,3 +362,87 @@ def _func_wrapper(loop_vars): | |||
[" Step %d, shape is %s" % (i, str(x.shape)) for i, x in enumerate(items)] | |||
)) | |||
return stacked_outputs, list(loop_vars) | |||
|
|||
def condition(cond, then_func, else_func): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
change signature to be the same with other packages?
@piiswrong Our initial name is ifelse. @zheng-da proposed to change a new name ‘cond’. ‘cond’ cannot make pylint happy, because it conflicts with the first argument of ‘while_loop’. So then I change it to ‘condition’. So could you guys help let know me a best naming so that we could get it merged today? @zheng-da @piiswrong |
I change |
condition
operatorcond
operator
Thank you so much guys for offering me valuable suggestions, and making this PR possible! |
* Initial commit for `Ifelse` * Address comments * Rename ifelse to condition * API change * Trigger CI * Rename condition to cond * Fix lint
Waiting for the
while_loop
operator to be merged so that I could rebase to master. Please do not merge for now.Description
This PR is part of the proposal of adding a set of control flow operators to MXNet. Link to proposal. See also
foreach
(#11531) andwhile_loop
(#11566).Checklist
Essentials
Please feel free to remove inapplicable items for your PR.
Changes
ifelse
operator in src/operator/control_flow.ccTODO
while_loop
to be mergedwhile_loop
because of numeric overflow