diff --git a/doc/source/rllib/rllib-learner.rst b/doc/source/rllib/rllib-learner.rst index 3a70479263b06..38bfea05c079c 100644 --- a/doc/source/rllib/rllib-learner.rst +++ b/doc/source/rllib/rllib-learner.rst @@ -319,12 +319,12 @@ Getting and setting state .. testcode:: - :hide: + :hide: - import tempfile + import tempfile - LEARNER_CKPT_DIR = str(tempfile.TemporaryDirectory()) - LEARNER_GROUP_CKPT_DIR = str(tempfile.TemporaryDirectory()) + LEARNER_CKPT_DIR = tempfile.mkdtemp() + LEARNER_GROUP_CKPT_DIR = tempfile.mkdtemp() Checkpointing diff --git a/rllib/algorithms/algorithm.py b/rllib/algorithms/algorithm.py index 0bfab714c4b1e..7798c134004dc 100644 --- a/rllib/algorithms/algorithm.py +++ b/rllib/algorithms/algorithm.py @@ -12,6 +12,7 @@ import os from packaging import version import pathlib +import pyarrow.fs import re import tempfile import time @@ -305,6 +306,7 @@ class Algorithm(Checkpointable, Trainable, AlgorithmBase): def from_checkpoint( cls, path: Optional[Union[str, Checkpoint]] = None, + filesystem: Optional["pyarrow.fs.FileSystem"] = None, *, # @OldAPIStack policy_ids: Optional[Collection[PolicyID]] = None, @@ -324,6 +326,8 @@ def from_checkpoint( Args: path: The path (str) to the checkpoint directory to use or an AIR Checkpoint instance to restore from. + filesystem: PyArrow FileSystem to use to access data at the `path`. If not + specified, this is inferred from the URI scheme of `path`. policy_ids: Optional list of PolicyIDs to recover. This allows users to restore an Algorithm with only a subset of the originally present Policies. @@ -371,7 +375,7 @@ def from_checkpoint( ) # New API stack -> Use Checkpointable's default implementation. elif checkpoint_info["checkpoint_version"] >= version.Version("2.0"): - return super().from_checkpoint(path, **kwargs) + return super().from_checkpoint(path, filesystem=filesystem, **kwargs) # This is a msgpack checkpoint. 
if checkpoint_info["format"] == "msgpack": diff --git a/rllib/algorithms/tests/test_algorithm.py b/rllib/algorithms/tests/test_algorithm.py index 12c98ce50f604..ce4bf0b8d9258 100644 --- a/rllib/algorithms/tests/test_algorithm.py +++ b/rllib/algorithms/tests/test_algorithm.py @@ -33,7 +33,7 @@ class TestAlgorithm(unittest.TestCase): @classmethod def setUpClass(cls): - ray.init() + ray.init(local_mode=True) register_env("multi_cart", lambda cfg: MultiAgentCartPole(cfg)) @classmethod diff --git a/rllib/utils/checkpoints.py b/rllib/utils/checkpoints.py index f9ad27bea5f84..89106db219921 100644 --- a/rllib/utils/checkpoints.py +++ b/rllib/utils/checkpoints.py @@ -9,6 +9,8 @@ from types import MappingProxyType from typing import Any, Collection, Dict, List, Optional, Tuple, Union +import pyarrow.fs + import ray import ray.cloudpickle as pickle from ray.rllib.core import ( @@ -92,6 +94,7 @@ def save_to_path( path: Optional[Union[str, pathlib.Path]] = None, *, state: Optional[StateDict] = None, + filesystem: Optional["pyarrow.fs.FileSystem"] = None, ) -> str: """Saves the state of the implementing class (or `state`) to `path`. @@ -126,17 +129,39 @@ def save_to_path( created (and returned). state: An optional state dict to be used instead of getting a new state of the implementing class through `self.get_state()`. + filesystem: PyArrow FileSystem to use to access data at the `path`. + If not specified, this is inferred from the URI scheme of `path`. Returns: The path (str) where the state has been saved. """ - # Create path, if necessary. + + # If no path is given create a local temporary directory. if path is None: - path = path or tempfile.mkdtemp() + import uuid + + # Get the location of the temporary directory on the OS. + tmp_dir = pathlib.Path(tempfile.gettempdir()) + # Create a random directory name. + random_dir_name = str(uuid.uuid4()) + # Create the path, but do not create the directory on the + # filesystem, yet. This is done by `PyArrow`. 
+ path = path or tmp_dir / random_dir_name + + # We need a string path for `pyarrow.fs.FileSystem.from_uri`. + path = path if isinstance(path, str) else path.as_posix() + + # If we have no filesystem, figure it out. + if path and not filesystem: + # Note the path needs to be a path that is relative to the + # filesystem (e.g. `gs://tmp/...` -> `tmp/...`). + filesystem, path = pyarrow.fs.FileSystem.from_uri(path) # Make sure, path exists. + filesystem.create_dir(path, recursive=True) + + # Convert to `pathlib.Path` for easy handling. path = pathlib.Path(path) - path.mkdir(parents=True, exist_ok=True) # Write metadata file to disk. metadata = self.get_metadata() @@ -144,11 +169,15 @@ def save_to_path( metadata["checkpoint_version"] = str( CHECKPOINT_VERSION_LEARNER_AND_ENV_RUNNER ) - with open(path / self.METADATA_FILE_NAME, "w") as f: - json.dump(metadata, f) + with filesystem.open_output_stream( + (path / self.METADATA_FILE_NAME).as_posix() + ) as f: + f.write(json.dumps(metadata).encode("utf-8")) # Write the class and constructor args information to disk. - with open(path / self.CLASS_AND_CTOR_ARGS_FILE_NAME, "wb") as f: + with filesystem.open_output_stream( + (path / self.CLASS_AND_CTOR_ARGS_FILE_NAME).as_posix() + ) as f: pickle.dump( { "class": type(self), @@ -262,10 +291,12 @@ def _rmdir(_, _dir=worker_temp_dir): # By providing the `state` arg, we make sure that the component does not # have to call its own `get_state()` anymore, but uses what's provided # here. - comp.save_to_path(comp_path, state=comp_state) + comp.save_to_path(comp_path, filesystem=filesystem, state=comp_state) # Write all the remaining state to disk. 
- with open(path / self.STATE_FILE_NAME, "wb") as f: + with filesystem.open_output_stream( + (path / self.STATE_FILE_NAME).as_posix() + ) as f: pickle.dump(state, f) return str(path) @@ -275,6 +306,7 @@ def restore_from_path( path: Union[str, pathlib.Path], *, component: Optional[str] = None, + filesystem: Optional["pyarrow.fs.FileSystem"] = None, **kwargs, ) -> None: """Restores the state of the implementing class from the given path. @@ -313,10 +345,21 @@ def restore_from_path( the subcomponent and thus, only that subcomponent's state is restored/loaded. All other state of `self` remains unchanged in this case. + filesystem: PyArrow FileSystem to use to access data at the `path`. If not + specified, this is inferred from the URI scheme of `path`. **kwargs: Forward compatibility kwargs. """ + path = path if isinstance(path, str) else path.as_posix() + + if path and not filesystem: + # Note the path needs to be a path that is relative to the + # filesystem (e.g. `gs://tmp/...` -> `tmp/...`). + filesystem, path = pyarrow.fs.FileSystem.from_uri(path) + # Only here convert to a `Path` instance b/c otherwise + # cloud path gets broken (i.e. 'gs://' -> 'gs:/'). path = pathlib.Path(path) - if not path.is_dir(): + + if not _exists_at_fs_path(filesystem, path.as_posix()): raise FileNotFoundError(f"`path` ({path}) not found!") # Restore components of `self` that themselves are `Checkpointable`. @@ -330,7 +373,7 @@ def restore_from_path( comp_dir = path / comp_name # If subcomponent's dir is not in path, ignore it and don't restore this # subcomponent's state from disk. - if not comp_dir.is_dir(): + if not _exists_at_fs_path(filesystem, comp_dir.as_posix()): continue else: comp_dir = path @@ -380,17 +423,24 @@ def _restore( # Call `restore_from_path()` on local subcomponent, thereby passing in the # **kwargs. 
else: - comp.restore_from_path(comp_dir, component=comp_arg, **kwargs) + comp.restore_from_path( + comp_dir, filesystem=filesystem, component=comp_arg, **kwargs + ) # Restore the rest of the state (not based on subcomponents). if component is None: - with open(path / self.STATE_FILE_NAME, "rb") as f: + with filesystem.open_input_stream( + (path / self.STATE_FILE_NAME).as_posix() + ) as f: state = pickle.load(f) self.set_state(state) @classmethod def from_checkpoint( - cls, path: Union[str, pathlib.Path], **kwargs + cls, + path: Union[str, pathlib.Path], + filesystem: Optional["pyarrow.fs.FileSystem"] = None, + **kwargs, ) -> "Checkpointable": """Creates a new Checkpointable instance from the given location and returns it. @@ -398,6 +448,8 @@ def from_checkpoint( path: The checkpoint path to load (a) the information on how to construct a new instance of the implementing class and (b) the state to restore the created instance to. + filesystem: PyArrow FileSystem to use to access data at the `path`. If not + specified, this is inferred from the URI scheme of `path`. kwargs: Forward compatibility kwargs. Note that these kwargs are sent to each subcomponent's `from_checkpoint()` call. @@ -405,10 +457,22 @@ def from_checkpoint( A new instance of the implementing class, already set to the state stored under `path`. """ + # We need a string path for the `PyArrow` filesystem. + path = path if isinstance(path, str) else path.as_posix() + + # If no filesystem is passed in create one. + if path and not filesystem: + # Note the path needs to be a path that is relative to the + # filesystem (e.g. `gs://tmp/...` -> `tmp/...`). + filesystem, path = pyarrow.fs.FileSystem.from_uri(path) + # Only here convert to a `Path` instance b/c otherwise + # cloud path gets broken (i.e. 'gs://' -> 'gs:/'). path = pathlib.Path(path) # Get the class constructor to call. 
- with open(path / cls.CLASS_AND_CTOR_ARGS_FILE_NAME, "rb") as f: + with filesystem.open_input_stream( + (path / cls.CLASS_AND_CTOR_ARGS_FILE_NAME).as_posix() + ) as f: ctor_info = pickle.load(f) ctor = ctor_info["class"] @@ -430,7 +494,7 @@ def from_checkpoint( **ctor_info["ctor_args_and_kwargs"][1], ) # Restore the state of the constructed object. - obj.restore_from_path(path, **kwargs) + obj.restore_from_path(path, filesystem=filesystem, **kwargs) # Return the new object. return obj @@ -536,8 +600,22 @@ def _get_subcomponents(self, name, components): return None if not subcomponents else subcomponents +def _exists_at_fs_path(fs: pyarrow.fs.FileSystem, path: str) -> bool: + """Returns `True` if the path can be found in the filesystem.""" + valid = fs.get_file_info(path) + return valid.type != pyarrow.fs.FileType.NotFound + + +def _is_dir(file_info: pyarrow.fs.FileInfo) -> bool: + """Returns `True`, if the file info is from a directory.""" + return file_info.type == pyarrow.fs.FileType.Directory + + @PublicAPI(stability="alpha") -def get_checkpoint_info(checkpoint: Union[str, Checkpoint]) -> Dict[str, Any]: +def get_checkpoint_info( + checkpoint: Union[str, Checkpoint], + filesystem: Optional["pyarrow.fs.FileSystem"] = None, +) -> Dict[str, Any]: """Returns a dict with information about an Algorithm/Policy checkpoint. If the given checkpoint is a >=v1.0 checkpoint directory, try reading all @@ -545,6 +623,8 @@ def get_checkpoint_info(checkpoint: Union[str, Checkpoint]) -> Dict[str, Any]: Args: checkpoint: The checkpoint directory (str) or an AIR Checkpoint object. + filesystem: PyArrow FileSystem to use to access data at the `checkpoint`. If not + specified, this is inferred from the URI scheme provided by `checkpoint`. Returns: A dict containing the keys: @@ -573,22 +653,33 @@ def get_checkpoint_info(checkpoint: Union[str, Checkpoint]) -> Dict[str, Any]: # `checkpoint` is a Checkpoint instance: Translate to directory and continue. 
if isinstance(checkpoint, Checkpoint): checkpoint = checkpoint.to_directory() + + if checkpoint and not filesystem: + # Note the path needs to be a path that is relative to the + # filesystem (e.g. `gs://tmp/...` -> `tmp/...`). + filesystem, checkpoint = pyarrow.fs.FileSystem.from_uri(checkpoint) + # Only here convert to a `Path` instance b/c otherwise + # cloud path gets broken (i.e. 'gs://' -> 'gs:/'). checkpoint = pathlib.Path(checkpoint) # Checkpoint is dir. - if checkpoint.is_dir(): + if _exists_at_fs_path(filesystem, checkpoint.as_posix()) and _is_dir( + filesystem.get_file_info(checkpoint.as_posix()) + ): info.update({"checkpoint_dir": str(checkpoint)}) # Figure out whether this is an older checkpoint format # (with a `checkpoint-\d+` file in it). - for file in checkpoint.iterdir(): - path_file = checkpoint / file - if path_file.is_file(): - if re.match("checkpoint-\\d+", file.name): + file_info_list = filesystem.get_file_info( + pyarrow.fs.FileSelector(checkpoint.as_posix(), recursive=False) + ) + for file_info in file_info_list: + if file_info.is_file: + if re.match("checkpoint-\\d+", file_info.base_name): info.update( { "checkpoint_version": version.Version("0.1"), - "state_file": str(path_file), + "state_file": str(file_info.base_name), } ) return info @@ -598,8 +689,14 @@ def get_checkpoint_info(checkpoint: Union[str, Checkpoint]) -> Dict[str, Any]: # If rllib_checkpoint.json file present, read available information from it # and then continue with the checkpoint analysis (possibly overriding further # information). 
- if (checkpoint / "rllib_checkpoint.json").is_file(): - with open(checkpoint / "rllib_checkpoint.json") as f: + if _exists_at_fs_path( + filesystem, (checkpoint / "rllib_checkpoint.json").as_posix() + ): + # if (checkpoint / "rllib_checkpoint.json").is_file(): + with filesystem.open_input_stream( + (checkpoint / "rllib_checkpoint.json").as_posix() + ) as f: + # with open(checkpoint / "rllib_checkpoint.json") as f: rllib_checkpoint_info = json.load(fp=f) if "checkpoint_version" in rllib_checkpoint_info: rllib_checkpoint_info["checkpoint_version"] = version.Version( @@ -618,7 +715,10 @@ def get_checkpoint_info(checkpoint: Union[str, Checkpoint]) -> Dict[str, Any]: # Policy checkpoint file found. for extension in ["pkl", "msgpck"]: - if (checkpoint / ("policy_state." + extension)).is_file(): + if _exists_at_fs_path( + filesystem, (checkpoint / ("policy_state." + extension)).as_posix() + ): + # if (checkpoint / ("policy_state." + extension)).is_file(): info.update( { "type": "Policy", @@ -633,7 +733,10 @@ def get_checkpoint_info(checkpoint: Union[str, Checkpoint]) -> Dict[str, Any]: format = None for extension in ["pkl", "msgpck"]: state_file = checkpoint / f"algorithm_state.{extension}" - if state_file.is_file(): + if ( + _exists_at_fs_path(filesystem, state_file.as_posix()) + and filesystem.get_file_info(state_file.as_posix()).is_file + ): format = "cloudpickle" if extension == "pkl" else "msgpack" break if format is None: @@ -651,10 +754,15 @@ def get_checkpoint_info(checkpoint: Union[str, Checkpoint]) -> Dict[str, Any]: # Collect all policy IDs in the sub-dir "policies/". 
policies_dir = checkpoint / "policies" - if policies_dir.is_dir(): + if _exists_at_fs_path(filesystem, policies_dir.as_posix()) and _is_dir( + filesystem.get_file_info(policies_dir.as_posix()) + ): policy_ids = set() - for policy_id in policies_dir.iterdir(): - policy_ids.add(policy_id.name) + file_info_list = filesystem.get_file_info( + pyarrow.fs.FileSelector(policies_dir.as_posix(), recursive=False) + ) + for file_info in file_info_list: + policy_ids.add(file_info.base_name) info.update({"policy_ids": policy_ids}) # Collect all module IDs in the sub-dir "learner/module_state/". @@ -664,18 +772,27 @@ def get_checkpoint_info(checkpoint: Union[str, Checkpoint]) -> Dict[str, Any]: / COMPONENT_LEARNER / COMPONENT_RL_MODULE ) - if modules_dir.is_dir(): + if _exists_at_fs_path(filesystem, checkpoint.as_posix()) and _is_dir( + filesystem.get_file_info(modules_dir.as_posix()) + ): module_ids = set() - for module_id in modules_dir.iterdir(): + file_info_list = filesystem.get_file_info( + pyarrow.fs.FileSelector(modules_dir.as_posix(), recursive=False) + ) + for file_info in file_info_list: # Only add subdirs (those are the ones where the RLModule data # is stored, not files (could be json metadata files). - if (modules_dir / module_id).is_dir(): - module_ids.add(module_id.name) + module_dir = modules_dir / file_info.base_name + if _is_dir(filesystem.get_file_info(module_dir.as_posix())): + module_ids.add(file_info.base_name) info.update({"module_ids": module_ids}) # Checkpoint is a file: Use as-is (interpreting it as old Algorithm checkpoint # version). 
- elif checkpoint.is_file(): + elif ( + _exists_at_fs_path(filesystem, checkpoint.as_posix()) + and filesystem.get_file_info(checkpoint.as_posix()).is_file + ): info.update( { "checkpoint_version": version.Version("0.1"), diff --git a/rllib/utils/tests/test_checkpoint_utils.py b/rllib/utils/tests/test_checkpoint_utils.py index dde997e72808c..1b27857ef9efb 100644 --- a/rllib/utils/tests/test_checkpoint_utils.py +++ b/rllib/utils/tests/test_checkpoint_utils.py @@ -41,7 +41,7 @@ def test_get_checkpoint_info_v0_1(self): self.assertTrue(info["type"] == "Algorithm") self.assertTrue(str(info["checkpoint_version"]) == "0.1") self.assertTrue(info["checkpoint_dir"] == checkpoint_dir) - self.assertTrue(info["state_file"] == algo_state_file) + self.assertTrue(info["state_file"] == Path(algo_state_file).name) self.assertTrue(info["policy_ids"] is None) def test_get_checkpoint_info_v1_1(self):