Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
Signed-off-by: sven1977 <svenmika1977@gmail.com>
  • Loading branch information
sven1977 committed Sep 28, 2024
1 parent 1248596 commit 379cdfa
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 27 deletions.
34 changes: 17 additions & 17 deletions rllib/algorithms/tests/test_worker_failures.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,26 +673,26 @@ def test_modules_are_restored_on_recovered_worker(self):
self.assertEqual(algo.eval_env_runner_group.num_healthy_remote_workers(), 1)
self.assertEqual(algo.eval_env_runner_group.num_remote_worker_restarts(), 1)

# Let's verify that our custom module exists on both recovered workers.
# TODO (sven): Reinstate once EnvRunners moved to new get/set_state APIs (from
# get/set_weights).
# def has_test_module(w):
# return "test_module" in w.module
# Let's verify that our custom module exists on all recovered workers.
def has_test_module(w):
return "test_module" in w.module

# Rollout worker has test module.
# self.assertTrue(
# all(algo.env_runner_group.foreach_worker(
# has_test_module, local_worker=False
# ))
# )
self.assertTrue(
all(
algo.env_runner_group.foreach_worker(
has_test_module, local_env_runner=False
)
)
)
# Eval worker has test module.
# self.assertTrue(
# all(
# algo.eval_env_runner_group.foreach_worker(
# has_test_module, local_worker=False
# )
# )
# )
self.assertTrue(
all(
algo.eval_env_runner_group.foreach_worker(
has_test_module, local_env_runner=False
)
)
)

def test_eval_workers_failing_recover(self):
# Counter that will survive restarts.
Expand Down
2 changes: 2 additions & 0 deletions rllib/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
COMPONENT_LEARNER_GROUP = "learner_group"
COMPONENT_METRICS_LOGGER = "metrics_logger"
COMPONENT_MODULE_TO_ENV_CONNECTOR = "module_to_env_connector"
COMPONENT_MULTI_RL_MODULE_SPEC = "_multi_rl_module_spec"
COMPONENT_OPTIMIZER = "optimizer"
COMPONENT_RL_MODULE = "rl_module"

Expand All @@ -25,6 +26,7 @@
"COMPONENT_LEARNER_GROUP",
"COMPONENT_METRICS_LOGGER",
"COMPONENT_MODULE_TO_ENV_CONNECTOR",
"COMPONENT_MULTI_RL_MODULE_SPEC",
"COMPONENT_OPTIMIZER",
"COMPONENT_RL_MODULE",
"DEFAULT_AGENT_ID",
Expand Down
45 changes: 35 additions & 10 deletions rllib/core/rl_module/multi_rl_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
ValuesView,
)

from ray.rllib.core import COMPONENT_MULTI_RL_MODULE_SPEC
from ray.rllib.core.models.specs.typing import SpecType
from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleSpec

