diff --git a/python/mxnet/ndarray/numpy_extension/random.py b/python/mxnet/ndarray/numpy_extension/random.py
index 8bd17cf092b0..1ddd28f9e013 100644
--- a/python/mxnet/ndarray/numpy_extension/random.py
+++ b/python/mxnet/ndarray/numpy_extension/random.py
@@ -165,18 +165,22 @@ def uniform_n(low=0.0, high=1.0, batch_shape=None, dtype=None, ctx=None):
         ctx = current_context()
     if batch_shape == ():
         batch_shape = None
+    else:
+        if isinstance(batch_shape, int):
+            batch_shape = (batch_shape,)
+        batch_shape = (-2,) + batch_shape
     if input_type == (True, True):
-        return _npi.uniform_n(low, high, low=None, high=None, size=batch_shape,
-                              ctx=ctx, dtype=dtype)
+        return _npi.uniform(low, high, low=None, high=None, size=batch_shape,
+                            ctx=ctx, dtype=dtype)
     elif input_type == (False, True):
-        return _npi.uniform_n(high, low=low, high=None, size=batch_shape,
-                              ctx=ctx, dtype=dtype)
+        return _npi.uniform(high, low=low, high=None, size=batch_shape,
+                            ctx=ctx, dtype=dtype)
     elif input_type == (True, False):
-        return _npi.uniform_n(low, low=None, high=high, size=batch_shape,
-                              ctx=ctx, dtype=dtype)
+        return _npi.uniform(low, low=None, high=high, size=batch_shape,
+                            ctx=ctx, dtype=dtype)
     else:
-        return _npi.uniform_n(low=low, high=high, size=batch_shape,
-                              ctx=ctx, dtype=dtype)
+        return _npi.uniform(low=low, high=high, size=batch_shape,
+                            ctx=ctx, dtype=dtype)
 
 
 def normal_n(loc=0.0, scale=1.0, batch_shape=None, dtype=None, ctx=None):
@@ -252,15 +256,19 @@ def normal_n(loc=0.0, scale=1.0, batch_shape=None, dtype=None, ctx=None):
         ctx = current_context()
     if batch_shape == ():
         batch_shape = None
+    else:
+        if isinstance(batch_shape, int):
+            batch_shape = (batch_shape,)
+        batch_shape = (-2,) + batch_shape
     if input_type == (True, True):
-        return _npi.normal_n(loc, scale, loc=None, scale=None, size=batch_shape,
-                             ctx=ctx, dtype=dtype)
+        return _npi.normal(loc, scale, loc=None, scale=None, size=batch_shape,
+                           ctx=ctx, dtype=dtype)
     elif input_type == (False, True):
-        return _npi.normal_n(scale, loc=loc, scale=None, size=batch_shape,
-                             ctx=ctx, dtype=dtype)
+        return _npi.normal(scale, loc=loc, scale=None, size=batch_shape,
+                           ctx=ctx, dtype=dtype)
     elif input_type == (True, False):
-        return _npi.normal_n(loc, loc=None, scale=scale, size=batch_shape,
-                             ctx=ctx, dtype=dtype)
+        return _npi.normal(loc, loc=None, scale=scale, size=batch_shape,
+                           ctx=ctx, dtype=dtype)
     else:
-        return _npi.normal_n(loc=loc, scale=scale, size=batch_shape,
-                             ctx=ctx, dtype=dtype)
+        return _npi.normal(loc=loc, scale=scale, size=batch_shape,
+                           ctx=ctx, dtype=dtype)
diff --git a/python/mxnet/symbol/numpy_extension/random.py b/python/mxnet/symbol/numpy_extension/random.py
index 35bc8489c27e..bad6a74d139f 100644
--- a/python/mxnet/symbol/numpy_extension/random.py
+++ b/python/mxnet/symbol/numpy_extension/random.py
@@ -165,18 +165,22 @@ def uniform_n(low=0.0, high=1.0, batch_shape=None, dtype=None, ctx=None):
         ctx = current_context()
     if batch_shape == ():
         batch_shape = None
+    else:
+        if isinstance(batch_shape, int):
+            batch_shape = (batch_shape,)
+        batch_shape = (-2,) + batch_shape
     if input_type == (True, True):
-        return _npi.uniform_n(low, high, low=None, high=None, size=batch_shape,
-                              ctx=ctx, dtype=dtype)
+        return _npi.uniform(low, high, low=None, high=None, size=batch_shape,
+                            ctx=ctx, dtype=dtype)
     elif input_type == (False, True):
-        return _npi.uniform_n(high, low=low, high=None, size=batch_shape,
-                              ctx=ctx, dtype=dtype)
+        return _npi.uniform(high, low=low, high=None, size=batch_shape,
+                            ctx=ctx, dtype=dtype)
     elif input_type == (True, False):
-        return _npi.uniform_n(low, low=None, high=high, size=batch_shape,
-                              ctx=ctx, dtype=dtype)
+        return _npi.uniform(low, low=None, high=high, size=batch_shape,
+                            ctx=ctx, dtype=dtype)
     else:
-        return _npi.uniform_n(low=low, high=high, size=batch_shape,
-                              ctx=ctx, dtype=dtype)
+        return _npi.uniform(low=low, high=high, size=batch_shape,
+                            ctx=ctx, dtype=dtype)
 
 
 def normal_n(loc=0.0, scale=1.0, batch_shape=None, dtype=None, ctx=None):
@@ -252,15 +256,19 @@ def normal_n(loc=0.0, scale=1.0, batch_shape=None, dtype=None, ctx=None):
         ctx = current_context()
     if batch_shape == ():
         batch_shape = None
+    else:
+        if isinstance(batch_shape, int):
+            batch_shape = (batch_shape,)
+        batch_shape = (-2,) + batch_shape
     if input_type == (True, True):
-        return _npi.normal_n(loc, scale, loc=None, scale=None, size=batch_shape,
-                             ctx=ctx, dtype=dtype)
+        return _npi.normal(loc, scale, loc=None, scale=None, size=batch_shape,
+                           ctx=ctx, dtype=dtype)
     elif input_type == (False, True):
-        return _npi.normal_n(scale, loc=loc, scale=None, size=batch_shape,
-                             ctx=ctx, dtype=dtype)
+        return _npi.normal(scale, loc=loc, scale=None, size=batch_shape,
+                           ctx=ctx, dtype=dtype)
     elif input_type == (True, False):
-        return _npi.normal_n(loc, loc=None, scale=scale, size=batch_shape,
-                             ctx=ctx, dtype=dtype)
+        return _npi.normal(loc, loc=None, scale=scale, size=batch_shape,
+                           ctx=ctx, dtype=dtype)
     else:
-        return _npi.normal_n(loc=loc, scale=scale, size=batch_shape,
-                             ctx=ctx, dtype=dtype)
+        return _npi.normal(loc=loc, scale=scale, size=batch_shape,
+                           ctx=ctx, dtype=dtype)
diff --git a/src/operator/numpy/random/dist_common.h b/src/operator/numpy/random/dist_common.h
index e8358294eaf0..375b8d225ddf 100644
--- a/src/operator/numpy/random/dist_common.h
+++ b/src/operator/numpy/random/dist_common.h
@@ -143,33 +143,60 @@ template <typename DistParam>
 inline bool TwoparamsDistOpShape(const nnvm::NodeAttrs &attrs,
                                  std::vector<TShape> *in_attrs,
                                  std::vector<TShape> *out_attrs) {
+  // The inferShape function for sampling Ops has two modes: Concat/Broadcast,
+  // if size[0] == -2, the Concat schema will be selected:
+  // output_size = (size[1:],) + broadcast(param1.shape, param2.shape)
+  // otherwise output_size = broadcast(param1.shape, param2.shape, size)
   const DistParam &param = nnvm::get<DistParam>(attrs.parsed);
+  // Variable indicating the mode.
+  bool concat_mode = false;
+  // Variable storing the info from `size` parameter.
+  std::vector<dim_t> oshape_vec;
   if (param.size.has_value()) {
     // Size declared.
-    std::vector<dim_t> oshape_vec;
     const mxnet::Tuple<int> &size = param.size.value();
-    for (int i = 0; i < size.ndim(); ++i) {
+    int head = size[0];
+    if (head == -2) {
+      concat_mode = true;
+    } else {
+      oshape_vec.emplace_back(head);
+    }
+    for (int i = 1; i < size.ndim(); ++i) {
       oshape_vec.emplace_back(size[i]);
     }
+    // If under the broadcast mode, `size` is equivalent to the final output_size.
+    if (!concat_mode) {
     SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape(oshape_vec));
     for (size_t input_idx = 0; input_idx < in_attrs->size(); input_idx++) {
       CheckBroadcastable((*in_attrs)[input_idx], (*out_attrs)[0]);
+    }
     }
-  } else {
-    // Size undeclared.
+  }
+  // Under concat mode, or `size` is not declared.
+  if (concat_mode || (!param.size.has_value())) {
+    // broadcast(param1.shape, param2.shape).
+    mxnet::TShape param_broadcast_shape;
     if (in_attrs->size() == 2U) {
       // Both params from ndarray.
-      mxnet::TShape &low = (*in_attrs)[0];
-      mxnet::TShape &high = (*in_attrs)[1];
-      mxnet::TShape out(std::max(low.ndim(), high.ndim()), -1);
-      InferBroadcastShape(low, high, &out);
-      SHAPE_ASSIGN_CHECK(*out_attrs, 0, out);
+      mxnet::TShape &param1 = (*in_attrs)[0];
+      mxnet::TShape &param2 = (*in_attrs)[1];
+      mxnet::TShape out(std::max(param1.ndim(), param2.ndim()), -1);
+      InferBroadcastShape(param1, param2, &out);
+      param_broadcast_shape = out;
     } else if (in_attrs->size() == 1U) {
       // One param from ndarray.
-      SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0))
+      param_broadcast_shape = in_attrs->at(0);
     } else if (in_attrs->size() == 0) {
       // Two scalar case.
-      SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape(0, -1))
+      param_broadcast_shape = TShape(0, -1);
+    }
+    if (concat_mode) {
+      for (int i = 0; i < param_broadcast_shape.ndim(); ++i) {
+        oshape_vec.emplace_back(param_broadcast_shape[i]);
+      }
+      SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape(oshape_vec));
+    } else {
+      SHAPE_ASSIGN_CHECK(*out_attrs, 0, param_broadcast_shape);
     }
   }
   if (out_attrs->size() == 2U) {
diff --git a/src/operator/numpy/random/np_bernoulli_op.cc b/src/operator/numpy/random/np_bernoulli_op.cc
index d67ad1b8d7f6..1377d525015d 100644
--- a/src/operator/numpy/random/np_bernoulli_op.cc
+++ b/src/operator/numpy/random/np_bernoulli_op.cc
@@ -53,7 +53,7 @@ NNVM_REGISTER_OP(_npi_bernoulli)
     return (num_inputs == 0) ? std::vector<std::string>() : std::vector<std::string>{"input1"};
   })
 .set_attr_parser(ParamParser<NumpyBernoulliParam>)
-.set_attr<mxnet::FInferShape>("FInferShape", UnaryDistOpShape<NumpyBernoulliParam>)
+.set_attr<mxnet::FInferShape>("FInferShape", TwoparamsDistOpShape<NumpyBernoulliParam>)
 .set_attr<nnvm::FInferType>("FInferType", NumpyBernoulliOpType)
 .set_attr<FResourceRequest>("FResourceRequest",
   [](const nnvm::NodeAttrs& attrs) {
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 218f80da4a35..268d58c7026a 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -3764,10 +3764,9 @@ def __init__(self, shape, op_name):
         def hybrid_forward(self, F, param1, param2):
             op = getattr(F.npx.random, self._op_name, None)
             assert op is not None
-            # return param1 + param2 + op(batch_shape=self._shape)
             return op(param1, param2, batch_shape=self._shape)
 
-    batch_shapes = [(10,), (2, 3), 6, (), None]
+    batch_shapes = [(10,), (2, 3), 6, ()]
     event_shapes = [(), (2,), (2,2)]
     dtypes = ['float16', 'float32', 'float64']
     op_names = ['uniform_n', 'normal_n']
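Note (not part of the patch): the shape rule the diff introduces can be summarized in plain Python. The sketch below is illustrative only; normalize_batch_shape and infer_output_shape are hypothetical helper names, and numpy stands in for the backend's broadcast logic. The frontend rewrites batch_shape into size = (-2,) + batch_shape, and the backend's concat mode then produces size[1:] + broadcast(param1.shape, param2.shape).

import numpy as np

def normalize_batch_shape(batch_shape):
    # Mirrors the frontend change: () -> None, int -> 1-tuple,
    # otherwise prefix the -2 sentinel before forwarding as `size`.
    if batch_shape == ():
        return None
    if isinstance(batch_shape, int):
        batch_shape = (batch_shape,)
    return (-2,) + tuple(batch_shape)

def infer_output_shape(size, param1_shape, param2_shape):
    # Mirrors TwoparamsDistOpShape: size[0] == -2 selects concat mode,
    # output = size[1:] + broadcast(param1, param2); otherwise `size`
    # itself is the output and the params must broadcast to it.
    param_broadcast = np.broadcast_shapes(param1_shape, param2_shape)
    if size is None:
        return param_broadcast
    if size[0] == -2:
        return tuple(size[1:]) + param_broadcast
    assert np.broadcast_shapes(size, param_broadcast) == tuple(size)
    return tuple(size)

print(infer_output_shape(normalize_batch_shape((2, 3)), (4,), (1,)))  # (2, 3, 4)
print(infer_output_shape(normalize_batch_shape(6), (2,), (2,)))       # (6, 2)
print(infer_output_shape(normalize_batch_shape(()), (4,), (1,)))      # (4,)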
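For reference, a minimal usage sketch of the frontend touched above, assuming an MXNet build with this patch applied (values are random; only the output shapes are the point):

from mxnet import np, npx
npx.set_np()

low = np.zeros((4,))   # event shape (4,)
high = np.ones((1,))   # broadcasts against `low`

# batch_shape is prepended to broadcast(low.shape, high.shape) == (4,)
print(npx.random.uniform_n(low, high, batch_shape=(2, 3)).shape)   # (2, 3, 4)

# an int batch_shape behaves like a 1-tuple
print(npx.random.normal_n(np.zeros((2,)), np.ones((2,)), batch_shape=6).shape)  # (6, 2)

# batch_shape=() falls back to plain parameter broadcasting
print(npx.random.uniform_n(low, high, batch_shape=()).shape)       # (4,)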