Skip to content

Commit

Permalink
Improve API, Add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Cadene committed Jun 12, 2024
1 parent 0b847a5 commit 3d79727
Show file tree
Hide file tree
Showing 8 changed files with 601 additions and 372 deletions.
92 changes: 2 additions & 90 deletions lerobot/common/datasets/push_dataset_to_hub/_download_raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,95 +14,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file contains all obsolete download scripts. They are centralized here to not have to load
useless dependencies when using datasets.
This file contains download scripts for raw datasets.
"""

import io
import logging
import shutil
from pathlib import Path

import tqdm
from huggingface_hub import snapshot_download


def download_raw(raw_dir, dataset_id):
if "aloha" in dataset_id or "image" in dataset_id:
download_hub(raw_dir, dataset_id)
elif "pusht" in dataset_id:
download_pusht(raw_dir)
elif "xarm" in dataset_id:
download_xarm(raw_dir)
elif "umi" in dataset_id:
download_umi(raw_dir)
else:
raise ValueError(dataset_id)


def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
import zipfile

import requests

print(f"downloading from {url}")
response = requests.get(url, stream=True)
if response.status_code == 200:
total_size = int(response.headers.get("content-length", 0))
progress_bar = tqdm.tqdm(total=total_size, unit="B", unit_scale=True)

zip_file = io.BytesIO()
for chunk in response.iter_content(chunk_size=1024):
if chunk:
zip_file.write(chunk)
progress_bar.update(len(chunk))

progress_bar.close()

zip_file.seek(0)

with zipfile.ZipFile(zip_file, "r") as zip_ref:
zip_ref.extractall(destination_folder)


def download_pusht(raw_dir: str):
pusht_url = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip"

raw_dir = Path(raw_dir)
raw_dir.mkdir(parents=True, exist_ok=True)
download_and_extract_zip(pusht_url, raw_dir)
# file is created inside a useful "pusht" directory, so we move it out and delete the dir
zarr_path = raw_dir / "pusht_cchi_v7_replay.zarr"
shutil.move(raw_dir / "pusht" / "pusht_cchi_v7_replay.zarr", zarr_path)
shutil.rmtree(raw_dir / "pusht")


def download_xarm(raw_dir: Path):
"""Download all xarm datasets at once"""
import zipfile

import gdown

raw_dir = Path(raw_dir)
raw_dir.mkdir(parents=True, exist_ok=True)
# from https://github.com/fyhMer/fowm/blob/main/scripts/download_datasets.py
url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
zip_path = raw_dir / "data.zip"
gdown.download(url, str(zip_path), quiet=False)
print("Extracting...")
with zipfile.ZipFile(str(zip_path), "r") as zip_f:
for pkl_path in zip_f.namelist():
if pkl_path.startswith("data/xarm") and pkl_path.endswith(".pkl"):
zip_f.extract(member=pkl_path)
# move to corresponding raw directory
extract_dir = pkl_path.replace("/buffer.pkl", "")
raw_pkl_path = raw_dir / "buffer.pkl"
shutil.move(pkl_path, raw_pkl_path)
shutil.rmtree(extract_dir)
zip_path.unlink()


def download_hub(raw_dir: Path, dataset_id: str):
def download_raw(raw_dir: Path, dataset_id: str):
raw_dir = Path(raw_dir)
raw_dir.mkdir(parents=True, exist_ok=True)

Expand All @@ -111,15 +32,6 @@ def download_hub(raw_dir: Path, dataset_id: str):
logging.info(f"Finish downloading from huggingface.co/cadene for {dataset_id}")


def download_umi(raw_dir: Path):
url_cup_in_the_wild = "https://real.stanford.edu/umi/data/zarr_datasets/cup_in_the_wild.zarr.zip"
zarr_path = raw_dir / "cup_in_the_wild.zarr"

raw_dir = Path(raw_dir)
raw_dir.mkdir(parents=True, exist_ok=True)
download_and_extract_zip(url_cup_in_the_wild, zarr_path)


if __name__ == "__main__":
data_dir = Path("data")
dataset_ids = [
Expand Down
46 changes: 24 additions & 22 deletions lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
from lerobot.common.datasets.utils import (
calculate_episode_data_index,
hf_transform_to_torch,
)
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
Expand Down Expand Up @@ -70,16 +71,17 @@ def check_format(raw_dir) -> bool:
assert c < h and c < w, f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided."


def load_from_raw(raw_dir, out_dir, fps, video, debug):
def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
# only frames from simulation are uncompressed
compressed_images = "sim" not in raw_dir.name

hdf5_files = list(raw_dir.glob("*.hdf5"))
ep_dicts = []
episode_data_index = {"from": [], "to": []}
hdf5_files = sorted(raw_dir.glob("episode_*.hdf5"))
num_episodes = len(hdf5_files)

id_from = 0
for ep_idx, ep_path in tqdm.tqdm(enumerate(hdf5_files), total=len(hdf5_files)):
ep_dicts = []
ep_ids = episodes if episodes else range(num_episodes)
for ep_idx in tqdm.tqdm(ep_ids):
ep_path = hdf5_files[ep_idx]
with h5py.File(ep_path, "r") as ep:
num_frames = ep["/action"].shape[0]

Expand Down Expand Up @@ -114,12 +116,12 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):

if video:
# save png images in temporary directory
tmp_imgs_dir = out_dir / "tmp_images"
tmp_imgs_dir = videos_dir / "tmp_images"
save_images_concurrently(imgs_array, tmp_imgs_dir)

# encode images to a mp4 video
fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
video_path = out_dir / "videos" / fname
video_path = videos_dir / fname
encode_video_frames(tmp_imgs_dir, video_path, fps)

# clean temporary images directory
Expand Down Expand Up @@ -147,19 +149,13 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
assert isinstance(ep_idx, int)
ep_dicts.append(ep_dict)

episode_data_index["from"].append(id_from)
episode_data_index["to"].append(id_from + num_frames)

id_from += num_frames

gc.collect()

# process first episode only
if debug:
break

data_dict = concatenate_episodes(ep_dicts)
return data_dict, episode_data_index

total_frames = data_dict["frame_index"].shape[0]
data_dict["index"] = torch.arange(0, total_frames, 1)
return data_dict


def to_hf_dataset(data_dict, video) -> Dataset:
Expand Down Expand Up @@ -197,16 +193,22 @@ def to_hf_dataset(data_dict, video) -> Dataset:
return hf_dataset


def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
def from_raw_to_lerobot_format(
raw_dir: Path,
videos_dir: Path,
fps: int | None = None,
video: bool = True,
episodes: list[int] | None = None,
):
# sanity check
check_format(raw_dir)

if fps is None:
fps = 50

data_dir, episode_data_index = load_from_raw(raw_dir, out_dir, fps, video, debug)
hf_dataset = to_hf_dataset(data_dir, video)

data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
hf_dataset = to_hf_dataset(data_dict, video)
episode_data_index = calculate_episode_data_index(hf_dataset)
info = {
"fps": fps,
"video": video,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
Contains utilities to process raw data format from dora-record
"""