from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.utils import force_list
from ray.rllib.utils.annotations import (
Expand Down Expand Up @@ -297,11 +297,9 @@ def _forward_train(
) -> Union[Dict[str, Any], Dict[ModuleID, Dict[str, Any]]]:
"""Runs the forward_train pass.
TODO(avnishn, kourosh): Review type hints for forward methods.
Args:
batch: The batch of multi-agent data (i.e. mapping from module ids to
SampleBaches).
individual modules' batches).
Returns:
The output of the forward_train pass for the specified modules.
Expand All @@ -314,11 +312,9 @@ def _forward_inference(
) -> Union[Dict[str, Any], Dict[ModuleID, Dict[str, Any]]]:
"""Runs the forward_inference pass.
TODO(avnishn, kourosh): Review type hints for forward methods.
Args:
batch: The batch of multi-agent data (i.e. mapping from module ids to
SampleBaches).
individual modules' batches).
Returns:
The output of the forward_inference pass for the specified modules.
Expand All @@ -331,11 +327,9 @@ def _forward_exploration(
) -> Union[Dict[str, Any], Dict[ModuleID, Dict[str, Any]]]:
"""Runs the forward_exploration pass.
TODO(avnishn, kourosh): Review type hints for forward methods.
Args:
batch: The batch of multi-agent data (i.e. mapping from module ids to
SampleBaches).
individual modules' batches).
Returns:
The output of the forward_exploration pass for the specified modules.
Expand All @@ -353,6 +347,17 @@ def get_state(
) -> StateDict:
state = {}

# We also store the current MultiRLModuleSpec, as it might have changed over time
# (modules added to/removed from `self`).
if self._check_component(
COMPONENT_MULTI_RL_MODULE_SPEC,
components,
not_components,
):
state[COMPONENT_MULTI_RL_MODULE_SPEC] = MultiRLModuleSpec.from_module(
self
).to_dict()

for module_id, rl_module in self.get_checkpointable_components():
if self._check_component(module_id, components, not_components):
state[module_id] = rl_module.get_state(
Expand All @@ -376,7 +381,27 @@ def set_state(self, state: StateDict) -> None:
Args:
state: The state dict to set.
"""
# Check the given MultiRLModuleSpec and - if there are changes in the individual
# sub-modules - apply these to this MultiRLModule.
if COMPONENT_MULTI_RL_MODULE_SPEC in state:
    multi_rl_module_spec = MultiRLModuleSpec.from_dict(
        state[COMPONENT_MULTI_RL_MODULE_SPEC]
    )
    # Go through all of our current modules and check, whether they are listed
    # in the given MultiRLModuleSpec. If not, erase them from `self`.
    # NOTE: Iterate over a snapshot of the module IDs, not over the live
    # `self._rl_modules` mapping. `remove_module()` presumably deletes the entry
    # from that mapping, and changing a dict's size while iterating over it
    # raises a `RuntimeError`.
    for module_id in list(self._rl_modules.keys()):
        if module_id not in multi_rl_module_spec.module_specs:
            self.remove_module(module_id, raise_err_if_not_found=True)
    # Go through all the modules in the given MultiRLModuleSpec and if
    # they are not present in `self`, add them.
    for module_id, module_spec in multi_rl_module_spec.module_specs.items():
        if module_id not in self:
            self.add_module(module_id, module_spec.build(), override=False)

# Now, set the individual states. The spec entry is not a sub-module state and
# has already been handled above, so skip it here.
for module_id, module_state in state.items():
    if module_id == COMPONENT_MULTI_RL_MODULE_SPEC:
        continue
    if module_id in self:
        self._rl_modules[module_id].set_state(module_state)


Expand Down
4 changes: 4 additions & 0 deletions rllib/env/multi_agent_env_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,6 +688,7 @@ def get_state(
not_components: Optional[Union[str, Collection[str]]] = None,
**kwargs,
) -> StateDict:
# Basic state dict.
state = {
WEIGHTS_SEQ_NO: self._weights_seq_no,
NUM_ENV_STEPS_SAMPLED_LIFETIME: (
Expand All @@ -696,6 +697,7 @@ def get_state(
"agent_to_module_mapping_fn": self.config.policy_mapping_fn,
}

# RLModule (MultiRLModule) component.
if self._check_component(COMPONENT_RL_MODULE, components, not_components):
state[COMPONENT_RL_MODULE] = self.module.get_state(
components=self._get_subcomponents(COMPONENT_RL_MODULE, components),
Expand All @@ -704,10 +706,12 @@ def get_state(
),
**kwargs,
)
# Env-to-module connector.
if self._check_component(
COMPONENT_ENV_TO_MODULE_CONNECTOR, components, not_components
):
state[COMPONENT_ENV_TO_MODULE_CONNECTOR] = self._env_to_module.get_state()
# Module-to-env connector.
if self._check_component(
COMPONENT_MODULE_TO_ENV_CONNECTOR, components, not_components
):
Expand Down

0 comments on commit 379cdfa

Please sign in to comment.