Commit: Sparse square sum (apache#7206)
* Add square_sum op

* Add unit test and fix check_numeric_gradient

* Add .cu file and example

* Fix lint

* Remove gpu registration

* Use square_sum in test_module_fm
reminisce authored and eric-haibin-lin committed Jul 27, 2017
1 parent f0af872 commit 6f0719f
Showing 5 changed files with 428 additions and 16 deletions.
5 changes: 4 additions & 1 deletion python/mxnet/test_utils.py
@@ -552,7 +552,10 @@ def random_projection(shape):
     assert isinstance(grad_stype_dict, dict), "grad_stype_dict must be a dict"
     for k, v in grad_stype_dict.items():
         if k in args_grad and v in _STORAGE_TYPE_STR_TO_ID and v != 'default':
-            args_grad[k] = mx.nd.cast_storage(args_grad[k], stype=v)
+            # create an uninitialized sparse ndarray for the executor;
+            # if the symbolic grad is expected to be zero, it should not be initialized at all
+            args_grad[k] = mx.nd.zeros(args_grad[k].shape, args_grad[k].context,
+                                       args_grad[k].dtype, v)
 
     executor = out.bind(ctx, grad_req=grad_req,
                         args=location, args_grad=args_grad, aux_states=aux_states)
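
For context, a minimal sketch of what the new branch creates (shape and dtype are hypothetical; the positional arguments mirror the `mx.nd.zeros` call above)::

    import mxnet as mx

    # a row-sparse zeros array starts with no stored rows (uninitialized
    # storage), which is exactly how an all-zero gradient is represented
    grad = mx.nd.zeros((5, 2), mx.cpu(), 'float32', 'row_sparse')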
340 changes: 340 additions & 0 deletions src/operator/tensor/square_sum-inl.h
@@ -0,0 +1,340 @@
/*!
 * Copyright (c) 2017 by Contributors
 * \file square_sum-inl.h
 * \brief This is a temporary solution for fusing the operators
 * square and sum into a single composite op for row-sparse tensors.
 * The motivation for fusing square and sum for row-sparse tensors
 * is that the gradient of the fused operator is proportional to the
 * input ndarray, and is therefore a row-sparse ndarray as well.
 * This fused op will be deprecated once general operator fusion
 * is supported.
 */
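
// For y = sum(x .* x) along an axis, the input gradient is
// dL/dx = 2 * x .* broadcast(dL/dy); it vanishes wherever x is zero,
// so a row-sparse x yields a gradient with the same row index set.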

#ifndef MXNET_OPERATOR_TENSOR_SQUARE_SUM_INL_H_
#define MXNET_OPERATOR_TENSOR_SQUARE_SUM_INL_H_

#include <vector>
#include <algorithm>
#include <utility>
#include "../mxnet_op.h"
#include "./broadcast_reduce_op.h"

namespace mxnet {
namespace op {

inline bool SquareSumForwardInferStorageType(const nnvm::NodeAttrs& attrs,
                                             const Context& ctx,
                                             std::vector<int>* in_attrs,
                                             std::vector<int>* out_attrs) {
  CHECK_EQ(in_attrs->size(), 1U);
  CHECK_EQ(out_attrs->size(), 1U);
  CHECK_EQ((*in_attrs)[0], kRowSparseStorage)
    << "_square_sum only supports row-sparse ndarray as input";
  const ReduceAxesParam& param = nnvm::get<ReduceAxesParam>(attrs.parsed);
  if (param.axis[0] == 1 && param.keepdims) {  // sum per row and keep dims
    STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 0, kRowSparseStorage);
  } else {
    STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 0, kDefaultStorage);
  }
  return true;
}

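/*!
 * \brief storage inference for the backward pass: in_attrs holds
 * {ograd, input}, where ograd must be dense and input row-sparse;
 * the single output (igrad) is row-sparse.
 */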
inline bool SquareSumBackwardInferStorageType(const nnvm::NodeAttrs& attrs,
                                              const Context& ctx,
                                              std::vector<int>* in_attrs,
                                              std::vector<int>* out_attrs) {
  CHECK_EQ(in_attrs->size(), 2U);
  CHECK_EQ(out_attrs->size(), 1U);
  STORAGE_TYPE_ASSIGN_CHECK(*in_attrs, 0, kDefaultStorage);
  STORAGE_TYPE_ASSIGN_CHECK(*in_attrs, 1, kRowSparseStorage);
  STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 0, kRowSparseStorage);
  return true;
}

/*!
 * \brief square sum of a rsp
 * if axis = -1, same as mx.nd.sum(tensor*tensor)
 * if axis = 0, same as mx.nd.sum(tensor*tensor, axis=0)
 * if axis = 1, same as mx.nd.sum(tensor*tensor, axis=1)
 * where tensor*tensor is elemwise multiplication of two ndarrays.
 */
template<int req, int axis, bool keepdim>
struct SquareSumRspKernel;

/*!
 * \brief square sum of a rsp on axis=0 without keeping the dim
 */
template<int req>
struct SquareSumRspKernel<req, 0, false> {
  /*!
   * \param j the element index in out_data and column id of in_data
   */
  template<typename DType>
  MSHADOW_XINLINE static void Map(int j, DType* out_data, const DType* in_data,
                                  const int64_t nnr, const int64_t num_cols) {
    DType sum = 0;
    for (int64_t i = 0; i < nnr; ++i) {
      const DType val = in_data[i*num_cols+j];
      sum += val * val;
    }
    KERNEL_ASSIGN(out_data[j], req, sum);
  }
};

/*!
 * \brief square sum of a rsp on axis=1 without keeping the dim
 */
template<int req>
struct SquareSumRspKernel<req, 1, false> {
  /*!
   * \param i the i-th non-zero row of in_data
   */
  template<typename IType, typename DType>
  MSHADOW_XINLINE static void Map(int i, DType* out_data, const IType* in_row_idx,
                                  const DType* in_data, const int64_t num_cols) {
    DType sum = 0;
    const int64_t offset = i * num_cols;
    for (int64_t j = 0; j < num_cols; ++j) {
      const DType val = in_data[offset+j];
      sum += val * val;
    }
    KERNEL_ASSIGN(out_data[in_row_idx[i]], req, sum);
  }
};

/*!
 * \brief square sum of a rsp on axis=1 keeping the dim
 */
template<int req>
struct SquareSumRspKernel<req, 1, true> {
  /*!
   * \param i the i-th non-zero row of in_data
   */
  template<typename IType, typename DType>
  MSHADOW_XINLINE static void Map(int i, IType* out_row_idx, DType* out_data,
                                  const IType* in_row_idx, const DType* in_data,
                                  const int64_t num_cols) {
    DType sum = 0;
    out_row_idx[i] = in_row_idx[i];
    const int64_t offset = i * num_cols;
    for (int64_t j = 0; j < num_cols; ++j) {
      const DType val = in_data[offset+j];
      sum += val * val;
    }
    KERNEL_ASSIGN(out_data[i], req, sum);
  }
};

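/*!
 * \brief gradient kernels for square sum: in_grad = 2 * in_data * out_grad,
 * with out_grad broadcast along the axis that was reduced in the forward pass
 */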
template<int req, int axis>
struct SquareSumRspGradKernel;

template<int req>
struct SquareSumRspGradKernel<req, 0> {
  /*!
   * \param i element index in in_grad and in_data
   * \param in_grad_row_idx row_idx of the gradient of the op's input
   * \param in_grad gradient of the op's input
   * \param out_grad gradient of the op's output
   * \param in_row_idx row idx of the op's input
   * \param in_data op's input
   */
  template<typename IType, typename DType>
  MSHADOW_XINLINE static void Map(int i, IType* in_grad_row_idx, DType* in_grad,
                                  const DType* out_grad, const IType* in_row_idx,
                                  const DType* in_data, const int64_t num_cols) {
    const int64_t row = i / num_cols;
    in_grad_row_idx[row] = in_row_idx[row];
    KERNEL_ASSIGN(in_grad[i], req, 2*in_data[i]*out_grad[i%num_cols]);
  }
};

template<int req>
struct SquareSumRspGradKernel<req, 1> {
  /*!
   * \param i element index in in_grad and in_data
   * \param in_grad_row_idx row_idx of the gradient of the op's input
   * \param in_grad gradient of the op's input
   * \param out_grad gradient of the op's output
   * \param in_row_idx row idx of the op's input
   * \param in_data op's input
   */
  template<typename IType, typename DType>
  MSHADOW_XINLINE static void Map(int i, IType* in_grad_row_idx, DType* in_grad,
                                  const DType* out_grad, const IType* in_row_idx,
                                  const DType* in_data, const int64_t num_cols) {
    const int64_t row = i / num_cols;
    in_grad_row_idx[row] = in_row_idx[row];
    KERNEL_ASSIGN(in_grad[i], req, 2*in_data[i]*out_grad[in_row_idx[row]]);
  }
};

template<typename xpu>
void SquareSumRspImpl(const nnvm::NodeAttrs& attrs,
                      mshadow::Stream<xpu>* s,
                      const NDArray& input,
                      const OpReqType req,
                      NDArray* output) {
  const ReduceAxesParam& param = nnvm::get<ReduceAxesParam>(attrs.parsed);
  CHECK_EQ(param.axis.ndim(), 1U) << "_square_sum(row_sparse_matrix) only supports axis=0 or 1";
  CHECK(param.axis[0] == 0 || param.axis[0] == 1)
    << "_square_sum(row_sparse_matrix) only supports axis=0 or 1";
  CHECK_EQ(input.storage_type(), kRowSparseStorage)
    << "_square_sum op only supports row-sparse matrix as input";
  int64_t out_data_size = 0;
  if (param.axis[0] == 0) {  // axis = 0
    CHECK_EQ(output->storage_type(), kDefaultStorage);
    out_data_size = input.storage_shape()[1];
  } else if (param.keepdims) {  // axis = 1, keepdims = true
    CHECK_EQ(output->storage_type(), kRowSparseStorage);
    out_data_size = input.storage_shape()[0];
  } else {  // axis = 1, keepdims = false
    CHECK_EQ(output->storage_type(), kDefaultStorage);
    out_data_size = input.shape()[0];
  }
  CHECK_NE(req, kWriteInplace);

  using namespace mxnet_op;
  if (!input.storage_initialized()) {
    if (req == kWriteTo && output->storage_type() == kDefaultStorage) {
      MSHADOW_TYPE_SWITCH(output->data().type_flag_, DType, {
        Kernel<set_zero, xpu>::Launch(s, out_data_size, output->data().dptr<DType>());
      })
    }
    return;
  }

  if (output->storage_type() == kRowSparseStorage) {
    output->CheckAndAlloc({input.aux_shape(rowsparse::kIdx)});
  }
  const TBlob& out_data = output->data();
  const int64_t nnr = input.storage_shape()[0];
  const int64_t num_cols = input.storage_shape()[1];
  const TBlob& in_data = input.data();
  if (0 == param.axis[0]) {  // axis = 0, output is dense
    MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
      MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
        Kernel<SquareSumRspKernel<req_type, 0, false>, xpu>::Launch(s, num_cols,
            out_data.dptr<DType>(), input.data().dptr<DType>(), nnr, num_cols);
      })
    })
  } else {  // axis = 1
    const TBlob in_row_idx = input.aux_data(rowsparse::kIdx);
    if (param.keepdims) {  // output is rsp
      const TBlob out_row_idx = output->aux_data(rowsparse::kIdx);
      MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
        MSHADOW_IDX_TYPE_SWITCH(in_row_idx.type_flag_, IType, {
          MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
            Kernel<SquareSumRspKernel<req_type, 1, true>, xpu>::Launch(s, nnr,
                out_row_idx.dptr<IType>(), out_data.dptr<DType>(), in_row_idx.dptr<IType>(),
                in_data.dptr<DType>(), num_cols);
          })
        })
      })
    } else {  // output is dense
      if (req == kWriteTo) {
        MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
          Kernel<set_zero, xpu>::Launch(s, out_data_size, out_data.dptr<DType>());
        })
      }
      MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
        MSHADOW_IDX_TYPE_SWITCH(in_row_idx.type_flag_, IType, {
          MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
            Kernel<SquareSumRspKernel<req_type, 1, false>, xpu>::Launch(s, nnr,
                out_data.dptr<DType>(), in_row_idx.dptr<IType>(), in_data.dptr<DType>(), num_cols);
          })
        })
      })
    }
  }
}

template<typename xpu>
void SquareSumRspGradImpl(const nnvm::NodeAttrs& attrs,
                          mshadow::Stream<xpu>* s,
                          const NDArray& ograd,
                          const NDArray& input,
                          const OpReqType req,
                          NDArray* igrad) {
  const ReduceAxesParam& param = nnvm::get<ReduceAxesParam>(attrs.parsed);
  CHECK_EQ(param.axis.ndim(), 1U) << "_square_sum(row_sparse_matrix) only supports axis=0 or 1";
  CHECK(param.axis[0] == 0 || param.axis[0] == 1)
    << "_square_sum(row_sparse_matrix) only supports axis=0 or 1";
  CHECK_EQ(ograd.storage_type(), kDefaultStorage);
  CHECK_EQ(input.storage_type(), kRowSparseStorage);
  CHECK_EQ(igrad->storage_type(), kRowSparseStorage);
  CHECK_NE(req, kWriteInplace);
  if (!input.storage_initialized()) return;

  using namespace mxnet_op;
  igrad->CheckAndAlloc({input.aux_shape(rowsparse::kIdx)});
  const int64_t num_cols = input.storage_shape()[1];
  const TBlob& igrad_data = igrad->data();
  const TBlob igrad_row_idx = igrad->aux_data(rowsparse::kIdx);
  const TBlob& ograd_data = ograd.data();
  const TBlob in_data = input.data();
  const TBlob in_row_idx = input.aux_data(rowsparse::kIdx);
  MSHADOW_TYPE_SWITCH(igrad_data.type_flag_, DType, {
    MSHADOW_IDX_TYPE_SWITCH(igrad_row_idx.type_flag_, IType, {
      MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
        if (0 == param.axis[0]) {  // forward is sum per column
          Kernel<SquareSumRspGradKernel<req_type, 0>, xpu>::Launch(s, igrad_data.Size(),
              igrad_row_idx.dptr<IType>(), igrad_data.dptr<DType>(), ograd_data.dptr<DType>(),
              in_row_idx.dptr<IType>(), in_data.dptr<DType>(), num_cols);
        } else {  // forward is sum per row
          Kernel<SquareSumRspGradKernel<req_type, 1>, xpu>::Launch(s, igrad_data.Size(),
              igrad_row_idx.dptr<IType>(), igrad_data.dptr<DType>(), ograd_data.dptr<DType>(),
              in_row_idx.dptr<IType>(), in_data.dptr<DType>(), num_cols);
        }
      })
    })
  })
}

template<typename xpu>
void SquareSumOpForwardEx(const nnvm::NodeAttrs& attrs,
                          const OpContext& ctx,
                          const std::vector<NDArray>& inputs,
                          const std::vector<OpReqType>& req,
                          const std::vector<NDArray>& outputs) {
  CHECK_EQ(inputs.size(), 1U);
  CHECK_EQ(outputs.size(), 1U);
  CHECK_EQ(req.size(), 1U);
  mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
  const NDArrayStorageType istype = inputs[0].storage_type();
  if (istype == kRowSparseStorage) {
    NDArray output = outputs[0];
    SquareSumRspImpl(attrs, s, inputs[0], req[0], &output);
  } else {
    LOG(FATAL) << "_square_sum op only supports row-sparse ndarray"
                  " as input, while input stype = " << istype;
  }
}

template<typename xpu>
void SquareSumOpBackwardEx(const nnvm::NodeAttrs& attrs,
                           const OpContext& ctx,
                           const std::vector<NDArray>& inputs,
                           const std::vector<OpReqType>& req,
                           const std::vector<NDArray>& outputs) {
  CHECK_EQ(inputs.size(), 2U);
  CHECK_EQ(outputs.size(), 1U);
  CHECK_EQ(req.size(), 1U);
  mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
  const NDArrayStorageType ograd_stype = inputs[0].storage_type();
  const NDArrayStorageType input_stype = inputs[1].storage_type();
  if (input_stype == kRowSparseStorage && ograd_stype == kDefaultStorage) {
    NDArray output = outputs[0];
    SquareSumRspGradImpl(attrs, s, inputs[0], inputs[1], req[0], &output);
  } else {
    LOG(FATAL) << "_square_sum op backward only supports dense ndarray as ograd,"
                  " row-sparse ndarray as input and row-sparse ndarray as igrad,"
                  " while ograd_stype = " << ograd_stype
               << " input_stype = " << input_stype;
  }
}

} // namespace op
} // namespace mxnet

#endif // MXNET_OPERATOR_TENSOR_SQUARE_SUM_INL_H_
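
The storage-type inference above implies the following output storage types when calling the op from Python (a sketch, assuming the internal op is invoked directly; input values are hypothetical)::

    import mxnet as mx

    dns = mx.nd.array([[0, 0], [1, 2], [0, 0], [3, 4], [0, 0]])
    rsp = mx.nd.cast_storage(dns, stype='row_sparse')
    mx.nd._internal._square_sum(rsp, axis=0)                 # dense: per-column sums
    mx.nd._internal._square_sum(rsp, axis=1)                 # dense: per-row sums
    mx.nd._internal._square_sum(rsp, axis=1, keepdims=True)  # row-sparse output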
34 changes: 34 additions & 0 deletions src/operator/tensor/square_sum.cc
@@ -0,0 +1,34 @@
/*!
 * Copyright (c) 2017 by Contributors
 * \file square_sum.cc
 * \brief CPU implementation of the square_sum op.
 */
#include "./square_sum-inl.h"

namespace mxnet {
namespace op {
MXNET_OPERATOR_REGISTER_REDUCE(_square_sum)
.describe(R"code(Computes the square sum of array elements over a given axis
for a row-sparse matrix. This is a temporary solution for fusing the ops square and
sum together for row-sparse matrices, to save the memory needed for storing gradients.
It will be deprecated once the functionality of fusing operators is finished.

Example::

  dns = mx.nd.array([[0, 0], [1, 2], [0, 0], [3, 4], [0, 0]])
  rsp = mx.nd.cast_storage(dns, stype='row_sparse')
  sum = mx.nd._internal._square_sum(rsp, axis=1)
  sum = [0, 5, 0, 25, 0]

)code" ADD_FILELINE)
.set_attr<FInferStorageType>("FInferStorageType", SquareSumForwardInferStorageType)
.set_attr<FComputeEx>("FComputeEx<cpu>", SquareSumOpForwardEx<cpu>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_square_sum"});

MXNET_OPERATOR_REGISTER_REDUCE_BACKWARD(_backward_square_sum)
.set_num_inputs(2)
.set_attr<FInferStorageType>("FInferStorageType", SquareSumBackwardInferStorageType)
.set_attr<FComputeEx>("FComputeEx<cpu>", SquareSumOpBackwardEx<cpu>);

}  // namespace op
}  // namespace mxnet
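
As a quick sanity check against the dense equivalent (a sketch; `_square_sum` is an internal op, and the input values are taken from the example above)::

    import mxnet as mx

    dns = mx.nd.array([[0, 0], [1, 2], [0, 0], [3, 4], [0, 0]])
    rsp = mx.nd.cast_storage(dns, stype='row_sparse')
    fused = mx.nd._internal._square_sum(rsp, axis=1)
    # should match squaring followed by a dense sum over the same axis
    reference = mx.nd.sum(dns * dns, axis=1)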
