Commit

Merge remote-tracking branch 'origin' into user/aliberts/2024_06_13_add_cam_capture
aliberts committed Jun 13, 2024
2 parents 29dbd01 + 125bd93 commit f906983
Showing 11 changed files with 754 additions and 423 deletions.
9 changes: 6 additions & 3 deletions .github/workflows/test.yml
@@ -34,8 +34,8 @@ jobs:
         with:
           lfs: true # Ensure LFS files are pulled

-      - name: Install EGL
-        run: sudo apt-get update && sudo apt-get install -y libegl1-mesa-dev
+      - name: Install apt dependencies
+        run: sudo apt-get update && sudo apt-get install -y libegl1-mesa-dev ffmpeg

       - name: Install poetry
         run: |
@@ -72,6 +72,9 @@ jobs:
         with:
           lfs: true # Ensure LFS files are pulled

+      - name: Install apt dependencies
+        run: sudo apt-get update && sudo apt-get install -y ffmpeg
+
       - name: Install poetry
         run: |
           pipx install poetry && poetry config virtualenvs.in-project true
@@ -106,7 +109,7 @@ jobs:
         with:
           lfs: true # Ensure LFS files are pulled

-      - name: Install EGL
+      - name: Install apt dependencies
         run: sudo apt-get update && sudo apt-get install -y libegl1-mesa-dev

       - name: Install poetry
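The workflow changes above add ffmpeg to the apt packages installed by the test jobs. As a local counterpart, here is a minimal sketch (not part of the commit; purely illustrative) that verifies ffmpeg is on the PATH before running the test suite:

```python
# Illustrative pre-flight check: confirm ffmpeg is installed locally,
# mirroring the apt dependency the CI jobs now install.
import shutil
import subprocess

ffmpeg = shutil.which("ffmpeg")
if ffmpeg is None:
    raise RuntimeError("ffmpeg not found; install it, e.g. `sudo apt-get install -y ffmpeg`")
# Print the first line of `ffmpeg -version` as a sanity check.
print(subprocess.run([ffmpeg, "-version"], capture_output=True, text=True).stdout.splitlines()[0])
```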
10 changes: 5 additions & 5 deletions README.md
@@ -228,13 +228,13 @@ To add a dataset to the hub, you need to login using a write-access token, which
 huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
 ```

-Then move your dataset folder in `data` directory (e.g. `data/aloha_static_pingpong_test`), and push your dataset to the hub with:
+Then point to your raw dataset folder (e.g. `data/aloha_static_pingpong_test_raw`), and push your dataset to the hub with:
 ```bash
 python lerobot/scripts/push_dataset_to_hub.py \
-  --data-dir data \
-  --dataset-id aloha_static_pingpong_test \
-  --raw-format aloha_hdf5 \
-  --community-id lerobot
+  --raw-dir data/aloha_static_pingpong_test_raw \
+  --out-dir data \
+  --repo-id lerobot/aloha_static_pingpong_test \
+  --raw-format aloha_hdf5
 ```

 See `python lerobot/scripts/push_dataset_to_hub.py --help` for more instructions.
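Before running the push script, it can help to confirm that the cached token belongs to the account you intend to push under. A minimal sketch using the `huggingface_hub` client (not part of the commit; it assumes `huggingface_hub` is installed, which LeRobot already depends on):

```python
# Illustrative pre-flight check: make sure a valid token is cached
# and see which user the push will happen under.
from huggingface_hub import HfApi

api = HfApi()
user = api.whoami()  # raises if no valid token is found
print(f"Logged in as: {user['name']}")
```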
233 changes: 98 additions & 135 deletions lerobot/common/datasets/push_dataset_to_hub/_download_raw.py
@@ -14,156 +14,119 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
-This file contains all obsolete download scripts. They are centralized here to not have to load
-useless dependencies when using datasets.
+This file contains download scripts for raw datasets.
+Example of usage:
+```
+python lerobot/common/datasets/push_dataset_to_hub/_download_raw.py \
+--raw-dir data/cadene/pusht_raw \
+--repo-id cadene/pusht_raw
+```
 """

-import io
+import argparse
 import logging
-import shutil
+import warnings
 from pathlib import Path

-import tqdm
 from huggingface_hub import snapshot_download


-def download_raw(raw_dir, dataset_id):
-    if "aloha" in dataset_id or "image" in dataset_id:
-        download_hub(raw_dir, dataset_id)
-    elif "pusht" in dataset_id:
-        download_pusht(raw_dir)
-    elif "xarm" in dataset_id:
-        download_xarm(raw_dir)
-    elif "umi" in dataset_id:
-        download_umi(raw_dir)
-    else:
-        raise ValueError(dataset_id)
+def download_raw(raw_dir: Path, repo_id: str):
+    # Check repo_id is well formatted
+    if len(repo_id.split("/")) != 2:
+        raise ValueError(
+            f"`repo_id` is expected to contain a community or user id `/` the name of the dataset (e.g. 'lerobot/pusht'), but contains '{repo_id}'."
+        )
+    user_id, dataset_id = repo_id.split("/")


-def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
-    import zipfile
-
-    import requests
-
-    print(f"downloading from {url}")
-    response = requests.get(url, stream=True)
-    if response.status_code == 200:
-        total_size = int(response.headers.get("content-length", 0))
-        progress_bar = tqdm.tqdm(total=total_size, unit="B", unit_scale=True)
-
-        zip_file = io.BytesIO()
-        for chunk in response.iter_content(chunk_size=1024):
-            if chunk:
-                zip_file.write(chunk)
-                progress_bar.update(len(chunk))
-
-        progress_bar.close()
-
-        zip_file.seek(0)
-
-        with zipfile.ZipFile(zip_file, "r") as zip_ref:
-            zip_ref.extractall(destination_folder)
-
-
-def download_pusht(raw_dir: str):
-    pusht_url = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip"
-
-    raw_dir = Path(raw_dir)
-    raw_dir.mkdir(parents=True, exist_ok=True)
-    download_and_extract_zip(pusht_url, raw_dir)
-    # file is created inside a useful "pusht" directory, so we move it out and delete the dir
-    zarr_path = raw_dir / "pusht_cchi_v7_replay.zarr"
-    shutil.move(raw_dir / "pusht" / "pusht_cchi_v7_replay.zarr", zarr_path)
-    shutil.rmtree(raw_dir / "pusht")
-
-
-def download_xarm(raw_dir: Path):
-    """Download all xarm datasets at once"""
-    import zipfile
-
-    import gdown
+    if not dataset_id.endswith("_raw"):
+        warnings.warn(
+            f"`dataset_id` ({dataset_id}) doesn't end with '_raw' (e.g. 'lerobot/pusht_raw'). Following this naming convention by renaming your repository is advised, but not mandatory.",
+            stacklevel=1,
+        )

     raw_dir = Path(raw_dir)
+    # Send warning if raw_dir isn't well formatted
+    if raw_dir.parts[-2] != user_id or raw_dir.parts[-1] != dataset_id:
+        warnings.warn(
+            f"`raw_dir` ({raw_dir}) doesn't contain a community or user id `/` the name of the dataset that match the `repo_id` (e.g. 'data/lerobot/pusht_raw'). Following this naming convention is advised, but not mandatory.",
+            stacklevel=1,
+        )
     raw_dir.mkdir(parents=True, exist_ok=True)
-    # from https://github.com/fyhMer/fowm/blob/main/scripts/download_datasets.py
-    url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
-    zip_path = raw_dir / "data.zip"
-    gdown.download(url, str(zip_path), quiet=False)
-    print("Extracting...")
-    with zipfile.ZipFile(str(zip_path), "r") as zip_f:
-        for pkl_path in zip_f.namelist():
-            if pkl_path.startswith("data/xarm") and pkl_path.endswith(".pkl"):
-                zip_f.extract(member=pkl_path)
-                # move to corresponding raw directory
-                extract_dir = pkl_path.replace("/buffer.pkl", "")
-                raw_pkl_path = raw_dir / "buffer.pkl"
-                shutil.move(pkl_path, raw_pkl_path)
-                shutil.rmtree(extract_dir)
-    zip_path.unlink()
-
-
-def download_hub(raw_dir: Path, dataset_id: str):
-    raw_dir = Path(raw_dir)
-    raw_dir.mkdir(parents=True, exist_ok=True)
-
-    logging.info(f"Start downloading from huggingface.co/cadene for {dataset_id}")
-    snapshot_download(f"cadene/{dataset_id}_raw", repo_type="dataset", local_dir=raw_dir)
-    logging.info(f"Finish downloading from huggingface.co/cadene for {dataset_id}")

+    logging.info(f"Start downloading from huggingface.co/{user_id} for {dataset_id}")
+    snapshot_download(f"{repo_id}", repo_type="dataset", local_dir=raw_dir)
+    logging.info(f"Finish downloading from huggingface.co/{user_id} for {dataset_id}")

-def download_umi(raw_dir: Path):
-    url_cup_in_the_wild = "https://real.stanford.edu/umi/data/zarr_datasets/cup_in_the_wild.zarr.zip"
-    zarr_path = raw_dir / "cup_in_the_wild.zarr"
-
-    raw_dir = Path(raw_dir)
-    raw_dir.mkdir(parents=True, exist_ok=True)
-    download_and_extract_zip(url_cup_in_the_wild, zarr_path)
+def download_all_raw_datasets():
+    data_dir = Path("data")
+    repo_ids = [
+        "cadene/pusht_image_raw",
+        "cadene/xarm_lift_medium_image_raw",
+        "cadene/xarm_lift_medium_replay_image_raw",
+        "cadene/xarm_push_medium_image_raw",
+        "cadene/xarm_push_medium_replay_image_raw",
+        "cadene/aloha_sim_insertion_human_image_raw",
+        "cadene/aloha_sim_insertion_scripted_image_raw",
+        "cadene/aloha_sim_transfer_cube_human_image_raw",
+        "cadene/aloha_sim_transfer_cube_scripted_image_raw",
+        "cadene/pusht_raw",
+        "cadene/xarm_lift_medium_raw",
+        "cadene/xarm_lift_medium_replay_raw",
+        "cadene/xarm_push_medium_raw",
+        "cadene/xarm_push_medium_replay_raw",
+        "cadene/aloha_sim_insertion_human_raw",
+        "cadene/aloha_sim_insertion_scripted_raw",
+        "cadene/aloha_sim_transfer_cube_human_raw",
+        "cadene/aloha_sim_transfer_cube_scripted_raw",
+        "cadene/aloha_mobile_cabinet_raw",
+        "cadene/aloha_mobile_chair_raw",
+        "cadene/aloha_mobile_elevator_raw",
+        "cadene/aloha_mobile_shrimp_raw",
+        "cadene/aloha_mobile_wash_pan_raw",
+        "cadene/aloha_mobile_wipe_wine_raw",
+        "cadene/aloha_static_battery_raw",
+        "cadene/aloha_static_candy_raw",
+        "cadene/aloha_static_coffee_raw",
+        "cadene/aloha_static_coffee_new_raw",
+        "cadene/aloha_static_cups_open_raw",
+        "cadene/aloha_static_fork_pick_up_raw",
+        "cadene/aloha_static_pingpong_test_raw",
+        "cadene/aloha_static_pro_pencil_raw",
+        "cadene/aloha_static_screw_driver_raw",
+        "cadene/aloha_static_tape_raw",
+        "cadene/aloha_static_thread_velcro_raw",
+        "cadene/aloha_static_towel_raw",
+        "cadene/aloha_static_vinh_cup_raw",
+        "cadene/aloha_static_vinh_cup_left_raw",
+        "cadene/aloha_static_ziploc_slide_raw",
+        "cadene/umi_cup_in_the_wild_raw",
+    ]
+    for repo_id in repo_ids:
+        raw_dir = data_dir / repo_id
+        download_raw(raw_dir, repo_id)


+def main():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--raw-dir",
+        type=Path,
+        required=True,
+        help="Directory containing input raw datasets (e.g. `data/aloha_mobile_chair_raw` or `data/pusht_raw`).",
+    )
+    parser.add_argument(
+        "--repo-id",
+        type=str,
+        required=True,
+        help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht_raw`, `cadene/aloha_sim_insertion_human_raw`).",
+    )
+    args = parser.parse_args()
+    download_raw(**vars(args))


 if __name__ == "__main__":
-    data_dir = Path("data")
-    dataset_ids = [
-        "pusht_image",
-        "xarm_lift_medium_image",
-        "xarm_lift_medium_replay_image",
-        "xarm_push_medium_image",
-        "xarm_push_medium_replay_image",
-        "aloha_sim_insertion_human_image",
-        "aloha_sim_insertion_scripted_image",
-        "aloha_sim_transfer_cube_human_image",
-        "aloha_sim_transfer_cube_scripted_image",
-        "pusht",
-        "xarm_lift_medium",
-        "xarm_lift_medium_replay",
-        "xarm_push_medium",
-        "xarm_push_medium_replay",
-        "aloha_sim_insertion_human",
-        "aloha_sim_insertion_scripted",
-        "aloha_sim_transfer_cube_human",
-        "aloha_sim_transfer_cube_scripted",
-        "aloha_mobile_cabinet",
-        "aloha_mobile_chair",
-        "aloha_mobile_elevator",
-        "aloha_mobile_shrimp",
-        "aloha_mobile_wash_pan",
-        "aloha_mobile_wipe_wine",
-        "aloha_static_battery",
-        "aloha_static_candy",
-        "aloha_static_coffee",
-        "aloha_static_coffee_new",
-        "aloha_static_cups_open",
-        "aloha_static_fork_pick_up",
-        "aloha_static_pingpong_test",
-        "aloha_static_pro_pencil",
-        "aloha_static_screw_driver",
-        "aloha_static_tape",
-        "aloha_static_thread_velcro",
-        "aloha_static_towel",
-        "aloha_static_vinh_cup",
-        "aloha_static_vinh_cup_left",
-        "aloha_static_ziploc_slide",
-        "umi_cup_in_the_wild",
-    ]
-    for dataset_id in dataset_ids:
-        raw_dir = data_dir / f"{dataset_id}_raw"
-        download_raw(raw_dir, dataset_id)
+    main()
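For reference, a minimal sketch of calling the new `download_raw` entry point from Python rather than the CLI, based only on the signature shown in this diff; the repository id and local path reuse the example from the module docstring:

```python
# Illustrative usage of the repo_id-based API introduced in this commit.
from pathlib import Path

from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw

# data/<user_or_community>/<dataset_id> matches the naming convention the function checks for.
raw_dir = Path("data/cadene/pusht_raw")
download_raw(raw_dir, repo_id="cadene/pusht_raw")
```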