Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
minor unittest update
Browse files Browse the repository at this point in the history
  • Loading branch information
stefanhenneking committed Jul 18, 2017
1 parent f670b45 commit cdac7a1
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
3 changes: 1 addition & 2 deletions tests/python/gpu/test_operator_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
from test_operator import *
from test_optimizer import *
from test_random import *
from test_sparse_operator import test_sparse_dot
from test_sparse_operator import test_cast_storage_ex
from test_sparse_operator import test_cast_storage_ex, test_sparse_dot
import mxnet as mx
import numpy as np
from mxnet.test_utils import check_consistency, set_default_context
Expand Down
8 changes: 4 additions & 4 deletions tests/python/unittest/test_sparse_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,10 @@ def test_dot_csr(lhs_shape, rhs_shape, rhs_stype, trans_lhs, density=1):
rtol=1e-3, atol=1e-4)

lhs_shape = rand_shape_2d(50, 200)
test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(1, 4)), 'default', False) # test gpu
test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(1, 4)), 'default', True ) # vector kernel
test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(5, 10)), 'default', False) # test gpu
test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(5, 10)), 'default', True ) # scalar kernel
test_dot_csr(lhs_shape, (lhs_shape[1], 1), 'default', False) # test gpu SpMV
test_dot_csr(lhs_shape, (lhs_shape[0], 1), 'default', True ) # (vector kernel)
test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(5, 10)), 'default', False) # test gpu SpMM
test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(5, 10)), 'default', True ) # (scalar kernel)
if default_context().device_type is 'cpu':
test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), 'row_sparse', False)
test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(1, 10)), 'row_sparse', True )
Expand Down

0 comments on commit cdac7a1

Please sign in to comment.