align MuJoCo and real-world pipelines
peterdavidfagan committed May 8, 2024
1 parent 63b2837 commit 8f6fb4e
Showing 4 changed files with 84 additions and 35 deletions.
16 changes: 16 additions & 0 deletions robot_learning_baselines/config/dataset/debug-transporter.yaml
@@ -3,3 +3,19 @@ tfds_data_dir: /media/peter/400cb321-2b54-4ffc-a661-2d49ca87dfaf/transporter_dat
shuffle_buffer_size: 10
batch_size: ${config.training.transporter_pick.batch_size}

huggingface:
entity: peterdavidfagan
repo: transporter_networks_mujoco
files:
- colour_splitter_2024-05-08-12:35:43.tar.gz
#- data.tar.xz

crop:
u_min: 0
u_max: 640
v_min: 0
v_max: 480
# u_min: 395
# u_max: 755
# v_min: 220
# v_max: 580
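
The new huggingface and crop blocks move dataset selection and image cropping into the dataset config; the active bounds appear to cover the full 640x480 frame of the MuJoCo dataset, while the commented values match the real-world crop previously hardcoded in utils/data.py (removed later in this diff). A minimal sketch of how the training script consumes this block, assuming the file is loaded standalone with OmegaConf (the real pipeline composes it through Hydra):

from omegaconf import OmegaConf

cfg = OmegaConf.load("robot_learning_baselines/config/dataset/debug-transporter.yaml")
crop = cfg.crop
# v indexes image rows and u indexes columns, so this is (row_min, col_min, row_max, col_max)
crop_idx = (crop["v_min"], crop["u_min"], crop["v_max"], crop["u_max"])
print(crop_idx)  # (0, 0, 480, 640) for the values above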
75 changes: 55 additions & 20 deletions robot_learning_baselines/train_transporter.py
@@ -4,6 +4,7 @@
import os
from time import time
from tqdm import tqdm
from functools import partial

# linear algebra and deep learning frameworks
import numpy as np
@@ -16,7 +16,6 @@

# dataset
import tensorflow_datasets as tfds
from envlogger import reader

# model architecture
from transporter_networks.transporter import (
@@ -38,8 +38,8 @@

# custom training pipeline utilities
from utils.data import (
load_transporter_dataset,
preprocess_transporter_batch,
load_hf_transporter_dataset,
preprocess_transporter,
)

from utils.pipeline import (
@@ -62,16 +62,33 @@ def main(cfg: DictConfig) -> None:
assert jax.default_backend() != "cpu" # ensure accelerator is available
cfg = cfg["config"] # some hacky and wacky stuff from hydra (TODO: revise)

# precompile data preprocessing based on config
crop = cfg.dataset.crop
preprocess_transporter_batch = jax.jit(
jax.vmap(
partial(
preprocess_transporter,
crop_idx=(crop["v_min"], crop["u_min"], crop["v_max"], crop["u_max"])),
in_axes=(0, 0, 0, 0)
)
)


key = random.PRNGKey(0)
pick_model_key, place_model_key = jax.random.split(key, 2)

train_data = load_transporter_dataset(cfg.dataset)
train_data = load_hf_transporter_dataset(cfg.dataset)
cardinality = train_data.reduce(0, lambda x,_: x+1).numpy()

if cfg.wandb.use:
init_wandb(cfg)
batch = next(train_data.as_numpy_iterator())
(rgbd, rgbd_crop), (rgbd_normalized, rgbd_crop_normalized), pixels, ids = preprocess_transporter_batch(batch)
(rgbd, rgbd_crop), (rgbd_normalized, rgbd_crop_normalized), pixels, ids = preprocess_transporter_batch(
jnp.asarray(batch['pick_rgb']),
jnp.asarray(batch['pick_depth']),
jnp.asarray(batch['pick_pixel_coords']),
jnp.asarray(batch['place_pixel_coords']),
)
batch = {
"rgbd": rgbd,
"rgbd_crop": rgbd_crop,
@@ -90,7 +107,12 @@ def main(cfg: DictConfig) -> None:


batch = next(train_data.as_numpy_iterator())
(rgbd, rgbd_crop), (rgbd_normalized, rgbd_crop_normalized), pixels, ids = preprocess_transporter_batch(batch)
(rgbd, rgbd_crop), (rgbd_normalized, rgbd_crop_normalized), pixels, ids = preprocess_transporter_batch(
jnp.asarray(batch['pick_rgb']),
jnp.asarray(batch['pick_depth']),
jnp.asarray(batch['pick_pixel_coords']),
jnp.asarray(batch['place_pixel_coords'])
)
eval_data = {
"rgbd": rgbd,
"rgbd_crop": rgbd_crop,
@@ -132,19 +154,26 @@ def main(cfg: DictConfig) -> None:
}

# shuffle dataset
train_data = train_data.shuffle(10)
train_data_iter = train_data.as_numpy_iterator()
train_data_epoch = train_data.shuffle(16)

# TODO: get dataset size and use tqdm
for batch in tqdm(train_data_iter, leave=False, total=cardinality):
(rgbd, rgbd_crop), (rgbd_normalized, rgbd_crop_normalized), pixels, ids = preprocess_transporter_batch(batch)

for batch in tqdm(train_data_epoch, leave=False, total=cardinality):
(rgbd, rgbd_crop), (rgbd_normalized, rgbd_crop_normalized), pixels, ids = preprocess_transporter_batch(
jnp.asarray(batch['pick_rgb']),
jnp.asarray(batch['pick_depth']),
jnp.asarray(batch['pick_pixel_coords']),
jnp.asarray(batch['place_pixel_coords']),
)

# compute ce loss for pick network and update pick network
pick_train_state, pick_loss = pick_train_step(transporter.pick_model_state, rgbd_normalized, ids[0])
pick_train_state, pick_loss, pick_success_rate = pick_train_step(
transporter.pick_model_state,
rgbd_normalized,
ids[0])
transporter = transporter.replace(pick_model_state=pick_train_state)

# compute ce loss for place networks and update place network
place_train_state, place_loss = place_train_step(
place_train_state, place_loss, place_success_rate = place_train_step(
transporter.place_model_state,
rgbd_normalized,
rgbd_crop_normalized,
@@ -154,14 +183,15 @@ def main(cfg: DictConfig) -> None:


# report epoch metrics (optionally add to wandb)
pick_loss_epoch = transporter.pick_model_state.metrics.compute()
place_loss_epoch = transporter.place_model_state.metrics.compute()
print(f"Epoch {epoch}: pick_loss: {pick_loss_epoch}, place_loss: {place_loss_epoch}")
pick_metrics = transporter.pick_model_state.metrics.compute()
place_metrics = transporter.place_model_state.metrics.compute()

if cfg.wandb.use:
if cfg.wandb.use and (epoch%5==0):
wandb.log({
"pick_loss": pick_loss_epoch,
"place_loss": place_loss_epoch,
"pick_train_loss": pick_metrics["loss"],
"place_train_loss": place_metrics["loss"],
"pick_train_success_rate": pick_metrics["success_rate"],
"place_train_success_rate":place_metrics["success_rate"],
"epoch": epoch
})
visualize_transporter_predictions(cfg, transporter, eval_data, epoch)
@@ -170,6 +200,11 @@ def main(cfg: DictConfig) -> None:
transporter.pick_model_state.replace(metrics=pick_train_state.metrics.empty())
transporter.place_model_state.replace(metrics=place_train_state.metrics.empty())

# save model checkpoint
pick_chkpt_manager.save(epoch, pick_train_state)
place_chkpt_manager.save(epoch, place_train_state)



if __name__ == "__main__":
main()
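
The batched preprocessing is no longer a module-level constant with a hardcoded crop: partial bakes the config-driven crop_idx into preprocess_transporter, jax.vmap maps the per-example function over the leading batch axis of the four input arrays, and jax.jit compiles the composition once at startup. Because the crop bounds are Python ints closed over by partial, they are treated as compile-time constants. A toy illustration of the same partial -> vmap -> jit pattern (the crop function below is hypothetical, not the repo's preprocess_transporter):

from functools import partial

import jax
import jax.numpy as jnp

def crop_image(img, crop_idx):
    # crop_idx is a tuple of Python ints, so the slice is static under jit
    v_min, u_min, v_max, u_max = crop_idx
    return img[v_min:v_max, u_min:u_max]

batched_crop = jax.jit(jax.vmap(partial(crop_image, crop_idx=(0, 0, 4, 4)), in_axes=0))
out = batched_crop(jnp.ones((8, 6, 6)))  # maps over the batch axis, giving shape (8, 4, 4)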
24 changes: 11 additions & 13 deletions robot_learning_baselines/utils/data.py
@@ -45,18 +45,18 @@ def episode_step_to_transition(episode):
# download data from huggingface
DOWNLOAD_PATH="/tmp/transporter_dataset"
COMPRESSED_FILENAME="data.tar.xz"
hf_hub_download(
repo_id="peterdavidfagan/transporter_networks",
repo_type="dataset",
filename="data.tar.xz",
local_dir=DOWNLOAD_PATH,
)
for file in cfg["huggingface"]["files"]:
hf_hub_download(
repo_id=f"{cfg['huggingface']['entity']}/{cfg['huggingface']['repo']}",
repo_type="dataset",
filename=file,
local_dir=DOWNLOAD_PATH,
)

# uncompress file
COMPRESSED_FILEPATH=os.path.join(DOWNLOAD_PATH, COMPRESSED_FILENAME)
with tarfile.open(COMPRESSED_FILEPATH, 'r:xz') as tar:
tar.extractall(path=DOWNLOAD_PATH)
os.remove(COMPRESSED_FILEPATH)
COMPRESSED_FILEPATH=os.path.join(DOWNLOAD_PATH, file)
with tarfile.open(COMPRESSED_FILEPATH, 'r:*') as tar:  # 'r:*' auto-detects compression, since the config lists .tar.gz archives
tar.extractall(path=DOWNLOAD_PATH)
os.remove(COMPRESSED_FILEPATH)

# load with tfds
ds = tfds.builder_from_directory(DOWNLOAD_PATH).as_dataset(split="train")
@@ -306,8 +306,6 @@ def preprocess_transporter(rgb, depth, pick_pixels, place_pixels, crop_idx):

return (rgbd_crop_raw, rgbd_pick_crop_raw), (rgbd_crop, rgbd_pick_crop), (pick_pixels, place_pixels), (pick_id, place_id)

# hardcoded crop_idx for the huggingface dataset from Edinburgh University RAD Lab.
preprocess_transporter_batch = jax.jit(jax.vmap(partial(preprocess_transporter, crop_idx=(220, 395, 580, 755)), in_axes=(0, 0, 0, 0)))

def preprocess_batch(batch, text_tokenize_fn, action_head_type="diffusion", dummy=False):
"""
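load_hf_transporter_dataset now iterates over the archives listed under huggingface.files in the dataset config, downloading and extracting each one, rather than fetching a single hardcoded data.tar.xz. To populate that list, the archives available in a dataset repo can be enumerated with huggingface_hub; a small sketch, using the repo named in the config above:

from huggingface_hub import list_repo_files

files = list_repo_files("peterdavidfagan/transporter_networks_mujoco", repo_type="dataset")
archives = [f for f in files if f.endswith((".tar.gz", ".tar.xz"))]
print(archives)
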
4 changes: 2 additions & 2 deletions robot_learning_baselines/utils/wandb.py
@@ -187,12 +187,12 @@ def visualize_transporter_predictions(cfg, transporter, raw_batch, epoch):
# inspect model predictions
pick_pred_ = pick_pred[i,:].copy()
pick_pred_ = (pick_pred_ - pick_pred_.min()) / ((pick_pred_.max() - pick_pred_.min()))
pick_heatmap = pick_pred_.reshape((360, 360))
pick_heatmap = pick_pred_.reshape(rgb.shape[:2])
pick_heatmap = Image.fromarray(np.asarray(cm.viridis(pick_heatmap)*255, dtype=np.uint8))

place_pred_ = place_pred[i,:].copy()
place_pred_ = (place_pred_ - place_pred_.min()) / ((place_pred_.max() - place_pred_.min()))
place_heatmap = place_pred_.reshape((360, 360))
place_heatmap = place_pred_.reshape(rgb.shape[:2])
place_heatmap = Image.fromarray(np.asarray(cm.viridis(place_heatmap)*255, dtype=np.uint8))

data.append([
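Reshaping the flattened predictions to rgb.shape[:2] ties the heatmaps to the actual frame resolution instead of the previously hardcoded 360x360, so the same visualization works for both the MuJoCo and real-world cameras. A toy rendering of the same steps, with an illustrative frame size:

import numpy as np
from matplotlib import cm
from PIL import Image

rgb = np.zeros((480, 640, 3), dtype=np.uint8)          # placeholder camera frame
pred = np.random.rand(rgb.shape[0] * rgb.shape[1])     # flat per-pixel scores
pred = (pred - pred.min()) / (pred.max() - pred.min()) # normalize to [0, 1]
heatmap = Image.fromarray(np.asarray(cm.viridis(pred.reshape(rgb.shape[:2])) * 255, dtype=np.uint8))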
