Skip to content

Commit

Permalink
Move test to test_image_transforms.py
Browse files Browse the repository at this point in the history
  • Loading branch information
aliberts committed Jun 17, 2024
1 parent ced58dc commit 1ac2984
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 76 deletions.
23 changes: 13 additions & 10 deletions lerobot/scripts/visualize_image_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,10 @@
from lerobot.common.datasets.transforms import get_image_transforms

OUTPUT_DIR = Path("outputs/image_transforms")
N_EXAMPLES = 5
to_pil = ToPILImage()


def save_config_all_transforms(cfg, original_frame, output_dir):
def save_config_all_transforms(cfg, original_frame, output_dir, n_examples):
tf = get_image_transforms(
brightness_weight=cfg.brightness.weight,
brightness_min_max=cfg.brightness.min_max,
Expand All @@ -88,15 +87,15 @@ def save_config_all_transforms(cfg, original_frame, output_dir):
output_dir_all = output_dir / "all"
output_dir_all.mkdir(parents=True, exist_ok=True)

for i in range(1, N_EXAMPLES + 1):
for i in range(1, n_examples + 1):
transformed_frame = tf(original_frame)
to_pil(transformed_frame).save(output_dir_all / f"{i}.png", quality=100)

print("Combined transforms examples saved to:")
print(f" {output_dir_all}")


def save_config_single_transforms(cfg, original_frame, output_dir):
def save_config_single_transforms(cfg, original_frame, output_dir, n_examples):
transforms = [
"brightness",
"contrast",
Expand All @@ -115,7 +114,7 @@ def save_config_single_transforms(cfg, original_frame, output_dir):
output_dir_single = output_dir / f"{transform}"
output_dir_single.mkdir(parents=True, exist_ok=True)

for i in range(1, N_EXAMPLES + 1):
for i in range(1, n_examples + 1):
transformed_frame = tf(original_frame)
to_pil(transformed_frame).save(output_dir_single / f"{i}.png", quality=100)

Expand Down Expand Up @@ -151,11 +150,10 @@ def save_config_single_transforms(cfg, original_frame, output_dir):
print(f" {output_dir_single}")


@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
def visualize_transforms(cfg):
def visualize_transforms(cfg, output_dir: Path, n_examples: int = 5):
dataset = LeRobotDataset(cfg.dataset_repo_id)

output_dir = Path(OUTPUT_DIR) / cfg.dataset_repo_id.split("/")[-1]
output_dir = output_dir / cfg.dataset_repo_id.split("/")[-1]
output_dir.mkdir(parents=True, exist_ok=True)

# Get 1st frame from 1st camera of 1st episode
Expand All @@ -164,8 +162,13 @@ def visualize_transforms(cfg):
print("\nOriginal frame saved to:")
print(f" {output_dir / 'original_frame.png'}.")

save_config_all_transforms(cfg.training.image_transforms, original_frame, output_dir)
save_config_single_transforms(cfg.training.image_transforms, original_frame, output_dir)
save_config_all_transforms(cfg.training.image_transforms, original_frame, output_dir, n_examples)
save_config_single_transforms(cfg.training.image_transforms, original_frame, output_dir, n_examples)


@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
def visualize_transforms_cli(cfg):
visualize_transforms(cfg, output_dir=OUTPUT_DIR)


if __name__ == "__main__":
Expand Down
42 changes: 42 additions & 0 deletions tests/test_image_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.transforms import RandomSubsetApply, SharpnessJitter, get_image_transforms
from lerobot.common.utils.utils import init_hydra_config, seeded_context
from lerobot.scripts.visualize_image_transforms import visualize_transforms
from tests.utils import DEFAULT_CONFIG_PATH, require_x86_64_kernel

ARTIFACT_DIR = Path("tests/data/save_image_transforms_to_safetensors")
Expand Down Expand Up @@ -258,3 +259,44 @@ def test_sharpness_jitter_invalid_range_min_negative():
def test_sharpness_jitter_invalid_range_max_smaller():
with pytest.raises(ValueError):
SharpnessJitter((2.0, 0.1))


@pytest.mark.parametrize(
"repo_id, n_examples",
[
("lerobot/aloha_sim_transfer_cube_human", 3),
],
)
def test_visualize_image_transforms(repo_id, n_examples):
cfg = init_hydra_config(DEFAULT_CONFIG_PATH, overrides=[f"dataset_repo_id={repo_id}"])
output_dir = Path(__file__).parent / "outputs" / "image_transforms"
visualize_transforms(cfg, output_dir=output_dir, n_examples=n_examples)
output_dir = output_dir / repo_id.split("/")[-1]

# Check if the original frame image exists
assert (output_dir / "original_frame.png").exists(), "Original frame image was not saved."

# Check if the transformed images exist for each transform type
transforms = ["brightness", "contrast", "saturation", "hue", "sharpness"]
for transform in transforms:
transform_dir = output_dir / transform
assert transform_dir.exists(), f"{transform} directory was not created."
assert any(transform_dir.iterdir()), f"No transformed images found in {transform} directory."

# Check for specific files within each transform directory
expected_files = [f"{i}.png" for i in range(1, n_examples + 1)] + ["min.png", "max.png", "mean.png"]
for file_name in expected_files:
assert (
transform_dir / file_name
).exists(), f"{file_name} was not found in {transform} directory."

# Check if the combined transforms directory exists and contains the right files
combined_transforms_dir = output_dir / "all"
assert combined_transforms_dir.exists(), "Combined transforms directory was not created."
assert any(
combined_transforms_dir.iterdir()
), "No transformed images found in combined transforms directory."
for i in range(1, n_examples + 1):
assert (
combined_transforms_dir / f"{i}.png"
).exists(), f"Combined transform image {i}.png was not found."
66 changes: 0 additions & 66 deletions tests/test_visualize_image_transforms.py

This file was deleted.

0 comments on commit 1ac2984

Please sign in to comment.