Skip to content

Commit

Permalink
Convert observation to dictionary
Browse files Browse the repository at this point in the history
 - image_front
 - image_top
 - arm_qpos
 - arm_qvel
 - object_qpos
 - object_qvel

Add environment wrapper to filter and flatten observations.
  • Loading branch information
perezjln committed May 27, 2024
1 parent e22c8ad commit 37056f5
Show file tree
Hide file tree
Showing 31 changed files with 183 additions and 226 deletions.
12 changes: 10 additions & 2 deletions examples/gym_check_determinism.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
import gymnasium as gym

import gymnasium
from gymnasium.wrappers.filter_observation import FilterObservation
from gymnasium.wrappers.flatten_observation import FlattenObservation

import gym_lowcostrobot

if __name__ == "__main__":
for env_name in [
Expand All @@ -8,7 +13,10 @@
"ReachCube-v0",
"Stack-v0",
]:
env = gym.make(env_name)
env = gymnasium.make(env_name)
env = FilterObservation(env, ["arm_qpos", "object_qpos"])
env = FlattenObservation(env)

observation1, info = env.reset(seed=123)
observation2, info = env.reset(seed=123)
observation3, info = env.reset(seed=123)
Expand Down
1 change: 1 addition & 0 deletions examples/gym_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def do_env_sim():

max_step = 1000000
for _ in range(max_step):

action = env.action_space.sample()
observation, reward, terminated, truncated, info = env.step(action)

Expand Down
4 changes: 4 additions & 0 deletions examples/gym_manipulation_img.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,13 @@ def do_env_sim_image():

max_step = 1000
for _ in range(max_step):

action = env.action_space.sample()
_, _, terminated, truncated, info = env.step(action)

print(info["img"].shape)
input("Press Enter to continue...")

plt.imshow(info["img"])
plt.show()

Expand Down
1 change: 1 addition & 0 deletions examples/gym_manipulation_img_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ def do_env_sim_image():

max_step = 1000
for _ in range(max_step):

action = env.action_space.sample()
_, _, terminated, truncated, info = env.step(action)

Expand Down
30 changes: 19 additions & 11 deletions examples/gym_manipulation_sb3.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,14 @@
from gym_lowcostrobot.envs.lift_cube_env import LiftCubeEnv
from gym_lowcostrobot.envs.reach_cube_env import ReachCubeEnv

from gymnasium.wrappers.filter_observation import FilterObservation
from gymnasium.wrappers.flatten_observation import FlattenObservation

def do_td3_reach():

env = ReachCubeEnv()
env = FilterObservation(env, ["arm_qpos", "object_qpos"])
env = FlattenObservation(env)

# Define and train the TD3 agent
model = TD3("MlpPolicy", env, verbose=1)
Expand All @@ -22,55 +26,59 @@ def do_td3_reach():
print(f"Mean reward: {mean_reward} +/- {std_reward}")


def do_ppo_reach(device="cuda", render=True):
nb_parallel_env = 8
def do_ppo_reach(device="cpu", render=True):
nb_parallel_env = 4
envs = SubprocVecEnv(
[lambda: ReachCubeEnv() for _ in range(nb_parallel_env)]
[lambda: FlattenObservation(FilterObservation(ReachCubeEnv(), ["arm_qpos", "object_qpos"])) for _ in range(nb_parallel_env)]
)

# Define and train the PPO agent
model = PPO("MlpPolicy", envs, verbose=1, device=device)
model.learn(total_timesteps=int(1e3), tb_log_name="ppo_reach_cube", progress_bar=True)

# Evaluate the agent
env = ReachCubeEnv(render_mode="human")
env = FlattenObservation(FilterObservation(ReachCubeEnv(render_mode="human"), ["arm_qpos", "object_qpos"]))
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10, render=render)
print(f"Mean reward: {mean_reward} +/- {std_reward}")


def do_td3_lift():
env = LiftCubeEnv()
env = FilterObservation(env, ["arm_qpos", "object_qpos"])
env = FlattenObservation(env)

# Define the evaluation callback
eval_env = LiftCubeEnv()
eval_env = FilterObservation(eval_env, ["arm_qpos", "object_qpos"])
eval_env = FlattenObservation(eval_env)

eval_callback = EvalCallback(
eval_env, eval_freq=1000, n_eval_episodes=10, deterministic=True, callback_on_new_best=None
)

