-
Notifications
You must be signed in to change notification settings - Fork 2.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add abs function to parameterexpression #7497
Changes from 7 commits
03f43f8
eff2f34
966ee8e
fa1bae4
93e1d6d
a30b8db
1b98392
296e6ae
420435e
f72c7fe
7af0db7
76e8719
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -468,6 +468,15 @@ def __copy__(self): | |
def __deepcopy__(self, memo=None): | ||
return self | ||
|
||
def __abs__(self): | ||
"""Absolute of a ParameterExpression""" | ||
if HAS_SYMENGINE: | ||
return self._call(symengine.Abs) | ||
else: | ||
from sympy import Abs as _abs | ||
|
||
return self._call(_abs) | ||
|
||
def __eq__(self, other): | ||
"""Check if this parameter expression is equal to another parameter expression | ||
or a fixed value (only if this is a bound expression). | ||
|
@@ -490,6 +499,74 @@ def __eq__(self, other): | |
return len(self.parameters) == 0 and complex(self._symbol_expr) == other | ||
return False | ||
|
||
def __lt__(self, other: ParameterValueType) -> bool: | ||
"""Check if this parameter expression is less than another parameter expression | ||
or a fixed value (only if this is a bound expression). | ||
Args: | ||
other (ParameterExpression or a number): | ||
Parameter expression or numeric constant used for comparison | ||
Raises: | ||
TypeError: | ||
- If comparison to number with unbound parameters. | ||
- If comparison to type of other object unsupported. | ||
Returns: | ||
bool: result of the comparison | ||
""" | ||
if isinstance(other, ParameterExpression): | ||
if HAS_SYMENGINE: | ||
from sympy import sympify | ||
|
||
return sympify(self._symbol_expr).__lt__(sympify(other._symbol_expr)) | ||
else: | ||
return self._symbol_expr.__lt__(other._symbol_expr) | ||
elif isinstance(other, numbers.Number): | ||
if len(self.parameters) == 0: | ||
return float(self._symbol_expr) < other | ||
else: | ||
raise TypeError( | ||
"'<' not supported between instances of {type(self)} " | ||
"with unbound parameters {self.parameters} and {type(other)}." | ||
) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I feel like there will be situations where an expression might not have a single numerical value, but can be bound for certain comparisons. For example, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I moved the implementation to #7539, please have a look there. |
||
else: | ||
raise TypeError("'<' not supported between instances of {type(self)} and {type(other)}") | ||
czachow marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def __gt__(self, other: ParameterValueType) -> bool: | ||
"""Check if this parameter expression is greater than another parameter expression | ||
or a fixed value (only if this is a bound expression). | ||
Args: | ||
other (ParameterExpression or a number): | ||
czachow marked this conversation as resolved.
Show resolved
Hide resolved
|
||
Parameter expression or numeric constant used for comparison | ||
Raises: | ||
TypeError: | ||
- If comparison to number with unbound parameters. | ||
- If comparison to type of other object unsupported. | ||
Returns: | ||
bool: result of the comparison | ||
""" | ||
if isinstance(other, ParameterExpression): | ||
if HAS_SYMENGINE: | ||
from sympy import sympify | ||
|
||
return sympify(self._symbol_expr).__gt__(sympify(other._symbol_expr)) | ||
else: | ||
return self._symbol_expr.__gt__(other._symbol_expr) | ||
elif isinstance(other, numbers.Number): | ||
if len(self.parameters) == 0: | ||
return float(self._symbol_expr) > other | ||
else: | ||
raise TypeError( | ||
"'>' not supported between instances of {type(self)} " | ||
"with unbound parameters {self.parameters} and {type(other)}." | ||
) | ||
else: | ||
raise TypeError("'>' not supported between instances of {type(self)} and {type(other)}") | ||
|
||
def __ge__(self, other: ParameterValueType) -> bool: | ||
return not self.__lt__(other) | ||
|
||
def __le__(self, other: ParameterValueType) -> bool: | ||
return not self.__gt__(other) | ||
|
||
def __getstate__(self): | ||
if HAS_SYMENGINE: | ||
from sympy import sympify | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1178,7 +1178,30 @@ def test_compare_to_value_when_bound(self): | |
|
||
x = Parameter("x") | ||
bound_expr = x.bind({x: 2.3}) | ||
|
||
czachow marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self.assertEqual(bound_expr, 2.3) | ||
self.assertTrue(bound_expr < 3.0) | ||
self.assertTrue(bound_expr > 1.0) | ||
self.assertEqual(abs(bound_expr), 2.3) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we keep the comparison operators There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it would be best to separate out the tests of Also, I think we could do with more stringent tests of it; at the moment, if There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As suggested, I added more tests to the implementation. What do you mean by testing symbolic expressions, such as |
||
|
||
def test_raise_if_compare_not_supported(self): | ||
"""Verify raises if compare to object.""" | ||
x = Parameter("x") | ||
y = object | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You probably meant to instantiate an |
||
|
||
with self.assertRaisesRegex(TypeError, "not supported"): | ||
x.__lt__(y) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It'd be clearer to write |
||
with self.assertRaisesRegex(TypeError, "not supported"): | ||
x.__gt__(y) | ||
|
||
def test_raise_if_compare_to_value_not_bound(self): | ||
"""Verify raises if compare to value and not bound.""" | ||
x = Parameter("x") | ||
|
||
with self.assertRaisesRegex(TypeError, "unbound parameters"): | ||
x.__gt__(2.3) | ||
with self.assertRaisesRegex(TypeError, "unbound parameters"): | ||
x.__lt__(2.3) | ||
|
||
def test_cast_to_float_when_bound(self): | ||
"""Verify expression can be cast to a float when fully bound.""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These don't do what you think they do; they don't always return Boolean values. Also, in general it's better to use the Python operations, and let Python handle the numerics; this way can fail if the left-hand side doesn't evaluate
__lt__
, but the right hand does evaluate__gt__
, whereas if you'd used<
in that situation, Python would evaluate it correctly.For not returning a Boolean, see: