Skip to content

Commit

Permalink
Add __setattr__ to wrappers. Fixes Farama-Foundation#1176
Browse files Browse the repository at this point in the history
Adding this ensures that variables get set to the appropriate
location. By default, any public value set by the wrapper is sent
to the env to be set there instead of on the wrapper. If a variable
is meant to be set by the wrapper, it should be listed in the
_local_vars class variable of the wrapper. This is not ideal, but
seems to be the most reasonable design.

An example of needing to specify which vars to keep locally is here:
https://python-patterns.guide/gang-of-four/decorator-pattern/#implementing-dynamic-wrapper
The solution is to list which vars should be in the wrapper and
check them when setting a value. That is the approach used in this
commit, but more generalized.

In line with __getattr__, private values cannot be set on underlying
envs. There are two exceptions:
_cumulative_rewards was previously exempted in __getattr__ because it
is used by many envs.
_skip_agent_selection is added because is used byt the dead step
handling. If a wrapper can't set this, that functionality will break.
  • Loading branch information
dm-ackerman committed Feb 10, 2024
1 parent 6c8e8c1 commit c082b23
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 2 deletions.
35 changes: 34 additions & 1 deletion pettingzoo/utils/wrappers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,49 @@ class BaseWrapper(AECEnv[AgentID, ObsType, ActionType]):
All AECEnv wrappers should inherit from this base class
"""

# This is a list of object variables (as strings), used by THIS wrapper,
# which should be stored by the wrapper object and not by the underlying
# environment. They are used to store information that the wrapper needs
# to behave correctly. The list is used by __setattr__() to determine where
# to store variables. It is very important that this list is correct to
# prevent confusing bugs.
# Wrappers inheriting from this class should include their own _local_vars
# list with object variables used by that class. Note that 'env' is hardcoded
# as part of the __setattr__ function so should not be included.
_local_vars = []

def __init__(self, env: AECEnv[AgentID, ObsType, ActionType]):
super().__init__()
self.env = env

def __getattr__(self, name: str) -> Any:
"""Returns an attribute with ``name``, unless ``name`` starts with an underscore."""
if name.startswith("_") and name != "_cumulative_rewards":
if name.startswith("_") and name not in [
"_cumulative_rewards",
"_skip_agent_selection",
]:
raise AttributeError(f"accessing private attribute '{name}' is prohibited")
return getattr(self.env, name)

def __setattr__(self, name: str, value: Any) -> None:
"""Set attribute ``name`` if it is this class's value, otherwise send to env."""
# these are the attributes that can be set on this wrapper directly
if name == "env" or name in self._local_vars:
self.__dict__[name] = value
else:
# If this is being raised by your wrapper while you are trying to access
# a variable that is owned by the wrapper and NOT part of the env, you
# may have forgotten to add the variable to the _local_vars list.
if name.startswith("_") and name not in [
"_cumulative_rewards",
"_skip_agent_selection",
]:
raise AttributeError(
f"setting private attribute '{name}' is prohibited"
)
# send to the underlying environment to handle
setattr(self.__dict__["env"], name, value)

@property
def unwrapped(self) -> AECEnv:
return self.env.unwrapped
Expand Down
2 changes: 2 additions & 0 deletions pettingzoo/utils/wrappers/multi_episode_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ class MultiEpisodeEnv(BaseWrapper):
The result of this wrapper is that the environment is no longer Markovian around the environment reset.
"""

_local_vars = ["_num_episodes", "_episodes_elapsed", "_seed", "_options"]

def __init__(self, env: AECEnv, num_episodes: int):
"""__init__.
Expand Down
4 changes: 3 additions & 1 deletion pettingzoo/utils/wrappers/order_enforcing.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,16 @@ class OrderEnforcingWrapper(BaseWrapper[AgentID, ObsType, ActionType]):
* warn on calling step after environment is terminated or truncated
"""

_local_vars = ["_has_reset", "_has_rendered", "_has_updated"]

def __init__(self, env: AECEnv[AgentID, ObsType, ActionType]):
assert isinstance(
env, AECEnv
), "OrderEnforcingWrapper is only compatible with AEC environments"
super().__init__(env)
self._has_reset = False
self._has_rendered = False
self._has_updated = False
super().__init__(env)

def __getattr__(self, value: str) -> Any:
"""Raises an error message when data is gotten from the env.
Expand Down
2 changes: 2 additions & 0 deletions pettingzoo/utils/wrappers/terminate_illegal.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ class TerminateIllegalWrapper(BaseWrapper[AgentID, ObsType, ActionType]):
illegal_reward: number that is the value of the player making an illegal move.
"""

_local_vars = ["_prev_obs", "_prev_info", "_terminated", "_illegal_value"]

def __init__(
self, env: AECEnv[AgentID, ObsType, ActionType], illegal_reward: float
):
Expand Down

0 comments on commit c082b23

Please sign in to comment.