Skip to content

Commit

Permalink
Merge pull request #3153 from jsiirola/private-data-initializer
Browse files Browse the repository at this point in the history
Add `Block.register_private_data_initializer()`
  • Loading branch information
emma58 authored Feb 20, 2024
2 parents 783872b + b96bd2a commit 9f0d7eb
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 4 deletions.
27 changes: 23 additions & 4 deletions pyomo/core/base/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,21 @@
import sys
import weakref
import textwrap
from contextlib import contextmanager

from collections import defaultdict
from contextlib import contextmanager
from inspect import isclass, currentframe
from io import StringIO
from itertools import filterfalse, chain
from operator import itemgetter, attrgetter
from io import StringIO
from pyomo.common.pyomo_typing import overload

from pyomo.common.autoslots import AutoSlots
from pyomo.common.collections import Mapping
from pyomo.common.deprecation import deprecated, deprecation_warning, RenamedClass
from pyomo.common.formatting import StreamIndenter
from pyomo.common.gc_manager import PauseGC
from pyomo.common.log import is_debug_set
from pyomo.common.pyomo_typing import overload
from pyomo.common.timing import ConstructionTimer
from pyomo.core.base.component import (
Component,
Expand Down Expand Up @@ -1986,7 +1987,7 @@ def private_data(self, scope=None):
if self._private_data is None:
self._private_data = {}
if scope not in self._private_data:
self._private_data[scope] = {}
self._private_data[scope] = Block._private_data_initializers[scope]()
return self._private_data[scope]


Expand All @@ -2004,6 +2005,7 @@ class Block(ActiveIndexedComponent):
"""

_ComponentDataClass = _BlockData
_private_data_initializers = defaultdict(lambda: dict)

def __new__(cls, *args, **kwds):
if cls != Block:
Expand Down Expand Up @@ -2207,6 +2209,23 @@ def display(self, filename=None, ostream=None, prefix=""):
for key in sorted(self):
_BlockData.display(self[key], filename, ostream, prefix)

@staticmethod
def register_private_data_initializer(initializer, scope=None):
mod = currentframe().f_back.f_globals['__name__']
if scope is None:
scope = mod
elif not mod.startswith(scope):
raise ValueError(
"'private_data' scope must be substrings of the caller's module name. "
f"Received '{scope}' when calling register_private_data_initializer()."
)
if scope in Block._private_data_initializers:
raise RuntimeError(
"Duplicate initializer registration for 'private_data' dictionary "
f"(scope={scope})"
)
Block._private_data_initializers[scope] = initializer


class ScalarBlock(_BlockData, Block):
def __init__(self, *args, **kwds):
Expand Down
58 changes: 58 additions & 0 deletions pyomo/core/tests/unit/test_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -3437,6 +3437,64 @@ def test_private_data(self):
mfe4 = m.b.b[1].private_data('pyomo.core.tests')
self.assertIs(mfe4, mfe3)

def test_register_private_data(self):
_save = Block._private_data_initializers

Block._private_data_initializers = pdi = _save.copy()
pdi.clear()
try:
self.assertEqual(len(pdi), 0)
b = Block(concrete=True)
ps = b.private_data()
self.assertEqual(ps, {})
self.assertEqual(len(pdi), 1)
finally:
Block._private_data_initializers = _save

def init():
return {'a': None, 'b': 1}

Block._private_data_initializers = pdi = _save.copy()
pdi.clear()
try:
self.assertEqual(len(pdi), 0)
Block.register_private_data_initializer(init)
self.assertEqual(len(pdi), 1)

b = Block(concrete=True)
ps = b.private_data()
self.assertEqual(ps, {'a': None, 'b': 1})
self.assertEqual(len(pdi), 1)
finally:
Block._private_data_initializers = _save

Block._private_data_initializers = pdi = _save.copy()
pdi.clear()
try:
Block.register_private_data_initializer(init)
self.assertEqual(len(pdi), 1)
Block.register_private_data_initializer(init, 'pyomo')
self.assertEqual(len(pdi), 2)

with self.assertRaisesRegex(
RuntimeError,
r"Duplicate initializer registration for 'private_data' "
r"dictionary \(scope=pyomo.core.tests.unit.test_block\)",
):
Block.register_private_data_initializer(init)

with self.assertRaisesRegex(
ValueError,
r"'private_data' scope must be substrings of the caller's "
r"module name. Received 'invalid' when calling "
r"register_private_data_initializer\(\).",
):
Block.register_private_data_initializer(init, 'invalid')

self.assertEqual(len(pdi), 2)
finally:
Block._private_data_initializers = _save


if __name__ == "__main__":
unittest.main()

0 comments on commit 9f0d7eb

Please sign in to comment.