add runtime storage fallback detection (apache#48)
* add runtime storage fallback detection

* replace cast storage ex with cast storage impl
eric-haibin-lin committed May 24, 2017
1 parent eb250de commit b5bcdd6
Showing 3 changed files with 31 additions and 9 deletions.
18 changes: 16 additions & 2 deletions src/common/utils.h
@@ -36,15 +36,29 @@ void CastStorageComputeEx(const nnvm::NodeAttrs& attrs,
 namespace common {
 
 #if DMLC_USE_CXX11
+/*
+ * \brief Get input TBlobs from NDArrays, potentially performing cast_storage op and store
+ *        temporary NDArrays in temps. If storage_fallback is false,
+ *        MXNET_EXEC_STORAGE_FALLBACK env var determines whether storage type fallback is allowed.
+ */
 template <typename xpu>
 inline void GetInputBlobs(const std::vector<NDArray>& nds,
                           std::vector<TBlob> *blobs,
                           std::vector<NDArray> *temps,
-                          const OpContext& ctx) {
+                          const OpContext& ctx,
+                          bool storage_fallback = false) {
+  if (storage_fallback == false) {
+    storage_fallback = dmlc::GetEnv("MXNET_EXEC_STORAGE_FALLBACK", true);
+  }
   for (auto& nd : nds) {
     if (nd.storage_type() != kDefaultStorage) {
+      if (storage_fallback == false) {
+        LOG(FATAL) << "Storage type conversion detected during execution. "
+                   << "You are probably executing an operator which "
+                   << "doesn't support NDArray inputs with non-default storage.";
+      }
       NDArray temp(nd.shape(), nd.ctx(), false);
-      op::CastStorageComputeEx<xpu>({}, ctx, {nd}, {}, {temp});
+      op::CastStorageComputeImpl<xpu>(ctx.get_stream<xpu>(), nd, temp);
       temps->push_back(temp);
       blobs->push_back(temp.data());
     } else {
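The new doc comment above describes the fallback decision: an explicit storage_fallback argument wins, otherwise the MXNET_EXEC_STORAGE_FALLBACK environment variable (enabled by default) decides. A minimal Python sketch of that decision, for illustration only (storage_fallback_allowed is a hypothetical name, not MXNet API):

    import os

    def storage_fallback_allowed(storage_fallback=False):
        # Mirrors the check in GetInputBlobs: an explicit True wins; otherwise
        # consult MXNET_EXEC_STORAGE_FALLBACK, treated as enabled unless set to "0".
        if not storage_fallback:
            storage_fallback = os.environ.get("MXNET_EXEC_STORAGE_FALLBACK", "1") != "0"
        return storage_fallback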
2 changes: 1 addition & 1 deletion src/operator/operator_common.h
@@ -327,7 +327,7 @@ void FCompExFallback(const nnvm::NodeAttrs& attrs,
                      FCompute fcompute) {
   std::vector<TBlob> in_blobs, out_blobs;
   std::vector<NDArray> tmps;
-  common::GetInputBlobs<xpu>(inputs, &in_blobs, &tmps, ctx);
+  common::GetInputBlobs<xpu>(inputs, &in_blobs, &tmps, ctx, true);
   common::GetOutputBlobs<xpu>(outputs, &out_blobs);
   fcompute(attrs, ctx, in_blobs, req, out_blobs);
 }
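FCompExFallback now calls GetInputBlobs with storage_fallback set to true, so an operator that explicitly routes through the fallback path always gets its sparse inputs densified, regardless of the environment variable. A rough Python-level analogue of that path (a sketch with hypothetical helper callables, not MXNet internals):

    def fcompex_fallback(dense_fn, inputs, is_default_storage, to_default_storage):
        # Densify any non-default-storage input, keeping the temporaries alive,
        # then invoke the ordinary dense compute function on the resulting blobs.
        temps = []
        blobs = []
        for nd in inputs:
            if not is_default_storage(nd):
                nd = to_default_storage(nd)   # cast_storage to dense
                temps.append(nd)              # keep the converted copy alive
            blobs.append(nd)
        return dense_fn(blobs)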
20 changes: 14 additions & 6 deletions tests/python/unittest/test_sparse_ndarray.py
@@ -6,6 +6,11 @@
 from numpy.testing import assert_allclose
 import numpy.random as rnd
 
+def assert_fcompex(f, *args, **kwargs):
+    prev_val = mx.test_utils.set_env_var("MXNET_EXEC_STORAGE_FALLBACK", "0", "1")
+    f(*args, **kwargs)
+    mx.test_utils.set_env_var("MXNET_EXEC_STORAGE_FALLBACK", prev_val)
+
 def check_sparse_nd_elemwise_binary(shapes, storage_types, f, g):
     # generate inputs
     nds = []
@@ -27,11 +32,14 @@ def test_sparse_nd_elemwise_add():
     op = mx.nd.elemwise_add
     for i in range(num_repeats):
         shape = [(rnd.randint(1, 10),rnd.randint(1, 10))] * 2
-        check_sparse_nd_elemwise_binary(shape, ['default_storage'] * 2, op, g)
-        check_sparse_nd_elemwise_binary(shape, ['default_storage', 'row_sparse'], op, g)
-        check_sparse_nd_elemwise_binary(shape, ['row_sparse', 'row_sparse'], op, g)
-
-# Test a operator which doesn't implement FComputeEx
+        assert_fcompex(check_sparse_nd_elemwise_binary,
+                       shape, ['default_storage'] * 2, op, g)
+        assert_fcompex(check_sparse_nd_elemwise_binary,
+                       shape, ['default_storage', 'row_sparse'], op, g)
+        assert_fcompex(check_sparse_nd_elemwise_binary,
+                       shape, ['row_sparse', 'row_sparse'], op, g)
+
+# test a operator which doesn't implement FComputeEx
 def test_sparse_nd_elementwise_fallback():
     num_repeats = 10
     g = lambda x,y: x + y
@@ -141,9 +149,9 @@ def check_sparse_nd_csr_slice(shape):
 
 if __name__ == '__main__':
     test_sparse_nd_zeros()
-    test_sparse_nd_elemwise_add()
     test_sparse_nd_elementwise_fallback()
     test_sparse_nd_copy()
+    test_sparse_nd_elemwise_add()
     test_sparse_nd_setitem()
     test_sparse_nd_basic()
     test_sparse_nd_slice()
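The assert_fcompex helper added above turns fallback off around a call, so an unexpected cast_storage aborts instead of silently converting. A hedged usage sketch of the same pattern (run_sparse_case is a hypothetical callable; set_env_var is the helper used by assert_fcompex above, assumed to return the previous value as that code implies):

    import mxnet as mx

    def run_without_fallback(run_sparse_case):
        # Temporarily forbid implicit storage fallback, then restore the old setting.
        prev_val = mx.test_utils.set_env_var("MXNET_EXEC_STORAGE_FALLBACK", "0", "1")
        try:
            run_sparse_case()
        finally:
            mx.test_utils.set_env_var("MXNET_EXEC_STORAGE_FALLBACK", prev_val)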
