Skip to content

Commit

Permalink
HDF5 trace recording wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
perezjln committed May 27, 2024
1 parent 37056f5 commit 14f8603
Show file tree
Hide file tree
Showing 2 changed files with 187 additions and 0 deletions.
30 changes: 30 additions & 0 deletions examples/trace_hdf5_record.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import argparse

from gym_lowcostrobot.envs.lift_cube_env import LiftCubeEnv
from gym_lowcostrobot.envs.reach_cube_env import ReachCubeEnv
from gym_lowcostrobot.envs.push_cube_env import PushCubeEnv
from gym_lowcostrobot.envs.pick_place_cube_env import PickPlaceCubeEnv
from gym_lowcostrobot.envs.stack_env import StackEnv

from gym_lowcostrobot.envs.wrappers.record_hdf5 import RecordHDF5Wrapper

def do_record_hdf5(args):

env = ReachCubeEnv(render_mode=None, image_state="multi", action_mode="ee")
env = RecordHDF5Wrapper(env, hdf5_folder="data", length=1000, name_prefix="reach")
env.reset()

max_step = 20
for _ in range(max_step):
action = env.action_space.sample()
observation, reward, terminated, truncated, info = env.step(action)
if terminated:
env.reset()

env.close()

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Trace video from HDF5 trace file")
parser.add_argument("--file_path", type=str, default="data/episode_49.hdf5", help="Path to HDF5 file")
args = parser.parse_args()
do_record_hdf5(args)
157 changes: 157 additions & 0 deletions gym_lowcostrobot/envs/wrappers/record_hdf5.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
"""Wrapper for recording videos."""
import os
from typing import Callable, Optional

import gymnasium as gym
from gymnasium import logger

import h5py
import numpy as np

"""
observations
observations/images
observations/images/front
observations/images/top
observations/qpos
observations/qvel
"""

class HDF5_Recorder:

def __init__(self):

self.step_id = 0
self.terminated = False
self.truncated = False
self.recorded_frames = 0
self.episode_id = 0
self.hdf5_file = None

self.lst_observations = []
self.lst_actions = []

def start_hdf5_recorder(self, hdf5_file):
"""Starts HDF5 recorder using :class:`HDF5_Recorder`."""
self.close()
self.hdf5_file = hdf5_file
self.recorded_frames = 1
self.recording = True
self.episode_id += 1

def capture_frame(self, observations, action):
"""Captures frame to video."""
assert self.hdf5_file is not None
self.lst_observations.append(observations)
self.lst_actions.append(action)
self.recorded_frames += 1

# numpy.stack([item["image_front"] for item in self.lst_observations])

def close(self):
"""Closes the hdf5 file."""
if self.hdf5_file is not None:
with h5py.File(self.hdf5_file, "w") as file:
file.create_dataset("observations/images/front", data=np.stack([item["image_front"] for item in self.lst_observations]))
file.create_dataset("observations/images/top", data=np.stack([item["image_top"] for item in self.lst_observations]))
file.create_dataset("observations/qpos", data=np.stack([item["arm_qpos"] for item in self.lst_observations]))
file.create_dataset("observations/qvel", data=np.stack([item["arm_qvel"] for item in self.lst_observations]))
file.create_dataset("action", data=self.lst_actions)
self.recorded_frames = 1
self.lst_observations = []
self.lst_actions = []


class RecordHDF5Wrapper(gym.Wrapper):

def __init__(
self,
env: gym.Env,
hdf5_folder: str,
length: int = 0,
name_prefix: str = "hdf5_record",
disable_logger: bool = False,
):
gym.Wrapper.__init__(self, env)

self.hdf5_folder = os.path.abspath(hdf5_folder)
self.hdf5_recorder = HDF5_Recorder()

if env.image_state != "multi":
raise ValueError(
f"Image state is {env.image_state}, which is incompatible with"
f" RecordHDF5Wrapper. Initialize your environment with a image_state == multi"
)

# Create output folder if needed
if os.path.isdir(self.hdf5_folder):
logger.warn(
f"Overwriting existing videos at {self.hdf5_folder} folder "
f"(try specifying a different `hdf5_folder` for the `RecordHDF5` wrapper if this is not desired)"
)
os.makedirs(self.hdf5_folder, exist_ok=True)

self.name_prefix = name_prefix
self.length = length
self.terminated = False
self.episode_id = 0
self.env = env

try:
self.is_vector_env = self.get_wrapper_attr("is_vector_env")
except AttributeError:
self.is_vector_env = False

def reset(self, **kwargs):
"""Reset the environment using kwargs and then starts recording if video enabled."""
observations, _ = self.env.reset(**kwargs)
self.terminated = False
self.start_hdf5_recorder()
return observations

def start_hdf5_recorder(self):
"""Starts video recorder using :class:`video_recorder.VideoRecorder`."""
self.close_hdf5_recorder()

video_name = f"{self.name_prefix}-episode-{self.episode_id}.hdf5"
self.hdf5_recorder.start_hdf5_recorder(hdf5_file=os.path.join(self.hdf5_folder, video_name))
self.recording = True


def step(self, action):
"""Steps through the environment using action, recording observations if :attr:`self.recording`."""

observations, rewards, terminateds, truncateds, infos = self.env.step(action)

# increment steps and episodes

if self.recording:
assert self.hdf5_recorder is not None

self.hdf5_recorder.capture_frame(observations, action)

if self.length > 0:
if self.hdf5_recorder.recorded_frames > self.length:
self.close_hdf5_recorder()
else:
if not self.is_vector_env:
if terminateds or truncateds:
self.close_hdf5_recorder()
elif terminateds[0] or truncateds[0]:
self.close_hdf5_recorder()

return observations, rewards, terminateds, truncateds, infos

def close_hdf5_recorder(self):
"""Closes the hdf5 recorder if currently recording."""
if self.hdf5_recorder is not None:
self.hdf5_recorder.close()
self.recorded_frames = 1

def render(self, *args, **kwargs):
return super().render(*args, **kwargs)

def close(self):
"""Closes the wrapper then the video recorder."""
super().close()
self.close_hdf5_recorder()

0 comments on commit 14f8603

Please sign in to comment.