Skip to content

Commit

Permalink
cleans up the image feature code
Browse files Browse the repository at this point in the history
  • Loading branch information
Mayankm96 committed Oct 30, 2024
1 parent da663c4 commit a569a81
Showing 1 changed file with 202 additions and 98 deletions.
300 changes: 202 additions & 98 deletions source/extensions/omni.isaac.lab/omni/isaac/lab/envs/mdp/observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,42 +186,46 @@ def body_incoming_wrench(env: ManagerBasedEnv, asset_cfg: SceneEntityCfg) -> tor


def imu_orientation(env: ManagerBasedEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("imu")) -> torch.Tensor:
"""Imu sensor orientation w.r.t the env.scene.origin.
"""Imu sensor orientation in the simulation world frame.
Args:
env: The environment.
asset_cfg: The SceneEntity associated with an Imu sensor.
asset_cfg: The SceneEntity associated with an IMU sensor. Defaults to SceneEntityCfg("imu").
Returns:
Orientation quaternion (wxyz), shape of torch.tensor is (num_env,4).
Orientation in the world frame in (w, x, y, z) quaternion form. Shape is (num_envs, 4).
"""
# extract the used quantities (to enable type-hinting)
asset: Imu = env.scene[asset_cfg.name]
# return the orientation quaternion
return asset.data.quat_w


def imu_ang_vel(env: ManagerBasedEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("imu")) -> torch.Tensor:
"""Imu sensor angular velocity w.r.t. env.scene.origin expressed in the sensor frame.
"""Imu sensor angular velocity w.r.t. environment origin expressed in the sensor frame.
Args:
env: The environment.
asset_cfg: The SceneEntity associated with an Imu sensor.
asset_cfg: The SceneEntity associated with an IMU sensor. Defaults to SceneEntityCfg("imu").
Returns:
Angular velocity (rad/s), shape of torch.tensor is (num_env,3).
The angular velocity (rad/s) in the sensor frame. Shape is (num_envs, 3).
"""
# extract the used quantities (to enable type-hinting)
asset: Imu = env.scene[asset_cfg.name]
# return the angular velocity
return asset.data.ang_vel_b


