Skip to content

Commit

Permalink
[RLlib] AlgorithmConfig.update_from_dict needs to work for MultiCal…
Browse files Browse the repository at this point in the history
…lbacks. (#33796)

Signed-off-by: Jun Gong <jungong@anyscale.com>
Signed-off-by: sven1977 <svenmika1977@gmail.com>
Co-authored-by: sven1977 <svenmika1977@gmail.com>
  • Loading branch information
Jun Gong and sven1977 authored Mar 28, 2023
1 parent 639188b commit bb1909c
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 8 deletions.
14 changes: 10 additions & 4 deletions rllib/algorithms/algorithm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,9 +599,11 @@ def update_from_dict(
# Some keys specify config sub-dicts and therefore should go through the
# correct methods to properly `.update()` those from given config dict
# (to not lose any sub-keys).
elif key == "callbacks_class":
# Resolve possible classpath.
value = deserialize_type(value, error=True)
elif key == "callbacks_class" and value != NOT_SERIALIZABLE:
# For backward compatibility reasons, only resolve possible
# classpath if value is a str type.
if isinstance(value, str):
value = deserialize_type(value, error=True)
self.callbacks(callbacks_class=value)
elif key == "env_config":
self.environment(env_config=value)
Expand Down Expand Up @@ -3142,7 +3144,11 @@ def items(self):
@staticmethod
def _serialize_dict(config):
# Serialize classes to classpaths:
config["callbacks"] = serialize_type(config["callbacks"])
if isinstance(config.get("callbacks"), (type, str)):
config["callbacks"] = serialize_type(config["callbacks"])
else:
# TODO(sven): Figure out how to serialize MultiCallbacks.
config["callbacks"] = NOT_SERIALIZABLE
config["sample_collector"] = serialize_type(config["sample_collector"])
if isinstance(config["env"], type):
config["env"] = serialize_type(config["env"])
Expand Down
2 changes: 1 addition & 1 deletion rllib/algorithms/tests/test_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def new_mapping_fn(agent_id, episode, worker, **kwargs):
# Add a new policy either by class (and options) or by instance.
pid = f"p{i}"
print(f"Adding policy {pid} ...")
# By instance.
# By (already instantiated) instance.
if i == 2:
new_pol = algo.add_policy(
pid,
Expand Down
18 changes: 18 additions & 0 deletions rllib/algorithms/tests/test_algorithm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@

import ray
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.algorithms.callbacks import MultiCallbacks
from ray.rllib.algorithms.ppo import PPO, PPOConfig
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.core.rl_module.marl_module import (
MultiAgentRLModuleSpec,
MultiAgentRLModule,
)
from ray.rllib.utils.serialization import NOT_SERIALIZABLE


class TestAlgorithmConfig(unittest.TestCase):
Expand All @@ -34,6 +36,22 @@ def test_running_specific_algo_with_generic_config(self):
algo.train()
algo.stop()

def test_update_from_dict_works_for_multi_callbacks(self):
"""Test to make sure callbacks config dict works."""
config_dict = {"callbacks": MultiCallbacks([])}
config = (
AlgorithmConfig(algo_class=PPO)
.environment("CartPole-v0")
.training(lr=0.12345, train_batch_size=3000)
)
# This should work.
config.update_from_dict(config_dict)

serialized = config.serialize()

# For now, we don't support serializing MultiCallbacks.
self.assertEqual(serialized["callbacks"], NOT_SERIALIZABLE)

def test_freezing_of_algo_config(self):
"""Tests, whether freezing an AlgorithmConfig actually works as expected."""
config = (
Expand Down
9 changes: 6 additions & 3 deletions rllib/policy/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,9 +235,12 @@ def __init__(
if isinstance(callbacks, DefaultCallbacks):
self.callbacks = callbacks()
elif isinstance(callbacks, (str, type)):
self.callbacks: "DefaultCallbacks" = deserialize_type(
self.config.get("callbacks")
)()
try:
self.callbacks: "DefaultCallbacks" = deserialize_type(
self.config.get("callbacks")
)()
except Exception:
pass # TEST
else:
self.callbacks: "DefaultCallbacks" = DefaultCallbacks()

Expand Down

0 comments on commit bb1909c

Please sign in to comment.