diff --git a/lerobot/common/datasets/push_dataset_to_hub/oxe_rlds_format.py b/lerobot/common/datasets/push_dataset_to_hub/oxe_rlds_format.py index dbfd73fe1..854d31a68 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/oxe_rlds_format.py +++ b/lerobot/common/datasets/push_dataset_to_hub/oxe_rlds_format.py @@ -103,6 +103,7 @@ def load_from_raw( states = [] actions = [] + rewards = torch.zeros(num_frames, dtype=torch.float32) ep_dict = {} image_array_dict = {key: [] for key in image_keys} @@ -112,9 +113,9 @@ def load_from_raw( for j, step in enumerate(steps): states.append(tf_to_torch(step['observation']['state'])) actions.append(tf_to_torch(step['action'])) + rewards[j] = torch.tensor(step['reward'].numpy(), dtype=torch.float32) - # if "language_text" in step: - # print(" - lang: ", step["language_text"]) + # TODO: language_text, is_terminal, is_last etc. for im_key in image_keys: if im_key not in step['observation']: @@ -156,6 +157,7 @@ def load_from_raw( ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames) ep_dict["frame_index"] = torch.arange(0, num_frames, 1) + ep_dict["reward"] = rewards ep_dict["next.done"] = done ep_dicts.append(ep_dict) @@ -198,6 +200,7 @@ def to_hf_dataset(data_dict, video) -> Dataset: features["episode_index"] = Value(dtype="int64", id=None) features["frame_index"] = Value(dtype="int64", id=None) features["timestamp"] = Value(dtype="float32", id=None) + features["reward"] = Value(dtype="float32", id=None) features["next.done"] = Value(dtype="bool", id=None) features["index"] = Value(dtype="int64", id=None) @@ -229,8 +232,8 @@ def from_raw_to_lerobot_format( if __name__ == "__main__": # TODO (YL) remove this - raw_dir = Path("/hdd/serl/serl_task1_combine_13jun/") - videos_dir = Path("/hdd/serl/tmp/") + raw_dir = Path("/hdd/tensorflow_datasets/austin_buds_dataset_converted_externally_to_rlds/0.1.0/") + videos_dir = Path("/hdd/tmp/") hf_dataset, episode_data_index, info = from_raw_to_lerobot_format( raw_dir, videos_dir, fps=5, video=True, episodes=None, )