
Commit 5e01c21
added possibility to record with a policy; added temporary fixes to train.py to enable training on mac
michel-aractingi committed Oct 24, 2024
1 parent 9a5356d commit 5e01c21
Showing 3 changed files with 143 additions and 51 deletions.
6 changes: 3 additions & 3 deletions lerobot/common/policies/tdmpc/modeling_tdmpc.py
@@ -345,7 +345,7 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
batch[key] = batch[key].transpose(1, 0)

action = batch["action"] # (t, b, action_dim)
reward = batch["reward"] # (t, b)
reward = batch["next.reward"] # (t, b)
observations = {k: v for k, v in batch.items() if k.startswith("observation.")}

# Apply random image augmentations.
@@ -422,7 +422,7 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
(
temporal_loss_coeffs
* F.mse_loss(reward_preds, reward, reduction="none")
* ~batch["reward_is_pad"]
* ~batch["next.reward_is_pad"]
# `reward_preds` depends on the current observation and the actions.
* ~batch["observation.state_is_pad"][0]
* ~batch["action_is_pad"]
@@ -443,7 +443,7 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
* ~batch["observation.state_is_pad"][0]
* ~batch["action_is_pad"]
# q_targets depends on the reward and the next observations.
* ~batch["reward_is_pad"]
* ~batch["next.reward_is_pad"]
* ~batch["observation.state_is_pad"][1:]
)
.sum(0)
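For context on the rename above: the TD-MPC loss now reads rewards and their padding masks under the "next.reward" / "next.reward_is_pad" keys, the same keys the recording script below writes into the dataset. A minimal sketch of the masking pattern, with hypothetical shapes and only torch assumed:

```python
# Minimal sketch of the padding-mask pattern used in the TD-MPC loss above.
# Shapes are hypothetical; keys follow the "next.reward" convention of this commit.
import torch
import torch.nn.functional as F

t, b = 5, 2                                                 # horizon, batch size
batch = {
    "next.reward": torch.randn(t, b),                       # (t, b) rewards shifted by one step
    "next.reward_is_pad": torch.zeros(t, b, dtype=torch.bool),
}
batch["next.reward_is_pad"][-1] = True                      # pretend the last step is padding

reward_preds = torch.randn(t, b)
per_step = F.mse_loss(reward_preds, batch["next.reward"], reduction="none")  # (t, b)
masked = per_step * ~batch["next.reward_is_pad"]            # padded steps contribute zero loss
loss = masked.sum(0).mean()                                 # sum over time, average over batch
print(loss)
```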
105 changes: 94 additions & 11 deletions lerobot/scripts/control_sim_robot.py
@@ -85,11 +85,15 @@
from pathlib import Path
import gymnasium as gym
import multiprocessing
from contextlib import nullcontext


import cv2
import torch
import numpy as np
import tqdm
from omegaconf import DictConfig

from PIL import Image
from datasets import Dataset, Features, Sequence, Value

@@ -99,12 +103,15 @@
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, get_default_encoding
from lerobot.common.datasets.utils import calculate_episode_data_index, create_branch, hf_transform_to_torch
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
from lerobot.common.datasets.video_utils import encode_video_frames
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.policies.factory import make_policy
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.robot_devices.utils import busy_wait
from lerobot.common.envs.factory import make_env
from lerobot.common.utils.utils import init_hydra_config, init_logging
from lerobot.scripts.eval import get_pretrained_policy_path
from lerobot.scripts.push_dataset_to_hub import (
push_dataset_card_to_hub,
push_meta_data_to_hub,
@@ -178,6 +185,29 @@ def is_headless():
print()
return True

def get_action_from_policy(policy, observation, device, use_amp=False):
with (
torch.inference_mode(),
torch.autocast(device_type=device.type)
if device.type == "cuda" and use_amp
else nullcontext(),
):
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
for name in observation:
if "image" in name:
observation[name] = observation[name].type(torch.float32) / 255
observation[name] = observation[name].permute(2, 0, 1).contiguous()
observation[name] = observation[name].unsqueeze(0)
observation[name] = observation[name].to(device)

# Compute the next action with the policy
# based on the current observation
action = policy.select_action(observation)
# Remove batch dimension
action = action.squeeze(0)
# Move to cpu, if not already the case
return action.to("cpu")

def init_read_leader(robot, fps, **kwargs):
axis_directions = kwargs.get('axis_directions', [1])
offsets = kwargs.get('offsets', [0])
@@ -240,7 +270,7 @@ def create_rl_hf_dataset(data_dict):
features["action"] = Sequence(
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
)
features["reward"] = Value(dtype="float32", id=None)
features["next.reward"] = Value(dtype="float32", id=None)

features["seed"] = Value(dtype="int64", id=None)
features["episode_index"] = Value(dtype="int64", id=None)
@@ -277,6 +307,8 @@ def teleoperate(env, robot: Robot, teleop_time_s=None, **kwargs):
def record(
env,
robot: Robot,
policy: torch.nn.Module | None = None,
policy_cfg: DictConfig | None = None,
fps: int | None = None,
root="data",
repo_id="lerobot/debug",
@@ -355,7 +387,23 @@ def on_press(key):
num_image_writers = num_image_writers_per_camera * 2 ###############
num_image_writers = max(num_image_writers, 1)

read_leader, command_queue = init_read_leader(robot, fps, **kwargs)
# Load policy if any
if policy is not None:
# Check device is available
device = get_safe_torch_device(policy_cfg.device, log=True)

policy.eval()
policy.to(device)

torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
set_global_seed(policy_cfg.seed)

# override fps using policy fps
fps = policy_cfg.env.fps
else:
read_leader, command_queue = init_read_leader(robot, fps, **kwargs)

if not is_headless() and visualize_images:
observations_queue = multiprocessing.Queue(1000)
show_images = multiprocessing.Process(target=show_image_observations, args=(observations_queue, ))
@@ -369,7 +417,7 @@ def on_press(key):
while episode_index < num_episodes:
logging.info(f"Recording episode {episode_index}")
say(f"Recording episode {episode_index}")
ep_dict = {'action':[], 'reward':[]}
ep_dict = {'action':[], 'next.reward':[]}
for k in state_keys_dict:
ep_dict[k] = []
frame_index = 0
@@ -381,9 +429,14 @@ def on_press(key):
observation, info = env.reset(seed=seed)
#with stop_reading_leader.get_lock():
#stop_reading_leader.Value = 0
read_leader.start()
if policy is None:
read_leader.start()
while timestamp < episode_time_s:
action = command_queue.get()
if policy is None:
action = command_queue.get()
else:
action = get_action_from_policy(policy, observation, device)

for key in image_keys:
str_key = key if key.startswith('observation.images.') else 'observation.images.' + key
futures += [
@@ -402,7 +455,7 @@ def on_press(key):
action = np.expand_dims(action, 0)
observation, reward, _, _ , info = env.step(action)
ep_dict['action'].append(torch.from_numpy(action))
ep_dict['reward'].append(torch.tensor(reward))
ep_dict['next.reward'].append(torch.tensor(reward))
print(reward)

frame_index += 1
@@ -417,9 +470,10 @@ def on_press(key):
#stop_reading_leader.Value = 1
# TODO (michel_aractingi): temp fix until I figure out the problem with shared memory
# stop_reading_leader is blocking
command_queue.close()
read_leader.terminate()
read_leader, command_queue = init_read_leader(robot, fps, **kwargs)
if policy is None:
command_queue.close()
read_leader.terminate()
read_leader, command_queue = init_read_leader(robot, fps, **kwargs)

timestamp = 0

@@ -451,7 +505,7 @@ def on_press(key):
for key in state_keys_dict:
ep_dict[key] = torch.vstack(ep_dict[key]) * 180.0 / np.pi
ep_dict['action'] = torch.vstack(ep_dict['action']) * 180.0 / np.pi
ep_dict['reward'] = torch.stack(ep_dict['reward'])
ep_dict['next.reward'] = torch.stack(ep_dict['next.reward'])

ep_dict["seed"] = torch.tensor([seed] * num_frames)
ep_dict["episode_index"] = torch.tensor([episode_index] * num_frames)
@@ -577,7 +631,11 @@ def on_press(key):
return lerobot_dataset


def replay(env, episodes: list, fps: int | None = None, root="data", repo_id="lerobot/debug"):
def replay(env,
episodes: list,
fps: int | None = None,
root="data",
repo_id="lerobot/debug"):

env = env()
local_dir = Path(root) / repo_id
@@ -700,6 +758,21 @@ def replay(env, episodes: list, fps: int | None = None, root="data", repo_id="le
default=0,
help="Visualize image observations with opencv.",
)
parser_record.add_argument(
"-p",
"--pretrained-policy-name-or-path",
type=str,
help=(
"Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
"saved using `Policy.save_pretrained`."
),
)
parser_record.add_argument(
"--policy-overrides",
type=str,
nargs="*",
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
)

parser_replay = subparsers.add_parser("replay", parents=[base_parser])
parser_replay.add_argument(
@@ -748,6 +821,16 @@ def replay(env, episodes: list, fps: int | None = None, root="data", repo_id="le
teleoperate(env_fn, robot, **kwargs)

elif control_mode == "record":
pretrained_policy_name_or_path = args.pretrained_policy_name_or_path
policy_overrides = args.policy_overrides
del kwargs["pretrained_policy_name_or_path"]
del kwargs["policy_overrides"]

if pretrained_policy_name_or_path is not None:
pretrained_policy_path = get_pretrained_policy_path(pretrained_policy_name_or_path)
kwargs["policy_cfg"] = init_hydra_config(pretrained_policy_path / "config.yaml", policy_overrides)
kwargs["policy"] = make_policy(hydra_cfg=kwargs["policy_cfg"], pretrained_policy_name_or_path=pretrained_policy_path)

record(env_fn, robot, **kwargs)

elif control_mode == "replay":
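The get_action_from_policy helper introduced above converts image observations to channel-first float32 in [0, 1], adds a batch dimension, moves everything to the policy device, and returns the selected action on CPU with the batch dimension squeezed away. A minimal sketch of that contract; the dummy policy and the observation keys are hypothetical, and it assumes the patched script is importable:

```python
# Minimal sketch of the get_action_from_policy contract (hypothetical policy and keys).
import torch

from lerobot.scripts.control_sim_robot import get_action_from_policy  # helper added in this commit


class DummyPolicy(torch.nn.Module):
    """Stand-in policy; the real script builds one with make_policy()."""

    def select_action(self, observation: dict[str, torch.Tensor]) -> torch.Tensor:
        # Return a batched zero action; a real policy would run inference here.
        batch_size = observation["observation.state"].shape[0]
        return torch.zeros(batch_size, 6)


policy = DummyPolicy()
device = torch.device("cpu")

# Observations as the simulated robot would provide them: HWC uint8 image plus a state vector.
observation = {
    "observation.images.top": torch.randint(0, 256, (480, 640, 3), dtype=torch.uint8),
    "observation.state": torch.zeros(6),
}

action = get_action_from_policy(policy, observation, device)
print(action.shape)  # torch.Size([6]) -- batch dimension squeezed away, tensor on CPU
```

In the record loop, this call replaces reading from the leader-arm command queue whenever a pretrained policy is supplied.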
83 changes: 46 additions & 37 deletions lerobot/scripts/train.py
@@ -135,8 +135,8 @@ def update_policy(

# Optimizer's gradients are already unscaled, so scaler.step does not unscale them,
# although it still skips optimizer.step() if the gradients contain infs or NaNs.
with lock if lock is not None else nullcontext():
grad_scaler.step(optimizer)
#with lock if lock is not None else nullcontext():
grad_scaler.step(optimizer)
# Updates the scale for next iteration.
grad_scaler.update()

@@ -311,6 +311,11 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No

logging.info("make_dataset")
offline_dataset = make_dataset(cfg)

remove_indices = ['observation.images.image_top', 'observation.velocity', 'seed']
# temp fix michel_Aractingi TODO
offline_dataset.hf_dataset = offline_dataset.hf_dataset.remove_columns(remove_indices)

if isinstance(offline_dataset, MultiLeRobotDataset):
logging.info(
"Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
@@ -477,7 +482,7 @@ def evaluate_and_checkpoint_if_needed(step, is_online):
**{k: {"shape": v, "dtype": np.dtype("float32")} for k, v in policy.config.output_shapes.items()},
"next.reward": {"shape": (), "dtype": np.dtype("float32")},
"next.done": {"shape": (), "dtype": np.dtype("?")},
"next.success": {"shape": (), "dtype": np.dtype("?")},
#"next.success": {"shape": (), "dtype": np.dtype("?")},
},
buffer_capacity=cfg.training.online_buffer_capacity,
fps=online_env.unwrapped.metadata["render_fps"],
@@ -504,6 +509,9 @@ def evaluate_and_checkpoint_if_needed(step, is_online):
num_samples=len(concat_dataset),
replacement=True,
)

# TODO michel_aractingi temp fix for inconsistent keys

dataloader = torch.utils.data.DataLoader(
concat_dataset,
batch_size=cfg.training.batch_size,
@@ -538,8 +546,8 @@ def evaluate_and_checkpoint_if_needed(step, is_online):

def sample_trajectory_and_update_buffer():
nonlocal rollout_start_seed
with lock:
online_rollout_policy.load_state_dict(policy.state_dict())
#with lock:
online_rollout_policy.load_state_dict(policy.state_dict())
online_rollout_policy.eval()
start_rollout_time = time.perf_counter()
with torch.no_grad():
@@ -556,37 +564,35 @@ def sample_trajectory_and_update_buffer():
)
online_rollout_s = time.perf_counter() - start_rollout_time

with lock:
start_update_buffer_time = time.perf_counter()
online_dataset.add_data(eval_info["episodes"])

# Update the concatenated dataset length used during sampling.
concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets)

# Update the sampling weights.
sampler.weights = compute_sampler_weights(
offline_dataset,
offline_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0),
online_dataset=online_dataset,
# +1 because online rollouts return an extra frame for the "final observation". Note: we don't have
# this final observation in the offline datasets, but we might add them in future.
online_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0) + 1,
online_sampling_ratio=cfg.training.online_sampling_ratio,
)
sampler.num_samples = len(concat_dataset)

update_online_buffer_s = time.perf_counter() - start_update_buffer_time
#with lock:
start_update_buffer_time = time.perf_counter()
online_dataset.add_data(eval_info["episodes"])
# Update the concatenated dataset length used during sampling.
concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets)
# Update the sampling weights.
sampler.weights = compute_sampler_weights(
offline_dataset,
offline_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0),
online_dataset=online_dataset,
# +1 because online rollouts return an extra frame for the "final observation". Note: we don't have
# this final observation in the offline datasets, but we might add them in future.
online_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0) + 1,
online_sampling_ratio=cfg.training.online_sampling_ratio,
)
sampler.num_samples = len(concat_dataset)
update_online_buffer_s = time.perf_counter() - start_update_buffer_time

return online_rollout_s, update_online_buffer_s

future = executor.submit(sample_trajectory_and_update_buffer)
# TODO remove parallelization for sim
#future = executor.submit(sample_trajectory_and_update_buffer)
# If we aren't doing async rollouts, or if we haven't yet gotten enough examples in our buffer, wait
# here until the rollout and buffer update is done, before proceeding to the policy update steps.
if (
not cfg.training.do_online_rollout_async
or len(online_dataset) <= cfg.training.online_buffer_seed_size
):
online_rollout_s, update_online_buffer_s = future.result()
online_rollout_s, update_online_buffer_s = sample_trajectory_and_update_buffer()#future.result()

if len(online_dataset) <= cfg.training.online_buffer_seed_size:
logging.info(
@@ -596,12 +602,15 @@ def sample_trajectory_and_update_buffer():

policy.train()
for _ in range(cfg.training.online_steps_between_rollouts):
with lock:
start_time = time.perf_counter()
batch = next(dl_iter)
dataloading_s = time.perf_counter() - start_time
#with lock:
start_time = time.perf_counter()
batch = next(dl_iter)
dataloading_s = time.perf_counter() - start_time

for key in batch:
# TODO michel_aractingi: convert float64 to float32 for Mac (the MPS backend does not support float64)
if batch[key].dtype == torch.float64:
batch[key] = batch[key].float()
batch[key] = batch[key].to(cfg.device, non_blocking=True)

train_info = update_policy(
@@ -619,8 +628,8 @@ def sample_trajectory_and_update_buffer():
train_info["online_rollout_s"] = online_rollout_s
train_info["update_online_buffer_s"] = update_online_buffer_s
train_info["await_update_online_buffer_s"] = await_update_online_buffer_s
with lock:
train_info["online_buffer_size"] = len(online_dataset)
#with lock:
train_info["online_buffer_size"] = len(online_dataset)

if step % cfg.training.log_freq == 0:
log_train_info(logger, train_info, step, cfg, online_dataset, is_online=True)
@@ -634,10 +643,10 @@ def sample_trajectory_and_update_buffer():

# If we're doing async rollouts, we should now wait until we've completed them before proceeding
# to do the next batch of rollouts.
if future.running():
start = time.perf_counter()
online_rollout_s, update_online_buffer_s = future.result()
await_update_online_buffer_s = time.perf_counter() - start
#if future.running():
#start = time.perf_counter()
#online_rollout_s, update_online_buffer_s = sample_trajectory_and_update_buffer()#future.result()
#await_update_online_buffer_s = time.perf_counter() - start

if online_step >= cfg.training.online_steps:
break
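The train.py edits above are flagged as temporary: the dataloader lock and the async rollout future are commented out so everything runs in one process, and each batch is downcast from float64 to float32 before being moved to the device, since the MPS backend used for training on Apple silicon does not support float64. A minimal standalone sketch of that cast-and-move step (batch keys and shapes are hypothetical):

```python
# Minimal sketch of the float64 -> float32 cast applied to each batch before it is
# moved to the training device; keys and shapes here are hypothetical.
import torch

# Pick MPS on Apple silicon when it is available, otherwise fall back to CPU.
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

# Float64 tensors typically sneak in through numpy's default dtype.
batch = {
    "observation.state": torch.zeros(8, 6, dtype=torch.float64),
    "action": torch.zeros(8, 6, dtype=torch.float32),
}

for key in batch:
    if batch[key].dtype == torch.float64:
        batch[key] = batch[key].float()  # MPS has no float64 kernels, so downcast first
    batch[key] = batch[key].to(device, non_blocking=True)

print({k: (v.dtype, v.device) for k, v in batch.items()})
```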
