From b72d57489127d094a21c8e23cada40a383601a7a Mon Sep 17 00:00:00 2001 From: Jihoon Oh Date: Mon, 17 Jun 2024 23:17:28 +0900 Subject: [PATCH] fix Unet global_cond_dim to use state dim, not action dim (#278) --- lerobot/common/policies/diffusion/modeling_diffusion.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index 47378fdf5..335653995 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -165,7 +165,9 @@ def __init__(self, config: DiffusionConfig): num_images = len([k for k in config.input_shapes if k.startswith("observation.image")]) self.unet = DiffusionConditionalUnet1d( config, - global_cond_dim=(config.output_shapes["action"][0] + self.rgb_encoder.feature_dim * num_images) + global_cond_dim=( + config.input_shapes["observation.state"][0] + self.rgb_encoder.feature_dim * num_images + ) * config.n_obs_steps, )