Skip to content

Commit

Permalink
Merge branch 'isaac-sim:main' into feature/hyperparam_tune
Browse files Browse the repository at this point in the history
  • Loading branch information
glvov-bdai authored Oct 28, 2024
2 parents 4196c8e + 9cc298e commit 551003b
Show file tree
Hide file tree
Showing 27 changed files with 759 additions and 35 deletions.
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ Guidelines for modifications:
* Chenyu Yang
* David Yang
* Dorsa Rohani
* Felix Yu
* Gary Lvov
* Giulio Romualdi
* HoJin Jeon
Expand Down
11 changes: 9 additions & 2 deletions docs/source/overview/environments.rst
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ Classic environments that are based on IsaacGymEnvs implementation of MuJoCo-sty
| | | |
| | |cartpole-depth-direct-link|| |
+------------------+-----------------------------+-------------------------------------------------------------------------+
| |cartpole| | |cartpole-resnet-link| | Move the cart to keep the pole upwards in the classic cartpole control |
| | | based off of features extracted from perceptive inputs with pre-trained |
| | |cartpole-theia-link| | frozen vision encoders |
+------------------+-----------------------------+-------------------------------------------------------------------------+

.. |humanoid| image:: ../_static/tasks/classic/humanoid.jpg
.. |ant| image:: ../_static/tasks/classic/ant.jpg
Expand All @@ -69,8 +73,11 @@ Classic environments that are based on IsaacGymEnvs implementation of MuJoCo-sty
.. |humanoid-link| replace:: `Isaac-Humanoid-v0 <https://github.com/isaac-sim/IsaacLab/blob/main/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/classic/humanoid/humanoid_env_cfg.py>`__
.. |ant-link| replace:: `Isaac-Ant-v0 <https://github.com/isaac-sim/IsaacLab/blob/main/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/classic/ant/ant_env_cfg.py>`__
.. |cartpole-link| replace:: `Isaac-Cartpole-v0 <https://github.com/isaac-sim/IsaacLab/blob/main/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/classic/cartpole/cartpole_env_cfg.py>`__
.. |cartpole-rgb-link| replace:: `Isaac-Cartpole-RGB-Camera-v0 <https://github.com/isaac-sim/IsaacLab/blob/main/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/classic/cartpole/cartpole_camera_env_cfg.py>`__
.. |cartpole-depth-link| replace:: `Isaac-Cartpole-Depth-Camera-v0 <https://github.com/isaac-sim/IsaacLab/blob/main/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/classic/cartpole/cartpole_camera_env_cfg.py>`__
.. |cartpole-rgb-link| replace:: `Isaac-Cartpole-RGB-v0 <https://github.com/isaac-sim/IsaacLab/blob/main/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/classic/cartpole/cartpole_camera_env_cfg.py>`__
.. |cartpole-depth-link| replace:: `Isaac-Cartpole-Depth-v0 <https://github.com/isaac-sim/IsaacLab/blob/main/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/classic/cartpole/cartpole_camera_env_cfg.py>`__
.. |cartpole-resnet-link| replace:: `Isaac-Cartpole-RGB-ResNet18-v0 <https://github.com/isaac-sim/IsaacLab/blob/main/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/classic/cartpole/cartpole_camera_env_cfg.py>`__
.. |cartpole-theia-link| replace:: `Isaac-Cartpole-RGB-TheiaTiny-v0 <https://github.com/isaac-sim/IsaacLab/blob/main/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/classic/cartpole/cartpole_camera_env_cfg.py>`__


