Skip to content

Commit

Permalink
Adds support for callable classes in manager terms (#104)
Browse files Browse the repository at this point in the history
# Description

This PR adds support for callable classes to the `term_cfg.func`
attribute. This is needed for complex behaviors where users may want to
define certain persistent behaviors as part of the terms.

The callable class should take in the term configuration object and the
environment instance as inputs to its constructor. Additionally, they
should implement the `__call__` function with the signature expected by
the manager.

For example, in the case of observation terms, this looks like:

```python
class complex_function_class:
    def __init__(self, cfg: ObservationTermCfg, env: object):
        self.cfg = cfg
        self.env = env
        # define some variables
        self.history_length = 2
        self._obs_history = torch.zeros(self.env.num_envs, self.history_length, 2, device=self.env.device)

    def __call__(self, env: object) -> torch.Tensor:
        new_obs = torch.rand(env.num_envs, 2, device=env.device)
        # update history
        self._obs_history[:, 1:] = self._obs_history[:, :1].clone()
        self._obs_history[:, 0] = new_obs
        # return obs
        return new_obs
```

## Type of change

- New feature (non-breaking change which adds functionality)
- This change requires a documentation update

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./orbit.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
  • Loading branch information
Mayankm96 authored Aug 3, 2023
1 parent 619337e commit 16c740f
Show file tree
Hide file tree
Showing 10 changed files with 116 additions and 2 deletions.
2 changes: 1 addition & 1 deletion source/extensions/omni.isaac.orbit/config/extension.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[package]

# Note: Semantic Versioning is used: https://semver.org/
version = "0.8.5"
version = "0.8.6"

# Description
title = "ORBIT framework for Robot Learning"
Expand Down
9 changes: 9 additions & 0 deletions source/extensions/omni.isaac.orbit/docs/CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
Changelog
---------

0.8.6 (2023-08-03)
~~~~~~~~~~~~~~~~~~

Added
^^^^^

* Added support for callable classes in the :class:`omni.isaac.orbit.managers.ManagerBase`.


0.8.5 (2023-08-03)
~~~~~~~~~~~~~~~~~~

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,11 @@ def _prepare_terms(self):
# check for non config
if term_cfg is None:
continue
# check if the term is a valid term config
if not isinstance(term_cfg, CurriculumTermCfg):
raise TypeError(
f"Configuration for the term '{term_name}' is not of type CurriculumTermCfg. Received '{type(term_cfg)}'."
)
# resolve common parameters
self._resolve_common_term_cfg(term_name, term_cfg, min_argc=2)
# add name and config to list
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,14 @@ def _resolve_common_term_cfg(self, term_name: str, term_cfg: ManagerBaseTermCfg,
# acquire the body indices
body_ids, _ = getattr(self._env, term_cfg.asset_name).find_bodies(term_cfg.body_names)
term_cfg.params["body_ids"] = body_ids
# get the corresponding function
# get the corresponding function or functional class
if isinstance(term_cfg.func, str):
term_cfg.func = string_to_callable(term_cfg.func)
# initialize the term if it is a class
if inspect.isclass(term_cfg.func):
term_cfg.func = term_cfg.func(cfg=term_cfg, env=self._env)
# add the "self" argument to the count
min_argc += 1
# check if function is callable
if not callable(term_cfg.func):
raise AttributeError(f"The term '{term_name}' is not callable. Received: {term_cfg.func}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@ class ManagerBaseTermCfg:
"""The function to be called for the term.
The function must take the environment object as the first argument.
Note:
It also supports `callable classes`_, i.e. classes that implement the :meth:`__call__`
method.
..`callable objects`: https://docs.python.org/3/reference/datamodel.html#object.__call__
"""
sensor_name: str | None = None
"""The name of the sensor required by the term. Defaults to None.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,10 @@ def _prepare_terms(self):
# check for non config
if term_cfg is None:
continue
if not isinstance(term_cfg, ObservationTermCfg):
raise TypeError(
f"Configuration for the term '{term_name}' is not of type ObservationTermCfg. Received '{type(term_cfg)}'."
)
# resolve common terms in the config
self._resolve_common_term_cfg(f"{group_name}/{term_name}", term_cfg, min_argc=1)
# check noise settings
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,11 @@ def _prepare_terms(self):
# check for non config
if term_cfg is None:
continue
# check for valid config type
if not isinstance(term_cfg, RandomizationTermCfg):
raise TypeError(
f"Configuration for the term '{term_name}' is not of type RandomizationTermCfg. Received '{type(term_cfg)}'."
)
# resolve common parameters
self._resolve_common_term_cfg(term_name, term_cfg, min_argc=2)
# check if mode is a new mode
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,11 @@ def _prepare_terms(self):
# check for non config
if term_cfg is None:
continue
# check for valid config type
if not isinstance(term_cfg, RewardTermCfg):
raise TypeError(
f"Configuration for the term '{term_name}' is not of type RewardTermCfg. Received '{type(term_cfg)}'."
)
# resolve common parameters
self._resolve_common_term_cfg(term_name, term_cfg, min_argc=1)
# remove zero scales and multiply non-zero ones by dt
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,11 @@ def _prepare_terms(self):
# check for non config
if term_cfg is None:
continue
# check for valid config type
if not isinstance(term_cfg, TerminationTermCfg):
raise TypeError(
f"Configuration for the term '{term_name}' is not of type TerminationTermCfg. Received '{type(term_cfg)}'."
)
# resolve common parameters
self._resolve_common_term_cfg(term_name, term_cfg, min_argc=1)
# add function to list
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,28 @@ def grilled_chicken_with_yoghurt(env, hot: bool, bland: float):
return hot * bland * torch.ones(env.num_envs, 5, device=env.device)


class complex_function_class:
def __init__(self, cfg: ObservationTermCfg, env: object):
self.cfg = cfg
self.env = env
# define some variables
self._cost = 2 * self.env.num_envs

def __call__(self, env: object) -> torch.Tensor:
return torch.ones(env.num_envs, 2, device=env.device) * self._cost


class non_callable_complex_function_class:
def __init__(self, cfg: ObservationTermCfg, env: object):
self.cfg = cfg
self.env = env
# define some variables
self._cost = 2 * self.env.num_envs

def call_me(self, env: object) -> torch.Tensor:
return torch.ones(env.num_envs, 2, device=env.device) * self._cost


class TestObservationManager(unittest.TestCase):
"""Test cases for various situations with observation manager."""

Expand Down Expand Up @@ -189,6 +211,53 @@ class PolicyCfg(ObservationGroupCfg):
with self.assertRaises(ValueError):
self.obs_man = ObservationManager(cfg, self.env)

def test_callable_class_term(self):
"""Test the observation computation with callable class term."""

@configclass
class MyObservationManagerCfg:
"""Test config class for observation manager."""

@configclass
class PolicyCfg(ObservationGroupCfg):
"""Test config class for policy observation group."""

term_1 = ObservationTermCfg(func=grilled_chicken, scale=10)
term_2 = ObservationTermCfg(func=complex_function_class, scale=0.2)

policy: ObservationGroupCfg = PolicyCfg()

# create observation manager
cfg = MyObservationManagerCfg()
self.obs_man = ObservationManager(cfg, self.env)
# compute observation using manager
observations = self.obs_man.compute()
# check the observation shape
self.assertEqual((self.env.num_envs, 6), observations["policy"].shape)
self.assertEqual(observations["policy"][0, -1].item(), 2 * self.env.num_envs * 0.2)

def test_non_callable_class_term(self):
"""Test the observation computation with non-callable class term."""

@configclass
class MyObservationManagerCfg:
"""Test config class for observation manager."""

@configclass
class PolicyCfg(ObservationGroupCfg):
"""Test config class for policy observation group."""

term_1 = ObservationTermCfg(func=grilled_chicken, scale=10)
term_2 = ObservationTermCfg(func=non_callable_complex_function_class, scale=0.2)

policy: ObservationGroupCfg = PolicyCfg()

# create observation manager config
cfg = MyObservationManagerCfg()
# create observation manager
with self.assertRaises(AttributeError):
self.obs_man = ObservationManager(cfg, self.env)


if __name__ == "__main__":
unittest.main()

0 comments on commit 16c740f

Please sign in to comment.