From 33166e1d4344bddce1a57d49201d5fed840e4eb4 Mon Sep 17 00:00:00 2001
From: Thomas Wolf
Date: Wed, 19 Jun 2024 10:07:41 +0200
Subject: [PATCH] update

---
 lerobot/common/policies/normalize.py | 23 ++++++++++--------
 tests/test_policies.py               | 36 +++++++---------------------
 2 files changed, 21 insertions(+), 38 deletions(-)

diff --git a/lerobot/common/policies/normalize.py b/lerobot/common/policies/normalize.py
index 8f9096a6d..11f0a7eca 100644
--- a/lerobot/common/policies/normalize.py
+++ b/lerobot/common/policies/normalize.py
@@ -21,7 +21,6 @@ def create_stats_buffers(
     shapes: dict[str, list[int]],
     modes: dict[str, str],
     stats: dict[str, dict[str, Tensor]] | None = None,
-    std_epsilon: float = 1e-5,
 ) -> dict[str, dict[str, nn.ParameterDict]]:
     """
     Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max
@@ -79,14 +78,10 @@ def create_stats_buffers(
             # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
             if mode == "mean_std":
                 buffer["mean"].data = stats[key]["mean"].clone()
-                buffer["std"].data = stats[key]["std"].clone().clamp_min(std_epsilon)
+                buffer["std"].data = stats[key]["std"].clone()
             elif mode == "min_max":
                 buffer["min"].data = stats[key]["min"].clone()
                 buffer["max"].data = stats[key]["max"].clone()
-                epsilon = (std_epsilon - (stats[key]["max"] - stats[key]["min"]).abs()).clamp_min(
-                    0
-                )  # To add to have at least std_epsilon between min and max
-                buffer["max"].data += epsilon

         stats_buffers[key] = buffer
     return stats_buffers
@@ -134,7 +129,8 @@ def __init__(
         self.shapes = shapes
         self.modes = modes
         self.stats = stats
-        stats_buffers = create_stats_buffers(shapes, modes, stats, std_epsilon=std_epsilon)
+        self.std_epsilon = std_epsilon
+        stats_buffers = create_stats_buffers(shapes, modes, stats)
         for key, buffer in stats_buffers.items():
             setattr(self, "buffer_" + key.replace(".", "_"), buffer)

@@ -150,12 +146,15 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
                 std = buffer["std"]
                 assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
                 assert not torch.isinf(std).any(), _no_stats_error_str("std")
-                output_batch[key] = (batch[key] - mean) / std
+                output_batch[key] = (batch[key] - mean) / std.clamp_min(self.std_epsilon)
             elif mode == "min_max":
                 min = buffer["min"]
                 max = buffer["max"]
                 assert not torch.isinf(min).any(), _no_stats_error_str("min")
                 assert not torch.isinf(max).any(), _no_stats_error_str("max")
+                # Widen max so that (max - min) is at least std_epsilon
+                epsilon = (self.std_epsilon - (max - min).abs()).clamp_min(0)
+                max = max + epsilon
                 # normalize to [0,1]
                 output_batch[key] = (batch[key] - min) / (max - min)
                 # normalize to [-1, 1]
@@ -207,8 +206,9 @@ def __init__(
         self.shapes = shapes
         self.modes = modes
         self.stats = stats
+        self.std_epsilon = std_epsilon
         # `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)`
-        stats_buffers = create_stats_buffers(shapes, modes, stats, std_epsilon=std_epsilon)
+        stats_buffers = create_stats_buffers(shapes, modes, stats)
         for key, buffer in stats_buffers.items():
             setattr(self, "buffer_" + key.replace(".", "_"), buffer)

@@ -224,12 +224,15 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
                 std = buffer["std"]
                 assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
                 assert not torch.isinf(std).any(), _no_stats_error_str("std")
-                output_batch[key] = batch[key] * std + mean
+                output_batch[key] = batch[key] * std.clamp_min(self.std_epsilon) + mean
             elif mode == "min_max":
                 min = buffer["min"]
                 max = buffer["max"]
                 assert not torch.isinf(min).any(), _no_stats_error_str("min")
                 assert not torch.isinf(max).any(), _no_stats_error_str("max")
+                # Widen max so that (max - min) is at least std_epsilon
+                epsilon = (self.std_epsilon - (max - min).abs()).clamp_min(0)
+                max = max + epsilon
                 output_batch[key] = (batch[key] + 1) / 2
                 output_batch[key] = output_batch[key] * (max - min) + min
             else:
diff --git a/tests/test_policies.py b/tests/test_policies.py
index 310e530af..acb2d9369 100644
--- a/tests/test_policies.py
+++ b/tests/test_policies.py
@@ -331,24 +331,14 @@ def test_normalize(insert_temporal_dim):
     ).all()

     assert torch.isclose(
-        normalize.buffer_action_test_std_cap.std[0],
-        dataset_stats["action_test_std_cap"]["std"][0],
+        normalize.buffer_action_test_std_cap.std,
+        dataset_stats["action_test_std_cap"]["std"],
         rtol=0.1,
         atol=1e-7,
     ).all()
     assert torch.isclose(
-        normalize.buffer_action_test_std_cap.std[1], torch.ones(1) * std_epsilon, rtol=0.1, atol=1e-7
-    ).all()
-    assert torch.isclose(
-        normalize.buffer_action_test_min_max_cap.max[0] - normalize.buffer_action_test_min_max_cap.min[0],
-        dataset_stats["action_test_min_max_cap"]["max"][0]
-        - dataset_stats["action_test_min_max_cap"]["min"][0],
-        rtol=0.1,
-        atol=1e-7,
-    ).all()
-    assert torch.isclose(
-        normalize.buffer_action_test_min_max_cap.max[1] - normalize.buffer_action_test_min_max_cap.min[1],
-        torch.ones(1) * std_epsilon,
+        normalize.buffer_action_test_min_max_cap.max - normalize.buffer_action_test_min_max_cap.min,
+        dataset_stats["action_test_min_max_cap"]["max"] - dataset_stats["action_test_min_max_cap"]["min"],
         rtol=0.1,
         atol=1e-7,
     ).all()
@@ -496,24 +486,14 @@ def test_normalize(insert_temporal_dim):
     ).all()

     assert torch.isclose(
-        unnormalize.buffer_action_test_std_cap.std[0],
-        dataset_stats["action_test_std_cap"]["std"][0],
-        rtol=0.1,
-        atol=1e-7,
-    ).all()
-    assert torch.isclose(
-        unnormalize.buffer_action_test_std_cap.std[1], torch.ones(1) * std_epsilon, rtol=0.1, atol=1e-7
-    ).all()
-    assert torch.isclose(
-        unnormalize.buffer_action_test_min_max_cap.max[0] - unnormalize.buffer_action_test_min_max_cap.min[0],
-        dataset_stats["action_test_min_max_cap"]["max"][0]
-        - dataset_stats["action_test_min_max_cap"]["min"][0],
+        unnormalize.buffer_action_test_std_cap.std,
+        dataset_stats["action_test_std_cap"]["std"],
         rtol=0.1,
         atol=1e-7,
     ).all()
     assert torch.isclose(
-        unnormalize.buffer_action_test_min_max_cap.max[1] - unnormalize.buffer_action_test_min_max_cap.min[1],
-        torch.ones(1) * std_epsilon,
+        unnormalize.buffer_action_test_min_max_cap.max - unnormalize.buffer_action_test_min_max_cap.min,
+        dataset_stats["action_test_min_max_cap"]["max"] - dataset_stats["action_test_min_max_cap"]["min"],
         rtol=0.1,
         atol=1e-7,
     ).all()
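
Not part of the patch: below is a minimal, standalone sketch of the safeguard the diff moves into forward(). It shows why clamping at normalization time keeps the stored statistics untouched while still avoiding division by zero. The tensors, their values, and the std_epsilon = 1e-5 constant are illustrative assumptions for this sketch, not code taken from the repository.

# Illustration only: the epsilon guard applied at forward time (assumed std_epsilon = 1e-5).
import torch

std_epsilon = 1e-5

# "mean_std" mode: the buffer keeps the raw std (including zeros);
# the division is made safe by clamping only when normalizing.
mean = torch.tensor([1.0, 3.0])
std = torch.tensor([0.25, 0.0])  # second dimension is degenerate
x = torch.tensor([1.5, 3.0])
print((x - mean) / std.clamp_min(std_epsilon))  # tensor([2., 0.]), no inf/nan

# "min_max" mode: widen max just enough so that (max - min) >= std_epsilon.
min_ = torch.tensor([0.0, 2.0])
max_ = torch.tensor([1.0, 2.0])  # second dimension has zero range
epsilon = (std_epsilon - (max_ - min_).abs()).clamp_min(0)
max_ = max_ + epsilon
y = torch.tensor([0.5, 2.0])
print((y - min_) / (max_ - min_))  # tensor([0.5, 0.]), no division by zero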