Skip to content

Commit

Permalink
Merge pull request #2809 from emma58/reversible-fix-disjunctive-logic
Browse files Browse the repository at this point in the history
Adding (reversible) `gdp.transform_current_disjunctive_logic` transformation
  • Loading branch information
emma58 authored Aug 23, 2023
2 parents 8abb24b + d595a32 commit 9a5ef6b
Show file tree
Hide file tree
Showing 11 changed files with 744 additions and 11 deletions.
6 changes: 5 additions & 1 deletion pyomo/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,11 @@
ConcreteModel,
AbstractModel,
)
from pyomo.core.base.transformation import Transformation, TransformationFactory
from pyomo.core.base.transformation import (
Transformation,
TransformationFactory,
ReverseTransformationToken,
)

from pyomo.core.base.instance2dat import instance2dat

Expand Down
6 changes: 5 additions & 1 deletion pyomo/core/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,11 @@
ConcreteModel,
AbstractModel,
)
from pyomo.core.base.transformation import Transformation, TransformationFactory
from pyomo.core.base.transformation import (
Transformation,
TransformationFactory,
ReverseTransformationToken,
)

from pyomo.core.base.instance2dat import instance2dat

Expand Down
63 changes: 62 additions & 1 deletion pyomo/core/base/transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
# ___________________________________________________________________________

from pyomo.common import Factory
from pyomo.common.collections import ComponentSet
from pyomo.common.errors import MouseTrap
from pyomo.common.deprecation import deprecated
from pyomo.common.modeling import unique_component_name
from pyomo.common.timing import TransformationTimer
Expand Down Expand Up @@ -72,9 +74,11 @@ def apply_to(self, model, **kwds):
timer = TransformationTimer(self, 'in-place')
if not hasattr(model, '_transformation_data'):
model._transformation_data = TransformationData()
self._apply_to(model, **kwds)
reverse_token = self._apply_to(model, **kwds)
timer.report()

return reverse_token

def create_using(self, model, **kwds):
"""
Create a new model with this transformation
Expand Down Expand Up @@ -105,6 +109,63 @@ def _create_using(self, model, **kwds):
return instance


class ReverseTransformationToken(object):
"""
Class returned by reversible transformations' apply_to methods that
can be passed back to the transformation in order to revert its changes
to the model.
We store the transformation that created it, so that we have some basic
error checking when the user attempts to revert, and we store a dictionary
that can be whatever the transformation wants/needs in order to revert
itself.
args:
transformation: The class of the transformation that created this token
model: The model being transformed when this token was created
targets: The targets on 'model' being transformed when this token
was created.
reverse_dict: Dictionary with everything the transformation needs to
undo itself.
"""

def __init__(self, transformation, model, targets, reverse_dict):
self._transformation = transformation
self._model = model
self._targets = ComponentSet(targets)
self._reverse_dict = reverse_dict

@property
def transformation(self):
return self._transformation

@property
def reverse_dict(self):
return self._reverse_dict

def check_token_valid(self, cls, model, targets):
if cls is not self._transformation:
raise ValueError(
"Attempting to reverse transformation of class '%s' "
"using a token created by a transformation of class "
"'%s'. Cannot revert transformation with a token from "
"another transformation." % (cls, self._transformation)
)
if model is not self._model:
raise MouseTrap(
"A reverse transformation was called on model '%s', but the "
"transformation that created this token was created from "
"model '%s'. We do not currently support reversing "
"transformations on clones of the transformed model."
% (model.name, self._model.name)
)
# TODO: Do we need to pass targets into this? I'm thinking no because
# people can untransform selectively. We just need to enforce that the
# targets given here were indeed transformed, but they don't have to
# correspond exactly to what happened before. I think we leave that for
# the transformation to sort out


TransformationFactory = Factory('transformation type')


Expand Down
2 changes: 0 additions & 2 deletions pyomo/dae/plugins/colloc.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,8 +462,6 @@ def _apply_to(self, instance, **kwds):

self._transformBlock(instance, currentds)

return instance

def _transformBlock(self, block, currentds):
self._fe = {}
for ds in block.component_objects(ContinuousSet, descend_into=True):
Expand Down
2 changes: 0 additions & 2 deletions pyomo/dae/plugins/finitedifference.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,6 @@ def _apply_to(self, instance, **kwds):

self._transformBlock(instance, currentds)

return instance

def _transformBlock(self, block, currentds):
self._fe = {}
for ds in block.component_objects(ContinuousSet):
Expand Down
1 change: 1 addition & 0 deletions pyomo/gdp/plugins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,5 @@ def load():
import pyomo.gdp.plugins.partition_disjuncts
import pyomo.gdp.plugins.between_steps
import pyomo.gdp.plugins.multiple_bigm
import pyomo.gdp.plugins.transform_current_disjunctive_state
import pyomo.gdp.plugins.bound_pretransformation
11 changes: 11 additions & 0 deletions pyomo/gdp/plugins/fix_disjuncts.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
# ___________________________________________________________________________
#
# Pyomo: Python Optimization Modeling Objects
# Copyright (c) 2008-2022
# National Technology and Engineering Solutions of Sandia, LLC
# Under the terms of Contract DE-NA0003525 with National Technology and
# Engineering Solutions of Sandia, LLC, the U.S. Government retains certain
# rights in this software.
# This software is distributed under the 3-clause BSD License.
# ___________________________________________________________________________

# -*- coding: utf-8 -*-
"""Transformation to fix and enforce disjunct True/False status."""

Expand Down
8 changes: 4 additions & 4 deletions pyomo/gdp/plugins/gdp_to_mip_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ def _restore_state(self):
def _process_arguments(self, instance, **kwds):
if not instance.ctype in (Block, Disjunct):
raise GDP_Error(
"Transformation called on %s of type %s. 'instance'"
" must be a ConcreteModel, Block, or Disjunct (in "
"Transformation called on %s of type %s. 'instance' "
"must be a ConcreteModel, Block, or Disjunct (in "
"the case of nested disjunctions)." % (instance.name, instance.ctype)
)

Expand Down Expand Up @@ -145,8 +145,8 @@ def _filter_inactive(targets):
for t in targets:
if not t.active:
self.logger.warning(
'GDP.Hull transformation passed a deactivated '
f'target ({t.name}). Skipping.'
f'GDP.{self.transformation_name} transformation passed '
f'a deactivated target ({t.name}). Skipping.'
)
else:
yield t
Expand Down
Loading

0 comments on commit 9a5ef6b

Please sign in to comment.