diff --git a/docs/algorithms/mce_irl.rst b/docs/algorithms/mce_irl.rst
index 78be84a36..99cc68411 100644
--- a/docs/algorithms/mce_irl.rst
+++ b/docs/algorithms/mce_irl.rst
@@ -13,7 +13,7 @@ Detailed example notebook: :doc:`../tutorials/6_train_mce`
 
     from functools import partial
 
-    from seals import base_envs as envs
+    from seals import base_envs
     from seals.diagnostics.cliff_world import CliffWorldEnv
 
     import numpy as np
@@ -32,7 +32,7 @@ Detailed example notebook: :doc:`../tutorials/6_train_mce`
     env_creator = partial(CliffWorldEnv, height=4, horizon=8, width=7, use_xy_obs=True)
     env_single = env_creator()
 
-    state_env_creator = lambda: envs.ExposePOMDPStateWrapper(env_creator())
+    state_env_creator = lambda: base_envs.ExposePOMDPStateWrapper(env_creator())
 
     # This is just a vectorized environment because `generate_trajectories` expects one
     state_venv = DummyVecEnv([state_env_creator] * 4)
diff --git a/docs/tutorials/6_train_mce.ipynb b/docs/tutorials/6_train_mce.ipynb
index 66c15cce8..115a1b6e7 100644
--- a/docs/tutorials/6_train_mce.ipynb
+++ b/docs/tutorials/6_train_mce.ipynb
@@ -25,7 +25,7 @@
    "source": [
     "from functools import partial\n",
-    "from seals import base_envs as envs\n",
+    "from seals import base_envs\n",
     "from seals.diagnostics.cliff_world import CliffWorldEnv\n",
     "from stable_baselines3.common.vec_env import DummyVecEnv\n",
     "\n",
@@ -43,7 +43,7 @@
    "env_creator = partial(CliffWorldEnv, height=4, horizon=8, width=7, use_xy_obs=True)\n",
    "env_single = env_creator()\n",
    "\n",
-   "state_env_creator = lambda: envs.ExposePOMDPStateWrapper(env_creator())\n",
+   "state_env_creator = lambda: base_envs.ExposePOMDPStateWrapper(env_creator())\n",
    "\n",
    "# This is just a vectorized environment because `generate_trajectories` expects one\n",
    "state_venv = DummyVecEnv([state_env_creator] * 4)"
@@ -247,4 +247,4 @@
  },
  "nbformat": 4,
  "nbformat_minor": 2
-}
\ No newline at end of file
+}
diff --git a/src/imitation/algorithms/mce_irl.py b/src/imitation/algorithms/mce_irl.py
index c739ec284..22262a5e1 100644
--- a/src/imitation/algorithms/mce_irl.py
+++ b/src/imitation/algorithms/mce_irl.py
@@ -13,7 +13,7 @@
 import numpy as np
 import scipy.special
 import torch as th
-from seals import base_envs as envs
+from seals import base_envs
 from stable_baselines3.common import policies
 
 from imitation.algorithms import base
@@ -24,7 +24,7 @@
 
 
 def mce_partition_fh(
-    env: envs.TabularModelPOMDP,
+    env: base_envs.TabularModelPOMDP,
     *,
     reward: Optional[np.ndarray] = None,
     discount: float = 1.0,
@@ -77,7 +77,7 @@
 
 
 def mce_occupancy_measures(
-    env: envs.TabularModelPOMDP,
+    env: base_envs.TabularModelPOMDP,
     *,
     reward: Optional[np.ndarray] = None,
     pi: Optional[np.ndarray] = None,
@@ -257,7 +257,7 @@ class MCEIRL(base.DemonstrationAlgorithm[types.TransitionsMinimal]):
     def __init__(
         self,
         demonstrations: Optional[MCEDemonstrations],
-        env: envs.TabularModelPOMDP,
+        env: base_envs.TabularModelPOMDP,
         reward_net: reward_nets.RewardNet,
         rng: np.random.Generator,
         optimizer_cls: Type[th.optim.Optimizer] = th.optim.Adam,
diff --git a/tests/algorithms/test_mce_irl.py b/tests/algorithms/test_mce_irl.py
index 33d33865a..f8347b46f 100644
--- a/tests/algorithms/test_mce_irl.py
+++ b/tests/algorithms/test_mce_irl.py
@@ -6,7 +6,7 @@
 import numpy as np
 import pytest
 import torch as th
-from seals import base_envs as envs
+from seals import base_envs
 from seals.diagnostics import random_trans
 from stable_baselines3.common import vec_env
 
@@ -116,7 +116,7 @@ def test_policy_om_random_mdp(discount: float):
     assert np.allclose(np.sum(D), expected_sum)
 
 
-class ReasonablePOMDP(envs.TabularModelPOMDP):
+class ReasonablePOMDP(base_envs.TabularModelPOMDP):
     """A tabular MDP with sensible parameters."""
 
     def __init__(self):
@@ -314,7 +314,7 @@ def test_mce_irl_demo_formats(rng):
         obs_dim=None,
         generator_seed=42,
     )
-    state_env = envs.ExposePOMDPStateWrapper(mdp)
+    state_env = base_envs.ExposePOMDPStateWrapper(mdp)
    state_venv = vec_env.DummyVecEnv([lambda: state_env])
     trajs = rollout.generate_trajectories(
         policy=None,
@@ -406,7 +406,7 @@ def test_mce_irl_reasonable_mdp(
     # make sure weights have non-insane norm
     assert tensor_iter_norm(reward_net.parameters()) < 1000
 
-    state_env = envs.ExposePOMDPStateWrapper(mdp)
+    state_env = base_envs.ExposePOMDPStateWrapper(mdp)
     state_venv = vec_env.DummyVecEnv([lambda: state_env])
     trajs = rollout.generate_trajectories(
         mce_irl.policy,