-
Notifications
You must be signed in to change notification settings - Fork 532
/
Copy pathtest_env.py
41 lines (34 loc) · 1.22 KB
/
test_env.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import retro
from retro.testing import testenv, handle
import os
import pytest
def test_env_create(testenv):
    """Check that ``testenv`` yields a truthy environment for the dummy JSON.

    The ``dummy.json`` file sitting next to this module is passed as both
    the ``info`` and the ``scenario`` argument.
    """
    here = os.path.dirname(__file__)
    dummy_json = os.path.join(here, 'dummy.json')
    env = testenv(info=dummy_json, scenario=dummy_json)
    assert env
@pytest.mark.parametrize('obs_type', [retro.Observations.IMAGE, retro.Observations.RAM])
def test_env_basic(obs_type, testenv):
    """Smoke-test ``reset`` and a single ``step`` for each observation type.

    The environment is built from the ``dummy.json`` file next to this
    module (used as both ``info`` and ``scenario``).  The test checks that:

    * the observation returned by ``reset`` and by ``step`` has the shape
      advertised by ``env.observation_space``;
    * the reward is a ``float`` equal to zero;
    * ``done`` is a ``bool`` and is false after one step;
    * ``info`` is a ``dict``.
    """
    json_path = os.path.join(os.path.dirname(__file__), 'dummy.json')
    env = testenv(info=json_path, scenario=json_path, obs_type=obs_type)

    # Initial observation must match the declared observation space.
    obs = env.reset()
    assert obs.shape == env.observation_space.shape

    # One step with a randomly sampled action.
    obs, rew, done, info = env.step(env.action_space.sample())
    assert obs.shape == env.observation_space.shape
    assert isinstance(rew, float)
    assert rew == 0
    assert isinstance(done, bool)
    assert not done
    assert isinstance(info, dict)
def test_env_data(testenv):
    """Check that a value written into ``env.data`` does not survive ``reset``.

    The environment is built from the ``dummy.json`` file next to this
    module.  Before the reset, ``env.data[env.system]`` must be an ``int``
    and a freshly written ``'foo'`` entry must read back as ``1``.  After
    the reset, ``'foo'`` must either be missing (``KeyError``) or hold a
    value other than ``1``.
    """
    dummy_json = os.path.join(os.path.dirname(__file__), 'dummy.json')
    env = testenv(info=dummy_json, scenario=dummy_json)

    system_value = env.data[env.system]
    assert isinstance(system_value, int)

    env.data['foo'] = 1
    assert env.data['foo'] == 1

    env.reset()
    try:
        leftover = env.data['foo']
    except KeyError:
        # The key disappearing entirely is an accepted outcome.
        pass
    else:
        # If the key is still present, it must no longer hold the old value.
        assert leftover != 1