# Define and train the TD3 agent
model = TD3("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=int(1e5), tb_log_name="td3_lift_cube", callback=eval_callback, progress_bar=True)
model.learn(total_timesteps=int(1e3), tb_log_name="td3_lift_cube", callback=eval_callback, progress_bar=True)

# Evaluate the agent
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10)
print(f"Mean reward: {mean_reward} +/- {std_reward}")


def do_ppo_lift():
nb_parallel_env = 8
nb_parallel_env = 4
envs = SubprocVecEnv(
[
lambda: LiftCubeEnv()
for _ in range(nb_parallel_env)
lambda: FlattenObservation(FilterObservation(LiftCubeEnv(), ["arm_qpos", "object_qpos"])) for _ in range(nb_parallel_env)
]
)

# Define and train the PPO agent
model = PPO("MlpPolicy", envs, verbose=1)
model.learn(total_timesteps=int(1e5), tb_log_name="ppo_lift_cube", progress_bar=True)
model.learn(total_timesteps=int(1e3), tb_log_name="ppo_lift_cube", progress_bar=True)

# Evaluate the agent
env = ReachCubeEnv()
env = FlattenObservation(FilterObservation(LiftCubeEnv(render_mode="human"), ["arm_qpos", "object_qpos"]))
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10)
print(f"Mean reward: {mean_reward} +/- {std_reward}")

Expand All @@ -80,4 +88,4 @@ def do_ppo_lift():
print("Available devices:")
print(torch.cuda.device_count())

do_ppo_reach()
do_td3_lift()
2 changes: 1 addition & 1 deletion examples/mujoco_simple_invk.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def do_simple_invk(robot_id="5dof"):
print("Cube dist:", np.linalg.norm(cube_pos - ee_pos))
if np.linalg.norm(cube_pos - ee_pos) < min_dist or np.linalg.norm(cube_pos - ee_pos) > max_dist:
print("Cube reached the target position")
data.joint(object_id).qpos[:3] = [np.random.rand() * 0.2, np.random.rand() * 0.2, 0.01]
data.joint(object_id).qpos[:3] = [np.random.rand() * 0.5, np.random.rand() * 0.5, 0.01]
mujoco.mj_step(m, data)
viewer.sync()

Expand Down
13 changes: 0 additions & 13 deletions gym_lowcostrobot/assets/low_cost_robot/arm_connector.part

This file was deleted.

13 changes: 0 additions & 13 deletions gym_lowcostrobot/assets/low_cost_robot/base.part

This file was deleted.

13 changes: 0 additions & 13 deletions gym_lowcostrobot/assets/low_cost_robot/connector.part

This file was deleted.

13 changes: 0 additions & 13 deletions gym_lowcostrobot/assets/low_cost_robot/dc11_a01_dummy.part

This file was deleted.

13 changes: 0 additions & 13 deletions gym_lowcostrobot/assets/low_cost_robot/dc11_a01_spacer_dummy.part

This file was deleted.

13 changes: 0 additions & 13 deletions gym_lowcostrobot/assets/low_cost_robot/dc15_a01_case_b_dummy.part

This file was deleted.

13 changes: 0 additions & 13 deletions gym_lowcostrobot/assets/low_cost_robot/dc15_a01_case_f_dummy.part

This file was deleted.

13 changes: 0 additions & 13 deletions gym_lowcostrobot/assets/low_cost_robot/dc15_a01_case_m_dummy.part

This file was deleted.

13 changes: 0 additions & 13 deletions gym_lowcostrobot/assets/low_cost_robot/dc15_a01_horn_dummy.part

This file was deleted.

This file was deleted.

13 changes: 0 additions & 13 deletions gym_lowcostrobot/assets/low_cost_robot/moving_side.part

This file was deleted.

