Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix edge case assigning new numeric types to Var/Param with units #3151

Merged
merged 2 commits into from
Feb 20, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions pyomo/core/base/param.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,16 +162,31 @@ def set_value(self, value, idx=NOTSET):
# required to be mutable.
#
_comp = self.parent_component()
if type(value) in native_types:
if value.__class__ in native_types:
# TODO: warn/error: check if this Param has units: assigning
# a dimensionless value to a united param should be an error
pass
elif _comp._units is not None:
_src_magnitude = expr_value(value)
_src_units = units.get_units(value)
value = units.convert_value(
num_value=_src_magnitude, from_units=_src_units, to_units=_comp._units
)
# Note: expr_value() could have just registered a new numeric type
if value.__class__ in native_types:
value = _src_magnitude
else:
_src_units = units.get_units(value)
value = units.convert_value(
num_value=_src_magnitude,
from_units=_src_units,
to_units=_comp._units,
)
# FIXME: we should call value() here [to ensure types get
# registered], but doing so breks non-numeric Params (which we
blnicho marked this conversation as resolved.
Show resolved Hide resolved
# allow). The real fix will be to follow the precedent from
# GetItemExpressiona and have separate types based on which
blnicho marked this conversation as resolved.
Show resolved Hide resolved
# expression "system" the Param should participate in (numeric,
# logical, or structural).
#
# else:
# value = expr_value(value)

old_value, self._value = self._value, value
try:
Expand Down
15 changes: 10 additions & 5 deletions pyomo/core/base/var.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,17 +384,22 @@ def set_value(self, val, skip_validation=False):
#
# Check if this Var has units: assigning dimensionless
# values to a variable with units should be an error
if type(val) not in native_numeric_types:
if self.parent_component()._units is not None:
_src_magnitude = value(val)
if val.__class__ in native_numeric_types:
pass
elif self.parent_component()._units is not None:
_src_magnitude = value(val)
# Note: value() could have just registered a new numeric type
if val.__class__ in native_numeric_types:
val = _src_magnitude
else:
_src_units = units.get_units(val)
val = units.convert_value(
num_value=_src_magnitude,
from_units=_src_units,
to_units=self.parent_component()._units,
)
else:
val = value(val)
else:
val = value(val)

if not skip_validation:
if val not in self.domain:
Expand Down
8 changes: 7 additions & 1 deletion pyomo/core/tests/unit/test_numvalue.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,8 @@ def test_numpy_basic_bool_registration(self):
@unittest.skipUnless(numpy_available, "This test requires NumPy")
def test_automatic_numpy_registration(self):
cmd = (
'import pyomo; from pyomo.core.base import Var, Param; import numpy as np; '
'import pyomo; from pyomo.core.base import Var, Param; '
'from pyomo.core.base.units_container import units; import numpy as np; '
'print(np.float64 in pyomo.common.numeric_types.native_numeric_types); '
'%s; print(np.float64 in pyomo.common.numeric_types.native_numeric_types)'
)
Expand All @@ -582,6 +583,11 @@ def _tester(expr):
_tester('Var() + np.float64(5)')
_tester('v = Var(); v.construct(); v.value = np.float64(5)')
_tester('p = Param(mutable=True); p.construct(); p.value = np.float64(5)')
_tester('v = Var(units=units.m); v.construct(); v.value = np.float64(5)')
_tester(
'p = Param(mutable=True, units=units.m); p.construct(); '
'p.value = np.float64(5)'
)


if __name__ == "__main__":
Expand Down