Skip to content

Commit

Permalink
[gym_jiminy/toolbox|rllib] Avoid relying on buggy 'gym.Env.set_wrappe…
Browse files Browse the repository at this point in the history
…r_attr'.
  • Loading branch information
duburcqa committed Jan 19, 2025
1 parent 819bdaa commit 92315c2
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 45 deletions.
17 changes: 11 additions & 6 deletions python/gym_jiminy/common/gym_jiminy/common/bases/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from collections import OrderedDict
from typing import (
Dict, Any, Tuple, List, TypeVar, Generic, TypedDict, Optional, Callable,
no_type_check, TYPE_CHECKING)
Mapping, no_type_check, TYPE_CHECKING)

import numpy as np
import numpy.typing as npt
Expand All @@ -20,7 +20,7 @@

import pinocchio as pin

from ..utils import DataNested
from ..utils import FieldNested, DataNested
if TYPE_CHECKING:
from ..envs.generic import BaseJiminyEnv
from ..quantities import QuantityManager
Expand Down Expand Up @@ -194,13 +194,22 @@ class InterfaceJiminyEnv(
action_space: gym.Space[Act]
observation_space: gym.Space[Obs]

action: Act

simulator: Simulator
robot: jiminy.Robot
stepper_state: jiminy.StepperState
robot_state: jiminy.RobotState
measurements: EngineObsType
is_simulation_running: npt.NDArray[np.bool_]

quantities: "QuantityManager"

log_fieldnames: Mapping[str, FieldNested]
"""Fielnames associated with all the variables that have been recorded to
the telemetry by any of the layer of the whole pipeline environment.
"""

num_steps: npt.NDArray[np.int64]
"""Number of simulation steps that has been performed since last reset of
the base environment.
Expand All @@ -211,10 +220,6 @@ class InterfaceJiminyEnv(
termination conditions.
"""

quantities: "QuantityManager"

action: Act

def __init__(self, *args: Any, **kwargs: Any) -> None:
# Track whether the observation has been refreshed manually since the
# last called '_controller_handle'. It typically happens at the end of
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ def __init__(self,
self.is_simulation_running = env.is_simulation_running
self.num_steps = env.num_steps
self.quantities = env.quantities
self.log_fieldnames = env.log_fieldnames

# Backup the parent environment
self.env = env
Expand Down
3 changes: 1 addition & 2 deletions python/gym_jiminy/common/gym_jiminy/common/envs/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,7 @@ def __init__(self,
# Store references to the variables to register to the telemetry
self._registered_variables: MutableMappingT[
str, Tuple[FieldNested, DataNested, bool]] = {}
self.log_fieldnames: MappingT[str, FieldNested] = _LazyDictItemFilter(
self._registered_variables, 0)
self.log_fieldnames = _LazyDictItemFilter(self._registered_variables, 0)

# Random number generator.
# This is used for generating random observations and actions, sampling
Expand Down
56 changes: 28 additions & 28 deletions python/gym_jiminy/rllib/gym_jiminy/rllib/curriculum.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import numpy as np
import gymnasium as gym
from packaging.version import parse as parse_version

from ray.rllib.core.rl_module import RLModule
from ray.rllib.env.env_context import EnvContext
Expand Down Expand Up @@ -233,36 +232,37 @@ def on_train_result(self,
score_and_proba_task_branches.append((
score_task_branch_, proba_task_branch_, space))

# Update the probability tree at runner-level
# Update the probability tree at runner-level.
# FIXME: `set_attr` is buggy on`gymnasium<=1.0` and cannot be used
# reliability in conjunction with `BasePipelineWrapper`.
# See PR: https://github.com/Farama-Foundation/Gymnasium/pull/1294
self._proba_task_tree = proba_task_tree
workers = algorithm.env_runner_group
assert workers is not None
if parse_version(gym.__version__) >= parse_version("1.0"):
workers.foreach_worker(
lambda worker: worker.env.unwrapped.set_attr(
'proba_task_tree',
(proba_task_tree,) * worker.num_envs))
else:
# Legacy code fallback because of buggy `set_attr`
def _update_runner_proba_task_tree(
env_runner: EnvRunner) -> None:
"""Update the probability task tree of all the environments
being managed by a given runner.
:param env_runner: Environment runner to consider.
"""
nonlocal proba_task_tree
assert isinstance(env_runner, SingleAgentEnvRunner)
env = env_runner.env.unwrapped
assert isinstance(env, gym.vector.SyncVectorEnv)
for env in env.unwrapped.envs:
while not isinstance(env, BaseTaskSettableWrapper):
assert isinstance(
env, (gym.Wrapper, BasePipelineWrapper))
env = env.env
env.proba_task_tree = proba_task_tree

workers.foreach_worker(_update_runner_proba_task_tree)

def _update_runner_proba_task_tree(
env_runner: EnvRunner) -> None:
"""Update the probability task tree of all the environments
being managed by a given runner.
:param env_runner: Environment runner to consider.
"""
nonlocal proba_task_tree
assert isinstance(env_runner, SingleAgentEnvRunner)
env = env_runner.env.unwrapped
assert isinstance(env, gym.vector.SyncVectorEnv)
for env in env.unwrapped.envs:
while not isinstance(env, BaseTaskSettableWrapper):
assert isinstance(
env, (gym.Wrapper, BasePipelineWrapper))
env = env.env
env.proba_task_tree = proba_task_tree

workers.foreach_worker(_update_runner_proba_task_tree)
# workers.foreach_worker(
# lambda worker: worker.env.unwrapped.set_attr(
# 'proba_task_tree',
# (proba_task_tree,) * worker.num_envs))

# Compute flattened probability tree
proba_task_tree_flat: List[float] = []
Expand Down
17 changes: 8 additions & 9 deletions python/gym_jiminy/rllib/gym_jiminy/rllib/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
import numpy as np
import gymnasium as gym
import plotext as plt
from packaging.version import parse as parse_version

import ray
from ray._private import services
Expand Down Expand Up @@ -1016,14 +1015,14 @@ def sample_from_runner(
if log_path is not None and log_path not in log_paths:
os.remove(log_path)

# Restore the original training/evaluation mode
if parse_version(gym.__version__) >= parse_version("1.0"):
env.set_attr('training', is_training_all)
else:
# Legacy code fallback because `set_attr` is buggy for `gymnasium<1.0`
assert isinstance(env, gym.vector.SyncVectorEnv)
for env, is_training in zip(env.envs, is_training_all):
env.get_wrapper_attr("train")(is_training)
# Restore the original training/evaluation mode.
# FIXME: `set_attr` is buggy on`gymnasium<=1.0` and cannot be used
# reliability in conjunction with `BasePipelineWrapper`.
# See PR: https://github.com/Farama-Foundation/Gymnasium/pull/1294
# env.set_attr('training', is_training_all)
assert isinstance(env, gym.vector.SyncVectorEnv)
for env, is_training in zip(env.envs, is_training_all):
env.get_wrapper_attr("train")(is_training)

return (metrics,), episodes, log_paths

Expand Down

0 comments on commit 92315c2

Please sign in to comment.