diff --git a/numexpr/__init__.py b/numexpr/__init__.py index 9cabe69..7946f85 100644 --- a/numexpr/__init__.py +++ b/numexpr/__init__.py @@ -31,7 +31,8 @@ import os, os.path import platform from numexpr.expressions import E -from numexpr.necompiler import NumExpr, disassemble, evaluate, re_evaluate +from numexpr.necompiler import (NumExpr, disassemble, evaluate, re_evaluate, + validate) from numexpr.utils import (_init_num_threads, get_vml_version, set_vml_accuracy_mode, set_vml_num_threads, diff --git a/numexpr/tests/test_numexpr.py b/numexpr/tests/test_numexpr.py index 32f5be4..ccb0b6c 100644 --- a/numexpr/tests/test_numexpr.py +++ b/numexpr/tests/test_numexpr.py @@ -31,7 +31,7 @@ from numpy import shape, allclose, array_equal, ravel, isnan, isinf import numexpr -from numexpr import E, NumExpr, evaluate, re_evaluate, disassemble, use_vml +from numexpr import E, NumExpr, evaluate, re_evaluate, validate, disassemble, use_vml from numexpr.expressions import ConstantNode import unittest @@ -370,10 +370,38 @@ def test_re_evaluate(self): assert_array_equal(x, array([86., 124., 168.])) def test_re_evaluate_dict(self): + a1 = array([1., 2., 3.]) + b1 = array([4., 5., 6.]) + c1 = array([7., 8., 9.]) + x = evaluate("2*a + 3*b*c", local_dict={'a': a1, 'b': b1, 'c': c1}) + x = re_evaluate() + assert_array_equal(x, array([86., 124., 168.])) + + def test_validate(self): a = array([1., 2., 3.]) b = array([4., 5., 6.]) c = array([7., 8., 9.]) - x = evaluate("2*a + 3*b*c", local_dict={'a': a, 'b': b, 'c': c}) + retval = validate("2*a + 3*b*c") + assert(retval is None) + x = re_evaluate() + assert_array_equal(x, array([86., 124., 168.])) + + def test_validate_missing_var(self): + a = array([1., 2., 3.]) + b = array([4., 5., 6.]) + retval = validate("2*a + 3*b*c") + assert(isinstance(retval, KeyError)) + + def test_validate_syntax(self): + retval = validate("2+") + assert(isinstance(retval, SyntaxError)) + + def test_validate_dict(self): + a1 = array([1., 2., 3.]) + b1 = array([4., 5., 6.]) + c1 = array([7., 8., 9.]) + retval = validate("2*a + 3*b*c", local_dict={'a': a1, 'b': b1, 'c': c1}) + assert(retval is None) x = re_evaluate() assert_array_equal(x, array([86., 124., 168.]))