def imu_lin_acc(env: ManagerBasedEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("imu")) -> torch.Tensor:
"""Imu sensor linear acceleration w.r.t. env.scene.origin expressed in sensor frame.
"""Imu sensor linear acceleration w.r.t. the environment origin expressed in sensor frame.
Args:
env: The environment.
asset_cfg: The SceneEntity associated with an Imu sensor.
asset_cfg: The SceneEntity associated with an IMU sensor. Defaults to SceneEntityCfg("imu").
Returns:
linear acceleration (m/s^2), shape of torch.tensor is (num_env,3).
The linear acceleration (m/s^2) in the sensor frame. Shape is (num_envs, 3).
"""
asset: Imu = env.scene[asset_cfg.name]
return asset.data.lin_acc_b
Expand Down Expand Up @@ -279,65 +283,89 @@ def image(
class image_features(ManagerTermBase):
"""Extracted image features from a pre-trained frozen encoder.
This method calls the :meth:`image` function to retrieve images, and then performs
inference on those images.
This term uses models from the model zoo in PyTorch and extracts features from the images.
It calls the :func:`image` function to get the images and then processes them using the model zoo.
A user can provide their own model zoo configuration to use different models for feature extraction.
The model zoo configuration should be a dictionary that maps different model names to a dictionary
that defines the model, preprocess and inference functions. The dictionary should have the following
entries:
- "model": A callable that returns the model when invoked without arguments.
- "preprocess": A callable that processes the images and returns the preprocessed results.
- "inference": A callable that, when given the model and preprocessed images, returns the extracted features.
If the model zoo configuration is not provided, the default model zoo configurations are used. The default
model zoo configurations include the models from Theia and ResNet.
Args:
sensor_cfg: The sensor configuration to poll. Defaults to SceneEntityCfg("tiled_camera").
data_type: The sensor data type. Defaults to "rgb".
convert_perspective_to_orthogonal: Whether to orthogonalize perspective depth images.
This is used only when the data type is "distance_to_camera". Defaults to False.
model_zoo_cfg: A user-defined dictionary that maps different model names to their respective configurations.
Defaults to None. If None, the default model zoo configurations are used.
model_name: The name of the model to use for inference. Defaults to "ResNet18".
model_device: The device to store and infer the model on. This is useful when offloading the computation
from the environment simulation device. Defaults to the environment device.
Returns:
The extracted features tensor. Shape is (num_envs, feature_dim).
Raises:
ValueError: When the model name is not found in the provided model zoo configuration.
ValueError: When the model name is not found in the default model zoo configuration.
"""

def __init__(self, cfg: ObservationTermCfg, env: ManagerBasedEnv):
# initialize the base class
super().__init__(cfg, env)
from torchvision import models
from transformers import AutoModel

def create_theia_model(model_name):
return {
"model": (
lambda: AutoModel.from_pretrained(f"theaiinstitute/{model_name}", trust_remote_code=True)
.eval()
.to("cuda:0")
),
"preprocess": lambda img: (img - torch.amin(img, dim=(1, 2), keepdim=True)) / (
torch.amax(img, dim=(1, 2), keepdim=True) - torch.amin(img, dim=(1, 2), keepdim=True)
),
"inference": lambda model, images: model.forward_feature(
images, do_rescale=False, interpolate_pos_encoding=True
),
}

def create_resnet_model(resnet_name):
return {
"model": lambda: getattr(models, resnet_name)(pretrained=True).eval().to("cuda:0"),
"preprocess": lambda img: (
img.permute(0, 3, 1, 2) # Convert [batch, height, width, 3] -> [batch, 3, height, width]
- torch.tensor([0.485, 0.456, 0.406], device=img.device).view(1, 3, 1, 1)
) / torch.tensor([0.229, 0.224, 0.225], device=img.device).view(1, 3, 1, 1),
"inference": lambda model, images: model(images),
}

# List of Theia models
theia_models = [
# extract parameters from the configuration
self.model_zoo_cfg: dict = cfg.params.get("model_zoo_cfg", None) # type: ignore
self.model_name: str = cfg.params.get("model_name", "ResNet18") # type: ignore
self.model_device: str = cfg.params.get("model_device", env.device) # type: ignore

# List of Theia models - These are configured through `_prepare_theia_transformer_model` function
default_theia_models = [
"theia-tiny-patch16-224-cddsv",
"theia-tiny-patch16-224-cdiv",
"theia-small-patch16-224-cdiv",
"theia-base-patch16-224-cdiv",
"theia-small-patch16-224-cddsv",
"theia-base-patch16-224-cddsv",
]

# List of ResNet models
resnet_models = ["resnet18", "resnet34", "resnet50", "resnet101"]

self.default_model_zoo_cfg = {}

# Add Theia models to the zoo
for model_name in theia_models:
self.default_model_zoo_cfg[model_name] = create_theia_model(model_name)

# Add ResNet models to the zoo
for resnet_name in resnet_models:
self.default_model_zoo_cfg[resnet_name] = create_resnet_model(resnet_name)

self.model_zoo_cfg = self.default_model_zoo_cfg
self.model_zoo = {}
# List of ResNet models - These are configured through `_prepare_resnet_model` function
default_resnet_models = ["resnet18", "resnet34", "resnet50", "resnet101"]

# Check if model name is specified in the model zoo configuration
if self.model_zoo_cfg is not None and self.model_name not in self.model_zoo_cfg:
raise ValueError(
f"Model name '{self.model_name}' not found in the provided model zoo configuration."
" Please add the model to the model zoo configuration or use a different model name."
f" Available models in the provided list: {list(self.model_zoo_cfg.keys())}."
f"\nHint: If you want to use a default model, consider using one of the following models:"
f" {default_theia_models + default_resnet_models}. In this case, you can remove the"
" 'model_zoo_cfg' parameter from the observation term configuration."
)
elif self.model_zoo_cfg is None:
if self.model_name in default_theia_models:
model_config = self._prepare_theia_transformer_model(self.model_name, self.model_device)
elif self.model_name in default_resnet_models:
model_config = self._prepare_resnet_model(self.model_name, self.model_device)
else:
raise ValueError(
f"Model name '{self.model_name}' not found in the default model zoo configuration."
f" Available models: {default_theia_models + default_resnet_models}."
)
else:
model_config = self.model_zoo_cfg[self.model_name]

# Retrieve the model, preprocess and inference functions
self._model = model_config["model"]()
self._preprocess_fn = model_config["preprocess"]
self._inference_fn = model_config["inference"]

def __call__(
self,
Expand All @@ -347,61 +375,137 @@ def __call__(
convert_perspective_to_orthogonal: bool = False,
model_zoo_cfg: dict | None = None,
model_name: str = "ResNet18",
model_device: str | None = "cuda:0",
reset_model: bool = False,
model_device: str | None = None,
) -> torch.Tensor:
"""Extracted image features from a pre-trained frozen encoder.
# obtain the images from the sensor
image_data = image(
env=env,
sensor_cfg=sensor_cfg,
data_type=data_type,
convert_perspective_to_orthogonal=convert_perspective_to_orthogonal,
normalize=True, # need this for training stability
)
# store the device of the image
image_device = image_data.device
# preprocess the images and obtain the features
image_processed = self._preprocess_fn(image_data)
# forward the images through the model
features = self._inference_fn(self._model, image_processed)

# move the features back to the image device
return features.detach().to(image_device)

"""
Helper functions.
"""

