From e5ed2764563be9ebf9ca03b469c2d778505fb183 Mon Sep 17 00:00:00 2001 From: simonsays1980 Date: Sun, 27 Oct 2024 22:18:57 +0100 Subject: [PATCH] [RLlib] Fix `ConnectorPipelineV2` restoring from checkpoint (by writing information about individual connector pieces to the `ctor_args_and_kwargs` file). (#48213) Signed-off-by: mohitjain2504 --- rllib/connectors/connector_pipeline_v2.py | 24 ++++++++++++++++++++++- rllib/connectors/connector_v2.py | 3 ++- rllib/tuned_examples/ppo/cartpole_ppo.py | 1 - 3 files changed, 25 insertions(+), 3 deletions(-) diff --git a/rllib/connectors/connector_pipeline_v2.py b/rllib/connectors/connector_pipeline_v2.py index f7d570a49ffa..583825fb1d67 100644 --- a/rllib/connectors/connector_pipeline_v2.py +++ b/rllib/connectors/connector_pipeline_v2.py @@ -59,7 +59,18 @@ def __init__( pipeline during construction. Note that you can always add (or remove) more ConnectorV2 pieces later on the fly. """ - self.connectors = connectors or [] + self.connectors = [] + + for conn in connectors: + # If we have a `ConnectorV2` instance just append. + if isinstance(conn, ConnectorV2): + self.connectors.append(conn) + # If, we have a class with `args` and `kwargs`, build the instance. + # Note that this way of constructing a pipeline should only be + # used internally when restoring the pipeline state from a + # checkpoint. + elif isinstance(conn, tuple) and len(conn) == 3: + self.connectors.append(conn[0](*conn[1], **conn[2])) super().__init__(input_observation_space, input_action_space, **kwargs) @@ -266,6 +277,17 @@ def get_checkpointable_components(self) -> List[Tuple[str, "Checkpointable"]]: # don't have to return the `connectors` c'tor kwarg from there. This is b/c all # connector pieces in this pipeline are themselves Checkpointable components, # so they will be properly written into this pipeline's checkpoint. + @override(Checkpointable) + def get_ctor_args_and_kwargs(self) -> Tuple[Tuple, Dict[str, Any]]: + return ( + (self.input_observation_space, self.input_action_space), # *args + { + "connectors": [ + (type(conn), *conn.get_ctor_args_and_kwargs()) + for conn in self.connectors + ] + }, + ) @override(ConnectorV2) def reset_state(self) -> None: diff --git a/rllib/connectors/connector_v2.py b/rllib/connectors/connector_v2.py index 83eada4ba87f..535f9eeb1657 100644 --- a/rllib/connectors/connector_v2.py +++ b/rllib/connectors/connector_v2.py @@ -97,6 +97,7 @@ def __init__( self._action_space = None self._input_observation_space = None self._input_action_space = None + self._kwargs = kwargs self.input_action_space = input_action_space self.input_observation_space = input_observation_space @@ -949,7 +950,7 @@ def set_state(self, state: StateDict) -> None: def get_ctor_args_and_kwargs(self) -> Tuple[Tuple, Dict[str, Any]]: return ( (self.input_observation_space, self.input_action_space), # *args - {}, # **kwargs + self._kwargs, # **kwargs ) def reset_state(self) -> None: diff --git a/rllib/tuned_examples/ppo/cartpole_ppo.py b/rllib/tuned_examples/ppo/cartpole_ppo.py index de33650280b0..a297989b53ac 100644 --- a/rllib/tuned_examples/ppo/cartpole_ppo.py +++ b/rllib/tuned_examples/ppo/cartpole_ppo.py @@ -25,7 +25,6 @@ ) ) - if __name__ == "__main__": from ray.rllib.utils.test_utils import run_rllib_example_script_experiment