This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fix index overflow bug in einsum #16589

Merged: 6 commits, Oct 26, 2019
Changes from all commits
9 changes: 9 additions & 0 deletions benchmark/python/einsum/benchmark_einsum.py
@@ -48,6 +48,15 @@ def test_np_einsum():
cost = measure_cost(500, np.einsum, *args, optimize=True)
print("Greedy einsum: {} ms".format(cost * 1000))

print("RNN Use Case:")
a = np.random.uniform(0, 1, size=(64, 128, 512))
b = np.random.uniform(0, 1, size=(128, 512, 2, 2))
args = ['bij, ijkl->bkl', a, b]
cost = measure_cost(2, np.einsum, *args, optimize=True)
print('Greedy einsum: {} ms'.format(cost * 1000))
cost = measure_cost(2, np.einsum, *args)
print('Basic einsum: {} ms'.format(cost * 1000))

print('Inner Product:')
a = np.ones(6000000)
b = np.ones(6000000)
12 changes: 12 additions & 0 deletions src/operator/mxnet_op.h
@@ -619,6 +619,18 @@ MSHADOW_XINLINE Shape<ndim> calc_stride(const Shape<ndim>& shape) {
return stride;
}

/* Increment coordinates */
template<int ndim>
MSHADOW_XINLINE bool inc(Shape<ndim>* coord, const Shape<ndim>& shape) {
++(*coord)[ndim-1];
#pragma unroll
for (int i = ndim - 1; i > 0 && (*coord)[i] >= shape[i]; --i) {
(*coord)[i] -= shape[i];
++(*coord)[i-1];
}
return (*coord)[0] < shape[0];
}

/* Increment coordinates and modify index */
template<int ndim>
MSHADOW_XINLINE void inc(Shape<ndim>* coord, const Shape<ndim>& shape,
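
The inc overload added above advances a multi-dimensional coordinate one step in row-major order and returns false once the whole shape has been visited, so callers can walk an iteration space without ever materializing a flat linear index. A minimal standalone sketch of the same odometer idea, using plain int arrays instead of mshadow::Shape (illustration only, not MXNet code):

#include <cstdio>

template <int ndim>
bool inc(int (&coord)[ndim], const int (&shape)[ndim]) {
  ++coord[ndim - 1];
  for (int i = ndim - 1; i > 0 && coord[i] >= shape[i]; --i) {
    coord[i] -= shape[i];      // carry into the next more significant axis
    ++coord[i - 1];
  }
  return coord[0] < shape[0];  // false once the last coordinate has been passed
}

int main() {
  int shape[2] = {2, 3};
  int coord[2] = {0, 0};
  do {
    std::printf("(%d, %d)\n", coord[0], coord[1]);  // visits (0,0), (0,1), ..., (1,2)
  } while (inc(coord, shape));
  return 0;
}

In the einsum kernel below, this do/while pattern replaces flattening the reduce coordinates through the full iteration space, which is presumably where a 32-bit index_t could overflow.
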
170 changes: 83 additions & 87 deletions src/operator/numpy/np_einsum_op-inl.h
@@ -73,8 +73,8 @@
namespace mxnet {
namespace op {

#define NPY_MAXDIMS 32
#define NPY_MAXARGS 32
#define NPY_MAXDIMS 16
#define NPY_MAXARGS 16

inline TShape get_stride(const TShape& shape) {
int ndim = shape.ndim(), prod = 1;
@@ -415,40 +415,45 @@ class EinsumOp {
}
}; // class EinsumOp

template<int dimension, int req, bool back>
struct numpy_einsum {
template<int dimension, int req, bool back, typename AType>
struct numpy_einsum{
template<typename DType>
MSHADOW_XINLINE static void Map(index_t i, DType* out,
common::StaticArray<DType*, NPY_MAXARGS> op,
mshadow::Shape<dimension> oshape,
mshadow::Shape<dimension> ostride,
common::StaticArray<mshadow::Shape<dimension>,
NPY_MAXARGS> ostride,
mshadow::Shape<dimension> reduceshape,
mshadow::Shape<dimension> reducestride,
mshadow::Shape<dimension> itershape,
common::StaticArray<mshadow::Shape<dimension>,
NPY_MAXARGS> iterstride,
NPY_MAXARGS> rstride,
int nop,
int iop0,
const DType* out_grad) {
using namespace mxnet_op;
index_t oidx = back ? dot(unravel(dot(unravel(i, oshape), ostride), itershape),
iterstride[iop0]) : i;
mshadow::Shape<dimension> oidx = unravel(i, oshape);
i = back ? dot(oidx, ostride[iop0]) : i;
if (req == kWriteTo) {
out[oidx] = (DType)0;
out[i] = (DType)0;
}
for (int rdim = 0; rdim < dimension; ++rdim) {
if (reduceshape[rdim] == 0) {
return;
}
}
for (int j = 0; j < reduceshape.Size(); j++) {
mshadow::Shape<dimension> idx = unravel(dot(unravel(j, reduceshape), reducestride) +
dot(unravel(i, oshape), ostride),
itershape);
DType tmp = back ? out_grad[dot(idx, iterstride[nop])] : (DType)1;
mshadow::Shape<dimension> ridx = unravel(0, reduceshape);
AType sum = 0;
do {
AType tmp = back ? static_cast<AType>(out_grad[dot(oidx, ostride[nop]) +
dot(ridx, rstride[nop])]): (AType)1;
for (int iop = 0; iop < nop; ++iop) {
if (iop != iop0) {
index_t k = dot(idx, iterstride[iop]);
tmp = tmp * op[iop][k];
index_t k = dot(oidx, ostride[iop]) + dot(ridx, rstride[iop]);
tmp = tmp * static_cast<AType>(op[iop][k]);
}
}
out[oidx] = out[oidx] + tmp;
}
sum = sum + tmp;
}while (inc(&ridx, reduceshape));
out[i] = out[i] + static_cast<DType>(sum);
}
};
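
The rewritten Map above addresses every operand with two small dot products, dot(oidx, ostride[iop]) + dot(ridx, rstride[iop]), where oidx is the output coordinate and ridx the reduce coordinate, and it accumulates in AType before casting back to DType. A rough sketch of the same addressing scheme specialized to 'ij,jk->ik' (plain arrays and hand-written strides, illustration only; none of these names come from the MXNet source):

#include <cstdio>

int main() {
  // 'ij,jk->ik' with I = 2, J = 3, K = 2; everything row-major.
  const int I = 2, J = 3, K = 2;
  float A[] = {1, 2, 3, 4, 5, 6};      // shape (I, J)
  float B[] = {1, 0, 0, 1, 1, 1};      // shape (J, K)
  float C[I * K] = {};
  // Output coordinate is (i, k); reduce coordinate is (j). Each operand's
  // offset is dot(out_coord, out_stride) + dot(reduce_coord, reduce_stride).
  const int a_ostride[2] = {J, 0}, a_rstride[1] = {1};  // A depends on i and j
  const int b_ostride[2] = {0, 1}, b_rstride[1] = {K};  // B depends on k and j
  for (int i = 0; i < I; ++i) {
    for (int k = 0; k < K; ++k) {
      double sum = 0.0;                // wider accumulator, like AType
      for (int j = 0; j < J; ++j) {
        const int ka = i * a_ostride[0] + k * a_ostride[1] + j * a_rstride[0];
        const int kb = i * b_ostride[0] + k * b_ostride[1] + j * b_rstride[0];
        sum += A[ka] * B[kb];
      }
      C[i * K + k] = static_cast<float>(sum);
    }
  }
  std::printf("C[1][0] = %g\n", C[1 * K + 0]);  // 4*1 + 5*0 + 6*1 = 10
  return 0;
}

Because each offset stays within a single operand's extent, the intermediate index arithmetic never grows past that operand's size, unlike the removed code path, which dotted coordinates with strides of the combined iteration space (iterstride_true).
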

@@ -603,12 +608,12 @@ inline void NumpyEinsumProcess(const std::vector<TBlob>& inputs,
}

/* Step 4: Set up the op_axes for the iterator */
TShape itershape(ndim_iter, -1), iterstride_true(ndim_iter, -1);
TShape itershape(ndim_iter, -1);
std::vector<TShape> iterstride(nop + 1, TShape(ndim_iter, 0));
TShape oshape = back ? inputs[0].shape_ : outputs[0].shape_;
TShape ostride_true = get_stride(oshape);
TShape reduceshape, ostride, reducestride;
std::vector<TShape> iterstride(nop + 1, TShape(ndim_iter, 0));
std::vector<TShape> remainshape(nop), opstride(nop), remainstride(nop);
TShape reduceshape;
std::vector<TShape> remainshape(nop);
int op_axes_arrays[NPY_MAXARGS][NPY_MAXDIMS];
int *op_axes[NPY_MAXARGS];

@@ -632,7 +637,6 @@ inline void NumpyEinsumProcess(const std::vector<TBlob>& inputs,
for (idim = 0; idim < ndim_output; ++idim) {
iterstride[nop][idim] = ostride_true[idim];
}
iterstride_true = get_stride(itershape);
reduceshape = TShape(ndim_iter - ndim_output, 0);
for (idim = ndim_output; idim < ndim_iter; ++idim) {
reduceshape[idim - ndim_output] = itershape[idim];
@@ -648,30 +652,6 @@ inline void NumpyEinsumProcess(const std::vector<TBlob>& inputs,
remainshape[iop] = TShape(rsh.begin(), rsh.end());
}

// calculate stride
ostride = TShape(ndim_output, 0);
for (idim = 0; idim < ndim_output; ++idim) {
ostride[idim] = iterstride_true[idim];
}
reducestride = TShape(ndim_iter - ndim_output, 0);
for (idim = ndim_output; idim < ndim_iter; ++idim) {
reducestride[idim - ndim_output] = iterstride_true[idim];
}
for (iop = 0; iop < nop; ++iop) {
opstride[iop] = TShape(opshape[iop].ndim(), 0);
remainstride[iop] = TShape(remainshape[iop].ndim(), 0);
int j = 0;
for (idim = 0; idim < ndim_iter; ++idim) {
if (op_axes_arrays[iop][idim] != -1 &&
itershape[idim] == opshape[iop][op_axes_arrays[iop][idim]]) {
opstride[iop][op_axes_arrays[iop][idim]] = iterstride_true[idim];
} else {
remainstride[iop][j++] = iterstride_true[idim];
}
}
CHECK_EQ(j, remainstride[iop].ndim());
}

// exclude the 0-dim case
if (ndim_iter == 0) {
ndim_iter = 1;
@@ -681,43 +661,44 @@ inline void NumpyEinsumProcess(const std::vector<TBlob>& inputs,
iterstride[iop] = pad(iterstride[iop], ndim_iter);
}
oshape = pad(oshape, ndim_iter);
ostride = pad(ostride, ndim_iter);
reduceshape = pad(reduceshape, ndim_iter);
reducestride = pad(reducestride, ndim_iter);
for (iop = 0; iop < nop; ++iop) {
opshape[iop] = pad(opshape[iop], ndim_iter);
opstride[iop] = pad(opstride[iop], ndim_iter);
remainshape[iop] = pad(remainshape[iop], ndim_iter);
remainstride[iop] = pad(remainstride[iop], ndim_iter);
}

if (!back) {
if (oshape.Size() == 0) {
return;
}
const TBlob &out_data = outputs[0];
MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
MXNET_ACC_TYPE_SWITCH(out_data.type_flag_, DType, AType, {
mxnet::common::StaticArray<DType*, NPY_MAXARGS> op;
for (iop = 0; iop < nop; ++iop) {
op[iop] = inputs[iop].dptr<DType>();
}
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
MXNET_NDIM_SWITCH_EX(ndim_iter, dimension, {
mxnet::common::StaticArray<mshadow::Shape<dimension>, NPY_MAXARGS> iterstride_arr;
for (iop = 0; iop <= nop; ++iop) {
iterstride_arr[iop] = iterstride[iop].get<dimension>();
mxnet::common::StaticArray<mshadow::Shape<dimension>, NPY_MAXARGS> ostride_arr;
mxnet::common::StaticArray<mshadow::Shape<dimension>, NPY_MAXARGS> rstride_arr;
for (iop = 0; iop < nop; ++iop) {
mshadow::Shape<dimension> otmp, rtmp;
for (idim = 0; idim < dimension; ++idim) {
otmp[idim] = idim < ndim_output ? iterstride[iop][idim] : 1;
rtmp[idim] = idim < dimension - ndim_output ? iterstride[iop][idim + ndim_output] : 1;
}
ostride_arr[iop] = otmp;
rstride_arr[iop] = rtmp;
}
Kernel<numpy_einsum<dimension, req_type, 0>,
Kernel<numpy_einsum<dimension, req_type, 0, AType>,
xpu>::Launch(ctx.get_stream<xpu>(),
oshape.Size(),
out_data.dptr<DType>(),
op,
oshape.get<dimension>(),
ostride.get<dimension>(),
ostride_arr,
reduceshape.get<dimension>(),
reducestride.get<dimension>(),
itershape.get<dimension>(),
iterstride_arr,
rstride_arr,
nop,
-1,
reinterpret_cast<DType*>(NULL));
@@ -743,31 +724,44 @@ inline void NumpyEinsumProcess(const std::vector<TBlob>& inputs,
for (int i = 0; i < nop; ++i) {
const TBlob &out_data = outputs[i];
const TBlob &out_grad = inputs[0];
MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
std::vector<TShape> opstride(nop + 1, TShape(ndim_iter, 0));
std::vector<TShape> remainstride(nop + 1, TShape(ndim_iter, 0));
for (iop = 0; iop <= nop; ++iop) {
int j = 0;
for (idim = 0; idim < ndim_iter; ++idim) {
if (op_axes_arrays[i][idim] == -1 ||
opshape[i][op_axes_arrays[i][idim]] == 1) {
remainstride[iop][j++] = iterstride[iop][idim];
} else {
opstride[iop][op_axes_arrays[i][idim]] = iterstride[iop][idim];
}
}
}
MXNET_ACC_TYPE_SWITCH(out_data.type_flag_, DType, AType, {
mxnet::common::StaticArray<DType*, NPY_MAXARGS> op;
for (iop = 0; iop < nop; ++iop) {
op[iop] = inputs[iop + back].dptr<DType>();
}
MXNET_ASSIGN_REQ_SWITCH(req[i], req_type, {
MXNET_NDIM_SWITCH_EX(ndim_iter, dimension, {
mxnet::common::StaticArray<mshadow::Shape<dimension>, NPY_MAXARGS> iterstride_arr;
mxnet::common::StaticArray<mshadow::Shape<dimension>, NPY_MAXARGS> opstride_arr;
mxnet::common::StaticArray<mshadow::Shape<dimension>, NPY_MAXARGS> remainstride_arr;
for (iop = 0; iop <= nop; ++iop) {
iterstride_arr[iop] = iterstride[iop].get<dimension>();
opstride_arr[iop] = opstride[iop].get<dimension>();
remainstride_arr[iop] = remainstride[iop].get<dimension>();
}
Kernel<numpy_einsum<dimension, req_type, 1>,
Kernel<numpy_einsum<dimension, req_type, 1, AType>,
xpu>::Launch(ctx.get_stream<xpu>(),
opshape[i].Size(),
out_data.dptr<DType>(),
op,
opshape[i].get<dimension>(),
opstride[i].get<dimension>(),
remainshape[i].get<dimension>(),
remainstride[i].get<dimension>(),
itershape.get<dimension>(),
iterstride_arr,
nop,
i,
out_grad.dptr<DType>());
opshape[i].Size(),
out_data.dptr<DType>(),
op,
opshape[i].get<dimension>(),
opstride_arr,
remainshape[i].get<dimension>(),
remainstride_arr,
nop,
i,
out_grad.dptr<DType>());
})
})
})
@@ -798,13 +792,14 @@ inline void NumpyEinsumForward(const OpStatePtr& state_ptr,
std::vector<std::vector<int> > pos;
std::string string_repr;
paths = einsum_path(state.subscripts, inputs, true, ctx.run_ctx, &pos, &string_repr);
int paths_len = paths.size(), temp_space_size = 0, max_temp_space_size = 0;
int paths_len = paths.size();
size_t temp_space_size = 0, max_temp_space_size = 0;
std::vector<TBlob> operands(inputs), tmp_operands, temp_space_vec(paths_len - 1);
for (int i = 0; i + 1 < paths_len; ++i) {
temp_space_size += paths[i].oshape.Size();
}
for (int i = 0; i < paths_len; ++i) {
max_temp_space_size = std::max(max_temp_space_size, static_cast<int>(paths[i].oshape.Size()));
max_temp_space_size = std::max(max_temp_space_size, paths[i].oshape.Size());
}
temp_space_size += max_temp_space_size;
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
@@ -813,7 +808,7 @@ inline void NumpyEinsumForward(const OpStatePtr& state_ptr,
false,
outputs[0].type_flag_));
Tensor<xpu, 1, DType> temp_space = state.tempspace->data().FlatTo1D<xpu, DType>();
int begin = max_temp_space_size;
size_t begin = max_temp_space_size;
for (int i = 0; i < paths_len - 1; ++i) {
TBlob tblob = TBlob(temp_space.Slice(begin, begin + paths[i].oshape.Size()));
temp_space_vec[i] = tblob.reshape(paths[i].oshape);
@@ -910,12 +905,13 @@ inline void NumpyEinsumBackward(const OpStatePtr& state_ptr,
}
// calculate temporary space size for temp_grad
const std::vector<Step>& paths = state.paths;
int paths_len = paths.size(), temp_space_size = 0, max_temp_space_size = 0;
int paths_len = paths.size();
size_t temp_space_size = 0, max_temp_space_size = 0;
for (int i = 0; i < paths_len - 1; ++i) {
temp_space_size += paths[i].oshape.Size();
}
for (int i = 0; i < paths_len; ++i) {
max_temp_space_size = std::max(max_temp_space_size, static_cast<int>(paths[i].oshape.Size()));
max_temp_space_size = std::max(max_temp_space_size, paths[i].oshape.Size());
}
temp_space_size += max_temp_space_size;
// replay the forward process
@@ -936,8 +932,8 @@ inline void NumpyEinsumBackward(const OpStatePtr& state_ptr,
}
}
// calculate temporary space size for tensordot
int tensordot_max_tempspace_size = 0;
int begin_tensordot_tempspace = 0;
size_t tensordot_max_tempspace_size = 0;
size_t begin_tensordot_tempspace = 0;
std::vector<TBlob> temp_inputs, temp_outputs;
std::vector<OpReqType> temp_req;
std::vector<size_t> tensordot_tempspace_size;
@@ -999,7 +995,7 @@ inline void NumpyEinsumBackward(const OpStatePtr& state_ptr,
}
tensordot_tempspace_size.push_back(cur_tensordot_tempspace_size);
tensordot_max_tempspace_size = std::max(tensordot_max_tempspace_size,
static_cast<int>(cur_tensordot_tempspace_size));
cur_tensordot_tempspace_size);
}
begin_tensordot_tempspace = temp_space_size;
temp_space_size += (tensordot_max_tempspace_size + sizeof(DType) - 1) / sizeof(DType);
@@ -1010,7 +1006,7 @@ inline void NumpyEinsumBackward(const OpStatePtr& state_ptr,
// allocate temporary space for gradients of intermediate results
Tensor<xpu, 1, DType> temp_space = ctx.requested[0].get_space_typed<xpu, 1, DType>
(Shape1(temp_space_size), s);
int begin = max_temp_space_size;
size_t begin = max_temp_space_size;
for (int i = 0; i + 1 < paths_len; ++i) {
TBlob tblob = TBlob(temp_space.Slice(begin, begin + paths[i].oshape.Size()));
temp_grad[i] = tblob.reshape(paths[i].oshape);
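
Both the forward and backward launches now go through MXNET_ACC_TYPE_SWITCH, so the kernel sums into AType and only converts to DType when writing the result. A generic illustration of why a wider accumulator helps (float and double here merely stand in for whatever DType/AType pair the switch selects; this is not MXNet code):

#include <cstdio>

int main() {
  const int n = 10000000;
  float fsum = 0.0f;   // accumulate in the storage type
  double dsum = 0.0;   // accumulate in a wider type, like AType
  for (int i = 0; i < n; ++i) {
    fsum += 0.1f;
    dsum += 0.1f;
  }
  // Once the float accumulator's spacing (ULP) approaches the addend, further
  // contributions round badly; the double accumulator stays close to 1e6.
  std::printf("float acc: %.1f  double acc: %.1f\n", fsum, dsum);
  return 0;
}
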
11 changes: 11 additions & 0 deletions src/operator/numpy/np_einsum_op.cc
@@ -305,6 +305,17 @@ bool NumpyEinsumShape(const nnvm::NodeAttrs& attrs,
oshape[i] = dimension_dict[static_cast<int>(output_str[i])];
}
SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
size_t lim = static_cast<size_t>(std::numeric_limits<index_t>::max());
for (int i = 0; i < num_args; ++i) {
CHECK_LE(in_attrs->at(i).Size(), lim)
<< "Size of operand " << i
<< " exceeds the maximum index."
<< " Try setting `USE_INT64_TENSOR_SIZE`.";
}
CHECK_LE(oshape.Size(), lim)
<< "Size of output"
<< " exceeds the maximum index."
<< " Try setting `USE_INT64_TENSOR_SIZE`.";
return shape_is_known(oshape);
}

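
The new checks reject any operand or output whose element count cannot be represented by index_t and point at the USE_INT64_TENSOR_SIZE build option instead of overflowing silently. A small sketch of the failure mode being guarded against when index_t is 32 bits (the shape below is hypothetical):

#include <cstdint>
#include <cstdio>
#include <limits>

int main() {
  const int64_t shape[4] = {1024, 1024, 1024, 4};  // hypothetical operand shape
  int64_t size = 1;
  for (int64_t d : shape) size *= d;               // 2^32 elements
  const int64_t lim = std::numeric_limits<int32_t>::max();
  std::printf("size = %lld, fits in a 32-bit index: %s\n",
              static_cast<long long>(size), size <= lim ? "yes" : "no");
  return 0;
}
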