Bug fix: missing attention mask in VAE encoder in ACT policy #279

Merged: 4 commits, merged on Jun 19, 2024
Changes from 1 commit
21 changes: 16 additions & 5 deletions lerobot/common/policies/act/modeling_act.py
@@ -314,9 +314,16 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tenso
             # Note: detach() shouldn't be necessary but leaving it the same as the original code just in case.
             pos_embed = self.vae_encoder_pos_enc.clone().detach()  # (1, S+2, D)

+            cls_joint_is_pad = torch.full((batch_size, 2), False).to(
+                batch["observation.state"].device
+            )  # False: not a padding
+            key_padding_mask = torch.cat([cls_joint_is_pad, batch["action_is_pad"]], axis=1)  # (bs, seq+1)
+
             # Forward pass through VAE encoder to get the latent PDF parameters.
             cls_token_out = self.vae_encoder(
-                vae_encoder_input.permute(1, 0, 2), pos_embed=pos_embed.permute(1, 0, 2)
+                vae_encoder_input.permute(1, 0, 2),
+                pos_embed=pos_embed.permute(1, 0, 2),
+                key_padding_mask=key_padding_mask,
             )[0]  # select the class token, with shape (B, D)
             latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)
             mu = latent_pdf_params[:, : self.config.latent_dim]
@@ -402,9 +409,11 @@ def __init__(self, config: ACTConfig):
         self.layers = nn.ModuleList([ACTEncoderLayer(config) for _ in range(config.n_encoder_layers)])
         self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity()

-    def forward(self, x: Tensor, pos_embed: Tensor | None = None) -> Tensor:
+    def forward(
+        self, x: Tensor, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None
+    ) -> Tensor:
         for layer in self.layers:
-            x = layer(x, pos_embed=pos_embed)
+            x = layer(x, pos_embed=pos_embed, key_padding_mask=key_padding_mask)
         x = self.norm(x)
         return x

@@ -427,12 +436,14 @@ def __init__(self, config: ACTConfig):
         self.activation = get_activation_fn(config.feedforward_activation)
         self.pre_norm = config.pre_norm

-    def forward(self, x, pos_embed: Tensor | None = None) -> Tensor:
+    def forward(self, x, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None) -> Tensor:
         skip = x
         if self.pre_norm:
             x = self.norm1(x)
         q = k = x if pos_embed is None else x + pos_embed
-        x = self.self_attn(q, k, value=x)[0]  # select just the output, not the attention weights
+        x = self.self_attn(q, k, value=x, key_padding_mask=key_padding_mask)[
+            0
+        ]  # select just the output, not the attention weights
         x = skip + self.dropout1(x)
         if self.pre_norm:
             skip = x
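For intuition about what the new `key_padding_mask` changes, here is a minimal, self-contained sketch (not lerobot code; the shapes, batch values, and fake padding pattern are made up). In `nn.MultiheadAttention`, key positions marked `True` in `key_padding_mask` are excluded from attention, so padded action steps at the end of a short episode can no longer influence the [cls] token that parameterizes the latent distribution — which is the missing-mask bug this PR addresses.

```python
import torch
from torch import nn

batch_size, chunk_size, dim_model, n_heads = 2, 5, 16, 4
attn = nn.MultiheadAttention(dim_model, n_heads)  # default layout is (seq, batch, dim)

# [cls] and robot-state tokens are never padding; pretend the last two action steps are padding.
cls_joint_is_pad = torch.full((batch_size, 2), False)
action_is_pad = torch.tensor([[False, False, False, True, True]] * batch_size)
key_padding_mask = torch.cat([cls_joint_is_pad, action_is_pad], dim=1)  # (batch, 2 + chunk)

x = torch.randn(2 + chunk_size, batch_size, dim_model)  # (seq, batch, dim)
_, weights = attn(x, x, x, key_padding_mask=key_padding_mask)

# Attention weights assigned to the two padded key positions are exactly zero.
print(weights[:, :, -2:].abs().max())  # tensor(0.)
```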
83 changes: 83 additions & 0 deletions lerobot/configs/policy/act_1000_actions.yaml
@@ -0,0 +1,83 @@
# @package _global_

seed: 1000
dataset_repo_id: lerobot/aloha_sim_insertion_human

override_dataset_stats:
  observation.images.top:
    # stats from imagenet, since we use a pretrained vision model
    mean: [[[0.485]], [[0.456]], [[0.406]]]  # (c,1,1)
    std: [[[0.229]], [[0.224]], [[0.225]]]  # (c,1,1)

training:
  offline_steps: 80000
  online_steps: 0
  eval_freq: 10000
  save_freq: 100000
  log_freq: 250
  save_checkpoint: true

  batch_size: 8
  lr: 1e-5
  lr_backbone: 1e-5
  weight_decay: 1e-4
  grad_clip_norm: 10
  online_steps_between_rollouts: 1

  delta_timestamps:
    action: "[i / ${fps} for i in range(${policy.chunk_size})]"

eval:
  n_episodes: 50
  batch_size: 50

# See `configuration_act.py` for more details.
policy:
  name: act

  # Input / output structure.
  n_obs_steps: 1
  chunk_size: 1000 # chunk_size
  n_action_steps: 1000

  input_shapes:
    # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
    observation.images.top: [3, 480, 640]
    observation.state: ["${env.state_dim}"]
  output_shapes:
    action: ["${env.action_dim}"]

  # Normalization / Unnormalization
  input_normalization_modes:
    observation.images.top: mean_std
    observation.state: mean_std
  output_normalization_modes:
    action: mean_std

  # Architecture.
  # Vision backbone.
  vision_backbone: resnet18
  pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
  replace_final_stride_with_dilation: false
  # Transformer layers.
  pre_norm: false
  dim_model: 512
  n_heads: 8
  dim_feedforward: 3200
  feedforward_activation: relu
  n_encoder_layers: 4
  # Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
  # that means only the first layer is used. Here we match the original implementation by setting this to 1.
  # See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
  n_decoder_layers: 1
  # VAE.
  use_vae: true
  latent_dim: 32
  n_vae_encoder_layers: 4

  # Inference.
  temporal_ensemble_momentum: null

  # Training and loss computation.
  dropout: 0.1
  kl_weight: 10.0
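As a quick illustration of the `delta_timestamps` entry above (a sketch only; it assumes the Aloha sim environment runs at `fps: 50` and that the training code evaluates the string after Hydra interpolation), the expression expands to one timestamp per action in the 1000-step chunk:

```python
# Hypothetical expansion of "[i / ${fps} for i in range(${policy.chunk_size})]"
fps = 50           # assumed env fps
chunk_size = 1000  # policy.chunk_size from this config
delta_timestamps = [i / fps for i in range(chunk_size)]
print(delta_timestamps[:3], "...", delta_timestamps[-1])  # [0.0, 0.02, 0.04] ... 19.98
```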
Four Git LFS files not shown (new safetensors test artifacts).
18 changes: 9 additions & 9 deletions tests/scripts/save_policy_to_safetensors.py
@@ -108,15 +108,15 @@ def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_override

 if __name__ == "__main__":
     env_policies = [
-        ("xarm", "tdmpc", []),
-        (
-            "pusht",
-            "diffusion",
-            ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
-        ),
-        ("aloha", "act", ["policy.n_action_steps=10"]),
-        ("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]),
-        ("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]),
+        # ("xarm", "tdmpc", []),
+        # (
+        #     "pusht",
+        #     "diffusion",
+        #     ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
+        # ),
+        ("aloha", "act_1000_actions", []),
alexander-soare (Collaborator) commented on Jun 17, 2024:

As you mentioned in your PR message, I think we can avoid adding a whole config file by doing:

Suggested change:
-        ("aloha", "act_1000_actions", []),
+        ("aloha", "act_1000_actions", ["policy.n_action_steps=1000", "policy.chunk_size=1000"]),

(and in the test). Feels like a no-brainer to me, but maybe I'm missing something?

Member (Author) replied:

Yes, in this case we still need an act_1000_actions config file...

Member (Author) replied:

Maybe @aliberts you have an opinion on this question around tests and config files? See the "How it was tested" section in the PR description.

Collaborator replied:

Oh sorry, I also meant changing to use just the act.json file. Also happy to hear what @aliberts says.

thomwolf (Member, Author) commented on Jun 18, 2024:

Oh yes, this also doesn't work because the safetensors files would overwrite each other (cf. my comment in the PR description, maybe a bit short), unless we either:

  • decide we always test in this 1000-action setup for the other ACT policy tests as well, or
  • decide to update the way we name the safetensors files.

Collaborator replied:

Got it, thanks. My stance is that it's better to add test artifacts than to add config files to the source code. That could even mean just moving the yaml file to the artifacts directory. Also happy with both of your proposals. Waiting on @aliberts.

aliberts (Collaborator) commented on Jun 19, 2024:

(Replying to the two options above.) IMO both are fine. The naming trick you went with is okay; although it adds artifacts, they're relatively small. I was concerned about speed too, but from what I'm seeing in the CI tests it doesn't increase test duration (cf. this branch vs main).

Long-term though, I think it'll be much better to have more fine-grained tests that exercise individual components of the policies and do away with these artifacts, similar to what's done in transformers. My motivations for this are:

  • Whenever this test breaks, it doesn't tell us which part broke it.
  • Having as many test artifacts as we do right now is really not great for cloning the repo.
  • It's kernel-dependent.

For now it's okay, because this granularity allows us to iterate faster while still having guardrails. Happy to hear your opinions on this as well.

# ("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]),
# ("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]),
]
for env, policy, extra_overrides in env_policies:
save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy, extra_overrides)
1 change: 1 addition & 0 deletions tests/test_policies.py
@@ -296,6 +296,7 @@ def test_normalize(insert_temporal_dim):
["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
),
("aloha", "act", ["policy.n_action_steps=10"]),
("aloha", "act_1000_actions", []),
("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]),
("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]),
],
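With this parametrization added, the new case can be run on its own via pytest's keyword selection, e.g. `pytest tests/test_policies.py -k act_1000_actions` (assuming the act_1000_actions config and its safetensors artifacts are present locally, and that the generated test id contains the policy name).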