Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
refactor sample_n (#17618)
Browse files Browse the repository at this point in the history
  • Loading branch information
xidulu committed Feb 28, 2020
1 parent 0e6ab21 commit 1af06d9
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 46 deletions.
40 changes: 24 additions & 16 deletions python/mxnet/ndarray/numpy_extension/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,18 +165,22 @@ def uniform_n(low=0.0, high=1.0, batch_shape=None, dtype=None, ctx=None):
ctx = current_context()
if batch_shape == ():
batch_shape = None
else:
if isinstance(batch_shape, int):
batch_shape = (batch_shape,)
batch_shape = (-2,) + batch_shape
if input_type == (True, True):
return _npi.uniform_n(low, high, low=None, high=None, size=batch_shape,
ctx=ctx, dtype=dtype)
return _npi.uniform(low, high, low=None, high=None, size=batch_shape,
ctx=ctx, dtype=dtype)
elif input_type == (False, True):
return _npi.uniform_n(high, low=low, high=None, size=batch_shape,
ctx=ctx, dtype=dtype)
return _npi.uniform(high, low=low, high=None, size=batch_shape,
ctx=ctx, dtype=dtype)
elif input_type == (True, False):
return _npi.uniform_n(low, low=None, high=high, size=batch_shape,
ctx=ctx, dtype=dtype)
return _npi.uniform(low, low=None, high=high, size=batch_shape,
ctx=ctx, dtype=dtype)
else:
return _npi.uniform_n(low=low, high=high, size=batch_shape,
ctx=ctx, dtype=dtype)
return _npi.uniform(low=low, high=high, size=batch_shape,
ctx=ctx, dtype=dtype)


def normal_n(loc=0.0, scale=1.0, batch_shape=None, dtype=None, ctx=None):
Expand Down Expand Up @@ -252,15 +256,19 @@ def normal_n(loc=0.0, scale=1.0, batch_shape=None, dtype=None, ctx=None):
ctx = current_context()
if batch_shape == ():
batch_shape = None
else:
if isinstance(batch_shape, int):
batch_shape = (batch_shape,)
batch_shape = (-2,) + batch_shape
if input_type == (True, True):
return _npi.normal_n(loc, scale, loc=None, scale=None, size=batch_shape,
ctx=ctx, dtype=dtype)
return _npi.normal(loc, scale, loc=None, scale=None, size=batch_shape,
ctx=ctx, dtype=dtype)
elif input_type == (False, True):
return _npi.normal_n(scale, loc=loc, scale=None, size=batch_shape,
ctx=ctx, dtype=dtype)
return _npi.normal(scale, loc=loc, scale=None, size=batch_shape,
ctx=ctx, dtype=dtype)
elif input_type == (True, False):
return _npi.normal_n(loc, loc=None, scale=scale, size=batch_shape,
ctx=ctx, dtype=dtype)
return _npi.normal(loc, loc=None, scale=scale, size=batch_shape,
ctx=ctx, dtype=dtype)
else:
return _npi.normal_n(loc=loc, scale=scale, size=batch_shape,
ctx=ctx, dtype=dtype)
return _npi.normal(loc=loc, scale=scale, size=batch_shape,
ctx=ctx, dtype=dtype)
40 changes: 24 additions & 16 deletions python/mxnet/symbol/numpy_extension/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,18 +165,22 @@ def uniform_n(low=0.0, high=1.0, batch_shape=None, dtype=None, ctx=None):
ctx = current_context()
if batch_shape == ():
batch_shape = None
else:
if isinstance(batch_shape, int):
batch_shape = (batch_shape,)
batch_shape = (-2,) + batch_shape
if input_type == (True, True):
return _npi.uniform_n(low, high, low=None, high=None, size=batch_shape,
ctx=ctx, dtype=dtype)
return _npi.uniform(low, high, low=None, high=None, size=batch_shape,
ctx=ctx, dtype=dtype)
elif input_type == (False, True):
return _npi.uniform_n(high, low=low, high=None, size=batch_shape,
ctx=ctx, dtype=dtype)
return _npi.uniform(high, low=low, high=None, size=batch_shape,
ctx=ctx, dtype=dtype)
elif input_type == (True, False):
return _npi.uniform_n(low, low=None, high=high, size=batch_shape,
ctx=ctx, dtype=dtype)
return _npi.uniform(low, low=None, high=high, size=batch_shape,
ctx=ctx, dtype=dtype)
else:
return _npi.uniform_n(low=low, high=high, size=batch_shape,
ctx=ctx, dtype=dtype)
return _npi.uniform(low=low, high=high, size=batch_shape,
ctx=ctx, dtype=dtype)


def normal_n(loc=0.0, scale=1.0, batch_shape=None, dtype=None, ctx=None):
Expand Down Expand Up @@ -252,15 +256,19 @@ def normal_n(loc=0.0, scale=1.0, batch_shape=None, dtype=None, ctx=None):
ctx = current_context()
if batch_shape == ():
batch_shape = None
else:
if isinstance(batch_shape, int):
batch_shape = (batch_shape,)
batch_shape = (-2,) + batch_shape
if input_type == (True, True):
return _npi.normal_n(loc, scale, loc=None, scale=None, size=batch_shape,
ctx=ctx, dtype=dtype)
return _npi.normal(loc, scale, loc=None, scale=None, size=batch_shape,
ctx=ctx, dtype=dtype)
elif input_type == (False, True):
return _npi.normal_n(scale, loc=loc, scale=None, size=batch_shape,
ctx=ctx, dtype=dtype)
return _npi.normal(scale, loc=loc, scale=None, size=batch_shape,
ctx=ctx, dtype=dtype)
elif input_type == (True, False):
return _npi.normal_n(loc, loc=None, scale=scale, size=batch_shape,
ctx=ctx, dtype=dtype)
return _npi.normal(loc, loc=None, scale=scale, size=batch_shape,
ctx=ctx, dtype=dtype)
else:
return _npi.normal_n(loc=loc, scale=scale, size=batch_shape,
ctx=ctx, dtype=dtype)
return _npi.normal(loc=loc, scale=scale, size=batch_shape,
ctx=ctx, dtype=dtype)
49 changes: 38 additions & 11 deletions src/operator/numpy/random/dist_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,33 +143,60 @@ template <typename DistParam>
inline bool TwoparamsDistOpShape(const nnvm::NodeAttrs &attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
// The inferShape function for sampling Ops has two modes: Concat/Broadcast,
// if size[0] == -2, the Concat schema will be selected:
// output_size = (size[1:],) + broadcast(param1.shape, param2.shape)
// otherwise output_size = broadcast(param1.shape, param2.shape, size)
const DistParam &param = nnvm::get<DistParam>(attrs.parsed);
// Variable indicating the mode.
bool concat_mode = false;
// Variable storing the info from `size` parameter.
std::vector<dim_t> oshape_vec;
if (param.size.has_value()) {
// Size declared.
std::vector<dim_t> oshape_vec;
const mxnet::Tuple<int> &size = param.size.value();
for (int i = 0; i < size.ndim(); ++i) {
int head = size[0];
if (head == -2) {
concat_mode = true;
} else {
oshape_vec.emplace_back(head);
}
for (int i = 1; i < size.ndim(); ++i) {
oshape_vec.emplace_back(size[i]);
}
// If under the broadcast mode, `size` is equivalent to the final output_size.
if (!concat_mode) {
SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape(oshape_vec));
for (size_t input_idx = 0; input_idx < in_attrs->size(); input_idx++) {
CheckBroadcastable((*in_attrs)[input_idx], (*out_attrs)[0]);
}
}
} else {
// Size undeclared.
}
// Under concat mode, or `size` is not declared.
if (concat_mode || (!param.size.has_value())) {
// broadcast(param1.shape, param2.shape).
mxnet::TShape param_broadcast_shape;
if (in_attrs->size() == 2U) {
// Both params from ndarray.
mxnet::TShape &low = (*in_attrs)[0];
mxnet::TShape &high = (*in_attrs)[1];
mxnet::TShape out(std::max(low.ndim(), high.ndim()), -1);
InferBroadcastShape(low, high, &out);
SHAPE_ASSIGN_CHECK(*out_attrs, 0, out);
mxnet::TShape &param1 = (*in_attrs)[0];
mxnet::TShape &param2 = (*in_attrs)[1];
mxnet::TShape out(std::max(param1.ndim(), param2.ndim()), -1);
InferBroadcastShape(param1, param2, &out);
param_broadcast_shape = out;
} else if (in_attrs->size() == 1U) {
// One param from ndarray.
SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0))
param_broadcast_shape = in_attrs->at(0);
} else if (in_attrs->size() == 0) {
// Two scalar case.
SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape(0, -1))
param_broadcast_shape = TShape(0, -1);
}
if (concat_mode) {
for (int i = 0; i < param_broadcast_shape.ndim(); ++i) {
oshape_vec.emplace_back(param_broadcast_shape[i]);
}
SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape(oshape_vec));
} else {
SHAPE_ASSIGN_CHECK(*out_attrs, 0, param_broadcast_shape);
}
}
if (out_attrs->size() == 2U) {
Expand Down
2 changes: 1 addition & 1 deletion src/operator/numpy/random/np_bernoulli_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ NNVM_REGISTER_OP(_npi_bernoulli)
return (num_inputs == 0) ? std::vector<std::string>() : std::vector<std::string>{"input1"};
})
.set_attr_parser(ParamParser<NumpyBernoulliParam>)
.set_attr<mxnet::FInferShape>("FInferShape", UnaryDistOpShape<NumpyBernoulliParam>)
.set_attr<mxnet::FInferShape>("FInferShape", TwoparamsDistOpShape<NumpyBernoulliParam>)
.set_attr<nnvm::FInferType>("FInferType", NumpyBernoulliOpType)
.set_attr<FResourceRequest>("FResourceRequest",
[](const nnvm::NodeAttrs& attrs) {
Expand Down
3 changes: 1 addition & 2 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -3764,10 +3764,9 @@ def __init__(self, shape, op_name):
def hybrid_forward(self, F, param1, param2):
op = getattr(F.npx.random, self._op_name, None)
assert op is not None
# return param1 + param2 + op(batch_shape=self._shape)
return op(param1, param2, batch_shape=self._shape)

batch_shapes = [(10,), (2, 3), 6, (), None]
batch_shapes = [(10,), (2, 3), 6, ()]
event_shapes = [(), (2,), (2,2)]
dtypes = ['float16', 'float32', 'float64']
op_names = ['uniform_n', 'normal_n']
Expand Down

0 comments on commit 1af06d9

Please sign in to comment.