Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added visualization for min and max transforms #271

Merged
merged 6 commits into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 43 additions & 10 deletions lerobot/scripts/visualize_image_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,10 @@
from lerobot.common.datasets.transforms import get_image_transforms

OUTPUT_DIR = Path("outputs/image_transforms")
N_EXAMPLES = 5
to_pil = ToPILImage()


def save_config_all_transforms(cfg, original_frame, output_dir):
def save_config_all_transforms(cfg, original_frame, output_dir, n_examples):
tf = get_image_transforms(
brightness_weight=cfg.brightness.weight,
brightness_min_max=cfg.brightness.min_max,
Expand All @@ -88,15 +87,15 @@ def save_config_all_transforms(cfg, original_frame, output_dir):
output_dir_all = output_dir / "all"
output_dir_all.mkdir(parents=True, exist_ok=True)

for i in range(1, N_EXAMPLES + 1):
for i in range(1, n_examples + 1):
transformed_frame = tf(original_frame)
to_pil(transformed_frame).save(output_dir_all / f"{i}.png", quality=100)

print("Combined transforms examples saved to:")
print(f" {output_dir_all}")


def save_config_single_transforms(cfg, original_frame, output_dir):
def save_config_single_transforms(cfg, original_frame, output_dir, n_examples):
transforms = [
"brightness",
"contrast",
Expand All @@ -106,6 +105,7 @@ def save_config_single_transforms(cfg, original_frame, output_dir):
]
print("Individual transforms examples saved to:")
for transform in transforms:
# Apply one transformation with random value in min_max range
kwargs = {
f"{transform}_weight": cfg[f"{transform}"].weight,
f"{transform}_min_max": cfg[f"{transform}"].min_max,
Expand All @@ -114,18 +114,46 @@ def save_config_single_transforms(cfg, original_frame, output_dir):
output_dir_single = output_dir / f"{transform}"
output_dir_single.mkdir(parents=True, exist_ok=True)

for i in range(1, N_EXAMPLES + 1):
for i in range(1, n_examples + 1):
transformed_frame = tf(original_frame)
to_pil(transformed_frame).save(output_dir_single / f"{i}.png", quality=100)

# Apply min transformation
min_value, max_value = cfg[f"{transform}"].min_max
kwargs = {
f"{transform}_weight": cfg[f"{transform}"].weight,
f"{transform}_min_max": (min_value, min_value),
}
tf = get_image_transforms(**kwargs)
transformed_frame = tf(original_frame)
to_pil(transformed_frame).save(output_dir_single / "min.png", quality=100)

# Apply max transformation
kwargs = {
f"{transform}_weight": cfg[f"{transform}"].weight,
f"{transform}_min_max": (max_value, max_value),
}
tf = get_image_transforms(**kwargs)
transformed_frame = tf(original_frame)
to_pil(transformed_frame).save(output_dir_single / "max.png", quality=100)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add this one as well?

Suggested change
# Apply mean transformation
mean_value = (min_value + max_value) / 2
kwargs = {
f"{transform}_weight": cfg[f"{transform}"].weight,
f"{transform}_min_max": (mean_value, mean_value),
}
tf = get_image_transforms(**kwargs)
transformed_frame = tf(original_frame)
to_pil(transformed_frame).save(output_dir_single / "mean.png", quality=100)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

f98c9a7 done

# Apply mean transformation
mean_value = (min_value + max_value) / 2
kwargs = {
f"{transform}_weight": cfg[f"{transform}"].weight,
f"{transform}_min_max": (mean_value, mean_value),
}
tf = get_image_transforms(**kwargs)
transformed_frame = tf(original_frame)
to_pil(transformed_frame).save(output_dir_single / "mean.png", quality=100)

print(f" {output_dir_single}")


@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
def visualize_transforms(cfg):
def visualize_transforms(cfg, output_dir: Path, n_examples: int = 5):
dataset = LeRobotDataset(cfg.dataset_repo_id)

output_dir = Path(OUTPUT_DIR) / cfg.dataset_repo_id.split("/")[-1]
output_dir = output_dir / cfg.dataset_repo_id.split("/")[-1]
output_dir.mkdir(parents=True, exist_ok=True)

# Get 1st frame from 1st camera of 1st episode
Expand All @@ -134,8 +162,13 @@ def visualize_transforms(cfg):
print("\nOriginal frame saved to:")
print(f" {output_dir / 'original_frame.png'}.")

save_config_all_transforms(cfg.training.image_transforms, original_frame, output_dir)
save_config_single_transforms(cfg.training.image_transforms, original_frame, output_dir)
save_config_all_transforms(cfg.training.image_transforms, original_frame, output_dir, n_examples)
save_config_single_transforms(cfg.training.image_transforms, original_frame, output_dir, n_examples)


@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
def visualize_transforms_cli(cfg):
    """Command-line entry point for the image-transform visualization.

    Hydra builds ``cfg`` from the ``default`` config found under ``../configs``;
    the images are then written below the module-level ``OUTPUT_DIR``, using
    the default number of random examples of ``visualize_transforms``.
    """
    visualize_transforms(cfg, output_dir=OUTPUT_DIR)


if __name__ == "__main__":
Expand Down
42 changes: 42 additions & 0 deletions tests/test_image_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.transforms import RandomSubsetApply, SharpnessJitter, get_image_transforms
from lerobot.common.utils.utils import init_hydra_config, seeded_context
from lerobot.scripts.visualize_image_transforms import visualize_transforms
from tests.utils import DEFAULT_CONFIG_PATH, require_x86_64_kernel

ARTIFACT_DIR = Path("tests/data/save_image_transforms_to_safetensors")
Expand Down Expand Up @@ -258,3 +259,44 @@ def test_sharpness_jitter_invalid_range_min_negative():
def test_sharpness_jitter_invalid_range_max_smaller():
with pytest.raises(ValueError):
SharpnessJitter((2.0, 0.1))


@pytest.mark.parametrize(
    "repo_id, n_examples",
    [
        ("lerobot/aloha_sim_transfer_cube_human", 3),
    ],
)
def test_visualize_image_transforms(repo_id, n_examples):
    """Check that ``visualize_transforms`` writes every expected image file.

    The script is run against a fresh temporary directory rather than a
    persistent folder inside the test tree. With a persistent folder, images
    left over from an earlier run would satisfy the existence checks below even
    if the current run wrote nothing, and the repository would accumulate
    generated artifacts.

    Args:
        repo_id: Dataset repository id injected into the Hydra config.
        n_examples: Number of randomly transformed images expected per folder.
    """
    # Function-scope import keeps this block self-contained.
    import tempfile

    cfg = init_hydra_config(DEFAULT_CONFIG_PATH, overrides=[f"dataset_repo_id={repo_id}"])

    with tempfile.TemporaryDirectory() as tmp_dir:
        root_dir = Path(tmp_dir) / "image_transforms"
        visualize_transforms(cfg, output_dir=root_dir, n_examples=n_examples)
        # The script nests its outputs under the last component of the repo id.
        output_dir = root_dir / repo_id.split("/")[-1]

        # Check if the original frame image exists
        assert (output_dir / "original_frame.png").exists(), "Original frame image was not saved."

        numbered_files = [f"{i}.png" for i in range(1, n_examples + 1)]

        # Check if the transformed images exist for each transform type
        transforms = ["brightness", "contrast", "saturation", "hue", "sharpness"]
        for transform in transforms:
            transform_dir = output_dir / transform
            assert transform_dir.exists(), f"{transform} directory was not created."
            assert any(transform_dir.iterdir()), f"No transformed images found in {transform} directory."

            # Check for specific files within each transform directory
            expected_files = numbered_files + ["min.png", "max.png", "mean.png"]
            for file_name in expected_files:
                assert (
                    transform_dir / file_name
                ).exists(), f"{file_name} was not found in {transform} directory."

        # Check if the combined transforms directory exists and contains the right files
        combined_transforms_dir = output_dir / "all"
        assert combined_transforms_dir.exists(), "Combined transforms directory was not created."
        assert any(
            combined_transforms_dir.iterdir()
        ), "No transformed images found in combined transforms directory."
        for file_name in numbered_files:
            assert (
                combined_transforms_dir / file_name
            ).exists(), f"Combined transform image {file_name} was not found."
Loading