def _prepare_theia_transformer_model(self, model_name: str, model_device: str) -> dict:
"""Prepare the Theia transformer model for inference.
Args:
env: The environment.
sensor_cfg: The sensor configuration to poll. Defaults to SceneEntityCfg("tiled_camera").
data_type: THe sensor configuration datatype. Defaults to "rgb".
convert_perspective_to_orthogonal: Whether to orthogonalize perspective depth images.
This is used only when the data type is "distance_to_camera". Defaults to False.
model_zoo_cfg: Map from model name to model configuration dictionary. Each model
configuration dictionary should include the following entries:
- "model": A callable that returns the model when invoked without arguments.
- "preprocess": A callable that processes the images and returns the preprocessed results.
- "inference": A callable that, when given the model and preprocessed images,
returns the extracted features.
model_name: The name of the model to use for inference. Defaults to "ResNet18".
model_device: The device to store and infer models on. This can be used help offload
computation from the main environment GPU. Defaults to "cuda:0".
reset_model: Initialize the model even if it already exists. Defaults to False.
model_name: The name of the Theia transformer model to prepare.
model_device: The device to store and infer the model on.
Returns:
torch.Tensor: the image features, on the same device as the image
A dictionary containing the model, preprocess and inference functions.
"""
if model_zoo_cfg is not None: # use other than default
self.model_zoo_cfg.update(model_zoo_cfg)
from transformers import AutoModel

if model_name not in self.model_zoo or reset_model:
# The following allows to only load a desired subset of a model zoo into GPU memory
# as it becomes needed, in a "lazy" evaluation.
print(f"[INFO]: Adding {model_name} to the model zoo")
self.model_zoo[model_name] = self.model_zoo_cfg[model_name]["model"]()
def _load_model() -> torch.nn.Module:
"""Load the Theia transformer model."""
model = AutoModel.from_pretrained(f"theaiinstitute/{model_name}", trust_remote_code=True).eval()
return model.to(model_device)

if model_device is not None:
# want to offload vision model inference to another device
self.model_zoo[model_name] = self.model_zoo[model_name].to(model_device)
def _preprocess_image(img: torch.Tensor) -> torch.Tensor:
"""Preprocess the image for the Theia transformer model.
images = image(
env=env,
sensor_cfg=sensor_cfg,
data_type=data_type,
convert_perspective_to_orthogonal=convert_perspective_to_orthogonal,
normalize=True, # want this for training stability
)
Args:
img: The image tensor to preprocess. Shape is (num_envs, height, width, channel).
Returns:
The preprocessed image tensor. Shape is (num_envs, height, width, channel).
"""
# Move the image to the model device
img = img.to(model_device)
# Normalize the image
min_img = torch.amin(img, dim=(1, 2), keepdim=True)
max_img = torch.amax(img, dim=(1, 2), keepdim=True)

return (img - min_img) / (max_img - min_img)

def _inference(model, images: torch.Tensor) -> torch.Tensor:
"""Inference the Theia transformer model.
Args:
model: The Theia transformer model.
images: The preprocessed image tensor. Shape is (num_envs, height, width, channel).
Returns:
The extracted features tensor. Shape is (num_envs, feature_dim).
"""
return model.forward_feature(images, do_rescale=False, interpolate_pos_encoding=True)

image_device = images.device
# return the model, preprocess and inference functions
return {
"model": _load_model,
"preprocess": _preprocess_image,
"inference": _inference,
}

if model_device is not None:
images = images.to(model_device)
def _prepare_resnet_model(self, model_name: str, model_device: str) -> dict:
"""Prepare the ResNet model for inference.
proc_images = self.model_zoo_cfg[model_name]["preprocess"](images)
features = self.model_zoo_cfg[model_name]["inference"](self.model_zoo[model_name], proc_images)
Args:
model_name: The name of the ResNet model to prepare.
model_device: The device to store and infer the model on.
Returns:
A dictionary containing the model, preprocess and inference functions.
"""
from torchvision import models

return features.to(image_device).clone()
def _load_model() -> torch.nn.Module:
"""Load the ResNet model."""
model = getattr(models, model_name)(pretrained=True).eval()
return model.to(model_device)

def _preprocess_image(img: torch.Tensor) -> torch.Tensor:
"""Preprocess the image for the ResNet model.
Args:
img: The image tensor to preprocess. Shape is (num_envs, height, width, channel).
Returns:
The preprocessed image tensor. Shape is (num_envs, channel, height, width).
"""
# move the image to the model device
img = img.to(model_device)
# permute the image to (num_envs, channel, height, width)
img = img.permute(0, 3, 1, 2)
# normalize the image
mean = torch.tensor([0.485, 0.456, 0.406], device=img.device).view(1, 3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225], device=img.device).view(1, 3, 1, 1)

return (img - mean) / std

def _inference(model, images: torch.Tensor) -> torch.Tensor:
"""Inference the ResNet model.
Args:
model: The ResNet model.
images: The preprocessed image tensor. Shape is (num_envs, channel, height, width).
Returns:
The extracted features tensor. Shape is (num_envs, feature_dim).
"""
return model(images)

# return the model, preprocess and inference functions
return {
"model": _load_model,
"preprocess": _preprocess_image,
"inference": _inference,
}


"""
Expand Down

0 comments on commit a569a81

Please sign in to comment.