Skip to content

Commit

Permalink
Merge branch 'master' of github.com:HumanCompatibleAI/imitation into …
Browse files Browse the repository at this point in the history
…imitation-envs-to-seals
  • Loading branch information
Rocamonde committed Oct 11, 2022
2 parents 49094c7 + 531fa06 commit 7376f24
Show file tree
Hide file tree
Showing 33 changed files with 317 additions and 211 deletions.
1 change: 0 additions & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ executors:
(?x)(
src/imitation/algorithms/preference_comparisons.py$
| src/imitation/rewards/reward_nets.py$
| src/imitation/util/sacred.py$
| src/imitation/algorithms/base.py$
| src/imitation/scripts/train_preference_comparisons.py$
| src/imitation/rewards/serialize.py$
Expand Down
1 change: 0 additions & 1 deletion ci/code_checks.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ SRC_FILES=(src/ tests/ experiments/ examples/ docs/conf.py setup.py ci/)
EXCLUDE_MYPY="(?x)(
src/imitation/algorithms/preference_comparisons.py$
| src/imitation/rewards/reward_nets.py$
| src/imitation/util/sacred.py$
| src/imitation/algorithms/base.py$
| src/imitation/scripts/train_preference_comparisons.py$
| src/imitation/rewards/serialize.py$
Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials/1_train_bc.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -200,4 +200,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}
2 changes: 1 addition & 1 deletion docs/tutorials/3_train_gail.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -187,4 +187,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}
2 changes: 1 addition & 1 deletion docs/tutorials/4_train_airl.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -181,4 +181,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}
2 changes: 1 addition & 1 deletion docs/tutorials/5_train_preference_comparisons.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -203,4 +203,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -236,4 +236,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}
2 changes: 1 addition & 1 deletion docs/tutorials/7_train_density.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -158,4 +158,4 @@
},
"nbformat": 4,
"nbformat_minor": 4
}
}
13 changes: 6 additions & 7 deletions src/imitation/algorithms/adversarial/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import collections
import dataclasses
import logging
import os
from typing import (
Callable,
Iterable,
Expand Down Expand Up @@ -127,7 +126,7 @@ def __init__(
gen_algo: base_class.BaseAlgorithm,
reward_net: reward_nets.RewardNet,
n_disc_updates_per_round: int = 2,
log_dir: str = "output/",
log_dir: types.AnyPath = "output/",
disc_opt_cls: Type[th.optim.Optimizer] = th.optim.Adam,
disc_opt_kwargs: Optional[Mapping] = None,
gen_train_timesteps: Optional[int] = None,
Expand Down Expand Up @@ -202,7 +201,7 @@ def __init__(
self.venv = venv
self.gen_algo = gen_algo
self._reward_net = reward_net.to(gen_algo.device)
self._log_dir = log_dir
self._log_dir = types.parse_path(log_dir)

# Create graph for optimising/recording stats on discriminator
self._disc_opt_cls = disc_opt_cls
Expand All @@ -215,10 +214,10 @@ def __init__(
)

if self._init_tensorboard:
logging.info("building summary directory at " + self._log_dir)
summary_dir = os.path.join(self._log_dir, "summary")
os.makedirs(summary_dir, exist_ok=True)
self._summary_writer = thboard.SummaryWriter(summary_dir)
logging.info(f"building summary directory at {self._log_dir}")
summary_dir = self._log_dir / "summary"
summary_dir.mkdir(parents=True, exist_ok=True)
self._summary_writer = thboard.SummaryWriter(str(summary_dir))

self.venv_buffering = wrappers.BufferingWrapper(self.venv)

Expand Down
2 changes: 1 addition & 1 deletion src/imitation/algorithms/bc.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,4 +472,4 @@ def save_policy(self, policy_path: types.AnyPath) -> None:
Args:
policy_path: path to save policy to.
"""
th.save(self.policy, types.path_to_str(policy_path))
th.save(self.policy, types.parse_path(policy_path))
22 changes: 8 additions & 14 deletions src/imitation/algorithms/dagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,8 @@ def reconstruct_trainer(
A deserialized `DAggerTrainer`.
"""
custom_logger = custom_logger or imit_logger.configure()
checkpoint_path = pathlib.Path(
types.path_to_str(scratch_dir),
"checkpoint-latest.pt",
)
scratch_dir = types.parse_path(scratch_dir)
checkpoint_path = scratch_dir / "checkpoint-latest.pt"
trainer = th.load(checkpoint_path, map_location=utils.get_device(device))
trainer.venv = venv
trainer._logger = custom_logger
Expand All @@ -109,14 +107,14 @@ def _save_dagger_demo(
# however that NPZ save here is likely more space efficient than
# pickle from types.save(), and types.save only accepts
# TrajectoryWithRew right now (subclass of Trajectory).
save_dir_obj = pathlib.Path(types.path_to_str(save_dir))
save_dir = types.parse_path(save_dir)
assert isinstance(trajectory, types.Trajectory)
actual_prefix = f"{prefix}-" if prefix else ""
timestamp = util.make_unique_timestamp()
filename = f"{actual_prefix}dagger-demo-{timestamp}.npz"

save_dir_obj.mkdir(parents=True, exist_ok=True)
npz_path = save_dir_obj / filename
save_dir.mkdir(parents=True, exist_ok=True)
npz_path = save_dir / filename
np.savez_compressed(npz_path, **dataclasses.asdict(trajectory))
logging.info(f"Saved demo at '{npz_path}'")

Expand Down Expand Up @@ -344,7 +342,7 @@ def __init__(
if beta_schedule is None:
beta_schedule = LinearBetaSchedule(15)
self.beta_schedule = beta_schedule
self.scratch_dir = pathlib.Path(types.path_to_str(scratch_dir))
self.scratch_dir = types.parse_path(scratch_dir)
self.venv = venv
self.round_num = 0
self._last_loaded_round = -1
Expand Down Expand Up @@ -397,11 +395,7 @@ def _load_all_demos(self):
return demo_transitions, num_demos_by_round

def _get_demo_paths(self, round_dir):
return [
os.path.join(round_dir, p)
for p in os.listdir(round_dir)
if p.endswith(".npz")
]
return [round_dir / p for p in os.listdir(round_dir) if p.endswith(".npz")]

def _demo_dir_path_for_round(self, round_num: Optional[int] = None) -> pathlib.Path:
if round_num is None:
Expand All @@ -411,7 +405,7 @@ def _demo_dir_path_for_round(self, round_num: Optional[int] = None) -> pathlib.P
def _try_load_demos(self) -> None:
"""Load the dataset for this round into self.bc_trainer as a DataLoader."""
demo_dir = self._demo_dir_path_for_round()
demo_paths = self._get_demo_paths(demo_dir) if os.path.isdir(demo_dir) else []
demo_paths = self._get_demo_paths(demo_dir) if demo_dir.is_dir() else []
if len(demo_paths) == 0:
raise NeedsDemosException(
f"No demos found for round {self.round_num} in dir '{demo_dir}'. "
Expand Down
4 changes: 1 addition & 3 deletions src/imitation/algorithms/mce_irl.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,7 @@ def set_pi(self, pi: np.ndarray) -> None:
self.pi = pi

def _predict(self, observation: th.Tensor, deterministic: bool = False):
raise NotImplementedError(
"Should never be called as predict overridden.",
)
raise NotImplementedError("Should never be called as predict overridden.")

def forward( # type: ignore[override]
self,
Expand Down
88 changes: 82 additions & 6 deletions src/imitation/data/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,87 @@ def dataclass_quick_asdict(obj) -> Dict[str, Any]:
return d


def path_to_str(path: AnyPath) -> str:
if isinstance(path, bytes):
return path.decode()
def parse_path(
path: AnyPath,
allow_relative: bool = True,
base_directory: Optional[pathlib.Path] = None,
) -> pathlib.Path:
"""Parse a path to a `pathlib.Path` object.
All resulting paths are resolved, absolute paths. If `allow_relative` is True,
then relative paths are allowed as input, and are resolved relative to the
current working directory, or relative to `base_directory` if it is
specified.
Args:
path: The path to parse. Can be a string, bytes, or `os.PathLike`.
allow_relative: If True, then relative paths are allowed as input, and
are resolved relative to the current working directory. If False,
an error is raised if the path is not absolute.
base_directory: If specified, then relative paths are resolved relative
to this directory, instead of the current working directory.
Returns:
A `pathlib.Path` object.
Raises:
ValueError: If `allow_relative` is False and the path is not absolute.
ValueError: If `base_directory` is specified and `allow_relative` is
False.
"""
if base_directory is not None and not allow_relative:
raise ValueError(
"If `base_directory` is specified, then `allow_relative` must be True.",
)

parsed_path: pathlib.Path
if isinstance(path, pathlib.Path):
parsed_path = path
elif isinstance(path, str):
parsed_path = pathlib.Path(path)
elif isinstance(path, bytes):
parsed_path = pathlib.Path(path.decode())
else:
parsed_path = pathlib.Path(str(path))

if parsed_path.is_absolute():
return parsed_path
else:
if allow_relative:
base_directory = base_directory or pathlib.Path.cwd()
# relative to current working directory
return base_directory / parsed_path
else:
raise ValueError(f"Path {str(parsed_path)} is not absolute")


def parse_optional_path(
path: Optional[AnyPath],
allow_relative: bool = True,
base_directory: Optional[pathlib.Path] = None,
) -> Optional[pathlib.Path]:
"""Parse an optional path to a `pathlib.Path` object.
All resulting paths are resolved, absolute paths. If `allow_relative` is True,
then relative paths are allowed as input, and are resolved relative to the
current working directory, or relative to `base_directory` if it is
specified.
Args:
path: The path to parse. Can be a string, bytes, or `os.PathLike`.
allow_relative: If True, then relative paths are allowed as input, and
are resolved relative to the current working directory. If False,
an error is raised if the path is not absolute.
base_directory: If specified, then relative paths are resolved relative
to this directory, instead of the current working directory.
Returns:
A `pathlib.Path` object, or None if `path` is None.
"""
if path is None:
return None
else:
return str(path)
return parse_path(path, allow_relative, base_directory)


@dataclasses.dataclass(frozen=True)
Expand Down Expand Up @@ -417,10 +493,10 @@ def save(path: AnyPath, trajectories: Sequence[Trajectory]):
trajectories: The trajectories to save.
Raises:
ValueError: If the trajectories are not all of the same type, i.e. some are
ValueError: If not all trajectories have the same type, i.e. some are
`Trajectory` and others are `TrajectoryWithRew`.
"""
p = pathlib.Path(path_to_str(path))
p = parse_path(path)
p.parent.mkdir(parents=True, exist_ok=True)
tmp_path = f"{p}.tmp"

Expand Down
2 changes: 1 addition & 1 deletion src/imitation/policies/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def _choose_action(self, obs: np.ndarray) -> np.ndarray:
def forward(self, *args):
# technically BasePolicy is a Torch module, so this needs a forward()
# method
raise NotImplementedError()
raise NotImplementedError() # pragma: no cover


class RandomPolicy(HardCodedPolicy):
Expand Down
18 changes: 9 additions & 9 deletions src/imitation/policies/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
# torch.load() and torch.save() calls

import logging
import os
import pathlib
from typing import Callable, Type, TypeVar

import huggingface_sb3 as hfsb3
from stable_baselines3.common import base_class, callbacks, policies, vec_env

from imitation.data import types
from imitation.policies import base
from imitation.util import registry

Expand Down Expand Up @@ -52,7 +52,7 @@ def load_stable_baselines_model(
The deserialized RL algorithm.
"""
logging.info(f"Loading Stable Baselines policy for '{cls}' from '{path}'")
path_obj = pathlib.Path(path)
path_obj = types.parse_path(path)

if path_obj.is_dir():
path_obj = path_obj / "model.zip"
Expand All @@ -66,7 +66,7 @@ def load_stable_baselines_model(
if vec_normalize_path.exists():
raise FileExistsError(
"Outdated policy format: we do not support restoring normalization "
"statistics from '{vec_normalize_path}'",
f"statistics from '{vec_normalize_path}'",
)

return cls.load(path_obj, env=venv, **kwargs)
Expand Down Expand Up @@ -181,7 +181,7 @@ def load_policy(


def save_stable_model(
output_dir: str,
output_dir: pathlib.Path,
model: base_class.BaseAlgorithm,
filename: str = "model.zip",
) -> None:
Expand All @@ -197,9 +197,9 @@ def save_stable_model(
# Save each model in new directory in case we want to add metadata or other
# information in future. (E.g. we used to save `VecNormalize` statistics here,
# although that is no longer necessary.)
os.makedirs(output_dir, exist_ok=True)
model.save(os.path.join(output_dir, filename))
logging.info("Saved policy to %s", output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
model.save(output_dir / filename)
logging.info(f"Saved policy to {output_dir}")


class SavePolicyCallback(callbacks.EventCallback):
Expand All @@ -211,7 +211,7 @@ class SavePolicyCallback(callbacks.EventCallback):

def __init__(
self,
policy_dir: str,
policy_dir: pathlib.Path,
*args,
**kwargs,
):
Expand All @@ -227,6 +227,6 @@ def __init__(

def _on_step(self) -> bool:
assert self.model is not None
output_dir = os.path.join(self.policy_dir, f"{self.num_timesteps:012d}")
output_dir = self.policy_dir / f"{self.num_timesteps:012d}"
save_stable_model(output_dir, self.model)
return True
Loading

0 comments on commit 7376f24

Please sign in to comment.