Skip to content

Commit

Permalink
Remove unrelated changes
Browse files Browse the repository at this point in the history
  • Loading branch information
pseudo-rnd-thoughts committed Apr 10, 2024
1 parent 1a2352a commit c02f2f7
Showing 1 changed file with 134 additions and 11 deletions.
145 changes: 134 additions & 11 deletions tests/python/test_atari_env.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import warnings
from itertools import product
from unittest.mock import patch

import gymnasium
import numpy as np
import pytest
from ale_py.env import AtariEnv
from ale_py.registration import _register_rom_configs, register_v0_v4_envs
from gymnasium.utils.env_checker import check_env
from utils import test_rom_path, tetris_env # noqa: F401

Expand Down Expand Up @@ -38,8 +40,122 @@ def test_check_env(env_id):
raise ValueError(warning.message.args[0])


def test_register_legacy_env_id():
prefix = "ALETest/"

_original_register_gym_configs = _register_rom_configs

def _mocked_register_gym_configs(*args, **kwargs):
return _original_register_gym_configs(*args, **kwargs, prefix=prefix)

with patch(
"ale_py.registration._register_rom_configs",
new=_mocked_register_gym_configs,
):
# Register internal IDs
register_v0_v4_envs()

# Check if we registered the proper environments
envids = set(map(lambda e: e.id, gymnasium.registry.values()))
legacy_games = [
"Adventure",
"AirRaid",
"Alien",
"Amidar",
"Assault",
"Asterix",
"Asteroids",
"Atlantis",
"BankHeist",
"BattleZone",
"BeamRider",
"Berzerk",
"Bowling",
"Boxing",
"Breakout",
"Carnival",
"Centipede",
"ChopperCommand",
"CrazyClimber",
"Defender",
"DemonAttack",
"DoubleDunk",
"ElevatorAction",
"Enduro",
"FishingDerby",
"Freeway",
"Frostbite",
"Gopher",
"Gravitar",
"Hero",
"IceHockey",
"Jamesbond",
"JourneyEscape",
"Kangaroo",
"Krull",
"KungFuMaster",
"MontezumaRevenge",
"MsPacman",
"NameThisGame",
"Phoenix",
"Pitfall",
"Pong",
"Pooyan",
"PrivateEye",
"Qbert",
"Riverraid",
"RoadRunner",
"Robotank",
"Seaquest",
"Skiing",
"Solaris",
"SpaceInvaders",
"StarGunner",
"Tennis",
"TimePilot",
"Tutankham",
"UpNDown",
"Venture",
"VideoPinball",
"WizardOfWor",
"YarsRevenge",
"Zaxxon",
]
legacy_games = map(lambda game: f"{prefix}{game}", legacy_games)

obs_types = ["", "-ram"]
suffixes = ["Deterministic", "NoFrameskip"]
versions = ["-v0", "-v4"]

all_ids = set(
map("".join, product(legacy_games, obs_types, suffixes, versions))
)
assert all_ids.issubset(envids)


def test_register_gym_envs(test_rom_path):
with patch("ale_py.roms.Tetris", create=True, new_callable=lambda: test_rom_path):
# Register internal IDs
# register_v5_envs()

# Check if we registered the proper environments
envids = set(map(lambda e: e.id, gymnasium.registry.values()))
games = ["ALE/Tetris"]

obs_types = ["", "-ram"]
suffixes = []
versions = ["-v5"]

all_ids = set(map("".join, product(games, obs_types, suffixes, versions)))
assert all_ids.issubset(envids)


def test_gym_make(tetris_env):
assert isinstance(tetris_env, gymnasium.Env)


