diff --git a/.circleci/config.yml b/.circleci/config.yml
index bd0009f24..c781c6910 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -41,7 +41,6 @@ executors:
         (?x)(
         src/imitation/algorithms/preference_comparisons.py$
         | src/imitation/rewards/reward_nets.py$
-        | src/imitation/util/sacred.py$
         | src/imitation/algorithms/base.py$
         | src/imitation/scripts/train_preference_comparisons.py$
         | src/imitation/rewards/serialize.py$
diff --git a/ci/code_checks.sh b/ci/code_checks.sh
index 3b7dd0610..733c0fff3 100755
--- a/ci/code_checks.sh
+++ b/ci/code_checks.sh
@@ -5,7 +5,6 @@ SRC_FILES=(src/ tests/ experiments/ examples/ docs/conf.py setup.py ci/)
 EXCLUDE_MYPY="(?x)(
 src/imitation/algorithms/preference_comparisons.py$
 | src/imitation/rewards/reward_nets.py$
-| src/imitation/util/sacred.py$
 | src/imitation/algorithms/base.py$
 | src/imitation/scripts/train_preference_comparisons.py$
 | src/imitation/rewards/serialize.py$
diff --git a/docs/tutorials/1_train_bc.ipynb b/docs/tutorials/1_train_bc.ipynb
index 500f0a3db..5c4f16b86 100644
--- a/docs/tutorials/1_train_bc.ipynb
+++ b/docs/tutorials/1_train_bc.ipynb
@@ -200,4 +200,4 @@
  },
  "nbformat": 4,
  "nbformat_minor": 2
-}
+}
\ No newline at end of file
diff --git a/docs/tutorials/3_train_gail.ipynb b/docs/tutorials/3_train_gail.ipynb
index 200c0fbed..5cdeca671 100644
--- a/docs/tutorials/3_train_gail.ipynb
+++ b/docs/tutorials/3_train_gail.ipynb
@@ -187,4 +187,4 @@
  },
  "nbformat": 4,
  "nbformat_minor": 2
-}
+}
\ No newline at end of file
diff --git a/docs/tutorials/4_train_airl.ipynb b/docs/tutorials/4_train_airl.ipynb
index 27ac1c332..1067ae2df 100644
--- a/docs/tutorials/4_train_airl.ipynb
+++ b/docs/tutorials/4_train_airl.ipynb
@@ -181,4 +181,4 @@
  },
  "nbformat": 4,
  "nbformat_minor": 2
-}
+}
\ No newline at end of file
diff --git a/docs/tutorials/5_train_preference_comparisons.ipynb b/docs/tutorials/5_train_preference_comparisons.ipynb
index fe63732d6..b2cf6a500 100644
--- a/docs/tutorials/5_train_preference_comparisons.ipynb
+++ b/docs/tutorials/5_train_preference_comparisons.ipynb
@@ -203,4 +203,4 @@
  },
  "nbformat": 4,
  "nbformat_minor": 2
-}
+}
\ No newline at end of file
diff --git a/docs/tutorials/5a_train_preference_comparisons_with_cnn.ipynb b/docs/tutorials/5a_train_preference_comparisons_with_cnn.ipynb
index 6c26c8f32..0dadeac4b 100644
--- a/docs/tutorials/5a_train_preference_comparisons_with_cnn.ipynb
+++ b/docs/tutorials/5a_train_preference_comparisons_with_cnn.ipynb
@@ -236,4 +236,4 @@
  },
  "nbformat": 4,
  "nbformat_minor": 5
-}
+}
\ No newline at end of file
diff --git a/docs/tutorials/7_train_density.ipynb b/docs/tutorials/7_train_density.ipynb
index d4c1c3f85..8a3654b6d 100644
--- a/docs/tutorials/7_train_density.ipynb
+++ b/docs/tutorials/7_train_density.ipynb
@@ -158,4 +158,4 @@
  },
  "nbformat": 4,
  "nbformat_minor": 4
-}
+}
\ No newline at end of file
diff --git a/src/imitation/algorithms/adversarial/common.py b/src/imitation/algorithms/adversarial/common.py
index 41c1b4129..34402aa9c 100644
--- a/src/imitation/algorithms/adversarial/common.py
+++ b/src/imitation/algorithms/adversarial/common.py
@@ -3,7 +3,6 @@
 import collections
 import dataclasses
 import logging
-import os
 from typing import (
     Callable,
     Iterable,
@@ -127,7 +126,7 @@ def __init__(
         gen_algo: base_class.BaseAlgorithm,
         reward_net: reward_nets.RewardNet,
         n_disc_updates_per_round: int = 2,
-        log_dir: str = "output/",
+        log_dir: types.AnyPath = "output/",
         disc_opt_cls: Type[th.optim.Optimizer] = th.optim.Adam,
         disc_opt_kwargs: Optional[Mapping] = None,
         gen_train_timesteps: Optional[int] = None,
@@ -202,7 +201,7 @@ def __init__(
         self.venv = venv
         self.gen_algo = gen_algo
         self._reward_net = reward_net.to(gen_algo.device)
-        self._log_dir = log_dir
+        self._log_dir = types.parse_path(log_dir)

         # Create graph for optimising/recording stats on discriminator
         self._disc_opt_cls = disc_opt_cls
@@ -215,10 +214,10 @@ def __init__(
         )

         if self._init_tensorboard:
-            logging.info("building summary directory at " + self._log_dir)
-            summary_dir = os.path.join(self._log_dir, "summary")
-            os.makedirs(summary_dir, exist_ok=True)
-            self._summary_writer = thboard.SummaryWriter(summary_dir)
+            logging.info(f"building summary directory at {self._log_dir}")
+            summary_dir = self._log_dir / "summary"
+            summary_dir.mkdir(parents=True, exist_ok=True)
+            self._summary_writer = thboard.SummaryWriter(str(summary_dir))

         self.venv_buffering = wrappers.BufferingWrapper(self.venv)
diff --git a/src/imitation/algorithms/bc.py b/src/imitation/algorithms/bc.py
index cc9a20cb2..d080e6e3c 100644
--- a/src/imitation/algorithms/bc.py
+++ b/src/imitation/algorithms/bc.py
@@ -472,4 +472,4 @@ def save_policy(self, policy_path: types.AnyPath) -> None:
         Args:
             policy_path: path to save policy to.
         """
-        th.save(self.policy, types.path_to_str(policy_path))
+        th.save(self.policy, types.parse_path(policy_path))
diff --git a/src/imitation/algorithms/dagger.py b/src/imitation/algorithms/dagger.py
index 9a3b086f8..fbd0a5fb3 100644
--- a/src/imitation/algorithms/dagger.py
+++ b/src/imitation/algorithms/dagger.py
@@ -90,10 +90,8 @@ def reconstruct_trainer(
         A deserialized `DAggerTrainer`.
     """
     custom_logger = custom_logger or imit_logger.configure()
-    checkpoint_path = pathlib.Path(
-        types.path_to_str(scratch_dir),
-        "checkpoint-latest.pt",
-    )
+    scratch_dir = types.parse_path(scratch_dir)
+    checkpoint_path = scratch_dir / "checkpoint-latest.pt"
     trainer = th.load(checkpoint_path, map_location=utils.get_device(device))
     trainer.venv = venv
     trainer._logger = custom_logger
@@ -109,14 +107,14 @@ def _save_dagger_demo(
     # however that NPZ save here is likely more space efficient than
     # pickle from types.save(), and types.save only accepts
     # TrajectoryWithRew right now (subclass of Trajectory).
-    save_dir_obj = pathlib.Path(types.path_to_str(save_dir))
+    save_dir = types.parse_path(save_dir)
     assert isinstance(trajectory, types.Trajectory)
     actual_prefix = f"{prefix}-" if prefix else ""
     timestamp = util.make_unique_timestamp()
     filename = f"{actual_prefix}dagger-demo-{timestamp}.npz"
-    save_dir_obj.mkdir(parents=True, exist_ok=True)
-    npz_path = save_dir_obj / filename
+    save_dir.mkdir(parents=True, exist_ok=True)
+    npz_path = save_dir / filename
     np.savez_compressed(npz_path, **dataclasses.asdict(trajectory))
     logging.info(f"Saved demo at '{npz_path}'")
@@ -344,7 +342,7 @@ def __init__(
         if beta_schedule is None:
             beta_schedule = LinearBetaSchedule(15)
         self.beta_schedule = beta_schedule
-        self.scratch_dir = pathlib.Path(types.path_to_str(scratch_dir))
+        self.scratch_dir = types.parse_path(scratch_dir)
         self.venv = venv
         self.round_num = 0
         self._last_loaded_round = -1
@@ -397,11 +395,7 @@ def _load_all_demos(self):
         return demo_transitions, num_demos_by_round

     def _get_demo_paths(self, round_dir):
-        return [
-            os.path.join(round_dir, p)
-            for p in os.listdir(round_dir)
-            if p.endswith(".npz")
-        ]
+        return [round_dir / p for p in os.listdir(round_dir) if p.endswith(".npz")]

     def _demo_dir_path_for_round(self, round_num: Optional[int] = None) -> pathlib.Path:
         if round_num is None:
@@ -411,7 +405,7 @@ def _demo_dir_path_for_round(self, round_num: Optional[int] = None) -> pathlib.P
     def _try_load_demos(self) -> None:
         """Load the dataset for this round into self.bc_trainer as a DataLoader."""
         demo_dir = self._demo_dir_path_for_round()
-        demo_paths = self._get_demo_paths(demo_dir) if os.path.isdir(demo_dir) else []
+        demo_paths = self._get_demo_paths(demo_dir) if demo_dir.is_dir() else []
         if len(demo_paths) == 0:
             raise NeedsDemosException(
                 f"No demos found for round {self.round_num} in dir '{demo_dir}'. "
diff --git a/src/imitation/algorithms/mce_irl.py b/src/imitation/algorithms/mce_irl.py
index a2ebcff91..c739ec284 100644
--- a/src/imitation/algorithms/mce_irl.py
+++ b/src/imitation/algorithms/mce_irl.py
@@ -176,9 +176,7 @@ def set_pi(self, pi: np.ndarray) -> None:
         self.pi = pi

     def _predict(self, observation: th.Tensor, deterministic: bool = False):
-        raise NotImplementedError(
-            "Should never be called as predict overridden.",
-        )
+        raise NotImplementedError("Should never be called as predict overridden.")

     def forward(  # type: ignore[override]
         self,
diff --git a/src/imitation/data/types.py b/src/imitation/data/types.py
index c0801446a..70c3f03f3 100644
--- a/src/imitation/data/types.py
+++ b/src/imitation/data/types.py
@@ -44,11 +44,87 @@ def dataclass_quick_asdict(obj) -> Dict[str, Any]:
     return d


-def path_to_str(path: AnyPath) -> str:
-    if isinstance(path, bytes):
-        return path.decode()
+def parse_path(
+    path: AnyPath,
+    allow_relative: bool = True,
+    base_directory: Optional[pathlib.Path] = None,
+) -> pathlib.Path:
+    """Parse a path to a `pathlib.Path` object.
+
+    All resulting paths are absolute. If `allow_relative` is True, then relative
+    paths are allowed as input, and are interpreted relative to the current
+    working directory, or to `base_directory` if it is specified.
+
+    Args:
+        path: The path to parse. Can be a string, bytes, or `os.PathLike`.
+        allow_relative: If True, then relative paths are allowed as input, and
+            are resolved relative to the current working directory. If False,
+            an error is raised if the path is not absolute.
+        base_directory: If specified, then relative paths are resolved relative
+            to this directory, instead of the current working directory.
+
+    Returns:
+        A `pathlib.Path` object.
+
+    Raises:
+        ValueError: If `allow_relative` is False and the path is not absolute.
+        ValueError: If `base_directory` is specified and `allow_relative` is
+            False.
+    """
+    if base_directory is not None and not allow_relative:
+        raise ValueError(
+            "If `base_directory` is specified, then `allow_relative` must be True.",
+        )
+
+    parsed_path: pathlib.Path
+    if isinstance(path, pathlib.Path):
+        parsed_path = path
+    elif isinstance(path, str):
+        parsed_path = pathlib.Path(path)
+    elif isinstance(path, bytes):
+        parsed_path = pathlib.Path(path.decode())
+    else:
+        parsed_path = pathlib.Path(str(path))
+
+    if parsed_path.is_absolute():
+        return parsed_path
+    else:
+        if allow_relative:
+            base_directory = base_directory or pathlib.Path.cwd()
+            # relative to current working directory
+            return base_directory / parsed_path
+        else:
+            raise ValueError(f"Path {str(parsed_path)} is not absolute")
+
+
+def parse_optional_path(
+    path: Optional[AnyPath],
+    allow_relative: bool = True,
+    base_directory: Optional[pathlib.Path] = None,
+) -> Optional[pathlib.Path]:
+    """Parse an optional path to a `pathlib.Path` object.
+
+    All resulting paths are absolute. If `allow_relative` is True, then relative
+    paths are allowed as input, and are interpreted relative to the current
+    working directory, or to `base_directory` if it is specified.
+
+    Args:
+        path: The path to parse. Can be a string, bytes, or `os.PathLike`.
+        allow_relative: If True, then relative paths are allowed as input, and
+            are resolved relative to the current working directory. If False,
+            an error is raised if the path is not absolute.
+        base_directory: If specified, then relative paths are resolved relative
+            to this directory, instead of the current working directory.
+
+    Returns:
+        A `pathlib.Path` object, or None if `path` is None.
+    """
+    if path is None:
+        return None
     else:
-        return str(path)
+        return parse_path(path, allow_relative, base_directory)


 @dataclasses.dataclass(frozen=True)
@@ -417,10 +493,10 @@ def save(path: AnyPath, trajectories: Sequence[Trajectory]):
         trajectories: The trajectories to save.

     Raises:
-        ValueError: If the trajectories are not all of the same type, i.e. some are
+        ValueError: If not all trajectories have the same type, i.e. some are
             `Trajectory` and others are `TrajectoryWithRew`.
     """
-    p = pathlib.Path(path_to_str(path))
+    p = parse_path(path)
     p.parent.mkdir(parents=True, exist_ok=True)
     tmp_path = f"{p}.tmp"
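[Reviewer note, not part of the patch] Since `parse_path` is the pivot of this refactor, a quick usage sketch may help; each assertion below follows directly from the implementation added above (POSIX paths; the new test skips these comparisons on Windows):

    import pathlib

    from imitation.data import types

    # Absolute inputs pass through unchanged, whatever the spelling.
    assert types.parse_path("/tmp/run") == pathlib.Path("/tmp/run")
    assert types.parse_path(b"/tmp/run") == pathlib.Path("/tmp/run")

    # Relative inputs are anchored to the current working directory...
    assert types.parse_path("run") == pathlib.Path.cwd() / "run"

    # ...or to an explicit base directory.
    base = pathlib.Path("/data")
    assert types.parse_path("run", base_directory=base) == base / "run"

    # allow_relative=False turns a relative input into a ValueError.
    try:
        types.parse_path("run", allow_relative=False)
    except ValueError as err:
        print(err)  # Path run is not absolute

    # parse_optional_path is the None-propagating variant.
    assert types.parse_optional_path(None) is None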
""" - p = pathlib.Path(path_to_str(path)) + p = parse_path(path) p.parent.mkdir(parents=True, exist_ok=True) tmp_path = f"{p}.tmp" diff --git a/src/imitation/policies/base.py b/src/imitation/policies/base.py index f0ee6d588..ffa3a281d 100644 --- a/src/imitation/policies/base.py +++ b/src/imitation/policies/base.py @@ -40,7 +40,7 @@ def _choose_action(self, obs: np.ndarray) -> np.ndarray: def forward(self, *args): # technically BasePolicy is a Torch module, so this needs a forward() # method - raise NotImplementedError() + raise NotImplementedError() # pragma: no cover class RandomPolicy(HardCodedPolicy): diff --git a/src/imitation/policies/serialize.py b/src/imitation/policies/serialize.py index 15b408412..e57d39e4a 100644 --- a/src/imitation/policies/serialize.py +++ b/src/imitation/policies/serialize.py @@ -4,13 +4,13 @@ # torch.load() and torch.save() calls import logging -import os import pathlib from typing import Callable, Type, TypeVar import huggingface_sb3 as hfsb3 from stable_baselines3.common import base_class, callbacks, policies, vec_env +from imitation.data import types from imitation.policies import base from imitation.util import registry @@ -52,7 +52,7 @@ def load_stable_baselines_model( The deserialized RL algorithm. """ logging.info(f"Loading Stable Baselines policy for '{cls}' from '{path}'") - path_obj = pathlib.Path(path) + path_obj = types.parse_path(path) if path_obj.is_dir(): path_obj = path_obj / "model.zip" @@ -66,7 +66,7 @@ def load_stable_baselines_model( if vec_normalize_path.exists(): raise FileExistsError( "Outdated policy format: we do not support restoring normalization " - "statistics from '{vec_normalize_path}'", + f"statistics from '{vec_normalize_path}'", ) return cls.load(path_obj, env=venv, **kwargs) @@ -181,7 +181,7 @@ def load_policy( def save_stable_model( - output_dir: str, + output_dir: pathlib.Path, model: base_class.BaseAlgorithm, filename: str = "model.zip", ) -> None: @@ -197,9 +197,9 @@ def save_stable_model( # Save each model in new directory in case we want to add metadata or other # information in future. (E.g. we used to save `VecNormalize` statistics here, # although that is no longer necessary.) 
diff --git a/src/imitation/scripts/analyze.py b/src/imitation/scripts/analyze.py
index 3787ceb88..417799752 100644
--- a/src/imitation/scripts/analyze.py
+++ b/src/imitation/scripts/analyze.py
@@ -5,7 +5,7 @@
 import json
 import logging
 import os
-import os.path as osp
+import pathlib
 import tempfile
 import warnings
 from collections import OrderedDict
@@ -15,6 +15,7 @@
 from sacred.observers import FileStorageObserver

 import imitation.util.sacred as sacred_util
+from imitation.data import types
 from imitation.scripts.config.analyze import analysis_ex
 from imitation.util.sacred import dict_get_nested as get
@@ -47,7 +48,8 @@ def _gather_sacred_dicts(
     # e.g. chain.from_iterable([["pathone", "pathtwo"], [], ["paththree"]]) =>
     # ("pathone", "pathtwo", "paththree")
     sacred_dirs = itertools.chain.from_iterable(
-        sacred_util.filter_subdirs(source_dir) for source_dir in source_dirs
+        sacred_util.filter_subdirs(types.parse_path(source_dir))
+        for source_dir in source_dirs
     )

     sacred_dicts_list = []
@@ -98,17 +100,17 @@ def gather_tb_directories() -> dict:
     Raises:
         OSError: If the symlink cannot be created.
     """
-    os.makedirs("/tmp/analysis_tb", exist_ok=True)
-    tmp_dir = tempfile.mkdtemp(dir="/tmp/analysis_tb/")
+    tb_analysis_dir = pathlib.Path("/tmp/analysis_tb")
+    tb_analysis_dir.mkdir(exist_ok=True)
+    tmp_dir = pathlib.Path(tempfile.mkdtemp(dir=tb_analysis_dir))

     tb_dirs_count = 0
     for sd in _gather_sacred_dicts():
         # Expecting a path like "~/ray_results/{run_name}/sacred/1".
         # Want to search for all Tensorboard dirs inside
         # "~/ray_results/{run_name}".
-        sacred_dir = sd.sacred_dir.rstrip("/")
-        run_dir = osp.dirname(osp.dirname(sacred_dir))
-        run_name = osp.basename(run_dir)
+        run_dir = sd.sacred_dir.parent.parent
+        run_name = run_dir.name

         # log is what we use as subdirectory in new code.
         # rl, tb, sb_tb all appear in old versions.
@@ -116,19 +118,19 @@ def gather_tb_directories() -> dict:
            tb_src_dirs = tuple(
                sacred_util.filter_subdirs(
                    run_dir,
-                    lambda path: osp.basename(path) == basename,
+                    lambda path: path.name == basename,
                ),
            )
            if tb_src_dirs:
                assert len(tb_src_dirs) == 1, "expect at most one TB dir of each type"
                tb_src_dir = tb_src_dirs[0]

-                symlinks_dir = osp.join(tmp_dir, basename)
-                os.makedirs(symlinks_dir, exist_ok=True)
+                symlinks_dir = tmp_dir / basename
+                symlinks_dir.mkdir(exist_ok=True)

-                tb_symlink = osp.join(symlinks_dir, run_name)
+                tb_symlink = symlinks_dir / run_name
                try:
-                    os.symlink(tb_src_dir, tb_symlink)
+                    tb_symlink.symlink_to(tb_src_dir)
                except OSError as e:
                    if os.name == "nt":  # Windows
                        msg = (
@@ -318,7 +320,8 @@ def _make_return_summary(stats: dict, prefix="") -> str:


 def main_console():
-    observer = FileStorageObserver(osp.join("output", "sacred", "analyze"))
+    observer_path = pathlib.Path.cwd() / "output" / "sacred" / "analyze"
+    observer = FileStorageObserver(observer_path)
     analysis_ex.observers.append(observer)
     analysis_ex.run_commandline()
diff --git a/src/imitation/scripts/common/common.py b/src/imitation/scripts/common/common.py
index 532e8780b..72d44f2f4 100644
--- a/src/imitation/scripts/common/common.py
+++ b/src/imitation/scripts/common/common.py
@@ -2,13 +2,14 @@

 import contextlib
 import logging
-import os
+import pathlib
 from typing import Any, Generator, Mapping, Sequence, Tuple, Union

 import numpy as np
 import sacred
 from stable_baselines3.common import vec_env

+from imitation.data import types
 from imitation.scripts.common import wb
 from imitation.util import logger as imit_logger
 from imitation.util import sacred as sacred_util
@@ -45,18 +46,15 @@ def update_log_format_strs(log_format_strs, log_format_strs_additional):


 @common_ingredient.config_hook
-def hook(config, command_name, logger):
+def hook(config, command_name: str, logger):
     del logger
     updates = {}
     if config["common"]["log_dir"] is None:
         env_sanitized = config["common"]["env_name"].replace("/", "_")
-        log_root = config["common"]["log_root"] or "output"
-        log_dir = os.path.join(
-            log_root,
-            command_name,
-            env_sanitized,
-            util.make_unique_timestamp(),
-        )
+        assert isinstance(env_sanitized, str)
+        config_log_root = config["common"]["log_root"] or "output"
+        log_root = types.parse_path(config_log_root)
+        log_dir = log_root / command_name / env_sanitized / util.make_unique_timestamp()
         updates["log_dir"] = log_dir
     return updates
@@ -86,7 +84,7 @@ def make_log_dir(
     _run,
     log_dir: str,
     log_level: Union[int, str],
-) -> str:
+) -> pathlib.Path:
     """Creates log directory and sets up symlink to Sacred logs.

     Args:
@@ -98,23 +96,24 @@ def make_log_dir(
     Returns:
         The `log_dir`. This avoids the caller needing to capture this argument.
     """
-    os.makedirs(log_dir, exist_ok=True)
+    parsed_log_dir = types.parse_path(log_dir)
+    parsed_log_dir.mkdir(parents=True, exist_ok=True)
     # convert strings of digits to numbers; but leave levels like 'INFO' unmodified
     try:
         log_level = int(log_level)
     except ValueError:
         pass
     logging.basicConfig(level=log_level)
-    logger.info("Logging to %s", log_dir)
-    sacred_util.build_sacred_symlink(log_dir, _run)
-    return log_dir
+    logger.info("Logging to %s", parsed_log_dir)
+    sacred_util.build_sacred_symlink(parsed_log_dir, _run)
+    return parsed_log_dir


 @common_ingredient.capture
 def setup_logging(
     _run,
     log_format_strs: Sequence[str],
-) -> Tuple[imit_logger.HierarchicalLogger, str]:
+) -> Tuple[imit_logger.HierarchicalLogger, pathlib.Path]:
     """Builds the imitation logger.

     Args:
@@ -126,9 +125,9 @@ def setup_logging(
     """
     log_dir = make_log_dir()
     if "wandb" in log_format_strs:
-        wb.wandb_init(log_dir=log_dir)
+        wb.wandb_init(log_dir=str(log_dir))
     custom_logger = imit_logger.configure(
-        folder=os.path.join(log_dir, "log"),
+        folder=log_dir / "log",
         format_strs=log_format_strs,
     )
     return custom_logger, log_dir
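[Reviewer note, not part of the patch] With the config hook rewritten, the default `log_dir` is composed with the `/` operator and `make_log_dir`/`setup_logging` now hand every script a `pathlib.Path`. A sketch of the resulting layout outside of Sacred (the timestamp shown is illustrative; the format comes from `util.make_unique_timestamp`):

    from imitation.data import types
    from imitation.util import util

    log_root = types.parse_path("output")  # config["common"]["log_root"] or "output"
    log_dir = log_root / "train_rl" / "CartPole-v1" / util.make_unique_timestamp()
    log_dir.mkdir(parents=True, exist_ok=True)
    # e.g. <cwd>/output/train_rl/CartPole-v1/20221201_120000_1a2b3c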
diff --git a/src/imitation/scripts/config/train_rl.py b/src/imitation/scripts/config/train_rl.py
index ae3add76f..b9ede3165 100644
--- a/src/imitation/scripts/config/train_rl.py
+++ b/src/imitation/scripts/config/train_rl.py
@@ -6,7 +6,11 @@

 train_rl_ex = sacred.Experiment(
     "train_rl",
-    ingredients=[common.common_ingredient, train.train_ingredient, rl.rl_ingredient],
+    ingredients=[
+        common.common_ingredient,
+        train.train_ingredient,
+        rl.rl_ingredient,
+    ],
 )
diff --git a/src/imitation/scripts/convert_trajs.py b/src/imitation/scripts/convert_trajs.py
index 5ae4eb0a8..85db4d3f9 100644
--- a/src/imitation/scripts/convert_trajs.py
+++ b/src/imitation/scripts/convert_trajs.py
@@ -9,21 +9,22 @@
 (i.e. "A.pkl" -> "A.npz", "A.npz" -> "A.npz", "A" -> "A.npz", "A.foo" -> "A.foo.npz").
 """

-import os
 import warnings

 from imitation.data import types


-def update_traj_file_in_place(path: str) -> None:
+def update_traj_file_in_place(path_str: str, /) -> None:
     """Modifies trajectories pickle file in-place to update data to new format.

     The new data is saved as `Sequence[imitation.types.TrajectoryWithRew]`.

     Args:
-        path: Path to a pickle file containing `Sequence[imitation.types.Trajectory]`
+        path_str: Path to a pickle file containing
+            `Sequence[imitation.types.Trajectory]`
             or `Sequence[imitation.old_types.TrajectoryWithRew]`.
     """
+    path = types.parse_path(path_str)
     with warnings.catch_warnings():
         # Filter out DeprecationWarning because we expect to load old trajectories here.
         warnings.filterwarnings(
@@ -33,9 +34,9 @@ def update_traj_file_in_place(path_str: str, /) -> None:
         )
         trajs = types.load(path)

-    path, ext = os.path.splitext(path)
+    ext = path.suffix
     new_ext = ".npz" if ext in (".pkl", ".npz") else ext + ".npz"
-    types.save(path + new_ext, trajs)
+    types.save(path.with_suffix(new_ext), trajs)


 def main():
diff --git a/src/imitation/scripts/eval_policy.py b/src/imitation/scripts/eval_policy.py
index 610882fd5..06eab4820 100644
--- a/src/imitation/scripts/eval_policy.py
+++ b/src/imitation/scripts/eval_policy.py
@@ -1,8 +1,7 @@
 """Evaluate policies: render policy interactively, save videos, log episode return."""

 import logging
-import os
-import os.path as osp
+import pathlib
 import time
 from typing import Any, Mapping, Optional
@@ -39,12 +38,12 @@ def step_wait(self):
         return ob


-def video_wrapper_factory(log_dir: str, **kwargs):
+def video_wrapper_factory(log_dir: pathlib.Path, **kwargs):
     """Returns a function that wraps the environment in a video recorder."""

-    def f(env: gym.Env, i: int) -> gym.Env:
+    def f(env: gym.Env, i: int) -> video_wrapper.VideoWrapper:
         """Wraps `env` in a recorder saving videos to `{log_dir}/videos/{i}`."""
-        directory = os.path.join(log_dir, "videos", str(i))
+        directory = log_dir / "videos" / str(i)
         return video_wrapper.VideoWrapper(env, directory=directory, **kwargs)

     return f
@@ -105,13 +104,14 @@ def eval_policy(
     )

     if rollout_save_path:
-        types.save(rollout_save_path.replace("{log_dir}", log_dir), trajs)
+        types.save(log_dir / rollout_save_path.replace("{log_dir}/", ""), trajs)

     return rollout.rollout_stats(trajs)


 def main_console():
-    observer = FileStorageObserver(osp.join("output", "sacred", "eval_policy"))
+    observer_path = pathlib.Path.cwd() / "output" / "sacred" / "eval_policy"
+    observer = FileStorageObserver(observer_path)
     eval_policy_ex.observers.append(observer)
     eval_policy_ex.run_commandline()
diff --git a/src/imitation/scripts/parallel.py b/src/imitation/scripts/parallel.py
index c30ad1149..325646f0f 100644
--- a/src/imitation/scripts/parallel.py
+++ b/src/imitation/scripts/parallel.py
@@ -2,7 +2,7 @@

 import collections.abc
 import copy
-import os
+import pathlib
 from typing import Any, Callable, Dict, Mapping, Optional, Sequence

 import ray
@@ -203,7 +203,8 @@ def inner(config: Mapping[str, Any], reporter) -> Mapping[str, Any]:


 def main_console():
-    observer = FileStorageObserver(os.path.join("output", "sacred", "parallel"))
+    observer_path = pathlib.Path.cwd() / "output" / "sacred" / "parallel"
+    observer = FileStorageObserver(observer_path)
     parallel_ex.observers.append(observer)
     parallel_ex.run_commandline()
diff --git a/src/imitation/scripts/train_adversarial.py b/src/imitation/scripts/train_adversarial.py
index 11db52341..b84aec720 100644
--- a/src/imitation/scripts/train_adversarial.py
+++ b/src/imitation/scripts/train_adversarial.py
@@ -2,8 +2,7 @@

 import functools
 import logging
-import os
-import os.path as osp
+import pathlib
 from typing import Any, Mapping, Optional, Type

 import sacred.commands
@@ -22,15 +21,15 @@
 logger = logging.getLogger("imitation.scripts.train_adversarial")


-def save(trainer, save_path):
+def save(trainer: common.AdversarialTrainer, save_path: pathlib.Path):
     """Save discriminator and generator."""
     # We implement this here and not in Trainer since we do not want to actually
     # serialize the whole Trainer (including e.g. expert demonstrations).
-    os.makedirs(save_path, exist_ok=True)
-    th.save(trainer.reward_train, os.path.join(save_path, "reward_train.pt"))
-    th.save(trainer.reward_test, os.path.join(save_path, "reward_test.pt"))
+    save_path.mkdir(parents=True, exist_ok=True)
+    th.save(trainer.reward_train, save_path / "reward_train.pt")
+    th.save(trainer.reward_test, save_path / "reward_test.pt")
     serialize.save_stable_model(
-        os.path.join(save_path, "gen_policy"),
+        save_path / "gen_policy",
         trainer.gen_algo,
     )
@@ -148,16 +147,16 @@ def train_adversarial(
         **algorithm_kwargs,
     )

-    def callback(round_num):
+    def callback(round_num: int, /) -> None:
         if checkpoint_interval > 0 and round_num % checkpoint_interval == 0:
-            save(trainer, os.path.join(log_dir, "checkpoints", f"{round_num:05d}"))
+            save(trainer, log_dir / "checkpoints" / f"{round_num:05d}")

     trainer.train(total_timesteps, callback)
     imit_stats = train.eval_policy(trainer.policy, trainer.venv_train)

     # Save final artifacts.
     if checkpoint_interval >= 0:
-        save(trainer, os.path.join(log_dir, "checkpoints", "final"))
+        save(trainer, log_dir / "checkpoints" / "final")

     return {
         "imit_stats": imit_stats,
@@ -176,7 +175,8 @@ def airl():


 def main_console():
-    observer = FileStorageObserver(osp.join("output", "sacred", "train_adversarial"))
+    observer_path = pathlib.Path.cwd() / "output" / "sacred" / "train_adversarial"
+    observer = FileStorageObserver(observer_path)
     train_adversarial_ex.observers.append(observer)
     train_adversarial_ex.run_commandline()
diff --git a/src/imitation/scripts/train_imitation.py b/src/imitation/scripts/train_imitation.py
index d7d7cdfe9..09393366e 100644
--- a/src/imitation/scripts/train_imitation.py
+++ b/src/imitation/scripts/train_imitation.py
@@ -2,6 +2,7 @@

 import logging
 import os.path as osp
+import pathlib
 import warnings
 from typing import Any, Mapping, Optional, Sequence, Type, cast
@@ -168,7 +169,8 @@ def dagger() -> Mapping[str, Mapping[str, float]]:


 def main_console():
-    observer = FileStorageObserver(osp.join("output", "sacred", "train_dagger"))
+    observer_path = pathlib.Path.cwd() / "output" / "sacred" / "train_dagger"
+    observer = FileStorageObserver(observer_path)
     train_imitation_ex.observers.append(observer)
     train_imitation_ex.run_commandline()
diff --git a/src/imitation/scripts/train_preference_comparisons.py b/src/imitation/scripts/train_preference_comparisons.py
index 006c5854f..5e6677a2d 100644
--- a/src/imitation/scripts/train_preference_comparisons.py
+++ b/src/imitation/scripts/train_preference_comparisons.py
@@ -5,7 +5,7 @@
 """

 import functools
-import os
+import pathlib
 from typing import Any, Mapping, Optional, Type, Union

 import torch as th
@@ -25,23 +25,23 @@

 def save_model(
     agent_trainer: preference_comparisons.AgentTrainer,
-    save_path: str,
+    save_path: pathlib.Path,
 ):
     """Save the model as model.pkl."""
     serialize.save_stable_model(
-        output_dir=os.path.join(save_path, "policy"),
+        output_dir=save_path / "policy",
         model=agent_trainer.algorithm,
     )


 def save_checkpoint(
     trainer: preference_comparisons.PreferenceComparisons,
-    save_path: str,
+    save_path: pathlib.Path,
     allow_save_policy: Optional[bool],
 ):
     """Save reward model and optionally policy."""
-    os.makedirs(save_path, exist_ok=True)
-    th.save(trainer.model, os.path.join(save_path, "reward_net.pt"))
+    save_path.mkdir(parents=True, exist_ok=True)
+    th.save(trainer.model, save_path / "reward_net.pt")
     if allow_save_policy:
         # Note: We should only save the model as model.pkl if `trajectory_generator`
         # contains one. Specifically we check if the `trajectory_generator` contains an
@@ -244,11 +244,7 @@ def save_callback(iteration_num):
         if checkpoint_interval > 0 and iteration_num % checkpoint_interval == 0:
             save_checkpoint(
                 trainer=main_trainer,
-                save_path=os.path.join(
-                    log_dir,
-                    "checkpoints",
-                    f"{iteration_num:04d}",
-                ),
+                save_path=log_dir / "checkpoints" / f"{iteration_num:04d}",
                 allow_save_policy=bool(trajectory_path is None),
             )
@@ -264,13 +260,13 @@ def save_callback(iteration_num):
     results["rollout"] = train.eval_policy(agent, venv)

     if save_preferences:
-        main_trainer.dataset.save(os.path.join(log_dir, "preferences.pkl"))
+        main_trainer.dataset.save(log_dir / "preferences.pkl")

     # Save final artifacts.
     if checkpoint_interval >= 0:
         save_checkpoint(
             trainer=main_trainer,
-            save_path=os.path.join(log_dir, "checkpoints", "final"),
+            save_path=log_dir / "checkpoints" / "final",
             allow_save_policy=bool(trajectory_path is None),
         )
@@ -278,9 +274,10 @@


 def main_console():
-    observer = FileStorageObserver(
-        os.path.join("output", "sacred", "train_preference_comparisons"),
+    observer_path = (
+        pathlib.Path.cwd() / "output" / "sacred" / "train_preference_comparisons"
     )
+    observer = FileStorageObserver(observer_path)
     train_preference_comparisons_ex.observers.append(observer)
     train_preference_comparisons_ex.run_commandline()
diff --git a/src/imitation/scripts/train_rl.py b/src/imitation/scripts/train_rl.py
index d68778d1a..5979c2608 100644
--- a/src/imitation/scripts/train_rl.py
+++ b/src/imitation/scripts/train_rl.py
@@ -9,8 +9,7 @@
 """

 import logging
-import os
-import os.path as osp
+import pathlib
 import warnings
 from typing import Any, Mapping, Optional
@@ -89,10 +88,10 @@ def train_rl(
     """
     rng = common.make_rng()
     custom_logger, log_dir = common.setup_logging()
-    rollout_dir = osp.join(log_dir, "rollouts")
-    policy_dir = osp.join(log_dir, "policies")
-    os.makedirs(rollout_dir, exist_ok=True)
-    os.makedirs(policy_dir, exist_ok=True)
+    rollout_dir = log_dir / "rollouts"
+    policy_dir = log_dir / "policies"
+    rollout_dir.mkdir(parents=True, exist_ok=True)
+    policy_dir.mkdir(parents=True, exist_ok=True)

     post_wrappers = [lambda env, idx: wrappers.RolloutInfoWrapper(env)]
     with common.make_venv(post_wrappers=post_wrappers) as venv:
@@ -140,7 +139,7 @@ def train_rl(

     # Save final artifacts after training is complete.
     if rollout_save_final:
-        save_path = osp.join(rollout_dir, "final.pkl")
+        save_path = rollout_dir / "final.pkl"
         sample_until = rollout.make_sample_until(
             rollout_save_n_timesteps,
             rollout_save_n_episodes,
@@ -150,7 +149,7 @@ def train_rl(
             rollout.rollout(rl_algo, rl_algo.get_env(), sample_until, rng=rng),
         )
     if policy_save_final:
-        output_dir = os.path.join(policy_dir, "final")
+        output_dir = policy_dir / "final"
         serialize.save_stable_model(output_dir, rl_algo)

     # Final evaluation of expert policy.
@@ -158,7 +157,8 @@ def train_rl(


 def main_console():
-    observer = FileStorageObserver(osp.join("output", "sacred", "train_rl"))
+    observer_path = pathlib.Path.cwd() / "output" / "sacred" / "train_rl"
+    observer = FileStorageObserver(observer_path)
     train_rl_ex.observers.append(observer)
     train_rl_ex.run_commandline()
diff --git a/src/imitation/util/logger.py b/src/imitation/util/logger.py
index 77e54df8c..70190b1fb 100644
--- a/src/imitation/util/logger.py
+++ b/src/imitation/util/logger.py
@@ -3,6 +3,7 @@
 import contextlib
 import datetime
 import os
+import pathlib
 import sys
 import tempfile
 from typing import Any, Dict, Generator, List, Optional, Sequence, Tuple, Union
@@ -43,7 +44,7 @@ def make_output_format(


 def _build_output_formats(
-    folder: str,
+    folder: pathlib.Path,
     format_strs: Sequence[str],
 ) -> Sequence[sb_logger.KVWriter]:
     """Build output formats for initializing a Stable Baselines Logger.
@@ -56,13 +57,13 @@ def _build_output_formats(
     Returns:
         A list of output formats, one corresponding to each `format_strs`.
     """
-    os.makedirs(folder, exist_ok=True)
+    folder.mkdir(parents=True, exist_ok=True)
     output_formats: List[sb_logger.KVWriter] = []
     for f in format_strs:
         if f == "wandb":
             output_formats.append(WandbOutputFormat())
         else:
-            output_formats.append(make_output_format(f, folder))
+            output_formats.append(make_output_format(f, str(folder)))
     return output_formats
@@ -266,11 +267,12 @@ def accumulate_means(self, name: str) -> Generator[None, None, None]:
         if subdir in self._cached_loggers:
             logger = self._cached_loggers[subdir]
         else:
-            assert self.default_logger.dir is not None
-            folder = os.path.join(self.default_logger.dir, "raw", subdir)
-            os.makedirs(folder, exist_ok=True)
+            default_logger_dir = self.default_logger.dir
+            assert default_logger_dir is not None
+            folder = types.parse_path(default_logger_dir) / "raw" / subdir
+            folder.mkdir(exist_ok=True, parents=True)
             output_formats = _build_output_formats(folder, self.format_strs)
-            logger = sb_logger.Logger(folder, list(output_formats))
+            logger = sb_logger.Logger(str(folder), list(output_formats))
             self._cached_loggers[subdir] = logger

         try:
@@ -400,15 +402,16 @@ def configure(
         The configured HierarchicalLogger instance.
""" if folder is None: + tempdir = types.parse_path(tempfile.gettempdir()) now = datetime.datetime.now() timestamp = now.strftime("imitation-%Y-%m-%d-%H-%M-%S-%f") - folder = os.path.join(tempfile.gettempdir(), timestamp) + folder = tempdir / timestamp else: - folder = types.path_to_str(folder) + folder = types.parse_path(folder) if format_strs is None: format_strs = ["stdout", "log", "csv"] output_formats = _build_output_formats(folder, format_strs) - default_logger = sb_logger.Logger(folder, list(output_formats)) + default_logger = sb_logger.Logger(str(folder), list(output_formats)) hier_format_strs = [f for f in format_strs if f != "wandb"] hier_logger = HierarchicalLogger(default_logger, hier_format_strs) return hier_logger diff --git a/src/imitation/util/sacred.py b/src/imitation/util/sacred.py index b96f872df..2c8ee421a 100644 --- a/src/imitation/util/sacred.py +++ b/src/imitation/util/sacred.py @@ -4,7 +4,7 @@ import os import pathlib import warnings -from typing import Any, Callable, NamedTuple, Sequence, Union +from typing import Any, Callable, NamedTuple, Optional, Sequence import sacred import sacred.observers @@ -16,35 +16,31 @@ class SacredDicts(NamedTuple): """Each dict `foo` is loaded from `f"{sacred_dir}/foo.json"`.""" - sacred_dir: str + sacred_dir: pathlib.Path config: dict run: dict @classmethod - def load_from_dir(cls, sacred_dir: str): - args = [] - for field in cls._fields: - if field == "sacred_dir": - args.append(sacred_dir) - else: - json_path = os.path.join(sacred_dir, f"{field}.json") - with open(json_path, "r") as f: - args.append(json.load(f)) - return cls(*args) + def load_from_dir(cls, sacred_dir: pathlib.Path): + return cls( + sacred_dir=sacred_dir, + config=json.loads((sacred_dir / "config.json").read_text()), + run=json.loads((sacred_dir / "run.json").read_text()), + ) -def dir_contains_sacred_jsons(dir_path: str) -> bool: - run_path = os.path.join(dir_path, "run.json") - config_path = os.path.join(dir_path, "config.json") - return os.path.isfile(run_path) and os.path.isfile(config_path) +def dir_contains_sacred_jsons(dir_path: pathlib.Path) -> bool: + run_path = dir_path / "run.json" + config_path = dir_path / "config.json" + return run_path.is_file() and config_path.is_file() def filter_subdirs( - root_dir: str, - filter_fn: Callable[[str], bool] = dir_contains_sacred_jsons, + root_dir: pathlib.Path, + filter_fn: Callable[[pathlib.Path], bool] = dir_contains_sacred_jsons, *, nested_ok: bool = False, -) -> Sequence[str]: +) -> Sequence[pathlib.Path]: """Walks through a directory tree, returning paths to filtered subdirectories. Does not follow symlinks. @@ -64,31 +60,30 @@ def filter_subdirs( paths is a subdirecotry of another. 
""" filtered_dirs = set() - for root, _, _ in os.walk(root_dir, followlinks=False): + for root_str, _, _ in os.walk(root_dir, followlinks=False): + root = pathlib.Path(root_str) if filter_fn(root): filtered_dirs.add(root) if not nested_ok: for dirpath in filtered_dirs: - components = os.path.split(dirpath) - for i in range(1, len(components)): - prefix = os.path.join(*components[0:i]) - if prefix in filtered_dirs: - raise ValueError(f"Parent {prefix} to {dir} also a dir directory") + for other_dirpath in filtered_dirs: + if dirpath != other_dirpath and other_dirpath in dirpath.parents: + raise ValueError( + f"Found nested directories: {dirpath} and {other_dirpath}", + ) return list(filtered_dirs) def build_sacred_symlink(log_dir: types.AnyPath, run: sacred.run.Run) -> None: """Constructs a symlink "{log_dir}/sacred" => "${SACRED_PATH}".""" - if isinstance(log_dir, bytes): - log_dir = log_dir.decode("utf-8") - log_dir = pathlib.Path(log_dir) + log_dir = types.parse_path(log_dir) sacred_dir = get_sacred_dir_from_run(run) if sacred_dir is None: warnings.warn(RuntimeWarning("Couldn't find sacred directory.")) return - symlink_path = pathlib.Path(log_dir, "sacred") + symlink_path = log_dir / "sacred" target_path = pathlib.Path(os.path.relpath(sacred_dir, start=log_dir)) # Path.symlink_to errors if the symlink already exists. In our case, we actually @@ -116,11 +111,11 @@ def build_sacred_symlink(log_dir: types.AnyPath, run: sacred.run.Run) -> None: raise e -def get_sacred_dir_from_run(run: sacred.run.Run) -> Union[pathlib.Path, None]: +def get_sacred_dir_from_run(run: sacred.run.Run) -> Optional[pathlib.Path]: """Returns path to the sacred directory, or None if not found.""" for obs in run.observers: if isinstance(obs, sacred.observers.FileStorageObserver): - return pathlib.Path(obs.dir) + return types.parse_path(obs.dir) return None diff --git a/src/imitation/util/video_wrapper.py b/src/imitation/util/video_wrapper.py index 6da02aa70..a59641aa1 100644 --- a/src/imitation/util/video_wrapper.py +++ b/src/imitation/util/video_wrapper.py @@ -1,13 +1,11 @@ """Wrapper to record rendered video frames from an environment.""" -import os +import pathlib from typing import Optional import gym from gym.wrappers.monitoring import video_recorder -from imitation.data import types - class VideoWrapper(gym.Wrapper): """Creates videos from wrapped environment by calling render after each timestep.""" @@ -15,12 +13,12 @@ class VideoWrapper(gym.Wrapper): episode_id: int video_recorder: Optional[video_recorder.VideoRecorder] single_video: bool - directory: str + directory: pathlib.Path def __init__( self, env: gym.Env, - directory: types.AnyPath, + directory: pathlib.Path, single_video: bool = True, ): """Builds a VideoWrapper. @@ -39,8 +37,8 @@ def __init__( self.video_recorder = None self.single_video = single_video - self.directory = str(os.path.abspath(directory)) - os.makedirs(self.directory) + self.directory = directory + self.directory.mkdir(parents=True, exist_ok=True) def _reset_video_recorder(self) -> None: """Creates a video recorder if one does not already exist. @@ -59,10 +57,7 @@ def _reset_video_recorder(self) -> None: # No video recorder -- start a new one. 
diff --git a/tests/data/test_rollout.py b/tests/data/test_rollout.py
index 34c136bcc..db125908a 100644
--- a/tests/data/test_rollout.py
+++ b/tests/data/test_rollout.py
@@ -300,7 +300,7 @@ def test_generate_trajectories_type_error(rng):
     sample_until = rollout.make_min_episodes(1)
     with pytest.raises(TypeError, match="Policy must be.*got instead"):
         rollout.generate_trajectories(
-            "strings_are_not_valid_policies",
+            "strings_are_not_valid_policies",  # type: ignore[arg-type]
             venv,
             rng=rng,
             sample_until=sample_until,
diff --git a/tests/data/test_types.py b/tests/data/test_types.py
index dbfacf3b2..f7223b93e 100644
--- a/tests/data/test_types.py
+++ b/tests/data/test_types.py
@@ -110,7 +110,7 @@ def _check_transitions_get_item(trans, key):
 @contextlib.contextmanager
 def pushd(dir_path):
     """Change directory temporarily inside context."""
-    orig_dir = os.getcwd()
+    orig_dir = pathlib.Path.cwd()
     try:
         os.chdir(dir_path)
         yield
@@ -212,15 +212,16 @@ def test_save_trajectories(
     if use_chdir:
         # Test no relative path without directory edge-case.
         chdir_context = pushd(tmpdir)
-        save_dir = ""
+        save_dir_str = ""
     else:
         chdir_context = contextlib.nullcontext()
-        save_dir = tmpdir
-
-    trajs = [trajectory_rew if use_rewards else trajectory]
-    save_path = pathlib.Path(save_dir, "trajs")
+        save_dir_str = tmpdir

     with chdir_context:
+        save_dir = types.parse_path(save_dir_str)
+        trajs = [trajectory_rew if use_rewards else trajectory]
+        save_path = save_dir / "trajs"
+
         if use_pickle:
             # Pickle format
             with open(save_path, "wb") as f:
@@ -389,12 +390,46 @@ def test_zero_length_fails():
         types.Trajectory(obs=np.array([42]), acts=empty, infos=None, terminal=True)


-def test_path_to_str():
-    assert types.path_to_str("") == ""
-    assert types.path_to_str(b"") == ""
-    assert types.path_to_str("foo") == "foo"
-    assert types.path_to_str(b"foo") == "foo"
-    assert types.path_to_str(pathlib.Path("foo")) == "foo"
-    assert types.path_to_str("/foo/bar") == "/foo/bar"
-    assert types.path_to_str(b"/foo/bar") == "/foo/bar"
-    assert types.path_to_str(pathlib.Path("/foo", "bar"))
+def test_parse_path():
+    if os.name == "nt":  # pragma: no cover
+        pytest.skip(
+            "Windows uses pathlib.WindowsPath instead when paths are resolved, "
+            "which cannot be compared directly to pathlib.Path objects.",
+        )
+    # absolute paths
+    assert types.parse_path("/foo/bar") == pathlib.Path("/foo/bar")
+    assert types.parse_path(pathlib.Path("/foo/bar")) == pathlib.Path("/foo/bar")
+    assert types.parse_path(b"/foo/bar") == pathlib.Path("/foo/bar")
+
+    # relative paths. implicit conversion to cwd
+    assert types.parse_path("foo/bar") == pathlib.Path.cwd() / "foo/bar"
+    assert types.parse_path(pathlib.Path("foo/bar")) == pathlib.Path.cwd() / "foo/bar"
+    assert types.parse_path(b"foo/bar") == pathlib.Path.cwd() / "foo/bar"
+
+    # relative paths. conversion using custom base directory
+    base_dir = pathlib.Path("/foo/bar")
+    assert types.parse_path("baz", base_directory=base_dir) == base_dir / "baz"
+    assert (
+        types.parse_path(pathlib.Path("baz"), base_directory=base_dir)
+        == base_dir / "baz"
+    )
+    assert types.parse_path(b"baz", base_directory=base_dir) == base_dir / "baz"
+
+    # pass a relative path but disallowing relative paths. should raise error.
+    with pytest.raises(ValueError, match="Path .* is not absolute"):
+        types.parse_path("foo/bar", allow_relative=False)
+
+    # pass a base directory but disallowing relative paths. should raise error.
+    with pytest.raises(
+        ValueError,
+        match="If `base_directory` is specified, then `allow_relative` must be True.",
+    ):
+        types.parse_path(
+            "foo/bar",
+            base_directory=pathlib.Path("/foo/bar"),
+            allow_relative=False,
+        )
+
+    # Parse optional path. Works the same way but passes None down the line.
+    assert types.parse_optional_path(None) is None
+    assert types.parse_optional_path("/foo/bar") == types.parse_path("/foo/bar")
diff --git a/tests/policies/test_policies.py b/tests/policies/test_policies.py
index c7953a006..d819cc6a1 100644
--- a/tests/policies/test_policies.py
+++ b/tests/policies/test_policies.py
@@ -1,7 +1,6 @@
 """Tests `imitation.policies.*`."""

 import functools
-import pathlib
 from typing import cast

 import gym
@@ -11,7 +10,7 @@
 from stable_baselines3.common import preprocessing
 from torch import nn

-from imitation.data import rollout
+from imitation.data import rollout, types
 from imitation.policies import base, serialize
 from imitation.util import registry, util
@@ -59,7 +58,7 @@ def test_save_stable_model_errors_and_warnings(
 ):
     """Check errors and warnings in `save_stable_model()`."""
     policy, env_name = policy_env_name_pair
-    tmpdir = pathlib.Path(tmpdir)
+    tmpdir = types.parse_path(tmpdir)
     venv = util.make_vec_env(env_name, rng=rng)

     # Trigger FileNotFoundError for no model.{zip,pkl}
@@ -104,7 +103,7 @@ def _test_serialize_identity(env_name, model_cfg, tmpdir, rng):
         rng=np.random.default_rng(0),
     )

-    serialize.save_stable_model(tmpdir, model)
+    serialize.save_stable_model(types.parse_path(tmpdir), model)
     loaded = serialize.load_policy(model_name, venv, path=tmpdir)
     venv.env_method("seed", 0)
     venv.reset()
diff --git a/tests/policies/test_replay_buffer_wrapper.py b/tests/policies/test_replay_buffer_wrapper.py
index 0eeee334b..40fc6eac5 100644
--- a/tests/policies/test_replay_buffer_wrapper.py
+++ b/tests/policies/test_replay_buffer_wrapper.py
@@ -56,7 +56,7 @@ def test_invalid_args(rng):
     # we ignore the type because we are intentionally
     # passing the wrong type for the test
     make_algo_with_wrapped_buffer(
-        rl_cls=sb3.PPO,
+        rl_cls=sb3.PPO,  # type: ignore[arg-type]
         policy_cls=policies.ActorCriticPolicy,
         replay_buffer_class=buffers.ReplayBuffer,
         rng=rng,
diff --git a/tests/scripts/test_scripts.py b/tests/scripts/test_scripts.py
index 5b8959312..3b90628e0 100644
--- a/tests/scripts/test_scripts.py
+++ b/tests/scripts/test_scripts.py
@@ -53,7 +53,14 @@
     train_rl,
 ]

-TEST_DATA_PATH = pathlib.Path("tests/testdata")
+TEST_DATA_PATH = types.parse_path("tests/testdata")
+
+if not TEST_DATA_PATH.exists():  # pragma: no cover
+    raise RuntimeError(
+        "Folder with test data has not been found. Make sure you are "
+        "running tests relative to the base imitation project folder.",
+    )
+
 CARTPOLE_TEST_DATA_PATH = TEST_DATA_PATH / "expert_models/cartpole_0/"
 CARTPOLE_TEST_ROLLOUT_PATH = CARTPOLE_TEST_DATA_PATH / "rollouts/final.pkl"
 CARTPOLE_TEST_POLICY_PATH = CARTPOLE_TEST_DATA_PATH / "policies/final"
@@ -279,7 +286,7 @@ def test_train_dagger_warmstart(tmpdir):
     )
     assert run.status == "COMPLETED"

-    log_dir = pathlib.Path(run.config["common"]["log_dir"])
+    log_dir = types.parse_path(run.config["common"]["log_dir"])
     policy_path = log_dir / "scratch" / "policy-latest.pt"
     run_warmstart = train_imitation.train_imitation_ex.run(
         command_name="dagger",
@@ -354,7 +361,7 @@ def test_train_bc_warmstart(tmpdir):
     assert run.status == "COMPLETED"
     assert isinstance(run.result, dict)

-    policy_path = pathlib.Path(run.config["common"]["log_dir"]) / "final.th"
+    policy_path = types.parse_path(run.config["common"]["log_dir"]) / "final.th"
     run_warmstart = train_imitation.train_imitation_ex.run(
         command_name="bc",
         named_configs=["seals_cartpole"] + ALGO_FAST_CONFIGS["imitation"],
@@ -516,7 +523,7 @@ def test_train_adversarial_warmstart(tmpdir, command):
         config_updates=config_updates,
     )

-    log_dir = pathlib.Path(run.config["common"]["log_dir"])
+    log_dir = types.parse_path(run.config["common"]["log_dir"])
     policy_path = log_dir / "checkpoints" / "final" / "gen_policy"

     run_warmstart = train_adversarial.train_adversarial_ex.run(
@@ -601,7 +608,7 @@ def test_transfer_learning(tmpdir: str) -> None:
     Args:
         tmpdir: Temporary directory to save results to.
     """
-    tmpdir_path = pathlib.Path(tmpdir)
+    tmpdir_path = types.parse_path(tmpdir)
     log_dir_train = tmpdir_path / "train"
     run = train_adversarial.train_adversarial_ex.run(
         command_name="airl",
@@ -650,7 +657,7 @@ def test_preference_comparisons_transfer_learning(
         tmpdir: Temporary directory to save results to.
         named_configs_dict: Named configs for preference_comparisons and rl.
     """
-    tmpdir_path = pathlib.Path(tmpdir)
+    tmpdir_path = types.parse_path(tmpdir)

     log_dir_train = tmpdir_path / "train"
     run = train_preference_comparisons.train_preference_comparisons_ex.run(
@@ -791,7 +798,7 @@ def test_parallel_arg_errors(tmpdir):


 def _generate_test_rollouts(tmpdir: str, env_named_config: str) -> pathlib.Path:
-    tmpdir_path = pathlib.Path(tmpdir)
+    tmpdir_path = types.parse_path(tmpdir)
     train_rl.train_rl_ex.run(
         named_configs=[env_named_config] + ALGO_FAST_CONFIGS["rl"],
         config_updates=dict(
@@ -863,7 +870,7 @@ def _run_train_bc_for_test_analyze_imit(run_name, sacred_logs_dir, log_dir):
     ),
 )
 def test_analyze_imitation(tmpdir: str, run_names: List[str], run_sacred_fn):
-    sacred_logs_dir = tmpdir_path = pathlib.Path(tmpdir)
+    sacred_logs_dir = tmpdir_path = types.parse_path(tmpdir)

     # Generate sacred logs (other logs are put in separate tmpdir for deletion).
     for run_name in run_names: