From e28fa2344c545c2ce35ee5948836d3dce8b30d69 Mon Sep 17 00:00:00 2001 From: Marina Barannikov Date: Mon, 17 Jun 2024 09:09:57 +0200 Subject: [PATCH] added visualization for min and max transforms (#271) Co-authored-by: Simon Alibert --- lerobot/scripts/visualize_image_transforms.py | 53 +++++++++++++++---- tests/test_image_transforms.py | 42 +++++++++++++++ 2 files changed, 85 insertions(+), 10 deletions(-) diff --git a/lerobot/scripts/visualize_image_transforms.py b/lerobot/scripts/visualize_image_transforms.py index fa3c0ab2a..b4558dc8c 100644 --- a/lerobot/scripts/visualize_image_transforms.py +++ b/lerobot/scripts/visualize_image_transforms.py @@ -65,11 +65,10 @@ from lerobot.common.datasets.transforms import get_image_transforms OUTPUT_DIR = Path("outputs/image_transforms") -N_EXAMPLES = 5 to_pil = ToPILImage() -def save_config_all_transforms(cfg, original_frame, output_dir): +def save_config_all_transforms(cfg, original_frame, output_dir, n_examples): tf = get_image_transforms( brightness_weight=cfg.brightness.weight, brightness_min_max=cfg.brightness.min_max, @@ -88,7 +87,7 @@ def save_config_all_transforms(cfg, original_frame, output_dir): output_dir_all = output_dir / "all" output_dir_all.mkdir(parents=True, exist_ok=True) - for i in range(1, N_EXAMPLES + 1): + for i in range(1, n_examples + 1): transformed_frame = tf(original_frame) to_pil(transformed_frame).save(output_dir_all / f"{i}.png", quality=100) @@ -96,7 +95,7 @@ def save_config_all_transforms(cfg, original_frame, output_dir): print(f" {output_dir_all}") -def save_config_single_transforms(cfg, original_frame, output_dir): +def save_config_single_transforms(cfg, original_frame, output_dir, n_examples): transforms = [ "brightness", "contrast", @@ -106,6 +105,7 @@ def save_config_single_transforms(cfg, original_frame, output_dir): ] print("Individual transforms examples saved to:") for transform in transforms: + # Apply one transformation with random value in min_max range kwargs = { f"{transform}_weight": cfg[f"{transform}"].weight, f"{transform}_min_max": cfg[f"{transform}"].min_max, @@ -114,18 +114,46 @@ def save_config_single_transforms(cfg, original_frame, output_dir): output_dir_single = output_dir / f"{transform}" output_dir_single.mkdir(parents=True, exist_ok=True) - for i in range(1, N_EXAMPLES + 1): + for i in range(1, n_examples + 1): transformed_frame = tf(original_frame) to_pil(transformed_frame).save(output_dir_single / f"{i}.png", quality=100) + # Apply min transformation + min_value, max_value = cfg[f"{transform}"].min_max + kwargs = { + f"{transform}_weight": cfg[f"{transform}"].weight, + f"{transform}_min_max": (min_value, min_value), + } + tf = get_image_transforms(**kwargs) + transformed_frame = tf(original_frame) + to_pil(transformed_frame).save(output_dir_single / "min.png", quality=100) + + # Apply max transformation + kwargs = { + f"{transform}_weight": cfg[f"{transform}"].weight, + f"{transform}_min_max": (max_value, max_value), + } + tf = get_image_transforms(**kwargs) + transformed_frame = tf(original_frame) + to_pil(transformed_frame).save(output_dir_single / "max.png", quality=100) + + # Apply mean transformation + mean_value = (min_value + max_value) / 2 + kwargs = { + f"{transform}_weight": cfg[f"{transform}"].weight, + f"{transform}_min_max": (mean_value, mean_value), + } + tf = get_image_transforms(**kwargs) + transformed_frame = tf(original_frame) + to_pil(transformed_frame).save(output_dir_single / "mean.png", quality=100) + print(f" {output_dir_single}") -@hydra.main(version_base="1.2", config_name="default", config_path="../configs") -def visualize_transforms(cfg): +def visualize_transforms(cfg, output_dir: Path, n_examples: int = 5): dataset = LeRobotDataset(cfg.dataset_repo_id) - output_dir = Path(OUTPUT_DIR) / cfg.dataset_repo_id.split("/")[-1] + output_dir = output_dir / cfg.dataset_repo_id.split("/")[-1] output_dir.mkdir(parents=True, exist_ok=True) # Get 1st frame from 1st camera of 1st episode @@ -134,8 +162,13 @@ def visualize_transforms(cfg): print("\nOriginal frame saved to:") print(f" {output_dir / 'original_frame.png'}.") - save_config_all_transforms(cfg.training.image_transforms, original_frame, output_dir) - save_config_single_transforms(cfg.training.image_transforms, original_frame, output_dir) + save_config_all_transforms(cfg.training.image_transforms, original_frame, output_dir, n_examples) + save_config_single_transforms(cfg.training.image_transforms, original_frame, output_dir, n_examples) + + +@hydra.main(version_base="1.2", config_name="default", config_path="../configs") +def visualize_transforms_cli(cfg): + visualize_transforms(cfg, output_dir=OUTPUT_DIR) if __name__ == "__main__": diff --git a/tests/test_image_transforms.py b/tests/test_image_transforms.py index ba6d972f3..ccc40ddfc 100644 --- a/tests/test_image_transforms.py +++ b/tests/test_image_transforms.py @@ -26,6 +26,7 @@ from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.transforms import RandomSubsetApply, SharpnessJitter, get_image_transforms from lerobot.common.utils.utils import init_hydra_config, seeded_context +from lerobot.scripts.visualize_image_transforms import visualize_transforms from tests.utils import DEFAULT_CONFIG_PATH, require_x86_64_kernel ARTIFACT_DIR = Path("tests/data/save_image_transforms_to_safetensors") @@ -258,3 +259,44 @@ def test_sharpness_jitter_invalid_range_min_negative(): def test_sharpness_jitter_invalid_range_max_smaller(): with pytest.raises(ValueError): SharpnessJitter((2.0, 0.1)) + + +@pytest.mark.parametrize( + "repo_id, n_examples", + [ + ("lerobot/aloha_sim_transfer_cube_human", 3), + ], +) +def test_visualize_image_transforms(repo_id, n_examples): + cfg = init_hydra_config(DEFAULT_CONFIG_PATH, overrides=[f"dataset_repo_id={repo_id}"]) + output_dir = Path(__file__).parent / "outputs" / "image_transforms" + visualize_transforms(cfg, output_dir=output_dir, n_examples=n_examples) + output_dir = output_dir / repo_id.split("/")[-1] + + # Check if the original frame image exists + assert (output_dir / "original_frame.png").exists(), "Original frame image was not saved." + + # Check if the transformed images exist for each transform type + transforms = ["brightness", "contrast", "saturation", "hue", "sharpness"] + for transform in transforms: + transform_dir = output_dir / transform + assert transform_dir.exists(), f"{transform} directory was not created." + assert any(transform_dir.iterdir()), f"No transformed images found in {transform} directory." + + # Check for specific files within each transform directory + expected_files = [f"{i}.png" for i in range(1, n_examples + 1)] + ["min.png", "max.png", "mean.png"] + for file_name in expected_files: + assert ( + transform_dir / file_name + ).exists(), f"{file_name} was not found in {transform} directory." + + # Check if the combined transforms directory exists and contains the right files + combined_transforms_dir = output_dir / "all" + assert combined_transforms_dir.exists(), "Combined transforms directory was not created." + assert any( + combined_transforms_dir.iterdir() + ), "No transformed images found in combined transforms directory." + for i in range(1, n_examples + 1): + assert ( + combined_transforms_dir / f"{i}.png" + ).exists(), f"Combined transform image {i}.png was not found."