Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
renamed default_type argument to rand_type for clarity
Browse files Browse the repository at this point in the history
updated function docstring with argument description

removed rand_type setting for non-max pooling tests
  • Loading branch information
Sam Skalicky committed Aug 8, 2018
1 parent d493036 commit 71c55fd
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 10 deletions.
10 changes: 7 additions & 3 deletions python/mxnet/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1204,7 +1204,7 @@ def check_speed(sym, location=None, ctx=None, N=20, grad_req=None, typ="whole",
def check_consistency(sym, ctx_list, scale=1.0, grad_req='write',
arg_params=None, aux_params=None, tol=None,
raise_on_err=True, ground_truth=None, equal_nan=False,
use_uniform=False, default_type=np.float64):
use_uniform=False, rand_type=np.float64):
"""Check symbol gives the same output for different running context
Parameters
Expand All @@ -1221,6 +1221,10 @@ def check_consistency(sym, ctx_list, scale=1.0, grad_req='write',
Optional, When flag set to true,
random input data generated follows uniform distribution,
not normal distribution
rand_type: np.dtype
Optional, data type of the random input data generated for any
argument not supplied via arg_params; defaults to np.float64
(the python float default)
Examples
--------
>>> # create the symbol
Expand Down Expand Up @@ -1282,10 +1286,10 @@ def check_consistency(sym, ctx_list, scale=1.0, grad_req='write',
if n not in arg_params:
if use_uniform:
arg_params[n] = np.random.uniform(low=-0.92, high=0.92,
size=arr.shape).astype(default_type)
size=arr.shape).astype(rand_type)
else:
arg_params[n] = np.random.normal(size=arr.shape,
scale=scale).astype(default_type)
scale=scale).astype(rand_type)
for n, arr in exe_list[0].aux_dict.items():
if n not in aux_params:
aux_params[n] = 0
Expand Down
14 changes: 7 additions & 7 deletions tests/python/gpu/test_operator_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,13 +614,13 @@ def test_pooling_with_type():
{'ctx': mx.cpu(0), 'pool_data': (2, 2, 10, 10), 'type_dict': {'pool_data': np.float64}},
{'ctx': mx.cpu(0), 'pool_data': (2, 2, 10, 10), 'type_dict': {'pool_data': np.float32}}]
sym = mx.sym.Pooling(kernel=(3,3), pool_type='max', pooling_convention='valid', name='pool')
check_consistency(sym, ctx_list, default_type=np.float16)
check_consistency(sym, ctx_list, rand_type=np.float16)

sym = mx.sym.Pooling(kernel=(3,3), pool_type='max', pooling_convention='full', name='pool')
check_consistency(sym, ctx_list, default_type=np.float16)
check_consistency(sym, ctx_list, rand_type=np.float16)

sym = mx.sym.Pooling(kernel=(300,300), pool_type='max', global_pool=True, name='pool')
check_consistency(sym, ctx_list, default_type=np.float16)
check_consistency(sym, ctx_list, rand_type=np.float16)


@with_seed()
Expand Down Expand Up @@ -774,16 +774,16 @@ def test_pooling_with_type2():
{'ctx': mx.cpu(0), 'pool_data': (10, 2, 10, 10), 'type_dict': {'pool_data': np.float32}}]

sym = mx.sym.Pooling(name='pool', kernel=(3,3), stride=(2,2), pool_type='max')
check_consistency(sym, ctx_list, default_type=np.float16)
check_consistency(sym, ctx_list, rand_type=np.float16)

sym = mx.sym.Pooling(name='pool', kernel=(3,3), pad=(1,1), pool_type='avg')
check_consistency(sym, ctx_list, default_type=np.float16)
check_consistency(sym, ctx_list)

sym = mx.sym.Pooling(name='pool', kernel=(5,5), pad=(2,2), pool_type='max')
check_consistency(sym, ctx_list, default_type=np.float16)
check_consistency(sym, ctx_list, rand_type=np.float16)

sym = mx.sym.Pooling(name='pool', kernel=(3,3), pad=(1,1), pool_type='sum')
check_consistency(sym, ctx_list, default_type=np.float16)
check_consistency(sym, ctx_list)

@unittest.skip("Flaky test https://github.com/apache/incubator-mxnet/issues/11517")
@with_seed()
Expand Down

0 comments on commit 71c55fd

Please sign in to comment.