diff --git a/benchmark/python/einsum/benchmark_einsum.py b/benchmark/python/einsum/benchmark_einsum.py
index 3593de2db9e1..6de8223287da 100644
--- a/benchmark/python/einsum/benchmark_einsum.py
+++ b/benchmark/python/einsum/benchmark_einsum.py
@@ -48,6 +48,15 @@ def test_np_einsum():
     cost = measure_cost(500, np.einsum, *args, optimize=True)
     print("Greedy einsum: {} ms".format(cost * 1000))
 
+    print("RNN Use Case:")
+    a = np.random.uniform(0, 1, size=(64, 128, 512))
+    b = np.random.uniform(0, 1, size=(128, 512, 2, 2))
+    args = ['bij, ijkl->bkl', a, b]
+    cost = measure_cost(2, np.einsum, *args, optimize=True)
+    print('Greedy einsum: {} ms'.format(cost * 1000))
+    cost = measure_cost(2, np.einsum, *args)
+    print('Basic einsum: {} ms'.format(cost * 1000))
+
     print('Inner Product:')
     a = np.ones(6000000)
     b = np.ones(6000000)
diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h
index 8ccc34247b6f..463c71b5b0eb 100644
--- a/src/operator/mxnet_op.h
+++ b/src/operator/mxnet_op.h
@@ -619,6 +619,18 @@ MSHADOW_XINLINE Shape<ndim> calc_stride(const Shape<ndim>& shape) {
   return stride;
 }
 
+/* Increment coordinates */
+template<int ndim>
+MSHADOW_XINLINE bool inc(Shape<ndim>* coord, const Shape<ndim>& shape) {
+  ++(*coord)[ndim-1];
+  #pragma unroll
+  for (int i = ndim - 1; i > 0 && (*coord)[i] >= shape[i]; --i) {
+    (*coord)[i] -= shape[i];
+    ++(*coord)[i-1];
+  }
+  return (*coord)[0] < shape[0];
+}
+
 /* Increment coordinates and modify index */
 template<int ndim>
 MSHADOW_XINLINE void inc(Shape<ndim>* coord, const Shape<ndim>& shape,
diff --git a/src/operator/numpy/np_einsum_op-inl.h b/src/operator/numpy/np_einsum_op-inl.h
index 2145abec682b..051280763331 100644
--- a/src/operator/numpy/np_einsum_op-inl.h
+++ b/src/operator/numpy/np_einsum_op-inl.h
@@ -73,8 +73,8 @@
 namespace mxnet {
 namespace op {
 
-#define NPY_MAXDIMS 32
-#define NPY_MAXARGS 32
+#define NPY_MAXDIMS 16
+#define NPY_MAXARGS 16
 
 inline TShape get_stride(const TShape& shape) {
   int ndim = shape.ndim(), prod = 1;
@@ -415,40 +415,45 @@ class EinsumOp {
   }
 };  // class EinsumOp
 
-template<int dimension, int req, bool back>
-struct numpy_einsum {
+template<int dimension, int req, bool back, typename AType>
+struct numpy_einsum{
   template<typename DType>
   MSHADOW_XINLINE static void Map(index_t i, DType* out,
                                   common::StaticArray<DType*, NPY_MAXARGS> op,
                                   mshadow::Shape<dimension> oshape,
-                                  mshadow::Shape<dimension> ostride,
+                                  common::StaticArray<mshadow::Shape<dimension>,
+                                                      NPY_MAXARGS> ostride,
                                   mshadow::Shape<dimension> reduceshape,
-                                  mshadow::Shape<dimension> reducestride,
-                                  mshadow::Shape<dimension> itershape,
                                   common::StaticArray<mshadow::Shape<dimension>,
-                                                      NPY_MAXARGS> iterstride,
+                                                      NPY_MAXARGS> rstride,
                                   int nop,
                                   int iop0,
                                   const DType* out_grad) {
     using namespace mxnet_op;
-    index_t oidx = back ? dot(unravel(dot(unravel(i, oshape), ostride), itershape),
-                              iterstride[iop0]) : i;
+    mshadow::Shape<dimension> oidx = unravel(i, oshape);
+    i = back ? dot(oidx, ostride[iop0]) : i;
     if (req == kWriteTo) {
-      out[oidx] = (DType)0;
+      out[i] = (DType)0;
+    }
+    for (int rdim = 0; rdim < dimension; ++rdim) {
+      if (reduceshape[rdim] == 0) {
+        return;
+      }
     }
-    for (int j = 0; j < reduceshape.Size(); j++) {
-      mshadow::Shape<dimension> idx = unravel(dot(unravel(j, reduceshape), reducestride) +
-                                              dot(unravel(i, oshape), ostride),
-                                              itershape);
-      DType tmp = back ? out_grad[dot(idx, iterstride[nop])] : (DType)1;
+    mshadow::Shape<dimension> ridx = unravel(0, reduceshape);
+    AType sum = 0;
+    do {
+      AType tmp = back ? static_cast<AType>(out_grad[dot(oidx, ostride[nop]) +
+                                                     dot(ridx, rstride[nop])]): (AType)1;
       for (int iop = 0; iop < nop; ++iop) {
         if (iop != iop0) {
-          index_t k = dot(idx, iterstride[iop]);
-          tmp = tmp * op[iop][k];
+          index_t k = dot(oidx, ostride[iop]) + dot(ridx, rstride[iop]);
+          tmp = tmp * static_cast<AType>(op[iop][k]);
         }
       }
-      out[oidx] = out[oidx] + tmp;
-    }
+      sum = sum + tmp;
+    } while (inc(&ridx, reduceshape));
+    out[i] = out[i] + static_cast<DType>(sum);
   }
 };
 
@@ -603,12 +608,12 @@ inline void NumpyEinsumProcess(const std::vector<TBlob>& inputs,
   }
 
   /* Step 4: Set up the op_axes for the iterator */
-  TShape itershape(ndim_iter, -1), iterstride_true(ndim_iter, -1);
+  TShape itershape(ndim_iter, -1);
+  std::vector<TShape> iterstride(nop + 1, TShape(ndim_iter, 0));
   TShape oshape = back ? inputs[0].shape_ : outputs[0].shape_;
   TShape ostride_true = get_stride(oshape);
-  TShape reduceshape, ostride, reducestride;
-  std::vector<TShape> iterstride(nop + 1, TShape(ndim_iter, 0));
-  std::vector<TShape> remainshape(nop), opstride(nop), remainstride(nop);
+  TShape reduceshape;
+  std::vector<TShape> remainshape(nop);
 
   int op_axes_arrays[NPY_MAXARGS][NPY_MAXDIMS];
   int *op_axes[NPY_MAXARGS];
@@ -632,7 +637,6 @@ inline void NumpyEinsumProcess(const std::vector<TBlob>& inputs,
   for (idim = 0; idim < ndim_output; ++idim) {
     iterstride[nop][idim] = ostride_true[idim];
   }
-  iterstride_true = get_stride(itershape);
   reduceshape = TShape(ndim_iter - ndim_output, 0);
   for (idim = ndim_output; idim < ndim_iter; ++idim) {
     reduceshape[idim - ndim_output] = itershape[idim];
   }
@@ -648,30 +652,6 @@ inline void NumpyEinsumProcess(const std::vector<TBlob>& inputs,
     remainshape[iop] = TShape(rsh.begin(), rsh.end());
   }
 
-  // calculate stride
-  ostride = TShape(ndim_output, 0);
-  for (idim = 0; idim < ndim_output; ++idim) {
-    ostride[idim] = iterstride_true[idim];
-  }
-  reducestride = TShape(ndim_iter - ndim_output, 0);
-  for (idim = ndim_output; idim < ndim_iter; ++idim) {
-    reducestride[idim - ndim_output] = iterstride_true[idim];
-  }
-  for (iop = 0; iop < nop; ++iop) {
-    opstride[iop] = TShape(opshape[iop].ndim(), 0);
-    remainstride[iop] = TShape(remainshape[iop].ndim(), 0);
-    int j = 0;
-    for (idim = 0; idim < ndim_iter; ++idim) {
-      if (op_axes_arrays[iop][idim] != -1 &&
-          itershape[idim] == opshape[iop][op_axes_arrays[iop][idim]]) {
-        opstride[iop][op_axes_arrays[iop][idim]] = iterstride_true[idim];
-      } else {
-        remainstride[iop][j++] = iterstride_true[idim];
-      }
-    }
-    CHECK_EQ(j, remainstride[iop].ndim());
-  }
-
   // exclude the 0-dim case
   if (ndim_iter == 0) {
     ndim_iter = 1;
   }
@@ -681,14 +661,10 @@ inline void NumpyEinsumProcess(const std::vector<TBlob>& inputs,
     iterstride[iop] = pad(iterstride[iop], ndim_iter);
   }
   oshape = pad(oshape, ndim_iter);
-  ostride = pad(ostride, ndim_iter);
   reduceshape = pad(reduceshape, ndim_iter);
-  reducestride = pad(reducestride, ndim_iter);
   for (iop = 0; iop < nop; ++iop) {
     opshape[iop] = pad(opshape[iop], ndim_iter);
-    opstride[iop] = pad(opstride[iop], ndim_iter);
     remainshape[iop] = pad(remainshape[iop], ndim_iter);
-    remainstride[iop] = pad(remainstride[iop], ndim_iter);
   }
 
   if (!back) {
@@ -696,28 +672,33 @@ inline void NumpyEinsumProcess(const std::vector<TBlob>& inputs,
       return;
     }
     const TBlob &out_data = outputs[0];
-    MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
+    MXNET_ACC_TYPE_SWITCH(out_data.type_flag_, DType, AType, {
       mxnet::common::StaticArray<DType*, NPY_MAXARGS> op;
       for (iop = 0; iop < nop; ++iop) {
        op[iop] = inputs[iop].dptr<DType>();
      }
      MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
        MXNET_NDIM_SWITCH_EX(ndim_iter, dimension, {
-          mxnet::common::StaticArray<mshadow::Shape<dimension>, NPY_MAXARGS> iterstride_arr;
-          for (iop = 0; iop <= nop; ++iop) {
-            iterstride_arr[iop] = iterstride[iop].get<dimension>();
+          mxnet::common::StaticArray<mshadow::Shape<dimension>, NPY_MAXARGS> ostride_arr;
+          mxnet::common::StaticArray<mshadow::Shape<dimension>, NPY_MAXARGS> rstride_arr;
+          for (iop = 0; iop < nop; ++iop) {
+            mshadow::Shape<dimension> otmp, rtmp;
+            for (idim = 0; idim < dimension; ++idim) {
+              otmp[idim] = idim < ndim_output ? iterstride[iop][idim] : 1;
+              rtmp[idim] = idim < dimension - ndim_output ? iterstride[iop][idim + ndim_output] : 1;
+            }
+            ostride_arr[iop] = otmp;
+            rstride_arr[iop] = rtmp;
          }
-          Kernel<numpy_einsum<dimension, req_type, 0>,
+          Kernel<numpy_einsum<dimension, req_type, 0, AType>,
                 xpu>::Launch(ctx.get_stream<xpu>(),
                              oshape.Size(),
                              out_data.dptr<DType>(),
                              op,
                              oshape.get<dimension>(),
-                             ostride.get<dimension>(),
+                             ostride_arr,
                              reduceshape.get<dimension>(),
-                             reducestride.get<dimension>(),
-                             itershape.get<dimension>(),
-                             iterstride_arr,
+                             rstride_arr,
                              nop,
                              -1,
                              reinterpret_cast<DType*>(NULL));
@@ -743,31 +724,44 @@ inline void NumpyEinsumProcess(const std::vector<TBlob>& inputs,
   for (int i = 0; i < nop; ++i) {
     const TBlob &out_data = outputs[i];
     const TBlob &out_grad = inputs[0];
-    MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
+    std::vector<TShape> opstride(nop + 1, TShape(ndim_iter, 0));
+    std::vector<TShape> remainstride(nop + 1, TShape(ndim_iter, 0));
+    for (iop = 0; iop <= nop; ++iop) {
+      int j = 0;
+      for (idim = 0; idim < ndim_iter; ++idim) {
+        if (op_axes_arrays[i][idim] == -1 ||
+            opshape[i][op_axes_arrays[i][idim]] == 1) {
+          remainstride[iop][j++] = iterstride[iop][idim];
+        } else {
+          opstride[iop][op_axes_arrays[i][idim]] = iterstride[iop][idim];
+        }
+      }
+    }
+    MXNET_ACC_TYPE_SWITCH(out_data.type_flag_, DType, AType, {
      mxnet::common::StaticArray<DType*, NPY_MAXARGS> op;
      for (iop = 0; iop < nop; ++iop) {
        op[iop] = inputs[iop + back].dptr<DType>();
      }
      MXNET_ASSIGN_REQ_SWITCH(req[i], req_type, {
        MXNET_NDIM_SWITCH_EX(ndim_iter, dimension, {
-          mxnet::common::StaticArray<mshadow::Shape<dimension>, NPY_MAXARGS> iterstride_arr;
+          mxnet::common::StaticArray<mshadow::Shape<dimension>, NPY_MAXARGS> opstride_arr;
+          mxnet::common::StaticArray<mshadow::Shape<dimension>, NPY_MAXARGS> remainstride_arr;
          for (iop = 0; iop <= nop; ++iop) {
-            iterstride_arr[iop] = iterstride[iop].get<dimension>();
+            opstride_arr[iop] = opstride[iop].get<dimension>();
+            remainstride_arr[iop] = remainstride[iop].get<dimension>();
          }
-          Kernel<numpy_einsum<dimension, req_type, 1>,
+          Kernel<numpy_einsum<dimension, req_type, 1, AType>,
                 xpu>::Launch(ctx.get_stream<xpu>(),
-                             opshape[i].Size(),
-                             out_data.dptr<DType>(),
-                             op,
-                             opshape[i].get<dimension>(),
-                             opstride[i].get<dimension>(),
-                             remainshape[i].get<dimension>(),
-                             remainstride[i].get<dimension>(),
-                             itershape.get<dimension>(),
-                             iterstride_arr,
-                             nop,
-                             i,
-                             out_grad.dptr<DType>());
+                              opshape[i].Size(),
+                              out_data.dptr<DType>(),
+                              op,
+                              opshape[i].get<dimension>(),
+                              opstride_arr,
+                              remainshape[i].get<dimension>(),
+                              remainstride_arr,
+                              nop,
+                              i,
+                              out_grad.dptr<DType>());
        })
      })
    })
@@ -798,13 +792,14 @@ inline void NumpyEinsumForward(const OpStatePtr& state_ptr,
   std::vector<std::vector<int> > pos;
   std::string string_repr;
   paths = einsum_path(state.subscripts, inputs, true, ctx.run_ctx, &pos, &string_repr);
-  int paths_len = paths.size(), temp_space_size = 0, max_temp_space_size = 0;
+  int paths_len = paths.size();
+  size_t temp_space_size = 0, max_temp_space_size = 0;
   std::vector<TBlob> operands(inputs), tmp_operands, temp_space_vec(paths_len - 1);
   for (int i = 0; i + 1 < paths_len; ++i) {
     temp_space_size += paths[i].oshape.Size();
   }
   for (int i = 0; i < paths_len; ++i) {
-    max_temp_space_size = std::max(max_temp_space_size, static_cast<int>(paths[i].oshape.Size()));
+    max_temp_space_size = std::max(max_temp_space_size, paths[i].oshape.Size());
   }
   temp_space_size += max_temp_space_size;
   MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
@@ -813,7 +808,7 @@ inline void NumpyEinsumForward(const OpStatePtr& state_ptr,
                                                false, outputs[0].type_flag_));
     Tensor<xpu, 1, DType> temp_space = state.tempspace->data().FlatTo1D<xpu, DType>();
-    int begin = max_temp_space_size;
+    size_t begin = max_temp_space_size;
     for (int i = 0; i < paths_len - 1; ++i) {
       TBlob tblob = TBlob(temp_space.Slice(begin, begin + paths[i].oshape.Size()));
       temp_space_vec[i] = tblob.reshape(paths[i].oshape);
@@ -910,12 +905,13 @@ inline void NumpyEinsumBackward(const OpStatePtr& state_ptr,
   }
   // calculate temporary space size for temp_grad
   const std::vector<Step>& paths = state.paths;
-  int paths_len = paths.size(), temp_space_size = 0, max_temp_space_size = 0;
+  int paths_len = paths.size();
+  size_t temp_space_size = 0, max_temp_space_size = 0;
   for (int i = 0; i < paths_len - 1; ++i) {
     temp_space_size += paths[i].oshape.Size();
   }
   for (int i = 0; i < paths_len; ++i) {
-    max_temp_space_size = std::max(max_temp_space_size, static_cast<int>(paths[i].oshape.Size()));
+    max_temp_space_size = std::max(max_temp_space_size, paths[i].oshape.Size());
   }
   temp_space_size += max_temp_space_size;
   // replay the forward process
@@ -936,8 +932,8 @@
     }
   }
   // calculate temporary space size for tensordot
-  int tensordot_max_tempspace_size = 0;
-  int begin_tensordot_tempspace = 0;
+  size_t tensordot_max_tempspace_size = 0;
+  size_t begin_tensordot_tempspace = 0;
   std::vector<TBlob> temp_inputs, temp_outputs;
   std::vector<OpReqType> temp_req;
   std::vector<size_t> tensordot_tempspace_size;
@@ -999,7 +995,7 @@
       }
       tensordot_tempspace_size.push_back(cur_tensordot_tempspace_size);
       tensordot_max_tempspace_size = std::max(tensordot_max_tempspace_size,
-                                              static_cast<int>(cur_tensordot_tempspace_size));
+                                              cur_tensordot_tempspace_size);
     }
     begin_tensordot_tempspace = temp_space_size;
     temp_space_size += (tensordot_max_tempspace_size + sizeof(DType) - 1) / sizeof(DType);
@@ -1010,7 +1006,7 @@
     // allocate temporary space for gradients of intermediate results
     Tensor<xpu, 1, DType> temp_space = ctx.requested[0].get_space_typed<xpu, 1, DType>
                                        (Shape1(temp_space_size), s);
-    int begin = max_temp_space_size;
+    size_t begin = max_temp_space_size;
     for (int i = 0; i + 1 < paths_len; ++i) {
       TBlob tblob = TBlob(temp_space.Slice(begin, begin + paths[i].oshape.Size()));
       temp_grad[i] = tblob.reshape(paths[i].oshape);
diff --git a/src/operator/numpy/np_einsum_op.cc b/src/operator/numpy/np_einsum_op.cc
index 4d232b9b7c04..522780f5f3ad 100644
--- a/src/operator/numpy/np_einsum_op.cc
+++ b/src/operator/numpy/np_einsum_op.cc
@@ -305,6 +305,17 @@ bool NumpyEinsumShape(const nnvm::NodeAttrs& attrs,
     oshape[i] = dimension_dict[static_cast<int>(output_str[i])];
   }
   SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
+  size_t lim = static_cast<size_t>(std::numeric_limits<index_t>::max());
+  for (int i = 0; i < num_args; ++i) {
+    CHECK_LE(in_attrs->at(i).Size(), lim)
+      << "Size of operand " << i
+      << " exceeds the maximum index."
+      << " Try setting `USE_INT64_TENSOR_SIZE`.";
+  }
+  CHECK_LE(oshape.Size(), lim)
+    << "Size of output"
+    << " exceeds the maximum index."
+ << " Try setting `USE_INT64_TENSOR_SIZE`."; return shape_is_known(oshape); } diff --git a/src/operator/numpy/np_einsum_path_op-inl.h b/src/operator/numpy/np_einsum_path_op-inl.h index cebd4e8ce9af..968d52106da7 100644 --- a/src/operator/numpy/np_einsum_path_op-inl.h +++ b/src/operator/numpy/np_einsum_path_op-inl.h @@ -80,7 +80,7 @@ struct Contraction { }; struct Alternative { - int cost[2]; + int64_t cost[2]; std::vector positions; SetVector new_input_sets; }; @@ -115,28 +115,28 @@ inline size_t _compute_size_by_dict(const std::bitset& indices, return ret; } -inline int _flop_count(const std::string& idx_contraction, - bool inner, - int num_terms, - const dim_t size_dictionary[]) { +inline int64_t _flop_count(const std::string& idx_contraction, + bool inner, + int num_terms, + const dim_t size_dictionary[]) { size_t overall_size = _compute_size_by_dict(idx_contraction, size_dictionary); int op_factor = std::max(1, num_terms - 1); if (inner) { ++op_factor; } - return overall_size * op_factor; + return static_cast(overall_size) * op_factor; } -inline int _flop_count(const std::bitset& idx_contraction, - bool inner, - int num_terms, - const dim_t size_dictionary[]) { +inline int64_t _flop_count(const std::bitset& idx_contraction, + bool inner, + int num_terms, + const dim_t size_dictionary[]) { size_t overall_size = _compute_size_by_dict(idx_contraction, size_dictionary); int op_factor = std::max(1, num_terms - 1); if (inner) { ++op_factor; } - return overall_size * op_factor; + return static_cast(overall_size) * op_factor; } inline Contraction _find_contraction(const std::vector& positions, @@ -164,16 +164,16 @@ inline int _parse_possible_contraction(const std::vector& positions, const SetVector& input_sets, const std::bitset& output_set, const dim_t idx_dict[], - int memory_limit, - int path_cost, - int naive_cost, + size_t memory_limit, + int64_t path_cost, + int64_t naive_cost, Alternative* ret) { // Find the contraction Contraction contract = _find_contraction(positions, input_sets, output_set); // Sieve the results based on memory_limit size_t new_size = _compute_size_by_dict(contract.new_result, idx_dict); - if (new_size > static_cast(memory_limit)) { + if (new_size > memory_limit) { return -1; } @@ -182,10 +182,10 @@ inline int _parse_possible_contraction(const std::vector& positions, for (auto p : positions) { old_sizes += _compute_size_by_dict(input_sets[p], idx_dict); } - int remove_size = old_sizes - new_size; + int64_t remove_size = static_cast(old_sizes) - static_cast(new_size); - int cost = _flop_count(contract.idx_contract, contract.idx_removed.any(), - positions.size(), idx_dict); + int64_t cost = _flop_count(contract.idx_contract, contract.idx_removed.any(), + positions.size(), idx_dict); ret->cost[0] = -remove_size; ret->cost[1] = cost; @@ -206,7 +206,7 @@ inline void _update_other_results(std::vector* results, int bx = best_con[0], by = best_con[1]; size_t size = results->size(); - for (int i = size - 1; i >= 0; --i) { + for (int i = static_cast(size) - 1; i >= 0; --i) { int x = results->at(i).positions[0], y = results->at(i).positions[1]; // Ignore results involving tensors just contracted @@ -233,9 +233,9 @@ inline void _update_other_results(std::vector* results, inline std::vector > _greedy_path(const SetVector* input_sets, const std::bitset& output_set, const dim_t idx_dict[], - int memory_limit) { - size_t isize = input_sets->size(); - size_t iteration_num = isize; + size_t memory_limit) { + int isize = static_cast(input_sets->size()); + int iteration_num = isize; 
   // Handle trivial cases that leaked through
   if (isize == 1) {
     return std::vector<std::vector<int> >{std::vector<int>{0}};
   }
@@ -245,23 +245,23 @@ inline std::vector<std::vector<int> > _greedy_path(const SetVector* input_sets,
 
   // Build up a naive cost
   std::vector<int> range(isize);
-  for (size_t i = 0; i < isize; ++i) {
+  for (int i = 0; i < isize; ++i) {
     range[i] = i;
   }
   Contraction contract = _find_contraction(range, *input_sets, output_set);
-  int naive_cost = _flop_count(contract.idx_contract, contract.idx_removed.any(),
-                               isize, idx_dict);
+  int64_t naive_cost = _flop_count(contract.idx_contract, contract.idx_removed.any(),
+                                   isize, idx_dict);
 
   // Initially iterate over all pairs
   std::vector<Alternative> known_contractions;
   Alternative best;
-  int path_cost = 0;
+  int64_t path_cost = 0;
   std::vector<std::vector<int> > ret;
 
-  for (size_t iteration = 0; iteration + 1 < iteration_num; ++iteration) {
+  for (int iteration = 0; iteration + 1 < iteration_num; ++iteration) {
     if (iteration == 0) {
-      for (int x = 0; x < static_cast<int>(isize); ++x) {
-        for (int y = x + 1; y < static_cast<int>(isize); ++y) {
+      for (int x = 0; x < isize; ++x) {
+        for (int y = x + 1; y < isize; ++y) {
           if (!((input_sets->at(x) & input_sets->at(y)).any())) {
             continue;
           }
@@ -280,7 +280,7 @@ inline std::vector<std::vector<int> > _greedy_path(const SetVector* input_sets,
         }
       }
     } else {
-      for (int x = 0; x < static_cast<int>(isize) - 1; ++x) {
+      for (int x = 0; x < isize - 1; ++x) {
         int y = isize - 1;
         if (!((input_sets->at(x) & input_sets->at(y)).any())) {
           continue;
@@ -303,8 +303,8 @@ inline std::vector<std::vector<int> > _greedy_path(const SetVector* input_sets,
     // If we do not have a inner contraction, rescan pairs including outer products
     if (known_contractions.size() == 0) {
       // Then check the outer productsj
-      for (int x = 0; x < static_cast<int>(isize); ++x) {
-        for (int y = x + 1; y < static_cast<int>(isize); ++y) {
+      for (int x = 0; x < isize; ++x) {
+        for (int y = x + 1; y < isize; ++y) {
           Alternative alternative;
           int result = _parse_possible_contraction(std::vector<int>{x, y},
                                                    *input_sets,
@@ -323,7 +323,7 @@ inline std::vector<std::vector<int> > _greedy_path(const SetVector* input_sets,
     // If we still did not find any remaining contractions, default back to einsum like behavior
     if (known_contractions.size() == 0) {
       std::vector<int> range(isize);
-      for (size_t i = 0; i < isize; ++i) {
+      for (int i = 0; i < isize; ++i) {
         range[i] = i;
       }
       ret.push_back(range);
@@ -332,17 +332,17 @@ inline std::vector<std::vector<int> > _greedy_path(const SetVector* input_sets,
     }
 
     // Sort based on first index
-    int best_cost[2], idx = -1;
-    size_t size = known_contractions.size();
-    for (size_t i = 0; i < size; ++i) {
+    int64_t best_cost[2];
+    int idx = -1, size = static_cast<int>(known_contractions.size());
+    for (int i = 0; i < size; ++i) {
       auto x = known_contractions[i];
       if (idx == -1) {
         best_cost[0] = x.cost[0];
         best_cost[1] = x.cost[1];
         idx = i;
       } else if (x.cost[0] < best_cost[0] ||
-                 (x.cost[0] == best_cost[0] &&
-                  x.cost[1] < best_cost[1])) {
+                (x.cost[0] == best_cost[0] &&
+                 x.cost[1] < best_cost[1])) {
         best_cost[0] = x.cost[0];
         best_cost[1] = x.cost[1];
         idx = i;
@@ -356,7 +356,7 @@ inline std::vector<std::vector<int> > _greedy_path(const SetVector* input_sets,
     // Next iteration only compute contractions with the new tensor
     // All other contractions have been accounted for
     input_sets = &best.new_input_sets;
-    isize = input_sets->size();
+    isize = static_cast<int>(input_sets->size());
 
     // Update path and total cost
     ret.push_back(best.positions);
@@ -708,9 +708,9 @@ inline std::vector<Step> einsum_path(const std::string& subscripts,
 
   // Build a few useful list and sets
   std::vector<std::string> input_list = split(parsed_subscripts[0], ",");
-  size_t isize = input_list.size();
+  int isize = static_cast<int>(input_list.size());
   SetVector input_sets;
-  for (int i = 0; i < static_cast<int>(isize); ++i) {
+  for (int i = 0; i < isize; ++i) {
     input_sets.push_back(str2set(input_list[i]));
   }
   std::bitset<MAXAXIS> output_set = str2set(parsed_subscripts[1]);
@@ -721,7 +721,7 @@ inline std::vector<Step> einsum_path(const std::string& subscripts,
   dim_t dimension_dict[MAXAXIS];
   SetVector broadcast_indices(isize);
   memset(dimension_dict, -1, sizeof(dimension_dict));
-  for (size_t i = 0; i < isize; ++i) {
+  for (int i = 0; i < isize; ++i) {
     const std::string& term = input_list[i];
     const TShape& sh = operands[i].shape_;
     CHECK_EQ(sh.ndim(), term.length())
@@ -756,8 +756,8 @@ inline std::vector<Step> einsum_path(const std::string& subscripts,
 
   // Compute size of each input array plus the output array
   std::vector<size_t> size_list(isize + 1);
-  size_t max_size = -1, memory_arg;
-  for (size_t i = 0; i < isize; ++i) {
+  size_t max_size = 0, memory_arg;
+  for (int i = 0; i < isize; ++i) {
     size_list[i] = _compute_size_by_dict(input_list[i], dimension_dict);
     max_size = std::max(max_size, size_list[i]);
   }
@@ -778,7 +778,7 @@ inline std::vector<Step> einsum_path(const std::string& subscripts,
   std::vector<std::vector<int> > path;
   if (optimize == false) {
     path.push_back(std::vector<int>());
-    for (size_t i = 0; i < isize; ++i) {
+    for (int i = 0; i < isize; ++i) {
       path[0].push_back(i);
     }
   } else {
@@ -801,7 +801,7 @@ inline std::vector<Step> einsum_path(const std::string& subscripts,
     Contraction contract = _find_contraction(contract_inds, input_sets, output_set);
     input_sets = contract.remaining;
 
-    int cost = _flop_count(contract.idx_contract,
+    int64_t cost = _flop_count(contract.idx_contract,
                            contract.idx_removed.any(),
                            contract_inds.size(),
                            dimension_dict);
@@ -847,9 +847,9 @@ inline std::vector<Step> einsum_path(const std::string& subscripts,
                        a < b);
               });
     }
-    size_t len_idx_result = idx_result.length();
+    int len_idx_result = static_cast<int>(idx_result.length());
     ret[i].oshape = TShape(len_idx_result, -1);
-    for (size_t j = 0; j < len_idx_result; ++j) {
+    for (int j = 0; j < len_idx_result; ++j) {
       ret[i].oshape[j] = dimension_dict[static_cast<int>(idx_result[j])];
     }
 
@@ -867,18 +867,18 @@ inline std::vector<Step> einsum_path(const std::string& subscripts,
       std::vector<int> left_pos, right_pos;
       left_pos.reserve(MAXAXIS);
       right_pos.reserve(MAXAXIS);
-      size_t tmp[MAXAXIS] = {0};
-      size_t length_left_input = tmp_inputs[0].length();
-      size_t length_right_input = tmp_inputs[1].length();
-      for (size_t j = 0; j < length_right_input; ++j) {
+      int tmp[MAXAXIS] = {0};
+      int length_left_input = static_cast<int>(tmp_inputs[0].length());
+      int length_right_input = static_cast<int>(tmp_inputs[1].length());
+      for (int j = 0; j < length_right_input; ++j) {
        if (contract.idx_removed.test(static_cast<int>(tmp_inputs[1][j]))) {
          tmp[static_cast<int>(tmp_inputs[1][j])] = j;
        }
      }
-      for (size_t j = 0; j < length_left_input; ++j) {
+      for (int j = 0; j < length_left_input; ++j) {
        if (contract.idx_removed.test(static_cast<int>(tmp_inputs[0][j]))) {
-          left_pos.push_back(static_cast<int>(j));
-          right_pos.push_back(static_cast<int>(tmp[static_cast<int>(tmp_inputs[0][j])]));
+          left_pos.push_back(j);
+          right_pos.push_back(tmp[static_cast<int>(tmp_inputs[0][j])]);
        }
      }
      // Calculate left_pos and right_pos
@@ -887,11 +887,11 @@ inline std::vector<Step> einsum_path(const std::string& subscripts,
      // Calculate do_einsum
      ret[i].do_einsum = (tensor_result != idx_result);
      // Calculate tshape
-      CHECK_EQ(tensor_result.length(), len_idx_result)
+      CHECK_EQ(static_cast<int>(tensor_result.length()), len_idx_result)
        << "tensordot produces dim " << tensor_result.length()
        << ", while einsum produces dim " << len_idx_result << ".";
      ret[i].tshape = TShape(len_idx_result, -1);
-      for (size_t j = 0; j < len_idx_result; ++j) {
+      for (int j = 0; j < len_idx_result; ++j) {
        ret[i].tshape[j] = dimension_dict[static_cast<int>(tensor_result[j])];
      }
      // Calculate blas2einsum_str
diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py
index 9e8156f3239c..62004ac6d263 100644
--- a/tests/python/unittest/test_numpy_interoperability.py
+++ b/tests/python/unittest/test_numpy_interoperability.py
@@ -258,6 +258,7 @@ def _add_workload_einsum():
     size_dict = dict(zip(chars, sizes))
 
     configs = [
+        # test_einsum_broadcast
         ('ij...,j...->ij...', [(2, 3, 4), (3,)]),
         ('ij...,...j->ij...', [(2, 3, 4), (3,)]),
         ('ij...,j->ij...', [(2, 3, 4), (3,)]),
@@ -310,6 +311,39 @@ def _add_workload_einsum():
         ('abjk,kl,jl,ab->ab', [(1, 1, 5, 4), (4, 6), (5, 6), (7, 7)]),
         ('obk,ijk->ioj', [(2, 4, 8), (2, 4, 8)]),
     ]
+    # check_einsum_sums
+    configs.extend([('i->', [(i,)]) for i in range(1, 17)])
+    configs.extend([('...i->...', [(2, 3, i,)]) for i in range(1, 17)])
+    configs.extend([('i...->...', [(2, i,)]) for i in range(1, 17)])
+    configs.extend([('i...->...', [(2, 3, i,)]) for i in range(1, 17)])
+    configs.extend([('ii', [(i, i,)]) for i in range(1, 17)])
+    configs.extend([('..., ...', [(3, i,), (2, 3, i,)]) for i in range(1, 17)])
+    configs.extend([('...i, ...i', [(2, 3, i,), (i,)]) for i in range(1, 17)])
+    configs.extend([('i..., i...', [(i, 3, 2,), (i,)]) for i in range(1, 11)])
+    configs.extend([('i, j', [(3,), (i,)]) for i in range(1, 17)])
+    configs.extend([('ij, j', [(4, i), (i,)]) for i in range(1, 17)])
+    configs.extend([('ji, j', [(i, 4), (i,)]) for i in range(1, 17)])
+    configs.extend([('ij, jk', [(4, i), (i, 6)]) for i in range(1, 8)])
+    configs.extend([
+        ('ij,jk,kl', [(3, 4), (4, 5), (5, 6)]),
+        ('ijk, jil -> kl', [(3, 4, 5), (4, 3, 2)]),
+        ('i, i, i -> i', [(8,), (8,), (8,)]),
+        (',i->', [(), (9,)]),
+        ('i,->', [(9,), ()]),
+    ])
+    configs.extend([('...,...', [(n,), (n,)]) for n in range(1, 25)])
+    configs.extend([('i,i', [(n,), (n,)]) for n in range(1, 25)])
+    configs.extend([('i,->i', [(n,), ()]) for n in range(1, 25)])
+    configs.extend([(',i->i', [(), (n,)]) for n in range(1, 25)])
+    configs.extend([('i,->', [(n,), ()]) for n in range(1, 25)])
+    configs.extend([(',i->', [(), (n,)]) for n in range(1, 25)])
+    configs.extend([('...,...', [(n - 1,), (n - 1,)]) for n in range(1, 25)])
+    configs.extend([('i,i', [(n - 1,), (n - 1,)]) for n in range(1, 25)])
+    configs.extend([('i,->i', [(n - 1,), ()]) for n in range(1, 25)])
+    configs.extend([(',i->i', [(), (n - 1,)]) for n in range(1, 25)])
+    configs.extend([('i,->', [(n - 1,), ()]) for n in range(1, 25)])
+    configs.extend([(',i->', [(), (n - 1,)]) for n in range(1, 25)])
+
     for optimize in [False, True]:
         for config in configs:
             subscripts, args = config
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index b764ac73d30c..227e1c625ef6 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -3377,16 +3377,22 @@ def dbg(name, data):
                                              _np.dot(args[0].T, _np.dot(_np.ones((2, 2)), args[2].T)),
                                              _np.dot(_np.dot(args[0], args[1]).T, _np.ones((2, 2))))),
         # broadcast bug
-        (('ij, ij -> i'), [(1, 4), (2, 4)], lambda *args: (_np.sum(args[1], axis=0)[None, :],
-                                                           _np.tile(args[0], [2, 1]))),
+        ('ij, ij -> i', [(1, 4), (2, 4)], lambda *args: (_np.sum(args[1], axis=0)[None, :],
+                                                         _np.tile(args[0], [2, 1]))),
+        # issue #16576
+        # commented due to long running time
+        # ('abiz,abjz->abij', [(64, 8, 128, 512), (64, 8, 128, 512)], lambda *args: (_np.matmul(_np.ones((64, 8, 128, 128)), args[1]),
+        #                                                                            _np.matmul(_np.ones((64, 8, 128, 128)), args[0]))),
     ]
-    dtypes = ['int32', 'float16', 'float32', 'float64']
+    dtypes = ['float16', 'float32', 'float64', 'int32']
+    acc_type = {'float16': 'float32', 'float32': 'float64', 'float64': 'float64',
+                'int32': 'int64'}
     for hybridize in [False, True]:
         for dtype in dtypes:
             for config in configs:
                 for optimize in [False, True]:
-                    rtol = 1e-0 if dtype == 'float16' else 1e-3
-                    atol = 1e-1 if dtype == 'float16' else 1e-5
+                    rtol = 1e-2 if dtype == 'float16' else 1e-3
+                    atol = 1e-4 if dtype == 'float16' else 1e-5
                     (subscripts, operands, get_grad) = config
                     test_einsum = TestEinsum(subscripts, optimize)
                     if hybridize:
@@ -3394,11 +3400,11 @@ def dbg(name, data):
                     x = []
                     x_np = []
                     for shape in operands:
-                        x_np.append(_np.array(_np.random.uniform(-10.0, 10.0, shape),
-                                              dtype=dtype))
-                        x.append(np.array(x_np[-1], dtype=dtype))
+                        tmp = _np.array(_np.random.uniform(-1.0, 1.0, shape), dtype=dtype)
+                        x_np.append(tmp.astype(acc_type[dtype]))
+                        x.append(np.array(tmp, dtype=dtype))
                         x[-1].attach_grad()
-                    expected_np = _np.einsum(subscripts, *x_np, optimize=optimize)
+                    expected_np = _np.einsum(subscripts, *x_np, optimize=optimize).astype(dtype)
                     with mx.autograd.record():
                         out_mx = test_einsum(*x)
                     assert out_mx.shape == expected_np.shape
@@ -3416,7 +3422,7 @@ def dbg(name, data):
                     expected_np = _np.einsum(subscripts, *x_np, optimize=optimize)
                     assert_almost_equal(out_mx.asnumpy(), expected_np, rtol=rtol, atol=atol)
                     for (iop, op) in enumerate(x):
-                        assert_almost_equal(op.grad.asnumpy(), get_grad(*x_np)[iop], rtol=rtol, atol=atol)
+                        assert_almost_equal(op.grad.asnumpy(), get_grad(*x_np)[iop].astype(dtype), rtol=rtol, atol=atol)
     configs = [
         (('ij,jk,kl->il'), [(2, 2), (2, 5), (5, 2)]),
         (('ea,fb,abcd,gc,hd->efgh'), [(5, 5), (5, 5), (5, 5, 5, 5), (5, 5), (5, 5)]),
@@ -3426,8 +3432,8 @@ def dbg(name, data):
         for dtype in dtypes:
             for config in configs:
                 (subscripts, operands) = config
-                rtol = 1e-0 if dtype == 'float16' else 1e-2
-                atol = 1e-1 if dtype == 'float16' else 1e-2
+                rtol = 1e-2 if dtype == 'float16' else 1e-3
+                atol = 1e-4 if dtype == 'float16' else 1e-5
                 grad = []
                 x_np = []
                 for shape in operands:
@@ -3441,7 +3447,8 @@ def dbg(name, data):
                     test_einsum = TestEinsum(subscripts, optimize)
                     if hybridize:
                         test_einsum.hybridize()
-                    expected_np = _np.einsum(subscripts, *x_np, optimize=optimize)
+                    expected_np = _np.einsum(subscripts, *[op.astype(acc_type[dtype]) for op in x_np],
+                                             optimize=optimize).astype(dtype)
                     with mx.autograd.record():
                         out_mx = test_einsum(*x)
                     assert out_mx.shape == expected_np.shape
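
Note (not part of the patch): the kernel above replaces the per-element unravel/dot index arithmetic with a running reduce coordinate that is advanced by the new inc() helper added to src/operator/mxnet_op.h. The following is a minimal standalone sketch of that odometer-style increment, assuming a plain std::array in place of mshadow::Shape; the names inc_coord and the driver in main() are illustrative only.

#include <array>
#include <cstdio>

// Advance a multi-dimensional coordinate by one position in row-major order:
// bump the last axis and carry into earlier axes when an axis wraps around.
// Returns false once every coordinate in the shape has been visited.
template <int ndim>
bool inc_coord(std::array<int, ndim>* coord, const std::array<int, ndim>& shape) {
  ++(*coord)[ndim - 1];
  for (int i = ndim - 1; i > 0 && (*coord)[i] >= shape[i]; --i) {
    (*coord)[i] -= shape[i];
    ++(*coord)[i - 1];
  }
  return (*coord)[0] < shape[0];
}

int main() {
  std::array<int, 3> shape{{2, 2, 3}};
  std::array<int, 3> coord{{0, 0, 0}};
  // Visits all 2 * 2 * 3 = 12 coordinates, mirroring the do/while loop that the
  // einsum kernel now runs over the reduce shape.
  do {
    std::printf("(%d, %d, %d)\n", coord[0], coord[1], coord[2]);
  } while (inc_coord<3>(&coord, shape));
  return 0;
}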