Skip to content

Commit

Permalink
add reward
Browse files Browse the repository at this point in the history
Signed-off-by: youliangtan <tan_you_liang@hotmail.com>
  • Loading branch information
youliangtan committed Jun 20, 2024
1 parent 3e4d7be commit 61e51c9
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions lerobot/common/datasets/push_dataset_to_hub/oxe_rlds_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def load_from_raw(

states = []
actions = []
rewards = torch.zeros(num_frames, dtype=torch.float32)
ep_dict = {}

image_array_dict = {key: [] for key in image_keys}
Expand All @@ -112,9 +113,9 @@ def load_from_raw(
for j, step in enumerate(steps):
states.append(tf_to_torch(step['observation']['state']))
actions.append(tf_to_torch(step['action']))
rewards[j] = torch.tensor(step['reward'].numpy(), dtype=torch.float32)

# if "language_text" in step:
# print(" - lang: ", step["language_text"])
# TODO: language_text, is_terminal, is_last etc.

for im_key in image_keys:
if im_key not in step['observation']:
Expand Down Expand Up @@ -156,6 +157,7 @@ def load_from_raw(
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames)
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
ep_dict["reward"] = rewards
ep_dict["next.done"] = done

ep_dicts.append(ep_dict)
Expand Down Expand Up @@ -198,6 +200,7 @@ def to_hf_dataset(data_dict, video) -> Dataset:
features["episode_index"] = Value(dtype="int64", id=None)
features["frame_index"] = Value(dtype="int64", id=None)
features["timestamp"] = Value(dtype="float32", id=None)
features["reward"] = Value(dtype="float32", id=None)
features["next.done"] = Value(dtype="bool", id=None)
features["index"] = Value(dtype="int64", id=None)

Expand Down Expand Up @@ -229,8 +232,8 @@ def from_raw_to_lerobot_format(

if __name__ == "__main__":
# TODO (YL) remove this
raw_dir = Path("/hdd/serl/serl_task1_combine_13jun/")
videos_dir = Path("/hdd/serl/tmp/")
raw_dir = Path("/hdd/tensorflow_datasets/austin_buds_dataset_converted_externally_to_rlds/0.1.0/")
videos_dir = Path("/hdd/tmp/")
hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
raw_dir, videos_dir, fps=5, video=True, episodes=None,
)
Expand Down

0 comments on commit 61e51c9

Please sign in to comment.