Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bug fix: missing attention mask in VAE encoder in ACT policy #279

Merged
merged 4 commits into from
Jun 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 22 additions & 5 deletions lerobot/common/policies/act/modeling_act.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,9 +314,23 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tenso
# Note: detach() shouldn't be necessary but leaving it the same as the original code just in case.
pos_embed = self.vae_encoder_pos_enc.clone().detach() # (1, S+2, D)

# Prepare key padding mask for the transformer encoder. We have 1 or 2 extra tokens at the start of the
# sequence depending whether we use the input states or not (cls and robot state)
# False means not a padding token.
cls_joint_is_pad = torch.full(
(batch_size, 2 if self.use_input_state else 1),
False,
device=batch["observation.state"].device,
)
key_padding_mask = torch.cat(
[cls_joint_is_pad, batch["action_is_pad"]], axis=1
) # (bs, seq+1 or 2)

# Forward pass through VAE encoder to get the latent PDF parameters.
cls_token_out = self.vae_encoder(
vae_encoder_input.permute(1, 0, 2), pos_embed=pos_embed.permute(1, 0, 2)
vae_encoder_input.permute(1, 0, 2),
pos_embed=pos_embed.permute(1, 0, 2),
key_padding_mask=key_padding_mask,
)[0] # select the class token, with shape (B, D)
latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)
mu = latent_pdf_params[:, : self.config.latent_dim]
Expand Down Expand Up @@ -402,9 +416,11 @@ def __init__(self, config: ACTConfig):
self.layers = nn.ModuleList([ACTEncoderLayer(config) for _ in range(config.n_encoder_layers)])
self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity()

def forward(self, x: Tensor, pos_embed: Tensor | None = None) -> Tensor:
def forward(
self, x: Tensor, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None
) -> Tensor:
for layer in self.layers:
x = layer(x, pos_embed=pos_embed)
x = layer(x, pos_embed=pos_embed, key_padding_mask=key_padding_mask)
x = self.norm(x)
return x

Expand All @@ -427,12 +443,13 @@ def __init__(self, config: ACTConfig):
self.activation = get_activation_fn(config.feedforward_activation)
self.pre_norm = config.pre_norm

def forward(self, x, pos_embed: Tensor | None = None) -> Tensor:
def forward(self, x, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None) -> Tensor:
skip = x
if self.pre_norm:
x = self.norm1(x)
q = k = x if pos_embed is None else x + pos_embed
x = self.self_attn(q, k, value=x)[0] # select just the output, not the attention weights
x = self.self_attn(q, k, value=x, key_padding_mask=key_padding_mask)
x = x[0] # note: [0] to select just the output, not the attention weights
x = skip + self.dropout1(x)
if self.pre_norm:
skip = x
Expand Down
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
28 changes: 15 additions & 13 deletions tests/scripts/save_policy_to_safetensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
return output_dict, grad_stats, param_stats, actions


def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_overrides):
env_policy_dir = Path(output_dir) / f"{env_name}_{policy_name}"
def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_overrides, file_name_extra):
env_policy_dir = Path(output_dir) / f"{env_name}_{policy_name}{file_name_extra}"

if env_policy_dir.exists():
print(f"Overwrite existing safetensors in '{env_policy_dir}':")
Expand All @@ -108,15 +108,17 @@ def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_override

if __name__ == "__main__":
env_policies = [
("xarm", "tdmpc", []),
(
"pusht",
"diffusion",
["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
),
("aloha", "act", ["policy.n_action_steps=10"]),
("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]),
("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]),
# ("xarm", "tdmpc", []),
# (
# "pusht",
# "diffusion",
# ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
# ),
("aloha", "act", ["policy.n_action_steps=1000", "policy.chunk_size=1000"], "_1000_steps"),
# ("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]),
# ("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]),
]
for env, policy, extra_overrides in env_policies:
save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy, extra_overrides)
for env, policy, extra_overrides, file_name_extra in env_policies:
save_policy_to_safetensors(
"tests/data/save_policy_to_safetensors", env, policy, extra_overrides, file_name_extra
)
18 changes: 11 additions & 7 deletions tests/test_policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,24 +315,26 @@ def test_normalize(insert_temporal_dim):


@pytest.mark.parametrize(
"env_name, policy_name, extra_overrides",
"env_name, policy_name, extra_overrides, file_name_extra",
[
("xarm", "tdmpc", []),
("xarm", "tdmpc", [], ""),
(
"pusht",
"diffusion",
["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
"",
),
("aloha", "act", ["policy.n_action_steps=10"]),
("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]),
("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]),
("aloha", "act", ["policy.n_action_steps=10"], ""),
("aloha", "act", ["policy.n_action_steps=1000", "policy.chunk_size=1000"], "_1000_steps"),
("dora_aloha_real", "act_real", ["policy.n_action_steps=10"], ""),
("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"], ""),
],
)
# As artifacts have been generated on an x86_64 kernel, this test won't
# pass if it's run on another platform due to floating point errors
@require_x86_64_kernel
@require_cpu
def test_backward_compatibility(env_name, policy_name, extra_overrides):
def test_backward_compatibility(env_name, policy_name, extra_overrides, file_name_extra):
"""
NOTE: If this test does not pass, and you have intentionally changed something in the policy:
1. Inspect the differences in policy outputs and make sure you can account for them. Your PR should
Expand All @@ -344,7 +346,9 @@ def test_backward_compatibility(env_name, policy_name, extra_overrides):
5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state.
6. Remember to stage and commit the resulting changes to `tests/data`.
"""
env_policy_dir = Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}"
env_policy_dir = (
Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}{file_name_extra}"
)
saved_output_dict = load_file(env_policy_dir / "output_dict.safetensors")
saved_grad_stats = load_file(env_policy_dir / "grad_stats.safetensors")
saved_param_stats = load_file(env_policy_dir / "param_stats.safetensors")
Expand Down
Loading