Fix index overflow bug in einsum (apache#16589)
* fix index overflow

* check index overflow

* fix index overflow in einsum path

* fix indent

* reduce NPY_MAXARGS

* safe accumulate
hzfan authored and yajiedesign committed Nov 6, 2019
1 parent 8e5f9cf commit 01be5aa
Showing 7 changed files with 226 additions and 157 deletions.
9 changes: 9 additions & 0 deletions benchmark/python/einsum/benchmark_einsum.py
@@ -48,6 +48,15 @@ def test_np_einsum():
    cost = measure_cost(500, np.einsum, *args, optimize=True)
    print("Greedy einsum: {} ms".format(cost * 1000))

    print("RNN Use Case:")
    a = np.random.uniform(0, 1, size=(64, 128, 512))
    b = np.random.uniform(0, 1, size=(128, 512, 2, 2))
    args = ['bij, ijkl->bkl', a, b]
    cost = measure_cost(2, np.einsum, *args, optimize=True)
    print('Greedy einsum: {} ms'.format(cost * 1000))
    cost = measure_cost(2, np.einsum, *args)
    print('Basic einsum: {} ms'.format(cost * 1000))

    print('Inner Product:')
    a = np.ones(6000000)
    b = np.ones(6000000)
12 changes: 12 additions & 0 deletions src/operator/mxnet_op.h
@@ -619,6 +619,18 @@ MSHADOW_XINLINE Shape<ndim> calc_stride(const Shape<ndim>& shape) {
return stride;
}

/* Increment coordinates */
template<int ndim>
MSHADOW_XINLINE bool inc(Shape<ndim>* coord, const Shape<ndim>& shape) {
++(*coord)[ndim-1];
#pragma unroll
for (int i = ndim - 1; i > 0 && (*coord)[i] >= shape[i]; --i) {
(*coord)[i] -= shape[i];
++(*coord)[i-1];
}
return (*coord)[0] < shape[0];
}

/* Increment coordinates and modify index */
template<int ndim>
MSHADOW_XINLINE void inc(Shape<ndim>* coord, const Shape<ndim>& shape,
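The new `inc` overload above advances a multi-dimensional coordinate in row-major (odometer) order and returns whether iteration should continue, so callers can visit every element without ever forming a large flat index per step. A minimal standalone sketch of the same logic, using plain arrays instead of `mshadow::Shape` (the function shape and demo sizes here are illustrative, not part of the patch):

```cpp
#include <cstdio>

// Odometer-style increment: bump the last axis, carry into earlier axes.
// Returns false once every coordinate in [0, shape) has been visited.
template <int ndim>
bool inc(int (&coord)[ndim], const int (&shape)[ndim]) {
  ++coord[ndim - 1];
  for (int i = ndim - 1; i > 0 && coord[i] >= shape[i]; --i) {
    coord[i] -= shape[i];
    ++coord[i - 1];
  }
  return coord[0] < shape[0];
}

int main() {
  int shape[2] = {2, 3};
  int coord[2] = {0, 0};
  do {
    std::printf("(%d, %d)\n", coord[0], coord[1]);  // (0,0) (0,1) ... (1,2)
  } while (inc(coord, shape));
  return 0;
}
```

Iterating by coordinates is part of the overflow fix: the kernel below only combines a coordinate with strides once per element, instead of unraveling a flat reduction index whose intermediate products could exceed the index range.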
170 changes: 83 additions & 87 deletions src/operator/numpy/np_einsum_op-inl.h
@@ -73,8 +73,8 @@
namespace mxnet {
namespace op {

#define NPY_MAXDIMS 32
#define NPY_MAXARGS 32
#define NPY_MAXDIMS 16
#define NPY_MAXARGS 16
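Halving `NPY_MAXDIMS` and `NPY_MAXARGS` shrinks the `StaticArray` stride tables that the rewritten kernel below receives by value, which matters on CUDA where kernel parameters are capped at 4 KB. A rough footprint check under assumed sizes (8-byte `index_t`, shapes padded to 16 dims; this motivation is an inference from the "reduce NPY_MAXARGS" commit bullet, not stated in the diff):

```cpp
#include <cstddef>
#include <cstdio>

int main() {
  const std::size_t max_args = 16;   // NPY_MAXARGS after this commit (was 32)
  const std::size_t max_dims = 16;   // NPY_MAXDIMS after this commit (was 32)
  const std::size_t index_size = 8;  // assumed sizeof(index_t) on 64-bit builds
  // One StaticArray<Shape<max_dims>, max_args> stride table:
  std::size_t table = max_args * max_dims * index_size;
  std::printf("one table: %zu B, two tables: %zu B (CUDA param limit: 4096 B)\n",
              table, 2 * table);  // 2048 and 4096; 32/32 would give 8192 and 16384
  return 0;
}
```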

inline TShape get_stride(const TShape& shape) {
int ndim = shape.ndim(), prod = 1;
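`get_stride` is truncated by the diff view; a hypothetical completion consistent with the visible signature, using the surrounding file's `TShape` type (the zero-stride convention for extent-1 axes is an assumption, mirroring NumPy's einsum iterator):

```cpp
// Hypothetical completion of the truncated helper above (a sketch).
inline TShape get_stride(const TShape& shape) {
  int ndim = shape.ndim(), prod = 1;
  TShape stride = TShape(ndim, -1);
  for (int i = ndim - 1; i >= 0; --i) {
    stride[i] = shape[i] > 1 ? prod : 0;  // extent-1 axes broadcast with stride 0
    prod = prod * shape[i];
  }
  return stride;
}
```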
@@ -415,40 +415,45 @@ class EinsumOp {
}
}; // class EinsumOp

template<int dimension, int req, bool back>
struct numpy_einsum {
template<int dimension, int req, bool back, typename AType>
struct numpy_einsum {
template<typename DType>
MSHADOW_XINLINE static void Map(index_t i, DType* out,
common::StaticArray<DType*, NPY_MAXARGS> op,
mshadow::Shape<dimension> oshape,
mshadow::Shape<dimension> ostride,
common::StaticArray<mshadow::Shape<dimension>,
NPY_MAXARGS> ostride,
mshadow::Shape<dimension> reduceshape,
mshadow::Shape<dimension> reducestride,
mshadow::Shape<dimension> itershape,
common::StaticArray<mshadow::Shape<dimension>,
NPY_MAXARGS> iterstride,
NPY_MAXARGS> rstride,
int nop,
int iop0,
const DType* out_grad) {
using namespace mxnet_op;
index_t oidx = back ? dot(unravel(dot(unravel(i, oshape), ostride), itershape),
iterstride[iop0]) : i;
mshadow::Shape<dimension> oidx = unravel(i, oshape);
i = back ? dot(oidx, ostride[iop0]) : i;
if (req == kWriteTo) {
out[oidx] = (DType)0;
out[i] = (DType)0;
}
for (int rdim = 0; rdim < dimension; ++rdim) {
if (reduceshape[rdim] == 0) {
return;
}
}
for (int j = 0; j < reduceshape.Size(); j++) {
mshadow::Shape<dimension> idx = unravel(dot(unravel(j, reduceshape), reducestride) +
dot(unravel(i, oshape), ostride),
itershape);
DType tmp = back ? out_grad[dot(idx, iterstride[nop])] : (DType)1;
mshadow::Shape<dimension> ridx = unravel(0, reduceshape);
AType sum = 0;
do {
AType tmp = back ? static_cast<AType>(out_grad[dot(oidx, ostride[nop]) +
dot(ridx, rstride[nop])]) : (AType)1;
for (int iop = 0; iop < nop; ++iop) {
if (iop != iop0) {
index_t k = dot(idx, iterstride[iop]);
tmp = tmp * op[iop][k];
index_t k = dot(oidx, ostride[iop]) + dot(ridx, rstride[iop]);
tmp = tmp * static_cast<AType>(op[iop][k]);
}
}
out[oidx] = out[oidx] + tmp;
}
sum = sum + tmp;
} while (inc(&ridx, reduceshape));
out[i] = out[i] + static_cast<DType>(sum);
}
};
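The new `AType` template parameter makes the kernel accumulate in a wider type and cast back to `DType` only on write-out (the "safe accumulate" bullet in the commit message), which matters when the reduction spans many elements. A standalone illustration of the idea (plain C++, not MXNet code):

```cpp
#include <cstdio>

// Accumulate float products in double, then narrow once at the end.
// Mirrors the AType/DType split in the kernel above.
float reduce_safe(const float* a, const float* b, int n) {
  double sum = 0.0;  // AType: wider accumulator
  for (int i = 0; i < n; ++i) {
    sum += static_cast<double>(a[i]) * static_cast<double>(b[i]);
  }
  return static_cast<float>(sum);  // DType: narrow once on write-out
}

int main() {
  float a[4] = {1e8f, 1.0f, -1e8f, 1.0f};
  float b[4] = {1.0f, 1.0f, 1.0f, 1.0f};
  std::printf("%g\n", reduce_safe(a, b, 4));  // 2, preserved by the double sum
  return 0;
}
```

In single precision, 1e8f + 1.0f already rounds back to 1e8f (the float ulp at 1e8 is 8), so a `float` accumulator would silently drop the small terms; the wider accumulator keeps them.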

@@ -603,12 +608,12 @@ inline void NumpyEinsumProcess(const std::vector<TBlob>& inputs,
}

/* Step 4: Set up the op_axes for the iterator */
TShape itershape(ndim_iter, -1), iterstride_true(ndim_iter, -1);
TShape itershape(ndim_iter, -1);
std::vector<TShape> iterstride(nop + 1, TShape(ndim_iter, 0));
TShape oshape = back ? inputs[0].shape_ : outputs[0].shape_;
TShape ostride_true = get_stride(oshape);
TShape reduceshape, ostride, reducestride;
std::vector<TShape> iterstride(nop + 1, TShape(ndim_iter, 0));
std::vector<TShape> remainshape(nop), opstride(nop), remainstride(nop);
TShape reduceshape;
std::vector<TShape> remainshape(nop);
int op_axes_arrays[NPY_MAXARGS][NPY_MAXDIMS];
int *op_axes[NPY_MAXARGS];

@@ -632,7 +637,6 @@ inline void NumpyEinsumProcess(const std::vector<TBlob>& inputs,
for (idim = 0; idim < ndim_output; ++idim) {
iterstride[nop][idim] = ostride_true[idim];
}
iterstride_true = get_stride(itershape);
reduceshape = TShape(ndim_iter - ndim_output, 0);
for (idim = ndim_output; idim < ndim_iter; ++idim) {
reduceshape[idim - ndim_output] = itershape[idim];
@@ -648,30 +652,6 @@ inline void NumpyEinsumProcess(const std::vector<TBlob>& inputs,
remainshape[iop] = TShape(rsh.begin(), rsh.end());
}

// calculate stride
ostride = TShape(ndim_output, 0);
for (idim = 0; idim < ndim_output; ++idim) {
ostride[idim] = iterstride_true[idim];
}
reducestride = TShape(ndim_iter - ndim_output, 0);
for (idim = ndim_output; idim < ndim_iter; ++idim) {
reducestride[idim - ndim_output] = iterstride_true[idim];
}
for (iop = 0; iop < nop; ++iop) {
opstride[iop] = TShape(opshape[iop].ndim(), 0);
remainstride[iop] = TShape(remainshape[iop].ndim(), 0);
int j = 0;
for (idim = 0; idim < ndim_iter; ++idim) {
if (op_axes_arrays[iop][idim] != -1 &&
itershape[idim] == opshape[iop][op_axes_arrays[iop][idim]]) {
opstride[iop][op_axes_arrays[iop][idim]] = iterstride_true[idim];
} else {
remainstride[iop][j++] = iterstride_true[idim];
}
}
CHECK_EQ(j, remainstride[iop].ndim());
}

// exclude the 0-dim case
if (ndim_iter == 0) {
ndim_iter = 1;
@@ -681,43 +661,44 @@ inline void NumpyEinsumProcess(const std::vector<TBlob>& inputs,
iterstride[iop] = pad(iterstride[iop], ndim_iter);
}
oshape = pad(oshape, ndim_iter);
ostride = pad(ostride, ndim_iter);
reduceshape = pad(reduceshape, ndim_iter);
reducestride = pad(reducestride, ndim_iter);
for (iop = 0; iop < nop; ++iop) {
opshape[iop] = pad(opshape[iop], ndim_iter);
opstride[iop] = pad(opstride[iop], ndim_iter);
remainshape[iop] = pad(remainshape[iop], ndim_iter);
remainstride[iop] = pad(remainstride[iop], ndim_iter);
}

if (!back) {
if (oshape.Size() == 0) {
return;
}
const TBlob &out_data = outputs[0];
MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
MXNET_ACC_TYPE_SWITCH(out_data.type_flag_, DType, AType, {
mxnet::common::StaticArray<DType*, NPY_MAXARGS> op;
for (iop = 0; iop < nop; ++iop) {
op[iop] = inputs[iop].dptr<DType>();
}
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
MXNET_NDIM_SWITCH_EX(ndim_iter, dimension, {
mxnet::common::StaticArray<mshadow::Shape<dimension>, NPY_MAXARGS> iterstride_arr;
for (iop = 0; iop <= nop; ++iop) {
iterstride_arr[iop] = iterstride[iop].get<dimension>();
mxnet::common::StaticArray<mshadow::Shape<dimension>, NPY_MAXARGS> ostride_arr;
mxnet::common::StaticArray<mshadow::Shape<dimension>, NPY_MAXARGS> rstride_arr;
for (iop = 0; iop < nop; ++iop) {
mshadow::Shape<dimension> otmp, rtmp;
for (idim = 0; idim < dimension; ++idim) {
otmp[idim] = idim < ndim_output ? iterstride[iop][idim] : 1;
rtmp[idim] = idim < dimension - ndim_output ? iterstride[iop][idim + ndim_output] : 1;
}
ostride_arr[iop] = otmp;
rstride_arr[iop] = rtmp;
}
Kernel<numpy_einsum<dimension, req_type, 0>,
Kernel<numpy_einsum<dimension, req_type, 0, AType>,
xpu>::Launch(ctx.get_stream<xpu>(),
oshape.Size(),
out_data.dptr<DType>(),
op,
oshape.get<dimension>(),
ostride.get<dimension>(),
ostride_arr,
reduceshape.get<dimension>(),
reducestride.get<dimension>(),
itershape.get<dimension>(),
iterstride_arr,
rstride_arr,
nop,
-1,
reinterpret_cast<DType*>(NULL));
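Putting the pieces together: for each output element the kernel fixes the output coordinate, walks the reduction coordinates with `inc`, and turns each coordinate pair into per-operand offsets via the `ostride`/`rstride` dot products. A tiny replay of that scheme for "ij,jk->ik" (plain C++; shapes, strides, and values are illustrative):

```cpp
#include <cstdio>

int main() {
  // "ij,jk->ik" with i=2, j=3, k=2. Output coords (i,k); reduce coord (j).
  const int I = 2, J = 3, K = 2;
  float a[I * J] = {1, 2, 3, 4, 5, 6};  // a(i,j): ostride {J,0}, rstride {1}
  float b[J * K] = {1, 0, 0, 1, 1, 1};  // b(j,k): ostride {0,1}, rstride {K}
  float out[I * K];
  for (int i = 0; i < I; ++i) {
    for (int k = 0; k < K; ++k) {    // one kernel invocation per output element
      double sum = 0.0;              // AType accumulator
      for (int j = 0; j < J; ++j) {  // odometer over reduceshape {J}
        // offset = dot(oidx, ostride[iop]) + dot(ridx, rstride[iop])
        sum += static_cast<double>(a[i * J + j]) *
               static_cast<double>(b[k + j * K]);
      }
      out[i * K + k] = static_cast<float>(sum);
    }
  }
  for (int i = 0; i < I; ++i)
    std::printf("%g %g\n", out[i * K], out[i * K + 1]);  // 4 5 / 10 11
  return 0;
}
```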
@@ -743,31 +724,44 @@ inline void NumpyEinsumProcess(const std::vector<TBlob>& inputs,
for (int i = 0; i < nop; ++i) {
const TBlob &out_data = outputs[i];
const TBlob &out_grad = inputs[0];
MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
std::vector<TShape> opstride(nop + 1, TShape(ndim_iter, 0));
std::vector<TShape> remainstride(nop + 1, TShape(ndim_iter, 0));
for (iop = 0; iop <= nop; ++iop) {
int j = 0;
for (idim = 0; idim < ndim_iter; ++idim) {
if (op_axes_arrays[i][idim] == -1 ||
opshape[i][op_axes_arrays[i][idim]] == 1) {
remainstride[iop][j++] = iterstride[iop][idim];
} else {
opstride[iop][op_axes_arrays[i][idim]] = iterstride[iop][idim];
}
}
}
MXNET_ACC_TYPE_SWITCH(out_data.type_flag_, DType, AType, {
mxnet::common::StaticArray<DType*, NPY_MAXARGS> op;
for (iop = 0; iop < nop; ++iop) {
op[iop] = inputs[iop + back].dptr<DType>();
}
MXNET_ASSIGN_REQ_SWITCH(req[i], req_type, {
MXNET_NDIM_SWITCH_EX(ndim_iter, dimension, {
mxnet::common::StaticArray<mshadow::Shape<dimension>, NPY_MAXARGS> iterstride_arr;
mxnet::common::StaticArray<mshadow::Shape<dimension>, NPY_MAXARGS> opstride_arr;
mxnet::common::StaticArray<mshadow::Shape<dimension>, NPY_MAXARGS> remainstride_arr;
for (iop = 0; iop <= nop; ++iop) {
iterstride_arr[iop] = iterstride[iop].get<dimension>();
opstride_arr[iop] = opstride[iop].get<dimension>();
remainstride_arr[iop] = remainstride[iop].get<dimension>();
}
Kernel<numpy_einsum<dimension, req_type, 1>,
Kernel<numpy_einsum<dimension, req_type, 1, AType>,
xpu>::Launch(ctx.get_stream<xpu>(),
opshape[i].Size(),
out_data.dptr<DType>(),
op,
opshape[i].get<dimension>(),
opstride[i].get<dimension>(),
remainshape[i].get<dimension>(),
remainstride[i].get<dimension>(),
itershape.get<dimension>(),
iterstride_arr,
nop,
i,
out_grad.dptr<DType>());
opshape[i].Size(),
out_data.dptr<DType>(),
op,
opshape[i].get<dimension>(),
opstride_arr,
remainshape[i].get<dimension>(),
remainstride_arr,
nop,
i,
out_grad.dptr<DType>());
})
})
})
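In the backward pass each operand gets its own axis split: iterator axes the operand actually indexes feed `opstride`, while broadcast axes (`op_axes == -1`) and extent-1 axes feed `remainstride` and are summed over when accumulating that operand's gradient. A toy classification pass mirroring the loop above, under assumed inputs (all values illustrative):

```cpp
#include <cstdio>

int main() {
  // Iterator has 3 axes; the operand maps axis 0 -> its dim 0, axis 1 is
  // broadcast (op_axes == -1), and axis 2 -> its dim 1 of extent 1.
  int op_axes[3] = {0, -1, 1};
  int opshape[2] = {4, 1};
  int iterstride[3] = {12, 4, 1};  // illustrative iterator strides
  int opstride[2] = {0, 0}, remainstride[3] = {0, 0, 0};
  int j = 0;
  for (int idim = 0; idim < 3; ++idim) {
    if (op_axes[idim] == -1 || opshape[op_axes[idim]] == 1) {
      remainstride[j++] = iterstride[idim];  // summed over in the gradient
    } else {
      opstride[op_axes[idim]] = iterstride[idim];
    }
  }
  std::printf("opstride: {%d, %d}; remain axes: %d\n",
              opstride[0], opstride[1], j);  // opstride: {12, 0}; remain axes: 2
  return 0;
}
```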
@@ -798,13 +792,14 @@ inline void NumpyEinsumForward(const OpStatePtr& state_ptr,
std::vector<std::vector<int> > pos;
std::string string_repr;
paths = einsum_path(state.subscripts, inputs, true, ctx.run_ctx, &pos, &string_repr);
int paths_len = paths.size(), temp_space_size = 0, max_temp_space_size = 0;
int paths_len = paths.size();
size_t temp_space_size = 0, max_temp_space_size = 0;
std::vector<TBlob> operands(inputs), tmp_operands, temp_space_vec(paths_len - 1);
for (int i = 0; i + 1 < paths_len; ++i) {
temp_space_size += paths[i].oshape.Size();
}
for (int i = 0; i < paths_len; ++i) {
max_temp_space_size = std::max(max_temp_space_size, static_cast<int>(paths[i].oshape.Size()));
max_temp_space_size = std::max(max_temp_space_size, paths[i].oshape.Size());
}
temp_space_size += max_temp_space_size;
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
@@ -813,7 +808,7 @@ inline void NumpyEinsumForward(const OpStatePtr& state_ptr,
false,
outputs[0].type_flag_));
Tensor<xpu, 1, DType> temp_space = state.tempspace->data().FlatTo1D<xpu, DType>();
int begin = max_temp_space_size;
size_t begin = max_temp_space_size;
for (int i = 0; i < paths_len - 1; ++i) {
TBlob tblob = TBlob(temp_space.Slice(begin, begin + paths[i].oshape.Size()));
temp_space_vec[i] = tblob.reshape(paths[i].oshape);
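The `int` to `size_t` change above is the overflow fix in the einsum path bookkeeping: with large-tensor builds a single path intermediate can hold more than `INT_MAX` elements, and summing several path output sizes in an `int` wraps around. A two-line demonstration of the failure mode (the element count is illustrative):

```cpp
#include <cstdio>

int main() {
  long long elements = 3000000000LL;        // one large path intermediate
  int wrapped = static_cast<int>(elements); // narrows: typically negative
  size_t safe = static_cast<size_t>(elements);
  std::printf("int: %d, size_t: %zu\n", wrapped, safe);
  return 0;
}
```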
@@ -910,12 +905,13 @@ inline void NumpyEinsumBackward(const OpStatePtr& state_ptr,
}
// calculate temporary space size for temp_grad
const std::vector<Step>& paths = state.paths;
int paths_len = paths.size(), temp_space_size = 0, max_temp_space_size = 0;
int paths_len = paths.size();
size_t temp_space_size = 0, max_temp_space_size = 0;
for (int i = 0; i < paths_len - 1; ++i) {
temp_space_size += paths[i].oshape.Size();
}
for (int i = 0; i < paths_len; ++i) {
max_temp_space_size = std::max(max_temp_space_size, static_cast<int>(paths[i].oshape.Size()));
max_temp_space_size = std::max(max_temp_space_size, paths[i].oshape.Size());
}
temp_space_size += max_temp_space_size;
// replay the forward process
@@ -936,8 +932,8 @@ inline void NumpyEinsumBackward(const OpStatePtr& state_ptr,
}
}
// calculate temporary space size for tensordot
int tensordot_max_tempspace_size = 0;
int begin_tensordot_tempspace = 0;
size_t tensordot_max_tempspace_size = 0;
size_t begin_tensordot_tempspace = 0;
std::vector<TBlob> temp_inputs, temp_outputs;
std::vector<OpReqType> temp_req;
std::vector<size_t> tensordot_tempspace_size;
@@ -999,7 +995,7 @@ inline void NumpyEinsumBackward(const OpStatePtr& state_ptr,
}
tensordot_tempspace_size.push_back(cur_tensordot_tempspace_size);
tensordot_max_tempspace_size = std::max(tensordot_max_tempspace_size,
static_cast<int>(cur_tensordot_tempspace_size));
cur_tensordot_tempspace_size);
}
begin_tensordot_tempspace = temp_space_size;
temp_space_size += (tensordot_max_tempspace_size + sizeof(DType) - 1) / sizeof(DType);
@@ -1010,7 +1006,7 @@ inline void NumpyEinsumBackward(const OpStatePtr& state_ptr,
// allocate temporary space for gradients of intermediate results
Tensor<xpu, 1, DType> temp_space = ctx.requested[0].get_space_typed<xpu, 1, DType>
(Shape1(temp_space_size), s);
int begin = max_temp_space_size;
size_t begin = max_temp_space_size;
for (int i = 0; i + 1 < paths_len; ++i) {
TBlob tblob = TBlob(temp_space.Slice(begin, begin + paths[i].oshape.Size()));
temp_grad[i] = tblob.reshape(paths[i].oshape);
11 changes: 11 additions & 0 deletions src/operator/numpy/np_einsum_op.cc
@@ -305,6 +305,17 @@ bool NumpyEinsumShape(const nnvm::NodeAttrs& attrs,
oshape[i] = dimension_dict[static_cast<int>(output_str[i])];
}
SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
size_t lim = static_cast<size_t>(std::numeric_limits<index_t>::max());
for (int i = 0; i < num_args; ++i) {
CHECK_LE(in_attrs->at(i).Size(), lim)
<< "Size of operand " << i
<< " exceeds the maximum index."
<< " Try setting `USE_INT64_TENSOR_SIZE`.";
}
CHECK_LE(oshape.Size(), lim)
<< "Size of output"
<< " exceeds the maximum index."
<< " Try setting `USE_INT64_TENSOR_SIZE`.";
return shape_is_known(oshape);
}
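The new checks reject operands and outputs whose element count cannot be represented in `index_t`, pointing users at the 64-bit build flag instead of silently overflowing at runtime. A sketch of the guard in isolation (the 32-bit `index_t` is the default-build assumption; `USE_INT64_TENSOR_SIZE` widens it):

```cpp
#include <cstdint>
#include <cstdio>
#include <limits>

int main() {
  using index_t = int32_t;  // default build; USE_INT64_TENSOR_SIZE makes this 64-bit
  size_t lim = static_cast<size_t>(std::numeric_limits<index_t>::max());
  size_t operand_size = 5000000000ull;  // a 5e9-element operand
  if (operand_size > lim)
    std::printf("exceeds the maximum index; try USE_INT64_TENSOR_SIZE\n");
  return 0;
}
```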

(Diffs for the remaining 3 changed files are not shown.)
