
Numpy compatible max #15161

Merged · 24 commits · Jun 19, 2019
92 changes: 92 additions & 0 deletions src/operator/numpy/np_broadcast_reduce_op.h
@@ -64,6 +64,24 @@ struct NumpyReduceAxesParam : public dmlc::Parameter<NumpyReduceAxesParam> {
}
};

struct NumpyMaxParam : public dmlc::Parameter<NumpyMaxParam> {
dmlc::optional<mxnet::Tuple<int>> axis;
bool keepdims;
dmlc::optional<double> initial;
DMLC_DECLARE_PARAMETER(NumpyMaxParam) {
DMLC_DECLARE_FIELD(axis)
.set_default(dmlc::optional<mxnet::Tuple<int>>())
.describe("Axis or axes along which a sum is performed. The default, axis=None, will sum "
"all of the elements of the input array. If axis is negative it counts from the "
"last to the first axis.");
DMLC_DECLARE_FIELD(keepdims).set_default(false)
.describe("If this is set to `True`, the reduced axes are left "
"in the result as dimension with size one.");
DMLC_DECLARE_FIELD(initial).set_default(dmlc::optional<double>())
.describe("Starting value for the sum.");
}
};
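
For orientation, a minimal NumPy sketch (illustrative only) of how these three fields map onto the call this operator mimics; note that initial is declared here but rejected at compute time further down:

import numpy as np

a = np.arange(24).reshape(2, 3, 4)
np.max(a, axis=(0, 2), keepdims=True)  # axis + keepdims, as in NumpyMaxParam
np.max(a)                              # default axis=None reduces over all elements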

inline TShape NumpyReduceAxesShapeImpl(const TShape& ishape,
const dmlc::optional<mxnet::Tuple<int>>& axis,
bool keepdims) {
@@ -152,6 +170,39 @@ inline bool NumpyReduceAxesShape(const nnvm::NodeAttrs& attrs,
return shape_is_known(out_attrs->at(0));
}

inline bool NumpyMaxShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
if (!shape_is_known(in_attrs->at(0))) {
return false;
}
const NumpyMaxParam& param = nnvm::get<NumpyMaxParam>(attrs.parsed);
// the reduction is valid only if none of the reduced axes has size zero
bool is_all_reduced_axes_not_zero = true;
const TShape& ishape = (*in_attrs)[0];
if (param.axis.has_value()) {
const mxnet::Tuple<int>& axes = param.axis.value();
for (int i = 0; i < axes.ndim(); ++i) {
if (ishape[axes[i]] == 0) {
is_all_reduced_axes_not_zero = false;
break;
}
}
} else {
if (ishape.Size() == 0) {
// a global reduction is valid only when the input has at least one element
is_all_reduced_axes_not_zero = false;
}
}
CHECK(is_all_reduced_axes_not_zero)
<< "zero-size array to reduction operation maximum which has no identity";
SHAPE_ASSIGN_CHECK(*out_attrs, 0,
NumpyReduceAxesShapeImpl((*in_attrs)[0], param.axis, param.keepdims));
return shape_is_known(out_attrs->at(0));
}
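
A NumPy-only sketch of the shape rule and zero-size check enforced above (assumes a recent NumPy; the quoted error message is the one the CHECK mirrors):

import numpy as np

np.max(np.ones((2, 3, 4)), axis=1).shape                  # (2, 4)
np.max(np.ones((2, 3, 4)), axis=1, keepdims=True).shape   # (2, 1, 4)
np.max(np.ones((2, 0, 4)), axis=0).shape                  # (0, 4): the reduced axis is non-empty, allowed
try:
    np.max(np.ones((2, 0, 4)), axis=1)                    # reduces a zero-sized axis
except ValueError as e:
    print(e)  # "zero-size array to reduction operation maximum which has no identity"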

template<bool safe_acc_hint = false>
inline bool NeedSafeAcc(int itype, int otype) {
bool rule = (itype != otype) || (itype != mshadow::kFloat32 && itype != mshadow::kFloat64);
@@ -187,6 +238,29 @@ void NumpyReduceAxesCompute(const nnvm::NodeAttrs& attrs,
}
}

template<typename xpu, typename reducer, typename OP = op::mshadow_op::identity>
void NumpyMaxCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
const NumpyMaxParam& param = nnvm::get<NumpyMaxParam>(attrs.parsed);
if (param.initial.has_value()) {
LOG(FATAL) << "initial is not supported yet";
}
if (inputs[0].shape_.Size() == 0U || outputs[0].shape_.Size() == 0U) return; // zero-size tensor
if (param.axis.has_value() && param.axis.value().ndim() == 0) {
UnaryOp::IdentityCompute<xpu>(attrs, ctx, inputs, req, outputs);
}
TShape small;
if (param.keepdims) {
small = outputs[0].shape_;
} else {
small = NumpyReduceAxesShapeImpl(inputs[0].shape_, param.axis, true);
}
ReduceAxesComputeImpl<xpu, reducer, false, false, OP>(ctx, inputs, req, outputs, small);
}
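
The kernel always reduces into the keepdims=True ("small") shape; whether the singleton axes remain visible is decided by the output blob's shape. A tiny Python sketch of that shape computation (hypothetical helper, ignoring negative axes):

def keepdims_shape(shape, axis):
    # the "small" shape: reduced axes collapsed to 1, other axes untouched
    if axis is None:
        return (1,) * len(shape)
    axes = (axis,) if isinstance(axis, int) else axis
    return tuple(1 if i in axes else d for i, d in enumerate(shape))

keepdims_shape((2, 3, 4, 5), axis=(1, 3))  # (2, 1, 4, 1)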

template<typename xpu, bool normalize = false>
inline void NumpyReduceAxesBackwardUseNone(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
@@ -213,6 +287,24 @@ inline void NumpyReduceAxesBackwardUseNone(const nnvm::NodeAttrs& attrs,
}
}

template<typename xpu, typename OP>
void NumpyMaxBackward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mshadow::expr;
const NumpyMaxParam& param = nnvm::get<NumpyMaxParam>(attrs.parsed);
TShape small;
if (param.keepdims) {
small = inputs[0].shape_;
} else {
small = NumpyReduceAxesShapeImpl(outputs[0].shape_, param.axis, true);
}
ReduceAxesBackwardUseInOutImpl<xpu, OP, false>(ctx, small, inputs, req, outputs);
}
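
The backward pass pairs the saved output with the input and, through mshadow_op::eq, routes the incoming gradient to every position equal to the maximum (ties all receive the gradient). A hedged NumPy sketch of the same computation (tuple axis in expand_dims needs NumPy >= 1.18):

import numpy as np

def np_max_backward(ograd, a, out, axis=None, keepdims=False):
    if not keepdims and axis is not None and axis != ():
        out = np.expand_dims(out, axis)     # restore the reduced axes so they broadcast
        ograd = np.expand_dims(ograd, axis)
    return ograd * (a == out)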

} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_H_
39 changes: 39 additions & 0 deletions src/operator/numpy/np_broadcast_reduce_op_value.cc
@@ -29,6 +29,7 @@ namespace mxnet {
namespace op {

DMLC_REGISTER_PARAMETER(NumpyReduceAxesParam);
DMLC_REGISTER_PARAMETER(NumpyMaxParam);

inline bool NumpySumType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
@@ -128,5 +129,43 @@ NNVM_REGISTER_OP(_backward_np_mean)
.set_num_inputs(1)
.set_attr<FCompute>("FCompute<cpu>", NumpyReduceAxesBackwardUseNone<cpu, true>);

inline bool NumpyMaxType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));

return out_attrs->at(0) != -1 && in_attrs->at(0) != -1;
}

NNVM_REGISTER_OP(_np_max)
.describe(R"code()code" ADD_FILELINE)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser(ParamParser<NumpyMaxParam>)
.set_attr<mxnet::FInferShape>("FInferShape", NumpyMaxShape)
.set_attr<nnvm::FInferType>("FInferType", NumpyMaxType)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"a"};
})
.add_argument("a", "NDArray-or-Symbol", "The input")
.add_arguments(NumpyMaxParam::__FIELDS__())
.set_attr<FCompute>("FCompute<cpu>", NumpyMaxCompute<cpu, mshadow::red::maximum>)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<nnvm::FGradient>("FGradient", ReduceGrad{"_backward_np_max"});

NNVM_REGISTER_OP(_backward_np_max)
.set_num_outputs(1)
.set_attr_parser(ParamParser<NumpyMaxParam>)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_num_inputs(3)
.set_attr<FCompute>("FCompute<cpu>", NumpyMaxBackward<cpu, mshadow_op::eq>);

} // namespace op
} // namespace mxnet
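
With the CPU registration above in place, the operator is reachable from the numpy-compatible front end exercised by the tests below. A minimal imperative sketch (assuming an MXNet build that includes this change):

import mxnet as mx
from mxnet import np

x = mx.nd.arange(24).reshape((2, 3, 4)).as_np_ndarray()
print(np.max(x, axis=1, keepdims=True).shape)  # (2, 1, 4)
print(np.max(x))                               # global maximum, 23.0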
5 changes: 5 additions & 0 deletions src/operator/numpy/np_broadcast_reduce_op_value.cu
@@ -39,6 +39,11 @@ NNVM_REGISTER_OP(_np_mean)
NNVM_REGISTER_OP(_backward_np_mean)
.set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesBackwardUseNone<gpu, true>);

NNVM_REGISTER_OP(_np_max)
.set_attr<FCompute>("FCompute<gpu>", NumpyMaxCompute<gpu, mshadow::red::maximum>);

NNVM_REGISTER_OP(_backward_np_max)
.set_attr<FCompute>("FCompute<gpu>", NumpyMaxBackward<gpu, mshadow_op::eq>);

} // namespace op
} // namespace mxnet
92 changes: 91 additions & 1 deletion tests/python/unittest/test_numpy_op.py
@@ -20,10 +20,11 @@
import numpy as _np
import mxnet as mx
from mxnet import np, npx
from mxnet.base import MXNetError
from mxnet.gluon import HybridBlock
from mxnet.test_utils import same, assert_almost_equal, rand_shape_nd, rand_ndarray
from mxnet.test_utils import check_numeric_gradient
from common import with_seed
from common import assertRaises, with_seed
import random


@@ -197,6 +198,95 @@ def is_int(dtype):
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)


@with_seed()
@npx.use_np_shape
def test_np_max():
@npx.use_np_shape
class TestMax(HybridBlock):
def __init__(self, axis=None, keepdims=False):
super(TestMax, self).__init__()
self._axis = axis
self._keepdims = keepdims

def hybrid_forward(self, F, a, *args, **kwargs):
return F.np.max(a, axis=self._axis, keepdims=self._keepdims)

def is_int(dtype):
return 'int' == dtype

def get_grad(axis):
if axis == ():
return _np.ones((2,3,4,5))
else:
temp = _np.zeros((2,3,4,5))
if axis == 0:
temp[-1,:,:,:] = 1
return temp
elif axis == 1:
temp[:,-1,:,:] = 1
return temp
elif axis == 2:
temp[:,:,-1,:] = 1
return temp
elif axis == 3:
temp[:,:,:,-1] = 1
return temp
elif not axis:
temp[-1,-1,-1,-1] = 1
return temp
raise ValueError('axis should be int or None or ()')

def _test_np_max_exception(shape, dim):
x = _np.random.uniform(-1.0, 1.0, shape)
x = mx.nd.array(x).as_np_ndarray()
out = mx.np.max(x)
assert out.ndim == dim, 'dimension mismatch, out.ndim={}, dim={}'.format(out.ndim, dim)

in_data_dim = random.choice([2, 3, 4])
shape = rand_shape_nd(in_data_dim, dim=3)
for hybridize in [False, True]:
for keepdims in [True, False]:
for axis in ([i for i in range(in_data_dim)] + [(), None]):
for itype in ['float16', 'float32', 'float64', 'int']:
# test gluon
test_max = TestMax(axis=axis, keepdims=keepdims)
if hybridize:
test_max.hybridize()
if is_int(itype):
x = mx.nd.arange(120).reshape((2, 3, 4, 5))
x = mx.nd.array(x)
else:
x = mx.nd.random.uniform(-1.0, 1.0, shape=shape, dtype=itype)
x = x.as_np_ndarray()
x.attach_grad()
expected_ret = _np.amax(x.asnumpy(), axis=axis, keepdims=keepdims)
with mx.autograd.record():
y = test_max(x)
assert y.shape == expected_ret.shape
assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3, atol=1e-5)
y.backward()
# only check the gradient with hardcoded input
if is_int(itype):
assert same(x.grad.asnumpy(), get_grad(axis)), \
'x={}\ny={}\nx.grad={}\nnumpy={}'.format(x.asnumpy(), y.asnumpy(), x.grad.asnumpy(), get_grad(axis))

# test imperative
mx_out = np.max(x, axis=axis, keepdims=keepdims)
np_out = _np.amax(x.asnumpy(), axis=axis, keepdims=keepdims)
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)

# test zero and zero dim
shapes = [(), (0,), (2, 0), (0, 2, 1)]
exceptions = [False, True, True, True]
dims = [0] * len(shapes)
for shape, exception, dim in zip(shapes, exceptions, dims):
if exception:
assertRaises(MXNetError, _test_np_max_exception, shape, dim)
else:
_test_np_max_exception(shape, dim)
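
A side note on the hard-coded gradient check above: the integer input mx.nd.arange(120).reshape((2, 3, 4, 5)) is strictly increasing, so along any axis the maximum sits at the last index and get_grad() can place the single 1 there deterministically. A quick sanity check (not part of the test):

import numpy as _np

a = _np.arange(120).reshape((2, 3, 4, 5))
assert (_np.argmax(a, axis=2) == 3).all()  # the max along axis 2 is always at the last index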


@with_seed()
@npx.use_np_shape
def test_np_transpose():