14 changes: 7 additions & 7 deletions gym_lowcostrobot/assets/low_cost_robot/robot.xml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
<compiler angle="radian"/>
<default>
<joint damping="0.2" frictionloss="0.1"/>
<position kp="10" forcerange="-1.0 1.0"/>
<position kp="10" forcerange="-10 10"/>
</default>
<asset>
<mesh name="base" file="base.stl"/>
Expand Down Expand Up @@ -33,7 +33,7 @@
<geom pos="0.0511555 0.000624643 0.0099573" quat="0.707107 0 0 0.707107" type="mesh" rgba="0.647059 0.647059 0.647059 1" mesh="dc11_a01_spacer_dummy"/>
<body name="pitch1_assembly" pos="0.0401555 0.0326246 0.0166573">
<inertial pos="-0.000767103 -0.0121505 0.0134241" quat="0.498429 0.53272 -0.473938 0.493113" mass="0.0606831" diaginertia="1.86261e-05 1.72746e-05 1.11693e-05"/>
<joint name="yaw" pos="0 0 0" axis="0 0 1" range="-3.14159 3.14159" actuatorfrcrange="-1 1"/>
<joint name="yaw" pos="0 0 0" axis="0 0 1" range="-3.14159 3.14159"/>
<geom pos="0 0 -0.0209" type="mesh" rgba="0.231373 0.380392 0.705882 1" mesh="rotation_connector"/>
<geom pos="-0.014 0.008 0.0264" quat="0 -0.707107 0 0.707107" type="mesh" rgba="0.647059 0.647059 0.647059 1" mesh="dc11_a01_spacer_dummy"/>
<geom pos="-0.014 -0.032 0.0044" quat="0 -0.707107 0 0.707107" type="mesh" rgba="0.647059 0.647059 0.647059 1" mesh="dc11_a01_spacer_dummy"/>
Expand All @@ -46,7 +46,7 @@
<geom pos="-0.014 -0.032 0.0264" quat="0.5 -0.5 -0.5 0.5" type="mesh" rgba="0.647059 0.647059 0.647059 1" mesh="dc11_a01_spacer_dummy"/>
<body name="pitch2_assembly" pos="-0.0188 0 0.0154" quat="0 0.707107 0 0.707107">
<inertial pos="0.0766242 -0.00031229 0.0187402" quat="0.52596 0.513053 0.489778 0.469319" mass="0.0432446" diaginertia="7.21796e-05 7.03107e-05 1.07533e-05"/>
<joint name="pitch1" pos="0 0 0" axis="0 0 1" range="-1.5708 1.22173" actuatorfrcrange="-1 1"/>
<joint name="pitch1" pos="0 0 0" axis="0 0 1" range="-1.5708 1.22173"/>
<geom pos="0 0 0.019" quat="0.5 -0.5 -0.5 -0.5" type="mesh" rgba="0.980392 0.713726 0.00392157 1" mesh="arm_connector"/>
<geom pos="0.1083 -0.0148 0.03035" quat="1 0 0 0" type="mesh" rgba="0.647059 0.647059 0.647059 1" mesh="dc15_a01_horn_idle2_dummy"/>
<geom pos="0.1083 -0.0148 0.01075" quat="0 -1 0 0" type="mesh" rgba="0.615686 0.811765 0.929412 1" mesh="dc15_a01_case_m_dummy"/>
Expand All @@ -55,7 +55,7 @@
<geom pos="0.1083 -0.0148 0.03025" quat="0 -1 0 0" type="mesh" rgba="0.498039 0.498039 0.498039 1" mesh="dc15_a01_case_b_dummy"/>
<body name="pitch3_assembly" pos="0.1083 -0.0148 0.00425" quat="0.707107 0 0 0.707107">
<inertial pos="-0.0551014 -0.00287792 0.0144813" quat="0.500323 0.499209 0.499868 0.5006" mass="0.0788335" diaginertia="6.80912e-05 6.45748e-05 9.84479e-06"/>
<joint name="pitch2" pos="0 0 0" axis="0 0 1" range="-1.48353 1.74533" actuatorfrcrange="-1 1"/>
<joint name="pitch2" pos="0 0 0" axis="0 0 1" range="-1.48353 1.74533" />
<geom pos="-0.00863031 0.00847376 0.0145" quat="0.5 0.5 0.5 -0.5" type="mesh" rgba="0.615686 0.811765 0.929412 1" mesh="connector"/>
<geom pos="-0.100476 -0.00269986 0.02635" quat="0.707107 0 0 -0.707107" type="mesh" rgba="0.647059 0.647059 0.647059 1" mesh="dc15_a01_horn_idle2_dummy"/>
<geom pos="-0.100476 -0.00269986 0.00675" quat="0 -0.707107 0.707107 0" type="mesh" rgba="0.615686 0.811765 0.929412 1" mesh="dc15_a01_case_m_dummy"/>
Expand All @@ -64,15 +64,15 @@
<geom pos="-0.100476 -0.00269986 0.02625" quat="0 -0.707107 0.707107 0" type="mesh" rgba="0.498039 0.498039 0.498039 1" mesh="dc15_a01_case_b_dummy"/>
<body name="effector_roll_assembly" pos="-0.100476 -0.00269986 0.02925" quat="0 -0.707107 -0.707107 0">
<inertial pos="-1.65017e-05 -0.02659 0.0195388" quat="0.936813 0.349829 -0.00055331 -0.000300569" mass="0.0240506" diaginertia="6.03208e-06 4.12894e-06 3.3522e-06"/>
<joint name="pitch3" pos="0 0 0" axis="0 0 1" range="-1.91986 1.91986" actuatorfrcrange="-1 1"/>
<joint name="pitch3" pos="0 0 0" axis="0 0 1" range="-1.91986 1.91986" />
<geom pos="-0.0109998 -0.0190002 0.039" quat="0.707107 -0.707107 0 0" type="mesh" rgba="0.615686 0.811765 0.929412 1" mesh="shoulder_rotation"/>
<geom pos="-7.44154e-06 -0.0385002 0.0133967" quat="0 0 0.707107 -0.707107" type="mesh" rgba="0.615686 0.811765 0.929412 1" mesh="dc15_a01_case_m_dummy"/>
<geom pos="-7.44154e-06 -0.0385002 0.0133967" quat="0 0 0.707107 -0.707107" type="mesh" rgba="0.980392 0.713726 0.00392157 1" mesh="dc15_a01_case_f_dummy"/>
<geom pos="-7.44154e-06 -0.0421002 0.0133967" quat="0 1 0 0" type="mesh" rgba="0.972549 0.529412 0.00392157 1" mesh="dc15_a01_horn_dummy"/>
<geom pos="-7.44154e-06 -0.0190002 0.0133967" quat="0 0 0.707107 -0.707107" type="mesh" rgba="0.498039 0.498039 0.498039 1" mesh="dc15_a01_case_b_dummy"/>
<body name="gripper_assembly" pos="-7.44154e-06 -0.0450002 0.0133967" quat="0.5 -0.5 -0.5 -0.5">
<inertial pos="-0.00548595 -0.000433143 -0.0190793" quat="0.700194 0.164851 0.167361 0.674197" mass="0.0360627" diaginertia="1.3261e-05 1.231e-05 5.3532e-06"/>
<joint name="effector_roll" pos="0 0 0" axis="0 0 1" range="-2.96706 2.96706" actuatorfrcrange="-1 1"/>
<joint name="effector_roll" pos="0 0 0" axis="0 0 1" range="-2.96706 2.96706" />
<geom pos="-0.00075 -0.01475 -0.02" quat="0.707107 -0.707107 0 0" type="mesh" rgba="0.917647 0.917647 0.917647 1" mesh="static_side"/>
<geom pos="0.00755 0.01135 -0.013" quat="0.5 -0.5 0.5 0.5" type="mesh" rgba="0.647059 0.647059 0.647059 1" mesh="dc15_a01_horn_idle2_dummy"/>
<geom pos="0.00755 -0.00825 -0.013" quat="0.5 0.5 0.5 -0.5" type="mesh" rgba="0.615686 0.811765 0.929412 1" mesh="dc15_a01_case_m_dummy"/>
Expand All @@ -81,7 +81,7 @@
<geom pos="0.00755 0.01125 -0.013" quat="0.5 0.5 0.5 -0.5" type="mesh" rgba="0.498039 0.498039 0.498039 1" mesh="dc15_a01_case_b_dummy"/>
<body name="moving_side" pos="0.00755 -0.01475 -0.013" quat="0.707107 -0.707107 0 0">
<inertial pos="-0.000395599 0.022415 0.0145636" quat="0.722353 0.689129 0.0389102 0.0423547" mass="0.0089856" diaginertia="3.28451e-06 2.24898e-06 1.41539e-06"/>
<joint name="gripper_opening" pos="0 0 0" axis="0 0 1" range="-1.74533 0.0523599" actuatorfrcrange="-1 1"/>
<joint name="gripper_opening" pos="0 0 0" axis="0 0 1" range="-1.74533 0.0523599" />
<geom pos="-0.00838199 -0.000256591 -0.003" type="mesh" rgba="0.768627 0.886275 0.952941 1" mesh="moving_side"/>
</body>
</body>
Expand Down
13 changes: 0 additions & 13 deletions gym_lowcostrobot/assets/low_cost_robot/rotation_connector.part

This file was deleted.

Loading

0 comments on commit 37056f5

Please sign in to comment.