Skip to content

Commit

Permalink
fix Unet global_cond_dim to use state dim, not action dim (#278)
Browse files Browse the repository at this point in the history
  • Loading branch information
ojh6404 authored Jun 17, 2024
1 parent 15dd682 commit b72d574
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion lerobot/common/policies/diffusion/modeling_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,9 @@ def __init__(self, config: DiffusionConfig):
num_images = len([k for k in config.input_shapes if k.startswith("observation.image")])
self.unet = DiffusionConditionalUnet1d(
config,
global_cond_dim=(config.output_shapes["action"][0] + self.rgb_encoder.feature_dim * num_images)
global_cond_dim=(
config.input_shapes["observation.state"][0] + self.rgb_encoder.feature_dim * num_images
)
* config.n_obs_steps,
)

Expand Down

0 comments on commit b72d574

Please sign in to comment.