From 8dc617246ddd4710479636c104178c7e1003a338 Mon Sep 17 00:00:00 2001 From: Jun Gong Date: Mon, 27 Mar 2023 23:45:34 -0700 Subject: [PATCH 1/3] AlgorithmConfig.update_from_dict needs to work for MultiCallbacks. Signed-off-by: Jun Gong --- rllib/algorithms/algorithm_config.py | 6 ++++-- rllib/algorithms/tests/test_algorithm_config.py | 14 ++++++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/rllib/algorithms/algorithm_config.py b/rllib/algorithms/algorithm_config.py index 17f7ed50bd3d..72b9e42edf27 100644 --- a/rllib/algorithms/algorithm_config.py +++ b/rllib/algorithms/algorithm_config.py @@ -600,8 +600,10 @@ def update_from_dict( # correct methods to properly `.update()` those from given config dict # (to not lose any sub-keys). elif key == "callbacks_class": - # Resolve possible classpath. - value = deserialize_type(value, error=True) + # For backward compatibility reasons, only resolve possible + # classpath if value is a str type. + if isinstance(value, str): + value = deserialize_type(value, error=True) self.callbacks(callbacks_class=value) elif key == "env_config": self.environment(env_config=value) diff --git a/rllib/algorithms/tests/test_algorithm_config.py b/rllib/algorithms/tests/test_algorithm_config.py index d928e673036b..eb72819b548f 100644 --- a/rllib/algorithms/tests/test_algorithm_config.py +++ b/rllib/algorithms/tests/test_algorithm_config.py @@ -3,6 +3,7 @@ import ray from ray.rllib.algorithms.algorithm_config import AlgorithmConfig +from ray.rllib.algorithms.callbacks import MultiCallbacks from ray.rllib.algorithms.ppo import PPO, PPOConfig from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec @@ -34,6 +35,19 @@ def test_running_specific_algo_with_generic_config(self): algo.train() algo.stop() + def test_update_from_dict_works_for_multi_callbacks(self): + """Test to make sure callbacks config dict works.""" + config_dict = { + "callbacks": MultiCallbacks([]) + } + config = ( + AlgorithmConfig(algo_class=PPO) + .environment("CartPole-v0") + .training(lr=0.12345, train_batch_size=3000) + ) + # This should work. + config.update_from_dict(config_dict) + def test_freezing_of_algo_config(self): """Tests, whether freezing an AlgorithmConfig actually works as expected.""" config = ( From b1c07eb62813d729df1f4598b0dc18bc119c37d0 Mon Sep 17 00:00:00 2001 From: Jun Gong Date: Mon, 27 Mar 2023 23:52:27 -0700 Subject: [PATCH 2/3] handle serialization as well Signed-off-by: Jun Gong --- rllib/algorithms/algorithm_config.py | 6 +++++- rllib/algorithms/tests/test_algorithm_config.py | 5 +++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/rllib/algorithms/algorithm_config.py b/rllib/algorithms/algorithm_config.py index 72b9e42edf27..342bd227cb67 100644 --- a/rllib/algorithms/algorithm_config.py +++ b/rllib/algorithms/algorithm_config.py @@ -3144,7 +3144,11 @@ def items(self): @staticmethod def _serialize_dict(config): # Serialize classes to classpaths: - config["callbacks"] = serialize_type(config["callbacks"]) + if isinstance(config.get("callbacks"), type): + config["callbacks"] = serialize_type(config["callbacks"]) + else: + # TODO(sven): Figure out how to serialize MultiCallbacks + config["callbacks"] = NOT_SERIALIZABLE config["sample_collector"] = serialize_type(config["sample_collector"]) if isinstance(config["env"], type): config["env"] = serialize_type(config["env"]) diff --git a/rllib/algorithms/tests/test_algorithm_config.py b/rllib/algorithms/tests/test_algorithm_config.py index eb72819b548f..0a35161f309f 100644 --- a/rllib/algorithms/tests/test_algorithm_config.py +++ b/rllib/algorithms/tests/test_algorithm_config.py @@ -48,6 +48,11 @@ def test_update_from_dict_works_for_multi_callbacks(self): # This should work. config.update_from_dict(config_dict) + serialized = config.serialize() + + # For now, we don't support serializing MultiCallbacks. + self.assertEqual(serialized["callbacks"], "__not_serializable__") + def test_freezing_of_algo_config(self): """Tests, whether freezing an AlgorithmConfig actually works as expected.""" config = ( From 9d6dc697d1392e5f40f9528c924f45a1af4e41f7 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Tue, 28 Mar 2023 14:10:52 +0200 Subject: [PATCH 3/3] fixes Signed-off-by: sven1977 --- rllib/algorithms/algorithm_config.py | 6 +++--- rllib/algorithms/tests/test_algorithm.py | 2 +- rllib/algorithms/tests/test_algorithm_config.py | 7 +++---- rllib/policy/policy.py | 9 ++++++--- 4 files changed, 13 insertions(+), 11 deletions(-) diff --git a/rllib/algorithms/algorithm_config.py b/rllib/algorithms/algorithm_config.py index 342bd227cb67..a8777f3166bb 100644 --- a/rllib/algorithms/algorithm_config.py +++ b/rllib/algorithms/algorithm_config.py @@ -599,7 +599,7 @@ def update_from_dict( # Some keys specify config sub-dicts and therefore should go through the # correct methods to properly `.update()` those from given config dict # (to not lose any sub-keys). - elif key == "callbacks_class": + elif key == "callbacks_class" and value != NOT_SERIALIZABLE: # For backward compatibility reasons, only resolve possible # classpath if value is a str type. if isinstance(value, str): @@ -3144,10 +3144,10 @@ def items(self): @staticmethod def _serialize_dict(config): # Serialize classes to classpaths: - if isinstance(config.get("callbacks"), type): + if isinstance(config.get("callbacks"), (type, str)): config["callbacks"] = serialize_type(config["callbacks"]) else: - # TODO(sven): Figure out how to serialize MultiCallbacks + # TODO(sven): Figure out how to serialize MultiCallbacks. config["callbacks"] = NOT_SERIALIZABLE config["sample_collector"] = serialize_type(config["sample_collector"]) if isinstance(config["env"], type): diff --git a/rllib/algorithms/tests/test_algorithm.py b/rllib/algorithms/tests/test_algorithm.py index 910e135de8e0..41168e06b47e 100644 --- a/rllib/algorithms/tests/test_algorithm.py +++ b/rllib/algorithms/tests/test_algorithm.py @@ -84,7 +84,7 @@ def new_mapping_fn(agent_id, episode, worker, **kwargs): # Add a new policy either by class (and options) or by instance. pid = f"p{i}" print(f"Adding policy {pid} ...") - # By instance. + # By (already instantiated) instance. if i == 2: new_pol = algo.add_policy( pid, diff --git a/rllib/algorithms/tests/test_algorithm_config.py b/rllib/algorithms/tests/test_algorithm_config.py index 0a35161f309f..8c0fe5b372f7 100644 --- a/rllib/algorithms/tests/test_algorithm_config.py +++ b/rllib/algorithms/tests/test_algorithm_config.py @@ -11,6 +11,7 @@ MultiAgentRLModuleSpec, MultiAgentRLModule, ) +from ray.rllib.utils.serialization import NOT_SERIALIZABLE class TestAlgorithmConfig(unittest.TestCase): @@ -37,9 +38,7 @@ def test_running_specific_algo_with_generic_config(self): def test_update_from_dict_works_for_multi_callbacks(self): """Test to make sure callbacks config dict works.""" - config_dict = { - "callbacks": MultiCallbacks([]) - } + config_dict = {"callbacks": MultiCallbacks([])} config = ( AlgorithmConfig(algo_class=PPO) .environment("CartPole-v0") @@ -51,7 +50,7 @@ def test_update_from_dict_works_for_multi_callbacks(self): serialized = config.serialize() # For now, we don't support serializing MultiCallbacks. - self.assertEqual(serialized["callbacks"], "__not_serializable__") + self.assertEqual(serialized["callbacks"], NOT_SERIALIZABLE) def test_freezing_of_algo_config(self): """Tests, whether freezing an AlgorithmConfig actually works as expected.""" diff --git a/rllib/policy/policy.py b/rllib/policy/policy.py index eb88e3519f2b..b616afa42f00 100644 --- a/rllib/policy/policy.py +++ b/rllib/policy/policy.py @@ -235,9 +235,12 @@ def __init__( if isinstance(callbacks, DefaultCallbacks): self.callbacks = callbacks() elif isinstance(callbacks, (str, type)): - self.callbacks: "DefaultCallbacks" = deserialize_type( - self.config.get("callbacks") - )() + try: + self.callbacks: "DefaultCallbacks" = deserialize_type( + self.config.get("callbacks") + )() + except Exception: + pass # TEST else: self.callbacks: "DefaultCallbacks" = DefaultCallbacks()