Skip to content

Commit

Permalink
Remove import rename for seals package
Browse files Browse the repository at this point in the history
  • Loading branch information
Rocamonde committed Oct 12, 2022
1 parent 41d01ee commit 95d37ca
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 13 deletions.
4 changes: 2 additions & 2 deletions docs/algorithms/mce_irl.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Detailed example notebook: :doc:`../tutorials/6_train_mce`

from functools import partial

from seals import base_envs as envs
from seals import base_envs
from seals.diagnostics.cliff_world import CliffWorldEnv
import numpy as np

Expand All @@ -32,7 +32,7 @@ Detailed example notebook: :doc:`../tutorials/6_train_mce`
env_creator = partial(CliffWorldEnv, height=4, horizon=8, width=7, use_xy_obs=True)
env_single = env_creator()

state_env_creator = lambda: envs.ExposePOMDPStateWrapper(env_creator())
state_env_creator = lambda: base_envs.ExposePOMDPStateWrapper(env_creator())

# This is just a vectorized environment because `generate_trajectories` expects one
state_venv = DummyVecEnv([state_env_creator] * 4)
Expand Down
6 changes: 3 additions & 3 deletions docs/tutorials/6_train_mce.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
"source": [
"from functools import partial\n",
"\n",
"from seals import base_envs as envs\n",
"from seals import base_envs\n",
"from seals.diagnostics.cliff_world import CliffWorldEnv\n",
"from stable_baselines3.common.vec_env import DummyVecEnv\n",
"\n",
Expand All @@ -43,7 +43,7 @@
"env_creator = partial(CliffWorldEnv, height=4, horizon=8, width=7, use_xy_obs=True)\n",
"env_single = env_creator()\n",
"\n",
"state_env_creator = lambda: envs.ExposePOMDPStateWrapper(env_creator())\n",
"state_env_creator = lambda: base_envs.ExposePOMDPStateWrapper(env_creator())\n",
"\n",
"# This is just a vectorized environment because `generate_trajectories` expects one\n",
"state_venv = DummyVecEnv([state_env_creator] * 4)"
Expand Down Expand Up @@ -247,4 +247,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}
8 changes: 4 additions & 4 deletions src/imitation/algorithms/mce_irl.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import numpy as np
import scipy.special
import torch as th
from seals import base_envs as envs
from seals import base_envs
from stable_baselines3.common import policies

from imitation.algorithms import base
Expand All @@ -24,7 +24,7 @@


def mce_partition_fh(
env: envs.TabularModelPOMDP,
env: base_envs.TabularModelPOMDP,
*,
reward: Optional[np.ndarray] = None,
discount: float = 1.0,
Expand Down Expand Up @@ -77,7 +77,7 @@ def mce_partition_fh(


def mce_occupancy_measures(
env: envs.TabularModelPOMDP,
env: base_envs.TabularModelPOMDP,
*,
reward: Optional[np.ndarray] = None,
pi: Optional[np.ndarray] = None,
Expand Down Expand Up @@ -257,7 +257,7 @@ class MCEIRL(base.DemonstrationAlgorithm[types.TransitionsMinimal]):
def __init__(
self,
demonstrations: Optional[MCEDemonstrations],
env: envs.TabularModelPOMDP,
env: base_envs.TabularModelPOMDP,
reward_net: reward_nets.RewardNet,
rng: np.random.Generator,
optimizer_cls: Type[th.optim.Optimizer] = th.optim.Adam,
Expand Down
8 changes: 4 additions & 4 deletions tests/algorithms/test_mce_irl.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
import pytest
import torch as th
from seals import base_envs as envs
from seals import base_envs
from seals.diagnostics import random_trans
from stable_baselines3.common import vec_env

Expand Down Expand Up @@ -116,7 +116,7 @@ def test_policy_om_random_mdp(discount: float):
assert np.allclose(np.sum(D), expected_sum)


class ReasonablePOMDP(envs.TabularModelPOMDP):
class ReasonablePOMDP(base_envs.TabularModelPOMDP):
"""A tabular MDP with sensible parameters."""

def __init__(self):
Expand Down Expand Up @@ -314,7 +314,7 @@ def test_mce_irl_demo_formats(rng):
obs_dim=None,
generator_seed=42,
)
state_env = envs.ExposePOMDPStateWrapper(mdp)
state_env = base_envs.ExposePOMDPStateWrapper(mdp)
state_venv = vec_env.DummyVecEnv([lambda: state_env])
trajs = rollout.generate_trajectories(
policy=None,
Expand Down Expand Up @@ -406,7 +406,7 @@ def test_mce_irl_reasonable_mdp(
# make sure weights have non-insane norm
assert tensor_iter_norm(reward_net.parameters()) < 1000

state_env = envs.ExposePOMDPStateWrapper(mdp)
state_env = base_envs.ExposePOMDPStateWrapper(mdp)
state_venv = vec_env.DummyVecEnv([lambda: state_env])
trajs = rollout.generate_trajectories(
mce_irl.policy,
Expand Down

0 comments on commit 95d37ca

Please sign in to comment.