Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Added default tolerance levels for regression checks for MBCC (#12006)
Browse files Browse the repository at this point in the history
* Added tolerance level for assert_almost_equal for MBCC

* Nudge to CI
  • Loading branch information
piyushghai authored and marcoabreu committed Aug 3, 2018
1 parent 619700a commit 2534164
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
2 changes: 2 additions & 0 deletions tests/nightly/model_backwards_compatibility_check/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
backslash = '/'
s3 = boto3.resource('s3')
ctx = mx.cpu(0)
atol_default = 1e-5
rtol_default = 1e-5


def get_model_path(model_name):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_module_checkpoint_api():
old_inference_results = load_inference_results(model_name)
inference_results = loaded_model.predict(data_iter)
# Check whether they are equal or not ?
assert_almost_equal(inference_results.asnumpy(), old_inference_results.asnumpy())
assert_almost_equal(inference_results.asnumpy(), old_inference_results.asnumpy(), rtol=rtol_default, atol=atol_default)
clean_model_files(model_files, model_name)
logging.info('=================================')

Expand All @@ -69,7 +69,7 @@ def test_lenet_gluon_load_params_api():
loaded_model.load_params(model_name + '-params')
output = loaded_model(test_data)
old_inference_results = mx.nd.load(model_name + '-inference')['inference']
assert_almost_equal(old_inference_results.asnumpy(), output.asnumpy())
assert_almost_equal(old_inference_results.asnumpy(), output.asnumpy(), rtol=rtol_default, atol=atol_default)
clean_model_files(model_files, model_name)
logging.info('=================================')
logging.info('Assertion passed for model : %s' % model_name)
Expand All @@ -92,7 +92,7 @@ def test_lenet_gluon_hybrid_imports_api():
loaded_model = gluon.SymbolBlock.imports(model_name + '-symbol.json', ['data'], model_name + '-0000.params')
output = loaded_model(test_data)
old_inference_results = mx.nd.load(model_name + '-inference')['inference']
assert_almost_equal(old_inference_results.asnumpy(), output.asnumpy())
assert_almost_equal(old_inference_results.asnumpy(), output.asnumpy(), rtol=rtol_default, atol=atol_default)
clean_model_files(model_files, model_name)
logging.info('=================================')
logging.info('Assertion passed for model : %s' % model_name)
Expand Down Expand Up @@ -124,7 +124,7 @@ def test_lstm_gluon_load_parameters_api():
loaded_model.load_parameters(model_name + '-params')
output = loaded_model(test_data)
old_inference_results = mx.nd.load(model_name + '-inference')['inference']
assert_almost_equal(old_inference_results.asnumpy(), output.asnumpy())
assert_almost_equal(old_inference_results.asnumpy(), output.asnumpy(), rtol=rtol_default, atol=atol_default)
clean_model_files(model_files, model_name)
logging.info('=================================')
logging.info('Assertion passed for model : %s' % model_name)
Expand Down

0 comments on commit 2534164

Please sign in to comment.