@pytest.mark.parametrize("tetris_env", [{"render_mode": "rgb_array"}], indirect=True)
def test_render_kwarg(tetris_env):
def test_gym_render_kwarg(tetris_env):
tetris_env.reset()
_, _, _, _, info = tetris_env.step(0)
assert "rgb" not in info
Expand All @@ -51,7 +167,7 @@ def test_render_kwarg(tetris_env):
@pytest.mark.parametrize(
"tetris_env", [{"max_num_frames_per_episode": 10, "frameskip": 1}], indirect=True
)
def test_truncate_on_max_episode_steps(tetris_env):
def test_gym_truncate_on_max_episode_steps(tetris_env):
tetris_env.reset()

is_truncated = False
Expand All @@ -63,12 +179,12 @@ def test_truncate_on_max_episode_steps(tetris_env):


@pytest.mark.parametrize("tetris_env", [{"mode": 0, "difficulty": 0}], indirect=True)
def test_mode_difficulty_kwarg(tetris_env):
def test_gym_mode_difficulty_kwarg(tetris_env):
pass


@pytest.mark.parametrize("tetris_env", [{"obs_type": "ram"}], indirect=True)
def test_ram_obs(tetris_env):
def test_gym_ram_obs(tetris_env):
tetris_env.reset()
obs, _, _, _, _ = tetris_env.step(0)
space = tetris_env.observation_space
Expand All @@ -84,7 +200,7 @@ def test_ram_obs(tetris_env):


@pytest.mark.parametrize("tetris_env", [{"obs_type": "grayscale"}], indirect=True)
def test_img_grayscale_obs(tetris_env):
def test_gym_img_grayscale_obs(tetris_env):
tetris_env.reset()
obs, _, _, _, _ = tetris_env.step(0)
space = tetris_env.observation_space
Expand All @@ -101,7 +217,7 @@ def test_img_grayscale_obs(tetris_env):


@pytest.mark.parametrize("tetris_env", [{"obs_type": "rgb"}], indirect=True)
def test_img_rgb_obs(tetris_env):
def test_gym_img_rgb_obs(tetris_env):
tetris_env.reset()
obs, _, _, _, _ = tetris_env.step(0)
space = tetris_env.observation_space
Expand All @@ -120,7 +236,7 @@ def test_img_rgb_obs(tetris_env):


@pytest.mark.parametrize("tetris_env", [{"full_action_space": True}], indirect=True)
def test_keys_to_action(tetris_env):
def test_gym_keys_to_action(tetris_env):
keys_full_action_space = {
(None,): 0,
(32,): 1,
Expand All @@ -147,7 +263,7 @@ def test_keys_to_action(tetris_env):


@pytest.mark.parametrize("tetris_env", [{"full_action_space": True}], indirect=True)
def test_action_meaning(tetris_env):
def test_gym_action_meaning(tetris_env):
action_meanings = [
"NOOP",
"FIRE",
Expand All @@ -172,7 +288,7 @@ def test_action_meaning(tetris_env):
assert tetris_env.unwrapped.get_action_meanings() == action_meanings


def test_clone_state(tetris_env):
def test_gym_clone_state(tetris_env):
tetris_env = tetris_env.unwrapped

tetris_env.reset(seed=0)
Expand All @@ -187,11 +303,11 @@ def test_clone_state(tetris_env):


@pytest.mark.parametrize("tetris_env", [{"full_action_space": True}], indirect=True)
def test_action_space(tetris_env):
def test_gym_action_space(tetris_env):
assert tetris_env.action_space.n == 18


def test_reset_infos(tetris_env):
def test_gym_reset_with_infos(tetris_env):
pack = tetris_env.reset(seed=0)

assert isinstance(pack, tuple)
Expand Down Expand Up @@ -231,3 +347,10 @@ def test_render_exception(tetris_env):

with pytest.raises(TypeError):
tetris_env.unwrapped.render(mode="human")


def test_gym_compliance(tetris_env):
with warnings.catch_warnings(record=True) as caught_warnings:
check_env(tetris_env.unwrapped, skip_render_check=True)

assert len(caught_warnings) == 0, [w.message for w in caught_warnings]

0 comments on commit c02f2f7

Please sign in to comment.