Skip to content

Commit

Permalink
[RLlib] Increase backward compatibility of checkpoints. (#47708)
Browse files Browse the repository at this point in the history
  • Loading branch information
simonsays1980 authored Sep 25, 2024
1 parent f17bb99 commit b1624c9
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 20 deletions.
28 changes: 19 additions & 9 deletions rllib/algorithms/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ def __init__(
object. If unspecified, a default logger is created.
**kwargs: Arguments passed to the Trainable base class.
"""
config = config or self.get_default_config()
config = config # or self.get_default_config()

# Translate possible dict into an AlgorithmConfig object, as well as,
# resolving generic config objects into specific ones (e.g. passing
Expand All @@ -466,22 +466,31 @@ def __init__(
# `self.get_default_config()` also returned a dict ->
# Last resort: Create core AlgorithmConfig from merged dicts.
if isinstance(default_config, dict):
config = AlgorithmConfig.from_dict(
config_dict=self.merge_algorithm_configs(
default_config, config, True
if "class" in config:
AlgorithmConfig.from_state(config)
else:
config = AlgorithmConfig.from_dict(
config_dict=self.merge_algorithm_configs(
default_config, config, True
)
)
)

# Default config is an AlgorithmConfig -> update its properties
# from the given config dict.
else:
config = default_config.update_from_dict(config)
if isinstance(config, dict) and "class" in config:
config = default_config.from_state(config)
else:
config = default_config.update_from_dict(config)
else:
default_config = self.get_default_config()
# Given AlgorithmConfig is not of the same type as the default config:
# This could be the case e.g. if the user is building an algo from a
# generic AlgorithmConfig() object.
if not isinstance(config, type(default_config)):
config = default_config.update_from_dict(config.to_dict())
else:
config = default_config.from_state(config.get_state())

# In case this algo is using a generic config (with no algo_class set), set it
# here.
Expand Down Expand Up @@ -2899,7 +2908,7 @@ def get_checkpointable_components(self) -> List[Tuple[str, "Checkpointable"]]:
@override(Checkpointable)
def get_ctor_args_and_kwargs(self) -> Tuple[Tuple, Dict[str, Any]]:
return (
(self.config,), # *args,
(self.config.get_state(),), # *args,
{}, # **kwargs
)

Expand Down Expand Up @@ -3296,7 +3305,7 @@ def __getstate__(self) -> Dict:
# Add config to state so complete Algorithm can be reproduced w/o it.
state = {
"algorithm_class": type(self),
"config": self.config,
"config": self.config.get_state(),
}

if hasattr(self, "env_runner_group"):
Expand Down Expand Up @@ -3437,7 +3446,8 @@ def _checkpoint_info_to_algorithm_state(

with open(checkpoint_info["state_file"], "rb") as f:
if msgpack is not None:
state = msgpack.load(f)
data = f.read()
state = msgpack.unpackb(data, raw=False)
else:
state = pickle.load(f)

Expand Down
61 changes: 55 additions & 6 deletions rllib/algorithms/algorithm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,12 +635,7 @@ def to_dict(self) -> AlgorithmConfigDict:
policies_dict = {}
for policy_id, policy_spec in config.pop("policies").items():
if isinstance(policy_spec, PolicySpec):
policies_dict[policy_id] = (
policy_spec.policy_class,
policy_spec.observation_space,
policy_spec.action_space,
policy_spec.config,
)
policies_dict[policy_id] = policy_spec.get_state()
else:
policies_dict[policy_id] = policy_spec
config["policies"] = policies_dict
Expand Down Expand Up @@ -783,6 +778,56 @@ def update_from_dict(

return self

def get_state(self) -> Dict[str, Any]:
"""Returns a dict state that can be pickled.
Returns:
A dictionary containing all attributes of the instance.
"""

state = self.__dict__.copy()
state["class"] = type(self)
state.pop("algo_class")
state.pop("_is_frozen")
state = {k: v for k, v in state.items() if v != DEPRECATED_VALUE}

# Convert `policies` (PolicySpecs?) into dict.
# Convert policies dict such that each policy ID maps to a old-style.
# 4-tuple: class, obs-, and action space, config.
# TODO (simon, sven): Remove when deprecating old stack.
if "policies" in state and isinstance(state["policies"], dict):
policies_dict = {}
for policy_id, policy_spec in state.pop("policies").items():
if isinstance(policy_spec, PolicySpec):
policies_dict[policy_id] = policy_spec.get_state()
else:
policies_dict[policy_id] = policy_spec
state["policies"] = policies_dict

# state = self._serialize_dict(state)

return state

@classmethod
def from_state(cls, state: Dict[str, Any]) -> "AlgorithmConfig":
"""Returns an instance constructed from the state.
Args:
cls: An `AlgorithmConfig` class.
state: A dictionary containing the state of an `AlgorithmConfig`.
See `AlgorithmConfig.get_state` for creating a state.
Returns:
An `AlgorithmConfig` instance with attributes from the `state`.
"""

ctor = state["class"]
config = ctor()

config.__dict__.update(state)

return config

# TODO(sven): We might want to have a `deserialize` method as well. Right now,
# simply using the from_dict() API works in this same (deserializing) manner,
# whether the dict used is actually code-free (already serialized) or not
Expand Down Expand Up @@ -4558,6 +4603,10 @@ def _validate_offline_settings(self):
@staticmethod
def _serialize_dict(config):
# Serialize classes to classpaths:
if "callbacks_class" in config:
config["callbacks"] = config.pop("callbacks_class")
if "class" in config:
config["class"] = serialize_type(config["class"])
config["callbacks"] = serialize_type(config["callbacks"])
config["sample_collector"] = serialize_type(config["sample_collector"])
if isinstance(config["env"], type):
Expand Down
17 changes: 17 additions & 0 deletions rllib/policy/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,23 @@ def __eq__(self, other: "PolicySpec"):
and self.config == other.config
)

def get_state(self) -> Dict[str, Any]:
"""Returns the state of a `PolicyDict` as a dict."""
return (
self.policy_class,
self.observation_space,
self.action_space,
self.config,
)

@classmethod
def from_state(cls, state: Dict[str, Any]) -> "PolicySpec":
"""Builds a `PolicySpec` from a state."""
policy_spec = PolicySpec()
policy_spec.__dict__.update(state)

return policy_spec

def serialize(self) -> Dict:
from ray.rllib.algorithms.registry import get_policy_class_name

Expand Down
7 changes: 5 additions & 2 deletions rllib/utils/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,6 @@ def from_checkpoint(
"an implementer of the `Checkpointable` API!"
)

# Construct an initial object.
obj = ctor(
*ctor_info["ctor_args_and_kwargs"][0],
**ctor_info["ctor_args_and_kwargs"][1],
Expand Down Expand Up @@ -713,6 +712,7 @@ def convert_to_msgpack_checkpoint(
this is the same as `msgpack_checkpoint_dir`.
"""
from ray.rllib.algorithms import Algorithm
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.core.rl_module import validate_module_id

# Try to import msgpack and msgpack_numpy.
Expand All @@ -726,7 +726,10 @@ def convert_to_msgpack_checkpoint(
# Serialize the algorithm class.
state["algorithm_class"] = serialize_type(state["algorithm_class"])
# Serialize the algorithm's config object.
state["config"] = state["config"].serialize()
if not isinstance(state["config"], dict):
state["config"] = state["config"].serialize()
else:
state["config"] = AlgorithmConfig._serialize_dict(state["config"])

# Extract policy states from worker state (Policies get their own
# checkpoint sub-dirs).
Expand Down
9 changes: 6 additions & 3 deletions rllib/utils/tests/test_checkpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,10 @@ def test_msgpack_checkpoint_translation(self):
pickle_w["policy_states"]["default_policy"]["policy_spec"]["config"]
)
check(pickle_w["policy_states"], msgpack_w["policy_states"])
check(pickle_state["config"].serialize(), msgpack_state["config"].serialize())
check(
AlgorithmConfig._serialize_dict(pickle_state["config"]),
AlgorithmConfig._serialize_dict(msgpack_state["config"]),
)

algo1.stop()
algo2.stop()
Expand Down Expand Up @@ -219,9 +222,9 @@ def mapping_fn(aid, episode, worker, **kwargs):
# handle comparing types/classes.
# The only exception is the `policies` field as it might have gotten
# regenerated from a set, thus the order of PIDs might be different.
p = pickle_state["config"].serialize()
p = AlgorithmConfig._serialize_dict(pickle_state["config"])
p_pols = p.pop("policies")
m = msgpack_state["config"].serialize()
m = AlgorithmConfig._serialize_dict(msgpack_state["config"])
m_pols = m.pop("policies")
check(p, m)
# Compare sets of policyIDs here.
Expand Down

0 comments on commit b1624c9

Please sign in to comment.