From 48951662f221a05fb65f5c5106d25dfc3cb1270c Mon Sep 17 00:00:00 2001
From: Thomas Wolf
Date: Wed, 19 Jun 2024 13:07:21 +0200
Subject: [PATCH] Bug fix: missing attention mask in VAE encoder in ACT policy
 (#279)

Co-authored-by: Alexander Soare
---
 lerobot/common/policies/act/modeling_act.py   | 27 ++++++++++++++----
 .../aloha_act_1000_steps/actions.safetensors  |  3 ++
 .../grad_stats.safetensors                    |  3 ++
 .../output_dict.safetensors                   |  3 ++
 .../param_stats.safetensors                   |  3 ++
 tests/scripts/save_policy_to_safetensors.py   | 28 ++++++++++---------
 tests/test_policies.py                        | 18 +++++++-----
 7 files changed, 60 insertions(+), 25 deletions(-)
 create mode 100644 tests/data/save_policy_to_safetensors/aloha_act_1000_steps/actions.safetensors
 create mode 100644 tests/data/save_policy_to_safetensors/aloha_act_1000_steps/grad_stats.safetensors
 create mode 100644 tests/data/save_policy_to_safetensors/aloha_act_1000_steps/output_dict.safetensors
 create mode 100644 tests/data/save_policy_to_safetensors/aloha_act_1000_steps/param_stats.safetensors

diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py
index bef59becb..5f302bc7a 100644
--- a/lerobot/common/policies/act/modeling_act.py
+++ b/lerobot/common/policies/act/modeling_act.py
@@ -314,9 +314,23 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tenso
             # Note: detach() shouldn't be necessary but leaving it the same as the original code just in case.
             pos_embed = self.vae_encoder_pos_enc.clone().detach()  # (1, S+2, D)
 
+            # Prepare key padding mask for the transformer encoder. We have 1 or 2 extra tokens at the start of the
+            # sequence depending whether we use the input states or not (cls and robot state)
+            # False means not a padding token.
+            cls_joint_is_pad = torch.full(
+                (batch_size, 2 if self.use_input_state else 1),
+                False,
+                device=batch["observation.state"].device,
+            )
+            key_padding_mask = torch.cat(
+                [cls_joint_is_pad, batch["action_is_pad"]], axis=1
+            )  # (bs, seq+1 or 2)
+
             # Forward pass through VAE encoder to get the latent PDF parameters.
             cls_token_out = self.vae_encoder(
-                vae_encoder_input.permute(1, 0, 2), pos_embed=pos_embed.permute(1, 0, 2)
+                vae_encoder_input.permute(1, 0, 2),
+                pos_embed=pos_embed.permute(1, 0, 2),
+                key_padding_mask=key_padding_mask,
             )[0]  # select the class token, with shape (B, D)
             latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)
             mu = latent_pdf_params[:, : self.config.latent_dim]
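The hunk above builds the key padding mask that was previously missing: the 1 or 2 prefix entries (the cls token, plus the robot-state token when input states are used) are always False because they are real tokens, and the rest comes straight from the batch's "action_is_pad". A standalone sketch of the resulting layout (toy shapes chosen for illustration, not taken from the patch):

    import torch

    batch_size, chunk_size, use_input_state = 2, 4, True
    # Prefix tokens (cls + robot state) are never padding.
    cls_joint_is_pad = torch.full((batch_size, 2 if use_input_state else 1), False)
    # Example: the last two action steps fall past the episode end.
    action_is_pad = torch.tensor([[False, False, True, True]] * batch_size)
    key_padding_mask = torch.cat([cls_joint_is_pad, action_is_pad], dim=1)
    print(key_padding_mask.shape)  # torch.Size([2, 6]) -> (batch_size, chunk_size + 2)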
@@ -402,9 +416,11 @@ def __init__(self, config: ACTConfig):
         self.layers = nn.ModuleList([ACTEncoderLayer(config) for _ in range(config.n_encoder_layers)])
         self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity()
 
-    def forward(self, x: Tensor, pos_embed: Tensor | None = None) -> Tensor:
+    def forward(
+        self, x: Tensor, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None
+    ) -> Tensor:
         for layer in self.layers:
-            x = layer(x, pos_embed=pos_embed)
+            x = layer(x, pos_embed=pos_embed, key_padding_mask=key_padding_mask)
         x = self.norm(x)
         return x
 
@@ -427,12 +443,13 @@ def __init__(self, config: ACTConfig):
         self.activation = get_activation_fn(config.feedforward_activation)
         self.pre_norm = config.pre_norm
 
-    def forward(self, x, pos_embed: Tensor | None = None) -> Tensor:
+    def forward(self, x, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None) -> Tensor:
         skip = x
         if self.pre_norm:
             x = self.norm1(x)
         q = k = x if pos_embed is None else x + pos_embed
-        x = self.self_attn(q, k, value=x)[0]  # select just the output, not the attention weights
+        x = self.self_attn(q, k, value=x, key_padding_mask=key_padding_mask)
+        x = x[0]  # note: [0] to select just the output, not the attention weights
         x = skip + self.dropout1(x)
         if self.pre_norm:
             skip = x
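For context on what the fix changes numerically: torch.nn.MultiheadAttention treats True entries in key_padding_mask as keys to ignore, so their attention weights become exactly zero and padded action tokens can no longer leak into the cls token that parameterizes the latent. A minimal standalone sketch (toy dimensions, not the policy code):

    import torch
    from torch import nn

    attn = nn.MultiheadAttention(embed_dim=8, num_heads=2)
    x = torch.randn(5, 1, 8)  # (S, B, D), matching the encoder's permuted layout
    # True marks padding keys; here the last two tokens are padding.
    key_padding_mask = torch.tensor([[False, False, False, True, True]])  # (B, S)

    with torch.no_grad():
        _, w_masked = attn(x, x, x, key_padding_mask=key_padding_mask)
        _, w_unmasked = attn(x, x, x)
    print(w_masked[..., 3:].abs().max())    # tensor(0.): padded keys get zero weight
    print(w_unmasked[..., 3:].abs().max())  # > 0: without the mask, padding leaks in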
diff --git a/tests/data/save_policy_to_safetensors/aloha_act_1000_steps/actions.safetensors b/tests/data/save_policy_to_safetensors/aloha_act_1000_steps/actions.safetensors
new file mode 100644
index 000000000..1529153d1
--- /dev/null
+++ b/tests/data/save_policy_to_safetensors/aloha_act_1000_steps/actions.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3f4e0e525aeb22ea94b79e26b39a87e6f2da9fbee33e493906aaf2aad9a7c1ef
+size 515400
diff --git a/tests/data/save_policy_to_safetensors/aloha_act_1000_steps/grad_stats.safetensors b/tests/data/save_policy_to_safetensors/aloha_act_1000_steps/grad_stats.safetensors
new file mode 100644
index 000000000..6a359f4e3
--- /dev/null
+++ b/tests/data/save_policy_to_safetensors/aloha_act_1000_steps/grad_stats.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6dc658a1c1616c7d1c211eb8f87cec3d44f7b67d6b3cea7a6ce12b32d74674da
+size 31688
diff --git a/tests/data/save_policy_to_safetensors/aloha_act_1000_steps/output_dict.safetensors b/tests/data/save_policy_to_safetensors/aloha_act_1000_steps/output_dict.safetensors
new file mode 100644
index 000000000..099011101
--- /dev/null
+++ b/tests/data/save_policy_to_safetensors/aloha_act_1000_steps/output_dict.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:03971f92b7907b6b7e6ac207f508666104cd84c26c5276f510c431db604e188b
+size 68
diff --git a/tests/data/save_policy_to_safetensors/aloha_act_1000_steps/param_stats.safetensors b/tests/data/save_policy_to_safetensors/aloha_act_1000_steps/param_stats.safetensors
new file mode 100644
index 000000000..157c382c3
--- /dev/null
+++ b/tests/data/save_policy_to_safetensors/aloha_act_1000_steps/param_stats.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:01d993c67a9267032fe9fbeff20b4359c209464976ea503040a0a76ae213450a
+size 33408
diff --git a/tests/scripts/save_policy_to_safetensors.py b/tests/scripts/save_policy_to_safetensors.py
index 961b7cef1..5fead55a0 100644
--- a/tests/scripts/save_policy_to_safetensors.py
+++ b/tests/scripts/save_policy_to_safetensors.py
@@ -89,8 +89,8 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
     return output_dict, grad_stats, param_stats, actions
 
 
-def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_overrides):
-    env_policy_dir = Path(output_dir) / f"{env_name}_{policy_name}"
+def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_overrides, file_name_extra):
+    env_policy_dir = Path(output_dir) / f"{env_name}_{policy_name}{file_name_extra}"
     if env_policy_dir.exists():
         print(f"Overwrite existing safetensors in '{env_policy_dir}':")
 
@@ -108,15 +108,17 @@ def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_override
 
 if __name__ == "__main__":
     env_policies = [
-        ("xarm", "tdmpc", []),
-        (
-            "pusht",
-            "diffusion",
-            ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
-        ),
-        ("aloha", "act", ["policy.n_action_steps=10"]),
-        ("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]),
-        ("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]),
+        # ("xarm", "tdmpc", []),
+        # (
+        #     "pusht",
+        #     "diffusion",
+        #     ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
+        # ),
+        ("aloha", "act", ["policy.n_action_steps=1000", "policy.chunk_size=1000"], "_1000_steps"),
+        # ("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]),
+        # ("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]),
     ]
-    for env, policy, extra_overrides in env_policies:
-        save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy, extra_overrides)
+    for env, policy, extra_overrides, file_name_extra in env_policies:
+        save_policy_to_safetensors(
+            "tests/data/save_policy_to_safetensors", env, policy, extra_overrides, file_name_extra
+        )
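The likely reason the regenerated fixture uses chunk_size=1000 (an inference from the patch, not stated in it): a 1000-step action chunk extends past the end of the short test episodes, so "action_is_pad" contains True entries and the new mask actually changes the VAE encoder's output; with the old short chunks the mask would typically be all False and the regression test could not have caught this bug. A sketch of that padding rule, with illustrative numbers:

    import torch

    episode_length, start_t, chunk_size = 400, 395, 1000  # illustrative values
    steps = torch.arange(start_t, start_t + chunk_size)
    action_is_pad = steps >= episode_length  # True for steps past the episode end
    print(action_is_pad.sum().item())        # 995 of the 1000 chunk steps are padding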
diff --git a/tests/test_policies.py b/tests/test_policies.py
index 95da20c9f..fdc747513 100644
--- a/tests/test_policies.py
+++ b/tests/test_policies.py
@@ -315,24 +315,26 @@ def test_normalize(insert_temporal_dim):
 
 
 @pytest.mark.parametrize(
-    "env_name, policy_name, extra_overrides",
+    "env_name, policy_name, extra_overrides, file_name_extra",
     [
-        ("xarm", "tdmpc", []),
+        ("xarm", "tdmpc", [], ""),
         (
             "pusht",
             "diffusion",
             ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
+            "",
         ),
-        ("aloha", "act", ["policy.n_action_steps=10"]),
-        ("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]),
-        ("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]),
+        ("aloha", "act", ["policy.n_action_steps=10"], ""),
+        ("aloha", "act", ["policy.n_action_steps=1000", "policy.chunk_size=1000"], "_1000_steps"),
+        ("dora_aloha_real", "act_real", ["policy.n_action_steps=10"], ""),
+        ("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"], ""),
     ],
 )
 # As artifacts have been generated on an x86_64 kernel, this test won't
 # pass if it's run on another platform due to floating point errors
 @require_x86_64_kernel
 @require_cpu
-def test_backward_compatibility(env_name, policy_name, extra_overrides):
+def test_backward_compatibility(env_name, policy_name, extra_overrides, file_name_extra):
     """
     NOTE: If this test does not pass, and you have intentionally changed something in the policy:
     1. Inspect the differences in policy outputs and make sure you can account for them. Your PR should
@@ -344,7 +346,9 @@ def test_backward_compatibility(env_name, policy_name, extra_overrides):
     5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state.
     6. Remember to stage and commit the resulting changes to `tests/data`.
     """
-    env_policy_dir = Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}"
+    env_policy_dir = (
+        Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}{file_name_extra}"
+    )
     saved_output_dict = load_file(env_policy_dir / "output_dict.safetensors")
     saved_grad_stats = load_file(env_policy_dir / "grad_stats.safetensors")
     saved_param_stats = load_file(env_policy_dir / "param_stats.safetensors")
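Per the test's docstring, the fixtures above are regenerated by running the temporarily edited script (python tests/scripts/save_policy_to_safetensors.py) and committing the new files under tests/data. A quick, optional way to sanity-check a regenerated artifact before committing, assuming the repository layout shown in this patch:

    from pathlib import Path
    from safetensors.torch import load_file

    fixture_dir = Path("tests/data/save_policy_to_safetensors/aloha_act_1000_steps")
    for name in ("output_dict", "grad_stats", "param_stats", "actions"):
        # Each file is a flat dict of named tensors; print their shapes for inspection.
        tensors = load_file(fixture_dir / f"{name}.safetensors")
        print(name, {k: tuple(v.shape) for k, v in tensors.items()})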