From c38f535c9fb16179e4745dc91f4ea55ac4328dd4 Mon Sep 17 00:00:00 2001 From: Marina Barannikov Date: Wed, 12 Jun 2024 19:45:42 +0200 Subject: [PATCH] FIx make_dataset to match transforms config (#264) Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> --- .github/workflows/test.yml | 2 ++ Makefile | 7 ++++++- lerobot/common/datasets/factory.py | 25 +++++++++++++------------ 3 files changed, 21 insertions(+), 13 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index a466cff7d..f10f541e2 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -10,6 +10,7 @@ on: - "examples/**" - ".github/**" - "poetry.lock" + - "Makefile" push: branches: - main @@ -19,6 +20,7 @@ on: - "examples/**" - ".github/**" - "poetry.lock" + - "Makefile" jobs: pytest: diff --git a/Makefile b/Makefile index 33f3edf2e..9bac437d6 100644 --- a/Makefile +++ b/Makefile @@ -5,7 +5,7 @@ PYTHON_PATH := $(shell which python) # If Poetry is installed, redefine PYTHON_PATH to use the Poetry-managed Python POETRY_CHECK := $(shell command -v poetry) ifneq ($(POETRY_CHECK),) - PYTHON_PATH := $(shell poetry run which python) + PYTHON_PATH := $(shell poetry run which python) endif export PATH := $(dir $(PYTHON_PATH)):$(PATH) @@ -46,6 +46,7 @@ test-act-ete-train: policy.n_action_steps=20 \ policy.chunk_size=20 \ training.batch_size=2 \ + training.image_transforms.enable=true \ hydra.run.dir=tests/outputs/act/ test-act-ete-eval: @@ -73,6 +74,7 @@ test-act-ete-train-amp: policy.chunk_size=20 \ training.batch_size=2 \ hydra.run.dir=tests/outputs/act_amp/ \ + training.image_transforms.enable=true \ use_amp=true test-act-ete-eval-amp: @@ -100,6 +102,7 @@ test-diffusion-ete-train: training.save_checkpoint=true \ training.save_freq=2 \ training.batch_size=2 \ + training.image_transforms.enable=true \ hydra.run.dir=tests/outputs/diffusion/ test-diffusion-ete-eval: @@ -127,6 +130,7 @@ test-tdmpc-ete-train: training.save_checkpoint=true \ training.save_freq=2 \ training.batch_size=2 \ + training.image_transforms.enable=true \ hydra.run.dir=tests/outputs/tdmpc/ test-tdmpc-ete-eval: @@ -159,5 +163,6 @@ test-act-pusht-tutorial: training.save_model=true \ training.save_freq=2 \ training.batch_size=2 \ + training.image_transforms.enable=true \ hydra.run.dir=tests/outputs/act_pusht/ rm lerobot/configs/policy/created_by_Makefile.yaml diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index fab8ca575..754bc91b2 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -74,19 +74,20 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData image_transforms = None if cfg.training.image_transforms.enable: + cfg_tf = cfg.training.image_transforms image_transforms = get_image_transforms( - brightness_weight=cfg.brightness.weight, - brightness_min_max=cfg.brightness.min_max, - contrast_weight=cfg.contrast.weight, - contrast_min_max=cfg.contrast.min_max, - saturation_weight=cfg.saturation.weight, - saturation_min_max=cfg.saturation.min_max, - hue_weight=cfg.hue.weight, - hue_min_max=cfg.hue.min_max, - sharpness_weight=cfg.sharpness.weight, - sharpness_min_max=cfg.sharpness.min_max, - max_num_transforms=cfg.max_num_transforms, - random_order=cfg.random_order, + brightness_weight=cfg_tf.brightness.weight, + brightness_min_max=cfg_tf.brightness.min_max, + contrast_weight=cfg_tf.contrast.weight, + contrast_min_max=cfg_tf.contrast.min_max, + saturation_weight=cfg_tf.saturation.weight, + saturation_min_max=cfg_tf.saturation.min_max, + hue_weight=cfg_tf.hue.weight, + hue_min_max=cfg_tf.hue.min_max, + sharpness_weight=cfg_tf.sharpness.weight, + sharpness_min_max=cfg_tf.sharpness.min_max, + max_num_transforms=cfg_tf.max_num_transforms, + random_order=cfg_tf.random_order, ) if isinstance(cfg.dataset_repo_id, str):