-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathsave_trajectories_tz.py
56 lines (39 loc) · 1.41 KB
/
save_trajectories_tz.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import numpy as np
import random
import os
from quasimetric_rl.data.d4rl.grid_tank_goal import Tank_reach_goal

# Collect random-policy rollouts from the tank grid-goal environment and
# save each episode as a compressed .npz file (one file per episode) for
# offline / trajectory-based training.

random.seed(0)      # reproducible action sampling
np.random.seed(0)   # reproducible numpy-side randomness, if the env uses it

env = Tank_reach_goal()
name = 'trajectories_custom'
# exist_ok avoids the race-prone exists()-then-makedirs pattern and lets the
# script be re-run without failing.
os.makedirs(name, exist_ok=True)

NUM_EPISODES = 1000
MAX_STEPS = 1000  # per-episode step cap in case the goal is never reached

for i in range(NUM_EPISODES):
    env.reset()
    observation_list = []
    next_observation_list = []  # fixed misspelling: was "next_obervation_list"
    reward_list = []
    terminal_list = []
    actions_list = []
    for j in range(MAX_STEPS):
        # NOTE: "action_ditct" is the attribute name defined by the env
        # itself (its typo) — do not "fix" it here.
        random_action = random.randrange(len(env.action_ditct))
        observation = env.get_observation()
        next_observation, reward, terminal, _ = env.step(random_action)
        observation_list.append(observation)
        next_observation_list.append(next_observation)
        reward_list.append(reward)
        terminal_list.append(terminal)
        actions_list.append(random_action)
        if terminal:
            print("Found the end!")
            break
    # Build the episode dict once after the rollout; the original re-created
    # an empty dict on every step for no effect.
    dict_data = {
        'observations': np.array(observation_list),
        'next_observations': np.array(next_observation_list),
        'rewards': np.array(reward_list),
        'terminals': np.array(terminal_list),
        'actions': np.array(actions_list, dtype=np.int64),
    }
    # all_observations = the T observations plus the final next_observation,
    # i.e. the full T+1 sequence of states visited in the episode.
    dict_data['all_observations'] = np.concatenate(
        [dict_data['observations'], dict_data['next_observations'][-1:]], axis=0)
    np.savez(name + f'/test_{i:04}', **dict_data)
    print(i)