diff --git a/examples/trace_hdf5_record.py b/examples/trace_hdf5_record.py new file mode 100644 index 0000000..8612aa4 --- /dev/null +++ b/examples/trace_hdf5_record.py @@ -0,0 +1,30 @@ +import argparse + +from gym_lowcostrobot.envs.lift_cube_env import LiftCubeEnv +from gym_lowcostrobot.envs.reach_cube_env import ReachCubeEnv +from gym_lowcostrobot.envs.push_cube_env import PushCubeEnv +from gym_lowcostrobot.envs.pick_place_cube_env import PickPlaceCubeEnv +from gym_lowcostrobot.envs.stack_env import StackEnv + +from gym_lowcostrobot.envs.wrappers.record_hdf5 import RecordHDF5Wrapper + +def do_record_hdf5(args): + + env = ReachCubeEnv(render_mode=None, image_state="multi", action_mode="ee") + env = RecordHDF5Wrapper(env, hdf5_folder="data", length=1000, name_prefix="reach") + env.reset() + + max_step = 20 + for _ in range(max_step): + action = env.action_space.sample() + observation, reward, terminated, truncated, info = env.step(action) + if terminated: + env.reset() + + env.close() + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Trace video from HDF5 trace file") + parser.add_argument("--file_path", type=str, default="data/episode_49.hdf5", help="Path to HDF5 file") + args = parser.parse_args() + do_record_hdf5(args) diff --git a/gym_lowcostrobot/envs/wrappers/record_hdf5.py b/gym_lowcostrobot/envs/wrappers/record_hdf5.py new file mode 100644 index 0000000..d8984dc --- /dev/null +++ b/gym_lowcostrobot/envs/wrappers/record_hdf5.py @@ -0,0 +1,157 @@ +"""Wrapper for recording videos.""" +import os +from typing import Callable, Optional + +import gymnasium as gym +from gymnasium import logger + +import h5py +import numpy as np + +""" +observations +observations/images +observations/images/front +observations/images/top +observations/qpos +observations/qvel +""" + +class HDF5_Recorder: + + def __init__(self): + + self.step_id = 0 + self.terminated = False + self.truncated = False + self.recorded_frames = 0 + self.episode_id = 0 + self.hdf5_file = None + + self.lst_observations = [] + self.lst_actions = [] + + def start_hdf5_recorder(self, hdf5_file): + """Starts HDF5 recorder using :class:`HDF5_Recorder`.""" + self.close() + self.hdf5_file = hdf5_file + self.recorded_frames = 1 + self.recording = True + self.episode_id += 1 + + def capture_frame(self, observations, action): + """Captures frame to video.""" + assert self.hdf5_file is not None + self.lst_observations.append(observations) + self.lst_actions.append(action) + self.recorded_frames += 1 + + # numpy.stack([item["image_front"] for item in self.lst_observations]) + + def close(self): + """Closes the hdf5 file.""" + if self.hdf5_file is not None: + with h5py.File(self.hdf5_file, "w") as file: + file.create_dataset("observations/images/front", data=np.stack([item["image_front"] for item in self.lst_observations])) + file.create_dataset("observations/images/top", data=np.stack([item["image_top"] for item in self.lst_observations])) + file.create_dataset("observations/qpos", data=np.stack([item["arm_qpos"] for item in self.lst_observations])) + file.create_dataset("observations/qvel", data=np.stack([item["arm_qvel"] for item in self.lst_observations])) + file.create_dataset("action", data=self.lst_actions) + self.recorded_frames = 1 + self.lst_observations = [] + self.lst_actions = [] + + +class RecordHDF5Wrapper(gym.Wrapper): + + def __init__( + self, + env: gym.Env, + hdf5_folder: str, + length: int = 0, + name_prefix: str = "hdf5_record", + disable_logger: bool = False, + ): + gym.Wrapper.__init__(self, env) + + self.hdf5_folder = os.path.abspath(hdf5_folder) + self.hdf5_recorder = HDF5_Recorder() + + if env.image_state != "multi": + raise ValueError( + f"Image state is {env.image_state}, which is incompatible with" + f" RecordHDF5Wrapper. Initialize your environment with a image_state == multi" + ) + + # Create output folder if needed + if os.path.isdir(self.hdf5_folder): + logger.warn( + f"Overwriting existing videos at {self.hdf5_folder} folder " + f"(try specifying a different `hdf5_folder` for the `RecordHDF5` wrapper if this is not desired)" + ) + os.makedirs(self.hdf5_folder, exist_ok=True) + + self.name_prefix = name_prefix + self.length = length + self.terminated = False + self.episode_id = 0 + self.env = env + + try: + self.is_vector_env = self.get_wrapper_attr("is_vector_env") + except AttributeError: + self.is_vector_env = False + + def reset(self, **kwargs): + """Reset the environment using kwargs and then starts recording if video enabled.""" + observations, _ = self.env.reset(**kwargs) + self.terminated = False + self.start_hdf5_recorder() + return observations + + def start_hdf5_recorder(self): + """Starts video recorder using :class:`video_recorder.VideoRecorder`.""" + self.close_hdf5_recorder() + + video_name = f"{self.name_prefix}-episode-{self.episode_id}.hdf5" + self.hdf5_recorder.start_hdf5_recorder(hdf5_file=os.path.join(self.hdf5_folder, video_name)) + self.recording = True + + + def step(self, action): + """Steps through the environment using action, recording observations if :attr:`self.recording`.""" + + observations, rewards, terminateds, truncateds, infos = self.env.step(action) + + # increment steps and episodes + + if self.recording: + assert self.hdf5_recorder is not None + + self.hdf5_recorder.capture_frame(observations, action) + + if self.length > 0: + if self.hdf5_recorder.recorded_frames > self.length: + self.close_hdf5_recorder() + else: + if not self.is_vector_env: + if terminateds or truncateds: + self.close_hdf5_recorder() + elif terminateds[0] or truncateds[0]: + self.close_hdf5_recorder() + + return observations, rewards, terminateds, truncateds, infos + + def close_hdf5_recorder(self): + """Closes the hdf5 recorder if currently recording.""" + if self.hdf5_recorder is not None: + self.hdf5_recorder.close() + self.recorded_frames = 1 + + def render(self, *args, **kwargs): + return super().render(*args, **kwargs) + + def close(self): + """Closes the wrapper then the video recorder.""" + super().close() + self.close_hdf5_recorder()