import logging
import re
from pathlib import Path

Expand All @@ -26,10 +25,10 @@
from datasets import Dataset, Features, Image, Sequence, Value

from lerobot.common.datasets.utils import (
calculate_episode_data_index,
hf_transform_to_torch,
)
from lerobot.common.datasets.video_utils import VideoFrame
from lerobot.common.utils.utils import init_logging


def check_format(raw_dir) -> bool:
Expand All @@ -41,7 +40,7 @@ def check_format(raw_dir) -> bool:
return True


def load_from_raw(raw_dir: Path, out_dir: Path, fps: int):
def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
# Load data stream that will be used as reference for the timestamps synchronization
reference_files = list(raw_dir.glob("observation.images.cam_*.parquet"))
if len(reference_files) == 0:
Expand Down Expand Up @@ -122,8 +121,7 @@ def get_episode_index(row):
raise ValueError(f"Episodes indices go from {ep_ids} instead of {expected_ep_ids}")

# Create symlink to raw videos directory (that needs to be absolute not relative)
out_dir.mkdir(parents=True, exist_ok=True)
videos_dir = out_dir / "videos"
videos_dir.parent.mkdir(parents=True, exist_ok=True)
videos_dir.symlink_to((raw_dir / "videos").absolute())

# sanity check the video paths are well formated
Expand Down Expand Up @@ -156,16 +154,7 @@ def get_episode_index(row):
else:
raise ValueError(key)

# Get the episode index containing for each unique episode index
first_ep_index_df = df.groupby("episode_index").agg(start_index=("index", "first")).reset_index()
from_ = first_ep_index_df["start_index"].tolist()
to_ = from_[1:] + [len(df)]
episode_data_index = {
"from": from_,
"to": to_,
}

return data_dict, episode_data_index
return data_dict


def to_hf_dataset(data_dict, video) -> Dataset:
Expand Down Expand Up @@ -203,12 +192,13 @@ def to_hf_dataset(data_dict, video) -> Dataset:
return hf_dataset


def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
init_logging()

if debug:
logging.warning("debug=True not implemented. Falling back to debug=False.")

def from_raw_to_lerobot_format(
raw_dir: Path,
videos_dir: Path,
fps: int | None = None,
video: bool = True,
episodes: list[int] | None = None,
):
# sanity check
check_format(raw_dir)

Expand All @@ -220,9 +210,9 @@ def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=Tru
if not video:
raise NotImplementedError()

data_df, episode_data_index = load_from_raw(raw_dir, out_dir, fps)
data_df = load_from_raw(raw_dir, videos_dir, fps, episodes)
hf_dataset = to_hf_dataset(data_df, video)

episode_data_index = calculate_episode_data_index(hf_dataset)
info = {
"fps": fps,
"video": video,
Expand Down
Loading

0 comments on commit 3d79727

Please sign in to comment.