.. |humanoid-direct-link| replace:: `Isaac-Humanoid-Direct-v0 <https://github.com/isaac-sim/IsaacLab/blob/main/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/direct/humanoid/humanoid_env.py>`__
.. |ant-direct-link| replace:: `Isaac-Ant-Direct-v0 <https://github.com/isaac-sim/IsaacLab/blob/main/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/direct/ant/ant_env.py>`__
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ extra_standard_library = [
"toml",
"trimesh",
"tqdm",
"torchvision",
"transformers",
"einops" # Needed for transformers, doesn't always auto-install
]
# Imports from Isaac Sim and Omniverse
known_third_party = [
Expand Down
3 changes: 2 additions & 1 deletion source/extensions/omni.isaac.lab/config/extension.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
[package]

# Note: Semantic Versioning is used: https://semver.org/
version = "0.27.4"

version = "0.27.7"

# Description
title = "Isaac Lab framework for Robot Learning"
Expand Down
29 changes: 29 additions & 0 deletions source/extensions/omni.isaac.lab/docs/CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,35 @@
Changelog
---------


0.27.7 (2024-10-28)
~~~~~~~~~~~~~~~~~~~

Added
^^^^^

* Added frozen encoder feature extraction observation space with ResNet and Theia


0.27.6 (2024-10-25)
~~~~~~~~~~~~~~~~~~~

Fixed
^^^^^

* Fixed usage of ``meshes`` property in :class:`omni.isaac.lab.sensors.RayCasterCamera` to use ``self.meshes`` instead of the undefined ``RayCaster.meshes``.
* Fixed issue in :class:`omni.isaac.lab.envs.ui.BaseEnvWindow` where undefined configs were being accessed when creating debug visualization elements in UI.


0.27.5 (2024-10-25)
~~~~~~~~~~~~~~~~~~~

Added
^^^^^

* Added utilities for serializing/deserializing Gymnasium spaces.


0.27.4 (2024-10-18)
~~~~~~~~~~~~~~~~~~~

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,14 @@
import omni.isaac.lab.utils.math as math_utils
from omni.isaac.lab.assets import Articulation, RigidObject
from omni.isaac.lab.managers import SceneEntityCfg
from omni.isaac.lab.managers.manager_base import ManagerTermBase
from omni.isaac.lab.managers.manager_term_cfg import ObservationTermCfg
from omni.isaac.lab.sensors import Camera, Imu, RayCaster, RayCasterCamera, TiledCamera

if TYPE_CHECKING:
from omni.isaac.lab.envs import ManagerBasedEnv, ManagerBasedRLEnv


"""
Root state.
"""
Expand Down Expand Up @@ -273,6 +276,134 @@ def image(
return images.clone()


class image_features(ManagerTermBase):
"""Extracted image features from a pre-trained frozen encoder.
This method calls the :meth:`image` function to retrieve images, and then performs
inference on those images.
"""

def __init__(self, cfg: ObservationTermCfg, env: ManagerBasedEnv):
super().__init__(cfg, env)
from torchvision import models
from transformers import AutoModel

def create_theia_model(model_name):
return {
"model": (
lambda: AutoModel.from_pretrained(f"theaiinstitute/{model_name}", trust_remote_code=True)
.eval()
.to("cuda:0")
),
"preprocess": lambda img: (img - torch.amin(img, dim=(1, 2), keepdim=True)) / (
torch.amax(img, dim=(1, 2), keepdim=True) - torch.amin(img, dim=(1, 2), keepdim=True)
),
"inference": lambda model, images: model.forward_feature(
images, do_rescale=False, interpolate_pos_encoding=True
),
}

def create_resnet_model(resnet_name):
return {
"model": lambda: getattr(models, resnet_name)(pretrained=True).eval().to("cuda:0"),
"preprocess": lambda img: (
img.permute(0, 3, 1, 2) # Convert [batch, height, width, 3] -> [batch, 3, height, width]
- torch.tensor([0.485, 0.456, 0.406], device=img.device).view(1, 3, 1, 1)
) / torch.tensor([0.229, 0.224, 0.225], device=img.device).view(1, 3, 1, 1),
"inference": lambda model, images: model(images),
}

# List of Theia models
theia_models = [
"theia-tiny-patch16-224-cddsv",
"theia-tiny-patch16-224-cdiv",
"theia-small-patch16-224-cdiv",
"theia-base-patch16-224-cdiv",
"theia-small-patch16-224-cddsv",
"theia-base-patch16-224-cddsv",
]

# List of ResNet models
resnet_models = ["resnet18", "resnet34", "resnet50", "resnet101"]

self.default_model_zoo_cfg = {}

# Add Theia models to the zoo
for model_name in theia_models:
self.default_model_zoo_cfg[model_name] = create_theia_model(model_name)

# Add ResNet models to the zoo
for resnet_name in resnet_models:
self.default_model_zoo_cfg[resnet_name] = create_resnet_model(resnet_name)

self.model_zoo_cfg = self.default_model_zoo_cfg
self.model_zoo = {}

def __call__(
self,
env: ManagerBasedEnv,
sensor_cfg: SceneEntityCfg = SceneEntityCfg("tiled_camera"),
data_type: str = "rgb",
convert_perspective_to_orthogonal: bool = False,
model_zoo_cfg: dict | None = None,
model_name: str = "ResNet18",
model_device: str | None = "cuda:0",
reset_model: bool = False,
) -> torch.Tensor:
"""Extracted image features from a pre-trained frozen encoder.
Args:
env: The environment.
sensor_cfg: The sensor configuration to poll. Defaults to SceneEntityCfg("tiled_camera").
data_type: THe sensor configuration datatype. Defaults to "rgb".
convert_perspective_to_orthogonal: Whether to orthogonalize perspective depth images.
This is used only when the data type is "distance_to_camera". Defaults to False.
model_zoo_cfg: Map from model name to model configuration dictionary. Each model
configuration dictionary should include the following entries:
- "model": A callable that returns the model when invoked without arguments.
- "preprocess": A callable that processes the images and returns the preprocessed results.
- "inference": A callable that, when given the model and preprocessed images,
returns the extracted features.
model_name: The name of the model to use for inference. Defaults to "ResNet18".
model_device: The device to store and infer models on. This can be used help offload
computation from the main environment GPU. Defaults to "cuda:0".
reset_model: Initialize the model even if it already exists. Defaults to False.
Returns:
torch.Tensor: the image features, on the same device as the image
"""
if model_zoo_cfg is not None: # use other than default
self.model_zoo_cfg.update(model_zoo_cfg)

if model_name not in self.model_zoo or reset_model:
# The following allows to only load a desired subset of a model zoo into GPU memory
# as it becomes needed, in a "lazy" evaluation.
print(f"[INFO]: Adding {model_name} to the model zoo")
self.model_zoo[model_name] = self.model_zoo_cfg[model_name]["model"]()

if model_device is not None and self.model_zoo[model_name].device != model_device:
# want to offload vision model inference to another device
self.model_zoo[model_name] = self.model_zoo[model_name].to(model_device)

images = image(
env=env,
sensor_cfg=sensor_cfg,
data_type=data_type,
convert_perspective_to_orthogonal=convert_perspective_to_orthogonal,
normalize=True, # want this for training stability
)

image_device = images.device

if model_device is not None:
images = images.to(model_device)

proc_images = self.model_zoo_cfg[model_name]["preprocess"](images)
features = self.model_zoo_cfg[model_name]["inference"](self.model_zoo[model_name], proc_images)

return features.to(image_device).clone()


"""
Actions.
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def joint_acc_l2(env: ManagerBasedRLEnv, asset_cfg: SceneEntityCfg = SceneEntity
return torch.sum(torch.square(asset.data.joint_acc[:, asset_cfg.joint_ids]), dim=1)


def joint_deviation_l1(env, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
def joint_deviation_l1(env: ManagerBasedRLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
"""Penalize joint positions that deviate from the default one."""
# extract the used quantities (to enable type-hinting)
asset: Articulation = env.scene[asset_cfg.name]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ def _create_debug_vis_ui_element(self, name: str, elem: object):
self.ui_window_elements[f"{name}_cb"] = SimpleCheckBox(
model=omni.ui.SimpleBoolModel(),
enabled=elem.has_debug_vis_implementation,
checked=elem.cfg.debug_vis,
checked=elem.cfg.debug_vis if elem.cfg else False,
on_checked_fn=lambda value, e=weakref.proxy(elem): e.set_debug_vis(value),
)
omni.isaac.ui.ui_utils.add_line_rect_flourish()
Expand Down
129 changes: 129 additions & 0 deletions source/extensions/omni.isaac.lab/omni/isaac/lab/envs/utils/spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# SPDX-License-Identifier: BSD-3-Clause

import gymnasium as gym
import json
import numpy as np
import torch
from typing import Any
Expand Down Expand Up @@ -90,3 +91,131 @@ def tensorize(s, x):

sample = (gym.vector.utils.batch_space(space, batch_size) if batch_size > 0 else space).sample()
return tensorize(space, sample)


def serialize_space(space: SpaceType) -> str:
"""Serialize a space specification as JSON.
Args:
space: Space specification.
Returns:
Serialized JSON representation.
"""
# Gymnasium spaces
if isinstance(space, gym.spaces.Discrete):
return json.dumps({"type": "gymnasium", "space": "Discrete", "n": int(space.n)})
elif isinstance(space, gym.spaces.Box):
return json.dumps({
"type": "gymnasium",
"space": "Box",
"low": space.low.tolist(),
"high": space.high.tolist(),
"shape": space.shape,
})
elif isinstance(space, gym.spaces.MultiDiscrete):
return json.dumps({"type": "gymnasium", "space": "MultiDiscrete", "nvec": space.nvec.tolist()})
elif isinstance(space, gym.spaces.Tuple):
return json.dumps({"type": "gymnasium", "space": "Tuple", "spaces": tuple(map(serialize_space, space.spaces))})
elif isinstance(space, gym.spaces.Dict):
return json.dumps(
{"type": "gymnasium", "space": "Dict", "spaces": {k: serialize_space(v) for k, v in space.spaces.items()}}
)
# Python data types
# Box
elif isinstance(space, int) or (isinstance(space, list) and all(isinstance(x, int) for x in space)):
return json.dumps({"type": "python", "space": "Box", "value": space})
# Discrete
elif isinstance(space, set) and len(space) == 1:
return json.dumps({"type": "python", "space": "Discrete", "value": next(iter(space))})
# MultiDiscrete
elif isinstance(space, list) and all(isinstance(x, set) and len(x) == 1 for x in space):
return json.dumps({"type": "python", "space": "MultiDiscrete", "value": [next(iter(x)) for x in space]})
# composite spaces
# Tuple
elif isinstance(space, tuple):
return json.dumps({"type": "python", "space": "Tuple", "value": [serialize_space(x) for x in space]})
# Dict
elif isinstance(space, dict):
return json.dumps(
{"type": "python", "space": "Dict", "value": {k: serialize_space(v) for k, v in space.items()}}
)
raise ValueError(f"Unsupported space ({space})")


def deserialize_space(string: str) -> gym.spaces.Space:
"""Deserialize a space specification encoded as JSON.
Args:
string: Serialized JSON representation.
Returns:
Space specification.
"""
obj = json.loads(string)
# Gymnasium spaces
if obj["type"] == "gymnasium":
if obj["space"] == "Discrete":
return gym.spaces.Discrete(n=obj["n"])
elif obj["space"] == "Box":
return gym.spaces.Box(low=np.array(obj["low"]), high=np.array(obj["high"]), shape=obj["shape"])
elif obj["space"] == "MultiDiscrete":
return gym.spaces.MultiDiscrete(nvec=np.array(obj["nvec"]))
elif obj["space"] == "Tuple":
return gym.spaces.Tuple(spaces=tuple(map(deserialize_space, obj["spaces"])))
elif obj["space"] == "Dict":
return gym.spaces.Dict(spaces={k: deserialize_space(v) for k, v in obj["spaces"].items()})
else:
raise ValueError(f"Unsupported space ({obj['spaces']})")
# Python data types
elif obj["type"] == "python":
if obj["space"] == "Discrete":
return {obj["value"]}
elif obj["space"] == "Box":
return obj["value"]
elif obj["space"] == "MultiDiscrete":
return [{x} for x in obj["value"]]
elif obj["space"] == "Tuple":
return tuple(map(deserialize_space, obj["value"]))
elif obj["space"] == "Dict":
return {k: deserialize_space(v) for k, v in obj["value"].items()}
else:
raise ValueError(f"Unsupported space ({obj['spaces']})")
else:
raise ValueError(f"Unsupported type ({obj['type']})")


def replace_env_cfg_spaces_with_strings(env_cfg: object) -> object:
"""Replace spaces objects with their serialized JSON representations in an environment config.
Args:
env_cfg: Environment config instance.
Returns:
Environment config instance with spaces replaced if any.
"""
for attr in ["observation_space", "action_space", "state_space"]:
if hasattr(env_cfg, attr):
setattr(env_cfg, attr, serialize_space(getattr(env_cfg, attr)))
for attr in ["observation_spaces", "action_spaces"]:
if hasattr(env_cfg, attr):
setattr(env_cfg, attr, {k: serialize_space(v) for k, v in getattr(env_cfg, attr).items()})
return env_cfg


def replace_strings_with_env_cfg_spaces(env_cfg: object) -> object:
"""Replace spaces objects with their serialized JSON representations in an environment config.
Args:
env_cfg: Environment config instance.
Returns:
Environment config instance with spaces replaced if any.
"""
for attr in ["observation_space", "action_space", "state_space"]:
if hasattr(env_cfg, attr):
setattr(env_cfg, attr, deserialize_space(getattr(env_cfg, attr)))
for attr in ["observation_spaces", "action_spaces"]:
if hasattr(env_cfg, attr):
setattr(env_cfg, attr, {k: deserialize_space(v) for k, v in getattr(env_cfg, attr).items()})
return env_cfg
Loading

0 comments on commit 551003b

Please sign in to comment.