Skip to content

Commit

Permalink
wip - still need to verify full training run
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-soare committed Mar 11, 2024
1 parent 304355c commit 87fcc53
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 7 deletions.
2 changes: 1 addition & 1 deletion lerobot/common/envs/pusht/pusht_image_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def _get_obs(self):
img = super()._render_frame(mode="rgb_array")

agent_pos = np.array(self.agent.position)
img_obs = np.moveaxis(img.astype(np.float32) / 255, -1, 0)
img_obs = np.moveaxis(img.astype(np.float32), -1, 0)
obs = {"image": img_obs, "agent_pos": agent_pos}

# draw action
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ def __init__(
if imagenet_norm:
# TODO(rcadene): move normalizer to dataset and env
this_normalizer = torchvision.transforms.Normalize(
# Note: This matches the normalization in the original impl. for PushT Image. This may not be
# the case for other tasks.
mean=[127.5, 127.5, 127.5],
std=[127.5, 127.5, 127.5],
)
Expand Down
12 changes: 6 additions & 6 deletions lerobot/configs/policy/diffusion.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ policy:
num_inference_steps: 100
obs_as_global_cond: ${obs_as_global_cond}
# crop_shape: null
diffusion_step_embed_dim: 256 # before 128
down_dims: [256, 512, 1024] # before [512, 1024, 2048]
diffusion_step_embed_dim: 128
down_dims: [512, 1024, 2048]
kernel_size: 5
n_groups: 8
cond_predict_scale: True
Expand Down Expand Up @@ -109,13 +109,13 @@ training:
debug: False
resume: True
# optimization
# lr_scheduler: cosine
# lr_warmup_steps: 500
num_epochs: 8000
lr_scheduler: cosine
lr_warmup_steps: 500
num_epochs: 500
# gradient_accumulate_every: 1
# EMA destroys performance when used with BatchNorm
# replace BatchNorm with GroupNorm.
# use_ema: True
use_ema: True
freeze_encoder: False
# training loop control
# in epochs
Expand Down

0 comments on commit 87fcc53

Please sign in to comment.