[RLlib] Introduce Checkpointable API for RLlib components and subcomponents. #46376
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,20 @@ | ||
import abc | ||
import logging | ||
import json | ||
import os | ||
from packaging import version | ||
import pathlib | ||
import re | ||
from typing import Any, Dict, Union | ||
import tempfile | ||
from typing import Any, Collection, Dict, List, Optional, Tuple, Union | ||
|
||
import ray | ||
import ray.cloudpickle as pickle | ||
from ray.rllib.utils.annotations import ( | ||
OverrideToImplementCustomLogic_CallToSuperRecommended, | ||
) | ||
from ray.rllib.utils.serialization import NOT_SERIALIZABLE, serialize_type | ||
from ray.rllib.utils.typing import StateDict | ||
from ray.train import Checkpoint | ||
from ray.util import log_once | ||
from ray.util.annotations import PublicAPI | ||
|
@@ -29,15 +37,287 @@ | |
# 1.1: Same as 1.0, but has a new "format" field in the rllib_checkpoint.json file | ||
# indicating, whether the checkpoint is `cloudpickle` (default) or `msgpack`. | ||
|
||
# 1.2: Introduces the checkpoint for the new Learner API if the Learner api is enabled. | ||
# 1.2: Introduces the checkpoint for the new Learner API if the Learner API is enabled. | ||
|
||
# 2.0: Introduces the Checkpointable API for all components on the new API stack | ||
# (if the Learner-, RLModule, EnvRunner, and ConnectorV2 APIs are enabled). | ||
|
||
CHECKPOINT_VERSION = version.Version("1.1") | ||
CHECKPOINT_VERSION_LEARNER = version.Version("1.2") | ||
CHECKPOINT_VERSION_LEARNER_AND_ENV_RUNNER = version.Version("2.0") | ||
|
||
|
||
@PublicAPI(stability="alpha") | ||
class Checkpointable(abc.ABC): | ||
"""Abstract base class for a component of RLlib that can be checkpointed to disk. | ||
|
||
Subclasses must implement the following APIs: | ||
- save_to_path() | ||
- restore_from_path() | ||
- from_checkpoint() | ||
- get_state() | ||
- set_state() | ||
- ctor_args_and_kwargs() | ||
- get_metadata() | ||
- get_checkpointable_components() | ||
""" | ||
|
||
# The subdirectory of this class (if it's a subcomponent of another Checkpointable). | ||
# For example, if A is-a Checkpointable and contains an instance | ||
# of class B (also a Checkpointable), then class B should set this to something like | ||
# "component_B". When class A's `save_to_path([path])` is called, the state of | ||
# class B can then be found in the dir: `path/component_B/`. | ||
COMPONENT_DIR_NAME = "" | ||
|
||
# The state file for the implementing class. | ||
# This file contains any state information that does NOT belong to any subcomponent | ||
# of the implementing class (which are `Checkpointable` themselves and thus should | ||
# have their own state- and metadata files). | ||
# After a `save_to_path([path])` this file can be found directly in: `path/`. | ||
STATE_FILE_NAME = "state.pkl" | ||
|
||
# The filename of the pickle file that contains the class information of the | ||
# Checkpointable as well as all constructor args to be passed to such a class in | ||
# order to construct a new instance. | ||
CLASS_AND_CTOR_ARGS_FILE_NAME = "class_and_ctor_args.pkl" | ||
|
||
# Subclasses may set this to their own metadata filename. | ||
# The dict returned by self.get_metadata() is stored in this JSON file inside | ||
# `COMPONENT_DIR_NAME/`. | ||
METADATA_FILE_NAME = "metadata.json" | ||
|
||
def save_to_path( | ||
self, | ||
path: Optional[Union[str, pathlib.Path]] = None, | ||
This does exclude paths to cloud storage like S3 I think. In the future we should provide this. When running with
Correct. Actually, for Keep in mind that the subcomponents of RLlib (e.g. RLModule, Learner, etc.) - if used in isolation - will never be run in connection with Tune, so this should be fine. It would be super nice, though, to have Algorithm also abide by this new API, but we'll see. For now, it'll use the tune.Trainable's |
||
state: Optional[StateDict] = None, | ||
) -> str: | ||
"""Saves the state of the implementing class (or `state`) to `path`. | ||
|
||
The state of the implementing class is always saved in the following format: | ||
|
||
.. testcode:: | ||
:skipif: True | ||
|
||
path/ | ||
    [component1]/ | ||
        [component1 subcomponentA]/ | ||
            ... | ||
        [component1 subcomponentB]/ | ||
            ... | ||
    [component2]/ | ||
        ... | ||
    [cls.METADATA_FILE_NAME].json | ||
    [cls.STATE_FILE_NAME].pkl | ||
|
||
Args: | ||
path: The path to the directory to save the state of the implementing class | ||
to. If `path` doesn't exist or is None, then a new directory will be | ||
created (and returned). | ||
state: An optional state dict to be used instead of getting a new state of | ||
the implementing class through `self.get_state()`. | ||
|
||
Returns: | ||
The path (str) where the state has been saved. | ||
""" | ||
# Create path, if necessary. | ||
if path is None: | ||
path = tempfile.mkdtemp() | ||
|
||
# Make sure, path exists. | ||
path = pathlib.Path(path) | ||
path.mkdir(parents=True, exist_ok=True) | ||
|
||
# Write metadata file to disk. | ||
metadata = self.get_metadata() | ||
if "checkpoint_version" not in metadata: | ||
metadata["checkpoint_version"] = str( | ||
CHECKPOINT_VERSION_LEARNER_AND_ENV_RUNNER | ||
) | ||
with open(path / self.METADATA_FILE_NAME, "w") as f: | ||
json.dump(metadata, f) | ||
|
||
# Write the class and constructor args information to disk. | ||
with open(path / self.CLASS_AND_CTOR_ARGS_FILE_NAME, "wb") as f: | ||
pickle.dump( | ||
{ | ||
"class": type(self), | ||
"ctor_args_and_kwargs": self.ctor_args_and_kwargs(), | ||
}, | ||
f, | ||
) | ||
|
||
# Get the entire state of this Checkpointable, or use provided `state`. | ||
state = state or self.get_state() | ||
|
||
# Write components of `self` that themselves are `Checkpointable`. | ||
for component_name, component in self.get_checkpointable_components(): | ||
# If subcomponent's name is not in `state`, ignore it and don't write this | ||
# subcomponent's state to disk. | ||
if component_name not in state: | ||
continue | ||
component_state = state.pop(component_name) | ||
# By providing the `state` arg, we make sure that the component does not | ||
# have to call its own `get_state()` anymore, but uses what's provided here. | ||
component.save_to_path(path / component_name, state=component_state) | ||
|
||
# Write all the remaining state (w/o the subcomponents' states) to disk. | ||
with open(path / self.STATE_FILE_NAME, "wb") as f: | ||
pickle.dump(state, f) | ||
|
||
return str(path) | ||
|
||
def restore_from_path(self, path: Union[str, pathlib.Path], **kwargs) -> None: | ||
"""Restores the state of the implementing class from the given path. | ||
|
||
The given `path` should have the following structure and contain the following | ||
files: | ||
|
||
.. testcode:: | ||
:skipif: True | ||
|
||
path/ | ||
    [component1]/ | ||
        [component1 subcomponentA]/ | ||
            ... | ||
        [component1 subcomponentB]/ | ||
            ... | ||
    [component2]/ | ||
        ... | ||
    [cls.STATE_FILE_NAME].pkl | ||
|
||
Note that the self.METADATA_FILE_NAME file is not required to restore the state. | ||
|
||
Args: | ||
path: The path to load the implementing class' state from. | ||
**kwargs: Forward compatibility kwargs. | ||
""" | ||
path = pathlib.Path(path) | ||
|
||
# Restore components of `self` that themselves are `Checkpointable`. | ||
for component_name, component in self.get_checkpointable_components(): | ||
component_dir = path / component_name | ||
# If subcomponent's dir is not in path, ignore it and don't restore this | ||
# subcomponent's state from disk. | ||
if not os.path.isdir(component_dir): | ||
Can we stick here to the
Great call. Always get confused with os.path vs pathlib. :| |
||
continue | ||
# Call `restore_from_path()` on subcomponent, thereby passing in the | ||
# **kwargs. | ||
component.restore_from_path(component_dir, **kwargs) | ||
|
||
with open(path / self.STATE_FILE_NAME, "rb") as f: | ||
state = pickle.load(f) | ||
self.set_state(state) | ||
|
||
def from_checkpoint( | ||
self, path: Union[str, pathlib.Path], **kwargs | ||
) -> "Checkpointable": | ||
"""Creates a new Checkpointable instance from the given location and returns it. | ||
|
||
Args: | ||
path: The checkpoint path to load (a) the information on how to construct | ||
a new instance of the implementing class and (b) the state to restore | ||
the created instance to. | ||
kwargs: Forward compatibility kwargs. Note that these kwargs are sent to | ||
each subcomponent's `from_checkpoint()` call. | ||
|
||
Returns: | ||
A new instance of the implementing class, already set to the state stored | ||
under `path`. | ||
""" | ||
path = pathlib.Path(path) | ||
|
||
# Get the class constructor to call. | ||
with open(path / self.CLASS_AND_CTOR_ARGS_FILE_NAME, "rb") as f: | ||
ctor_info = pickle.load(f) | ||
# Construct an initial object. | ||
obj = ctor_info["class"]( | ||
*ctor_info["ctor_args_and_kwargs"][0], | ||
**ctor_info["ctor_args_and_kwargs"][1], | ||
) | ||
# Restore the state of the constructed object. | ||
obj.restore_from_path(path, **kwargs) | ||
# Return the new object. | ||
return obj | ||
|
||
@abc.abstractmethod | ||
def get_state( | ||
self, | ||
components: Optional[Collection[str]] = None, | ||
not_components: Optional[Collection[str]] = None, | ||
**kwargs, | ||
) -> StateDict: | ||
"""Returns the implementing class's current state as a dict. | ||
|
||
Args: | ||
components: An optional list of string keys to be included in the | ||
returned state. This might be useful, if getting certain components | ||
of the state is expensive (e.g. reading/compiling the weights of a large | ||
NN) and at the same time, these components are not required by the | ||
caller. | ||
not_components: An optional list of string keys to be excluded in the | ||
returned state, even if the same string is part of `components`. | ||
This is useful to get the complete state of the class, except | ||
one or a few components. | ||
kwargs: Forward-compatibility kwargs. | ||
|
||
Returns: | ||
The current state of the implementing class (or only the `components` | ||
specified, w/o those in `not_components`). | ||
""" | ||
|
||
@abc.abstractmethod | ||
def set_state(self, state: StateDict) -> None: | ||
"""Sets the implementing class' state to the given state dict. | ||
|
||
If component keys are missing in `state`, these components of the implementing | ||
class will not be updated/set. | ||
|
||
Args: | ||
state: The state dict to restore the state from. Maps component keys | ||
to the corresponding subcomponent's own state. | ||
""" | ||
|
||
@abc.abstractmethod | ||
def ctor_args_and_kwargs(self) -> Tuple[Tuple, Dict[str, Any]]: | ||
"""Returns the args/kwargs used to create `self` from its constructor. | ||
|
||
Returns: | ||
A tuple of the args (as a tuple) and kwargs (as a Dict[str, Any]) used to | ||
construct `self` from its class constructor. | ||
""" | ||
|
||
@OverrideToImplementCustomLogic_CallToSuperRecommended | ||
def get_metadata(self) -> Dict: | ||
"""Returns JSON writable metadata further describing the implementing class. | ||
|
||
Note that this metadata is NOT part of any state and is thus NOT needed to | ||
Nice! Like this it becomes more handy. |
||
restore the state of a Checkpointable instance from a directory. Rather, the | ||
metadata will be written into `self.METADATA_FILE_NAME` when calling | ||
`self.save_to_path()` for the user's convenience. | ||
|
||
Returns: | ||
A JSON-encodable dict of metadata information. | ||
""" | ||
return { | ||
"class_and_ctor_args_file": self.CLASS_AND_CTOR_ARGS_FILE_NAME, | ||
"state_file": self.STATE_FILE_NAME, | ||
"ray_version": ray.__version__, | ||
"ray_commit": ray.__commit__, | ||
} | ||
|
||
def get_checkpointable_components(self) -> List[Tuple[str, "Checkpointable"]]: | ||
"""Returns the implementing class's own Checkpointable subcomponents. | ||
|
||
Returns: | ||
A list of 2-tuples (name, subcomponent) describing the implementing class' | ||
subcomponents, all of which have to be `Checkpointable` themselves and | ||
whose state is therefore written into subdirectories (rather than the main | ||
state file (self.STATE_FILE_NAME) when calling `self.save_to_path()`). | ||
""" | ||
return [] | ||
|
||
|
||
@PublicAPI(stability="alpha") | ||
def get_checkpoint_info(checkpoint: Union[str, Checkpoint]) -> Dict[str, Any]: | ||
"""Returns a dict with information about a Algorithm/Policy checkpoint. | ||
"""Returns a dict with information about an Algorithm/Policy checkpoint. | ||
|
||
If the given checkpoint is a >=v1.0 checkpoint directory, try reading all | ||
information from the contained `rllib_checkpoint.json` file. | ||
|
@@ -248,7 +528,7 @@ def convert_to_msgpack_checkpoint( | |
f, | ||
) | ||
|
||
# Write individual policies to disk, each in their own sub-directory. | ||
# Write individual policies to disk, each in their own subdirectory. | ||
for pid, policy_state in policy_states.items(): | ||
# From here on, disallow policyIDs that would not work as directory names. | ||
validate_policy_id(pid, error=True) | ||
|
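The directory layout that `save_to_path()` writes and `restore_from_path()` reads can be illustrated with a minimal, standard-library-only sketch. `MiniCheckpointable`, `Counter`, and `Pair` are hypothetical stand-ins, not RLlib classes. The sketch assumes plain `pickle` instead of `ray.cloudpickle` and leaves out the metadata and constructor-args files.

```python
import pathlib
import pickle
import tempfile


class MiniCheckpointable:
    """Hypothetical stand-in for the Checkpointable base class (not RLlib code)."""

    STATE_FILE_NAME = "state.pkl"

    def get_checkpointable_components(self):
        # (name, subcomponent) pairs; leaf components have none.
        return []

    def save_to_path(self, path=None, state=None):
        path = pathlib.Path(path or tempfile.mkdtemp())
        path.mkdir(parents=True, exist_ok=True)
        # Copy, so popping subcomponent states does not mutate the caller's dict.
        state = dict(self.get_state() if state is None else state)
        # Each subcomponent writes its own state into `path/<name>/`.
        for name, component in self.get_checkpointable_components():
            if name in state:
                component.save_to_path(path / name, state=state.pop(name))
        # Whatever is left belongs to this component itself.
        with open(path / self.STATE_FILE_NAME, "wb") as f:
            pickle.dump(state, f)
        return str(path)

    def restore_from_path(self, path):
        path = pathlib.Path(path)
        for name, component in self.get_checkpointable_components():
            if (path / name).is_dir():
                component.restore_from_path(path / name)
        with open(path / self.STATE_FILE_NAME, "rb") as f:
            self.set_state(pickle.load(f))


class Counter(MiniCheckpointable):
    def __init__(self):
        self.count = 0

    def get_state(self):
        return {"count": self.count}

    def set_state(self, state):
        self.count = state["count"]


class Pair(MiniCheckpointable):
    """Owns two Counter subcomponents plus one value of its own."""

    def __init__(self):
        self.left, self.right, self.label = Counter(), Counter(), "unset"

    def get_checkpointable_components(self):
        return [("left", self.left), ("right", self.right)]

    def get_state(self):
        return {
            "label": self.label,
            "left": self.left.get_state(),
            "right": self.right.get_state(),
        }

    def set_state(self, state):
        # Keys missing from `state` leave the corresponding part untouched.
        self.label = state.get("label", self.label)
        if "left" in state:
            self.left.set_state(state["left"])
        if "right" in state:
            self.right.set_state(state["right"])


source = Pair()
source.label, source.left.count, source.right.count = "demo", 3, 7
checkpoint_dir = source.save_to_path()

target = Pair()
target.restore_from_path(checkpoint_dir)
saved_files = sorted(
    p.relative_to(checkpoint_dir).as_posix()
    for p in pathlib.Path(checkpoint_dir).rglob("*.pkl")
)
```

In this sketch the root `state.pkl` holds only the `label`, because the `left` and `right` entries are popped and written into their own subdirectories before the remaining state is pickled.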
Dumb question at the beginning: Can't we directly derive from `train.Checkpoint` or at least its metaclass `_CheckpointMetaClass`?

They are different concepts, I believe. `train.Checkpoint` just describes the checkpoint itself, NOT the component that gets checkpointed.
I would probably interpret the new Checkpointable more as a competitor to tune.Trainable with its existing `save` and `restore` APIs and the subsequent ones (`save_checkpoint` and `load_checkpoint`). We'll have to see how to fit the Algorithm class itself into this new API and whether it's even possible. But for all the subcomponents (which, in isolation, will never be used with Tune), this new API will unify and simplify a lot of things for our users, especially when going to production/deployment. It delivers an important promise of the new API stack: pluggability and reusability of all components outside RLlib.
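To make the reusability point concrete, here is a rough sketch of how a component could be rebuilt from a checkpoint directory alone, in the spirit of `from_checkpoint()`. Everything in it is hypothetical: `Scaler`, `REGISTRY`, `save`, and the `ctor.json` file are illustration-only names. A name registry plus JSON stands in for the cloudpickled class object used in the PR, so that the example needs only the standard library.

```python
import json
import pathlib
import pickle
import tempfile


class Scaler:
    """Hypothetical component: constructor config plus mutable state."""

    def __init__(self, factor, offset=0.0):
        self.factor, self.offset = factor, offset
        self.num_calls = 0

    def __call__(self, x):
        self.num_calls += 1
        return self.factor * x + self.offset

    def ctor_args_and_kwargs(self):
        # (args, kwargs) needed to rebuild an equivalent, fresh instance.
        return (self.factor,), {"offset": self.offset}

    def get_state(self):
        return {"num_calls": self.num_calls}

    def set_state(self, state):
        self.num_calls = state["num_calls"]


# Assumption: classes are looked up by name instead of being pickled.
REGISTRY = {"Scaler": Scaler}


def save(obj, path):
    path = pathlib.Path(path)
    path.mkdir(parents=True, exist_ok=True)
    args, kwargs = obj.ctor_args_and_kwargs()
    ctor_info = {"class": type(obj).__name__, "args": list(args), "kwargs": kwargs}
    with open(path / "ctor.json", "w") as f:
        json.dump(ctor_info, f)
    with open(path / "state.pkl", "wb") as f:
        pickle.dump(obj.get_state(), f)


def from_checkpoint(path):
    path = pathlib.Path(path)
    with open(path / "ctor.json") as f:
        info = json.load(f)
    # 1) Construct a fresh object, then 2) restore its state.
    obj = REGISTRY[info["class"]](*info["args"], **info["kwargs"])
    with open(path / "state.pkl", "rb") as f:
        obj.set_state(pickle.load(f))
    return obj


original = Scaler(2.0, offset=1.0)
original(3.0)  # Advances `num_calls` to 1 before saving.
checkpoint_dir = tempfile.mkdtemp()
save(original, checkpoint_dir)
clone = from_checkpoint(checkpoint_dir)
```

The caller of `from_checkpoint()` needs only the directory, not the original constructor arguments, which is what makes a component reusable in a deployment setting.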