diff --git a/pyomo/core/base/set.py b/pyomo/core/base/set.py index 049775fd9dd..e4a6d13e96e 100644 --- a/pyomo/core/base/set.py +++ b/pyomo/core/base/set.py @@ -16,15 +16,18 @@ import math import sys import weakref -from pyomo.common.pyomo_typing import overload -from typing import Union, Type, Any as typingAny + from collections.abc import Iterator +from functools import partial +from typing import Union, Type, Any as typingAny +from pyomo.common.autoslots import AutoSlots from pyomo.common.collections import ComponentSet from pyomo.common.deprecation import deprecated, deprecation_warning, RenamedClass from pyomo.common.errors import DeveloperError, PyomoException from pyomo.common.log import is_debug_set from pyomo.common.modeling import NOTSET +from pyomo.common.pyomo_typing import overload from pyomo.common.sorting import sorted_robust from pyomo.common.timing import ConstructionTimer @@ -478,9 +481,7 @@ def __call__(self, parent, index): if not isinstance(_val, Sequence): _val = tuple(_val) - if len(_val) == 0: - return _val - if isinstance(_val[0], tuple): + if not _val or isinstance(_val[0], tuple): return _val return self._tuplize(_val, parent, index) @@ -501,7 +502,7 @@ def _tuplize(self, _val, parent, index): "length %s is not a multiple of dimen=%s" % (len(_val), d) ) - return list(tuple(_val[d * i : d * (i + 1)]) for i in range(len(_val) // d)) + return (tuple(_val[i : i + d]) for i in range(0, len(_val), d)) class _NotFound(object): @@ -1364,87 +1365,12 @@ def filter(self): return self._filter def add(self, *values): - count = 0 - _block = self.parent_block() - for value in values: - if normalize_index.flatten: - _value = normalize_index(value) - if _value.__class__ is tuple: - _d = len(_value) - else: - _d = 1 - else: - # If we are not normalizing indices, then we cannot reliably - # infer the set dimen - _d = 1 - if isinstance(value, Sequence) and self.dimen != 1: - _d = len(value) - _value = value - if _value not in self._domain: - raise ValueError( - "Cannot add value %s to Set %s.\n" - "\tThe value is not in the domain %s" - % (value, self.name, self._domain) - ) - - # We wrap this check in a try-except because some values - # (like lists) are not hashable and can raise exceptions. - try: - if _value in self: - logger.warning( - "Element %s already exists in Set %s; no action taken" - % (value, self.name) - ) - continue - except: - exc = sys.exc_info() - raise TypeError( - "Unable to insert '%s' into Set %s:\n\t%s: %s" - % (value, self.name, exc[0].__name__, exc[1]) - ) + N = len(self) + self.update(values) + return len(self) - N - if self._filter is not None: - if not self._filter(_block, _value): - continue - - if self._validate is not None: - try: - flag = self._validate(_block, _value) - except: - logger.error( - "Exception raised while validating element '%s' " - "for Set %s" % (value, self.name) - ) - raise - if not flag: - raise ValueError( - "The value=%s violates the validation rule of Set %s" - % (value, self.name) - ) - - # If the Set has a fixed dimension, check that this element is - # compatible. - if self._dimen is not None: - if _d != self._dimen: - if self._dimen is UnknownSetDimen: - # The first thing added to a Set with unknown - # dimension sets its dimension - self._dimen = _d - else: - raise ValueError( - "The value=%s has dimension %s and is not " - "valid for Set %s which has dimen=%s" - % (value, _d, self.name, self._dimen) - ) - - # Add the value to this object (this last redirection allows - # derived classes to implement a different storage mechanism) - self._add_impl(_value) - count += 1 - return count - - def _add_impl(self, value): - self._values.add(value) + def _update_impl(self, values): + self._values.update(values) def remove(self, val): self._values.remove(val) @@ -1457,17 +1383,147 @@ def clear(self): def set_value(self, val): self.clear() - for x in val: - self.add(x) + self.update(val) + + def _initialize(self, val): + try: + # We want to explicitly call the update() on *this class* to + # bypass potential double logging of the use of unordered + # data with ordered Sets + FiniteSetData.update(self, val) + except TypeError as e: + if 'not iterable' in str(e): + logger.error( + "Initializer for Set %s returned non-iterable object " + "of type %s." + % ( + self.name, + (val if val.__class__ is type else type(val).__name__), + ) + ) + raise def update(self, values): - for v in values: - if v not in self: - self.add(v) + # Special case: set operations that are not first attached + # to the model must be constructed. + if isinstance(values, SetOperator): + values.construct() + # It is important that val_iter is an actual iterator + val_iter = iter(values) + if self._dimen is not None: + if normalize_index.flatten: + val_iter = self._cb_normalized_dimen_verifier(self._dimen, val_iter) + else: + val_iter = self._cb_raw_dimen_verifier(self._dimen, val_iter) + elif normalize_index.flatten: + val_iter = map(normalize_index, val_iter) + else: + val_iter = self._cb_check_set_end(val_iter) + + if self._domain is not Any: + val_iter = self._cb_domain_verifier(self._domain, val_iter) + + if self._filter is not None: + val_iter = filter(partial(self._filter, self.parent_block()), val_iter) + + if self._validate is not None: + val_iter = self._cb_validate(self._validate, self.parent_block(), val_iter) + + # We wrap this check in a try-except because some values + # (like lists) are not hashable and can raise exceptions. + try: + self._update_impl(val_iter) + except Set._SetEndException: + pass def pop(self): return self._values.pop() + def _cb_domain_verifier(self, domain, val_iter): + for value in val_iter: + if value not in domain: + raise ValueError( + "Cannot add value %s to Set %s.\n" + "\tThe value is not in the domain %s" + % (value, self.name, self._domain) + ) + yield value + + def _cb_check_set_end(self, val_iter): + for value in val_iter: + if value is Set.End: + return + yield value + + def _cb_validate(self, validate, block, val_iter): + for value in val_iter: + try: + flag = validate(block, value) + except: + logger.error( + "Exception raised while validating element '%s' " + "for Set %s" % (value, self.name) + ) + raise + if not flag: + raise ValueError( + "The value=%s violates the validation rule of Set %s" + % (value, self.name) + ) + yield value + + def _cb_normalized_dimen_verifier(self, dimen, val_iter): + for value in val_iter: + if value.__class__ in native_types: + if dimen == 1: + yield value + continue + normalized_value = value + else: + normalized_value = normalize_index(value) + # Note: normalize_index() will never return a 1-tuple + if normalized_value.__class__ is tuple: + if dimen == len(normalized_value): + yield normalized_value[0] if dimen == 1 else normalized_value + continue + + _d = len(normalized_value) if normalized_value.__class__ is tuple else 1 + if _d == dimen: + yield normalized_value + elif dimen is UnknownSetDimen: + # The first thing added to a Set with unknown dimension + # sets its dimension + self._dimen = dimen = _d + yield normalized_value + else: + raise ValueError( + "The value=%s has dimension %s and is not " + "valid for Set %s which has dimen=%s" + % (value, _d, self.name, self._dimen) + ) + + def _cb_raw_dimen_verifier(self, dimen, val_iter): + for value in val_iter: + if isinstance(value, Sequence): + if dimen == len(value): + yield value + continue + elif dimen == 1: + yield value + continue + _d = len(value) if isinstance(value, Sequence) else 1 + if dimen is UnknownSetDimen: + # The first thing added to a Set with unknown dimension + # sets its dimension + self._dimen = dimen = _d + yield value + else: + raise ValueError( + "The value=%s has dimension %s and is not " + "valid for Set %s which has dimen=%s" + % (value, _d, self.name, self._dimen) + ) + class _FiniteSetData(metaclass=RenamedClass): __renamed__new_class__ = FiniteSetData @@ -1545,10 +1601,16 @@ def ordered_iter(self): return iter(self) def first(self): - return self.at(1) + try: + return next(iter(self)) + except StopIteration: + raise IndexError(f"{self.name} index out of range") from None def last(self): - return self.at(len(self)) + try: + return next(reversed(self)) + except StopIteration: + raise IndexError(f"{self.name} index out of range") from None def next(self, item, step=1): """ @@ -1655,27 +1717,30 @@ class OrderedSetData(_OrderedSetMixin, FiniteSetData): def __init__(self, component): self._values = {} - self._ordered_values = [] + self._ordered_values = None FiniteSetData.__init__(self, component=component) def _iter_impl(self): """ Return an iterator for the set. """ - return iter(self._ordered_values) + return iter(self._values) def __reversed__(self): - return reversed(self._ordered_values) + return reversed(self._values) - def _add_impl(self, value): - self._values[value] = len(self._values) - self._ordered_values.append(value) + def _update_impl(self, values): + for val in values: + # Note that we reset _ordered_values within the loop because + # of an old example where the initializer rule makes + # reference to values previously inserted into the Set + # (which triggered the creation of the _ordered_values) + self._ordered_values = None + self._values[val] = None def remove(self, val): - idx = self._values.pop(val) - self._ordered_values.pop(idx) - for i in range(idx, len(self._ordered_values)): - self._values[self._ordered_values[i]] -= 1 + self._values.pop(val) + self._ordered_values = None def discard(self, val): try: @@ -1685,15 +1750,15 @@ def discard(self, val): def clear(self): self._values.clear() - self._ordered_values = [] + self._ordered_values = None def pop(self): try: ans = self.last() except IndexError: - # Map the index error to a KeyError for consistency with - # set().pop() - raise KeyError('pop from an empty set') + # Map the exception for iterating over an empty dict to a + # KeyError for consistency with set().pop() + raise KeyError('pop from an empty set') from None self.discard(ans) return ans @@ -1704,6 +1769,8 @@ def at(self, index): The public Set API is 1-based, even though the internal _lookup and _values are (pythonically) 0-based. """ + if self._ordered_values is None: + self._rebuild_ordered_values() i = self._to_0_based_index(index) try: return self._ordered_values[i] @@ -1723,6 +1790,8 @@ def ord(self, item): # when they are actually put as Set members. So, we will look # for the exact thing that the user sent us and then fall back # on the scalar. + if self._ordered_values is None: + self._rebuild_ordered_values() try: return self._values[item] + 1 except KeyError: @@ -1733,6 +1802,12 @@ def ord(self, item): except KeyError: raise ValueError("%s.ord(x): x not in %s" % (self.name, self.name)) + def _rebuild_ordered_values(self): + _set = self._values + self._ordered_values = list(_set) + for i, v in enumerate(self._ordered_values): + _set[v] = i + class _OrderedSetData(metaclass=RenamedClass): __renamed__new_class__ = OrderedSetData @@ -1752,6 +1827,16 @@ class InsertionOrderSetData(OrderedSetData): __slots__ = () + def _initialize(self, val): + if type(val) in Set._UnorderedInitializers: + logger.warning( + "Initializing ordered Set %s with " + "a fundamentally unordered data source (type: %s). " + "This WILL potentially lead to nondeterministic behavior " + "in Pyomo" % (self.name, type(val).__name__) + ) + super()._initialize(val) + def set_value(self, val): if type(val) in Set._UnorderedInitializers: logger.warning( @@ -1760,7 +1845,8 @@ def set_value(self, val): "This WILL potentially lead to nondeterministic behavior " "in Pyomo" % (type(val).__name__,) ) - super(InsertionOrderSetData, self).set_value(val) + self.clear() + super().update(val) def update(self, values): if type(values) in Set._UnorderedInitializers: @@ -1770,7 +1856,7 @@ def update(self, values): "This WILL potentially lead to nondeterministic behavior " "in Pyomo" % (type(values).__name__,) ) - super(InsertionOrderSetData, self).update(values) + super().update(values) class _InsertionOrderSetData(metaclass=RenamedClass): @@ -1800,73 +1886,42 @@ class SortedSetData(_SortedSetMixin, OrderedSetData): Public Class Attributes: """ - __slots__ = ('_is_sorted',) - - def __init__(self, component): - # An empty set is sorted... - self._is_sorted = True - OrderedSetData.__init__(self, component=component) + __slots__ = () def _iter_impl(self): """ Return an iterator for the set. """ - if not self._is_sorted: - self._sort() - return super(SortedSetData, self)._iter_impl() + if self._ordered_values is None: + self._rebuild_ordered_values() + return iter(self._ordered_values) def __reversed__(self): - if not self._is_sorted: - self._sort() - return super(SortedSetData, self).__reversed__() + if self._ordered_values is None: + self._rebuild_ordered_values() + return reversed(self._ordered_values) - def _add_impl(self, value): - # Note that the sorted status has no bearing on insertion, - # so there is no reason to check if the data is correctly sorted - self._values[value] = len(self._values) - self._ordered_values.append(value) - self._is_sorted = False + def _update_impl(self, values): + for val in values: + # Note that we reset _ordered_values within the loop because + # of an old example where the initializer rule makes + # reference to values previously inserted into the Set + # (which triggered the creation of the _ordered_values) + self._ordered_values = None + self._values[val] = None # Note: removing data does not affect the sorted flag # def remove(self, val): # def discard(self, val): - def clear(self): - super(SortedSetData, self).clear() - self._is_sorted = True - - def at(self, index): - """ - Return the specified member of the set. - - The public Set API is 1-based, even though the - internal _lookup and _values are (pythonically) 0-based. - """ - if not self._is_sorted: - self._sort() - return super(SortedSetData, self).at(index) - - def ord(self, item): - """ - Return the position index of the input value. - - Note that Pyomo Set objects have positions starting at 1 (not 0). - - If the search item is not in the Set, then an IndexError is raised. - """ - if not self._is_sorted: - self._sort() - return super(SortedSetData, self).ord(item) - def sorted_data(self): return self.data() - def _sort(self): - self._ordered_values = list( - self.parent_component()._sort_fcn(self._ordered_values) - ) - self._values = {j: i for i, j in enumerate(self._ordered_values)} - self._is_sorted = True + def _rebuild_ordered_values(self): + _set = self._values + self._ordered_values = list(self.parent_component()._sort_fcn(_set)) + for i, v in enumerate(self._ordered_values): + _set[v] = i class _SortedSetData(metaclass=RenamedClass): @@ -1975,10 +2030,14 @@ class Set(IndexedComponent): """ - class End(object): + class _SetEndException(Exception): pass - class Skip(object): + class _SetEndType(type): + def __hash__(self): + raise Set._SetEndException() + + class End(metaclass=_SetEndType): pass class InsertionOrder(object): @@ -2223,21 +2282,6 @@ def _getitem_when_not_present(self, index): if _d is UnknownSetDimen and domain is not None and domain.dimen is not None: _d = domain.dimen - if self._init_values is not None: - self._init_values._dimen = _d - try: - _values = self._init_values(_block, index) - except TuplizeError as e: - raise ValueError( - str(e) % (self._name, "[%s]" % index if self.is_indexed() else "") - ) - - if _values is Set.Skip: - return - elif _values is None: - raise ValueError( - "Set rule or initializer returned None instead of Set.Skip" - ) if index is None and not self.is_indexed(): obj = self._data[index] = self else: @@ -2259,55 +2303,35 @@ def _getitem_when_not_present(self, index): obj._validate = self._init_validate if self._init_filter is not None: try: - _filter = Initializer(self._init_filter(_block, index)) - if _filter.constant(): + obj._filter = Initializer(self._init_filter(_block, index)) + if obj._filter.constant(): # _init_filter was the actual filter function; use it. - _filter = self._init_filter + obj._filter = self._init_filter except: # We will assume any exceptions raised when getting the # filter for this index indicate that the function # should have been passed directly to the underlying sets. - _filter = self._init_filter + obj._filter = self._init_filter else: - _filter = None + obj._filter = None if self._init_values is not None: - # _values was initialized above... - if obj.isordered() and type(_values) in Set._UnorderedInitializers: - logger.warning( - "Initializing ordered Set %s with a fundamentally " - "unordered data source (type: %s). This WILL potentially " - "lead to nondeterministic behavior in Pyomo" - % (self.name, type(_values).__name__) - ) - # Special case: set operations that are not first attached - # to the model must be constructed. - if isinstance(_values, SetOperator): - _values.construct() + # record the user-provided dimen in the initializer + self._init_values._dimen = _d try: - val_iter = iter(_values) - except TypeError: - logger.error( - "Initializer for Set %s%s returned non-iterable object " - "of type %s." - % ( - self.name, - ("[%s]" % (index,) if self.is_indexed() else ""), - ( - _values - if _values.__class__ is type - else type(_values).__name__ - ), - ) + _values = self._init_values(_block, index) + except TuplizeError as e: + raise ValueError( + str(e) % (self._name, "[%s]" % index if self.is_indexed() else "") ) - raise - for val in val_iter: - if val is Set.End: - break - if _filter is None or _filter(_block, val): - obj.add(val) - # We defer adding the filter until now so that add() doesn't - # call it a second time. - obj._filter = _filter + if _values is Set.Skip: + del self._data[index] + return + elif _values is None: + raise ValueError( + "Set rule or initializer returned None instead of Set.Skip" + ) + + obj._initialize(_values) return obj @staticmethod diff --git a/pyomo/core/tests/unit/test_set.py b/pyomo/core/tests/unit/test_set.py index 2b0da8b861d..0361f41c835 100644 --- a/pyomo/core/tests/unit/test_set.py +++ b/pyomo/core/tests/unit/test_set.py @@ -3763,8 +3763,8 @@ def I_init(m): m = ConcreteModel() m.I = Set(initialize={1, 3, 2, 4}) ref = ( - "Initializing ordered Set I with a " - "fundamentally unordered data source (type: set)." + 'Initializing ordered Set I with a fundamentally ' + 'unordered data source (type: set).' ) self.assertIn(ref, output.getvalue()) self.assertEqual(m.I.sorted_data(), (1, 2, 3, 4)) @@ -3811,6 +3811,7 @@ def I_init(m): self.assertEqual(m.I.data(), (4, 3, 2, 1)) self.assertEqual(m.I.dimen, 1) + def test_initialize_with_noniterable(self): output = StringIO() with LoggingIntercept(output, 'pyomo.core'): with self.assertRaisesRegex(TypeError, "'int' object is not iterable"): @@ -3819,6 +3820,14 @@ def I_init(m): ref = "Initializer for Set I returned non-iterable object of type int." self.assertIn(ref, output.getvalue()) + output = StringIO() + with LoggingIntercept(output, 'pyomo.core'): + with self.assertRaisesRegex(TypeError, "'int' object is not iterable"): + m = ConcreteModel() + m.I = Set([1, 2], initialize=5) + ref = "Initializer for Set I[1] returned non-iterable object of type int." + self.assertIn(ref, output.getvalue()) + def test_scalar_indexed_api(self): m = ConcreteModel() m.I = Set(initialize=range(3)) @@ -3877,12 +3886,13 @@ def _verify(_s, _l): m.I.add(4) _verify(m.I, [1, 3, 2, 4]) + N = len(m.I) output = StringIO() with LoggingIntercept(output, 'pyomo.core'): m.I.add(3) - self.assertEqual( - output.getvalue(), "Element 3 already exists in Set I; no action taken\n" - ) + # In Pyomo <= 6.7.3 duplicate values logged a warning. + self.assertEqual(output.getvalue(), "") + self.assertEqual(N, len(m.I)) _verify(m.I, [1, 3, 2, 4]) m.I.remove(3) @@ -3959,12 +3969,13 @@ def _verify(_s, _l): m.I.add(4) _verify(m.I, [1, 2, 3, 4]) + N = len(m.I) output = StringIO() with LoggingIntercept(output, 'pyomo.core'): m.I.add(3) - self.assertEqual( - output.getvalue(), "Element 3 already exists in Set I; no action taken\n" - ) + # In Pyomo <= 6.7.3 duplicate values logged a warning. + self.assertEqual(output.getvalue(), "") + self.assertEqual(N, len(m.I)) _verify(m.I, [1, 2, 3, 4]) m.I.remove(3) @@ -4052,12 +4063,13 @@ def _verify(_s, _l): m.I.add(4) _verify(m.I, [1, 2, 3, 4]) + N = len(m.I) output = StringIO() with LoggingIntercept(output, 'pyomo.core'): m.I.add(3) - self.assertEqual( - output.getvalue(), "Element 3 already exists in Set I; no action taken\n" - ) + # In Pyomo <= 6.7.3 duplicate values logged a warning. + self.assertEqual(output.getvalue(), "") + self.assertEqual(N, len(m.I)) _verify(m.I, [1, 2, 3, 4]) m.I.remove(3) @@ -4248,26 +4260,23 @@ def test_add_filter_validate(self): self.assertIn(1, m.I) self.assertIn(1.0, m.I) + N = len(m.I) output = StringIO() with LoggingIntercept(output, 'pyomo.core'): self.assertFalse(m.I.add(1)) - self.assertEqual( - output.getvalue(), "Element 1 already exists in Set I; no action taken\n" - ) + # In Pyomo <= 6.7.3 duplicate values logged a warning. + self.assertEqual(output.getvalue(), "") + self.assertEqual(N, len(m.I)) output = StringIO() with LoggingIntercept(output, 'pyomo.core'): self.assertFalse(m.I.add((1,))) - self.assertEqual( - output.getvalue(), "Element (1,) already exists in Set I; no action taken\n" - ) + # In Pyomo <= 6.7.3 duplicate values logged a warning. + self.assertEqual(output.getvalue(), "") m.J = Set() # Note that pypy raises a different exception from cpython - err = ( - r"Unable to insert '{}' into Set J:\n\tTypeError: " - r"((unhashable type: 'dict')|('dict' objects are unhashable))" - ) + err = r"((unhashable type: 'dict')|('dict' objects are unhashable))" with self.assertRaisesRegex(TypeError, err): m.J.add({}) @@ -4275,9 +4284,9 @@ def test_add_filter_validate(self): output = StringIO() with LoggingIntercept(output, 'pyomo.core'): self.assertFalse(m.J.add(1)) - self.assertEqual( - output.getvalue(), "Element 1 already exists in Set J; no action taken\n" - ) + # In Pyomo <= 6.7.3 duplicate values logged a warning. + self.assertEqual(output.getvalue(), "") + self.assertEqual(N, len(m.I)) def _l_tri(model, i, j): self.assertIs(model, m) @@ -5254,6 +5263,21 @@ def Bindex(m): self.assertIs(m.K.index_set()._domain, Integers) self.assertEqual(m.K.index_set(), [0, 1, 2, 3, 4]) + def test_normalize_index(self): + try: + _oldFlatten = normalize_index.flatten + normalize_index.flatten = True + + m = ConcreteModel() + with self.assertRaisesRegex( + ValueError, + r"The value=\(\(2, 3\),\) has dimension 2 and is not " + "valid for Set I which has dimen=1", + ): + m.I = Set(initialize=[1, ((2, 3),)]) + finally: + normalize_index.flatten = _oldFlatten + def test_no_normalize_index(self): try: _oldFlatten = normalize_index.flatten