diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index 976207fd4e70..f55381428615 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -1204,7 +1204,7 @@ def check_speed(sym, location=None, ctx=None, N=20, grad_req=None, typ="whole", def check_consistency(sym, ctx_list, scale=1.0, grad_req='write', arg_params=None, aux_params=None, tol=None, raise_on_err=True, ground_truth=None, equal_nan=False, - use_uniform=False, default_type=np.float64): + use_uniform=False, rand_type=np.float64): """Check symbol gives the same output for different running context Parameters ---------- @@ -1221,6 +1221,10 @@ def check_consistency(sym, ctx_list, scale=1.0, grad_req='write', Optional, When flag set to true, random input data generated follows uniform distribution, not normal distribution + rand_type: np.dtype + Optional, dtype of random input data generated for entries not supplied via arg_params; + defaults to np.float64 (python float default) + Examples -------- >>> # create the symbol @@ -1282,10 +1286,10 @@ def check_consistency(sym, ctx_list, scale=1.0, grad_req='write', if n not in arg_params: if use_uniform: arg_params[n] = np.random.uniform(low=-0.92, high=0.92, - size=arr.shape).astype(default_type) + size=arr.shape).astype(rand_type) else: arg_params[n] = np.random.normal(size=arr.shape, - scale=scale).astype(default_type) + scale=scale).astype(rand_type) for n, arr in exe_list[0].aux_dict.items(): if n not in aux_params: aux_params[n] = 0 diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index 921844acc085..f1cae5199587 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -614,13 +614,13 @@ def test_pooling_with_type(): {'ctx': mx.cpu(0), 'pool_data': (2, 2, 10, 10), 'type_dict': {'pool_data': np.float64}}, {'ctx': mx.cpu(0), 'pool_data': (2, 2, 10, 10), 'type_dict': {'pool_data': np.float32}}] sym = mx.sym.Pooling(kernel=(3,3), pool_type='max', pooling_convention='valid', name='pool') - 
check_consistency(sym, ctx_list, default_type=np.float16) + check_consistency(sym, ctx_list, rand_type=np.float16) sym = mx.sym.Pooling(kernel=(3,3), pool_type='max', pooling_convention='full', name='pool') - check_consistency(sym, ctx_list, default_type=np.float16) + check_consistency(sym, ctx_list, rand_type=np.float16) sym = mx.sym.Pooling(kernel=(300,300), pool_type='max', global_pool=True, name='pool') - check_consistency(sym, ctx_list, default_type=np.float16) + check_consistency(sym, ctx_list, rand_type=np.float16) @with_seed() @@ -774,16 +774,16 @@ def test_pooling_with_type2(): {'ctx': mx.cpu(0), 'pool_data': (10, 2, 10, 10), 'type_dict': {'pool_data': np.float32}}] sym = mx.sym.Pooling(name='pool', kernel=(3,3), stride=(2,2), pool_type='max') - check_consistency(sym, ctx_list, default_type=np.float16) + check_consistency(sym, ctx_list, rand_type=np.float16) sym = mx.sym.Pooling(name='pool', kernel=(3,3), pad=(1,1), pool_type='avg') - check_consistency(sym, ctx_list, default_type=np.float16) + check_consistency(sym, ctx_list) sym = mx.sym.Pooling(name='pool', kernel=(5,5), pad=(2,2), pool_type='max') - check_consistency(sym, ctx_list, default_type=np.float16) + check_consistency(sym, ctx_list, rand_type=np.float16) sym = mx.sym.Pooling(name='pool', kernel=(3,3), pad=(1,1), pool_type='sum') - check_consistency(sym, ctx_list, default_type=np.float16) + check_consistency(sym, ctx_list) @unittest.skip("Flaky test https://github.com/apache/incubator-mxnet/issues/11517") @with_seed()