diff --git a/lerobot/common/envs/pusht/pusht_image_env.py b/lerobot/common/envs/pusht/pusht_image_env.py
index 5f7bc03c8..2d52c89e4 100644
--- a/lerobot/common/envs/pusht/pusht_image_env.py
+++ b/lerobot/common/envs/pusht/pusht_image_env.py
@@ -25,7 +25,7 @@ def _get_obs(self):
         img = super()._render_frame(mode="rgb_array")

         agent_pos = np.array(self.agent.position)
-        img_obs = np.moveaxis(img.astype(np.float32) / 255, -1, 0)
+        img_obs = np.moveaxis(img.astype(np.float32), -1, 0)
         obs = {"image": img_obs, "agent_pos": agent_pos}

         # draw action
diff --git a/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py b/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py
index 0b4bba7dd..91472dd5e 100644
--- a/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py
+++ b/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py
@@ -123,6 +123,8 @@ def __init__(
                 if imagenet_norm:
                     # TODO(rcadene): move normalizer to dataset and env
                     this_normalizer = torchvision.transforms.Normalize(
+                        # Note: This matches the normalization in the original impl. for PushT Image. This may not be
+                        # the case for other tasks.
                         mean=[127.5, 127.5, 127.5],
                         std=[127.5, 127.5, 127.5],
                     )
diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml
index 28fd4e4e2..f07e4754c 100644
--- a/lerobot/configs/policy/diffusion.yaml
+++ b/lerobot/configs/policy/diffusion.yaml
@@ -42,8 +42,8 @@ policy:
   num_inference_steps: 100
   obs_as_global_cond: ${obs_as_global_cond}
   # crop_shape: null
-  diffusion_step_embed_dim: 256 # before 128
-  down_dims: [256, 512, 1024] # before [512, 1024, 2048]
+  diffusion_step_embed_dim: 128
+  down_dims: [512, 1024, 2048]
   kernel_size: 5
   n_groups: 8
   cond_predict_scale: True
@@ -109,13 +109,13 @@ training:
   debug: False
   resume: True
   # optimization
-  # lr_scheduler: cosine
-  # lr_warmup_steps: 500
-  num_epochs: 8000
+  lr_scheduler: cosine
+  lr_warmup_steps: 500
+  num_epochs: 500
   # gradient_accumulate_every: 1
   # EMA destroys performance when used with BatchNorm
   # replace BatchNorm with GroupNorm.
-  # use_ema: True
+  use_ema: True
   freeze_encoder: False
   # training loop control
   # in epochs
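
For context on the first two hunks: `torchvision.transforms.Normalize` with `mean = std = 127.5` already maps raw `[0, 255]` pixels to `[-1, 1]`, so dividing by 255 in the env first (as before this patch) squashes every input into `[0, 1]` and the normalizer then collapses it to roughly `-1`. A minimal sketch of the arithmetic (the 3x1x1 tensor is an illustrative stand-in for an observation image, not code from the repo):

```python
import torch
import torchvision

normalize = torchvision.transforms.Normalize(
    mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5]
)

# One pixel per channel, spanning the raw uint8 range, as float.
img = torch.tensor([0.0, 127.5, 255.0]).view(3, 1, 1)

# After this patch: keep [0, 255] and let Normalize rescale.
print(normalize(img).flatten())        # tensor([-1., 0., 1.])

# Before this patch: dividing by 255 first squashes the input into [0, 1],
# so Normalize collapses everything to roughly -1.
print(normalize(img / 255).flatten())  # tensor([-1.0000, -0.9961, -0.9922])
```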
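The `diffusion_step_embed_dim` / `down_dims` hunk reverts to the values the removed inline comments flag as the originals ("before 128", "before [512, 1024, 2048]"). A hedged sketch of where these values land, assuming the vendored U-Net keeps the upstream diffusion_policy constructor (the import path, `input_dim`, and `global_cond_dim` below are illustrative assumptions, not taken from this patch):

```python
from lerobot.common.policies.diffusion.model.conditional_unet1d import ConditionalUnet1D

unet = ConditionalUnet1D(
    input_dim=2,                   # PushT action dim (x, y) -- illustrative
    global_cond_dim=66,            # flattened obs features -- illustrative
    diffusion_step_embed_dim=128,  # restored original value
    down_dims=[512, 1024, 2048],   # channel widths of the three U-Net stages
    kernel_size=5,
    n_groups=8,
    cond_predict_scale=True,
)
```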
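The training hunk re-enables the cosine learning-rate schedule with warmup and drops `num_epochs` from 8000 to 500. A hedged sketch of the corresponding scheduler construction, assuming the `diffusers` helper used by the upstream diffusion_policy code (the model, optimizer, and `steps_per_epoch` below are placeholders):

```python
import torch
from diffusers.optimization import get_scheduler

model = torch.nn.Linear(10, 10)  # placeholder for the policy network
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

num_epochs = 500       # num_epochs from the config above
steps_per_epoch = 100  # depends on dataset size / batch size -- placeholder
lr_scheduler = get_scheduler(
    "cosine",
    optimizer=optimizer,
    num_warmup_steps=500,  # lr_warmup_steps from the config above
    num_training_steps=num_epochs * steps_per_epoch,
)
```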
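Finally, `use_ema: True` re-enables exponential moving averaging of the policy weights; the surrounding comments note that EMA misbehaves with BatchNorm, which is why the U-Net uses GroupNorm instead. A minimal sketch of the idea (the `decay` value and placeholder network are illustrative, not the repo's settings):

```python
import copy
import torch

def update_ema(ema_model: torch.nn.Module, model: torch.nn.Module, decay: float = 0.999):
    # Blend online weights into the EMA copy: ema <- decay * ema + (1 - decay) * online.
    with torch.no_grad():
        for ema_p, p in zip(ema_model.parameters(), model.parameters()):
            ema_p.mul_(decay).add_(p, alpha=1 - decay)

model = torch.nn.Linear(4, 4)  # placeholder network
ema_model = copy.deepcopy(model)

# Call after every optimizer step; evaluate / roll out with ema_model.
update_ema(ema_model, model)
```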