diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index 113f5863c0..f9d44c90bc 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -37,6 +37,7 @@ Guidelines for modifications:
 * Chenyu Yang
 * Jia Lin Yuan
 * Jingzhou Liu
 * Kourosh Darvish
+* Lorenz Wellhausen
 * Qinxi Yu
 * René Zurbrügg
diff --git a/source/extensions/omni.isaac.orbit/omni/isaac/orbit/managers/manager_term_cfg.py b/source/extensions/omni.isaac.orbit/omni/isaac/orbit/managers/manager_term_cfg.py
index d6602b252f..fbe5edd9fd 100644
--- a/source/extensions/omni.isaac.orbit/omni/isaac/orbit/managers/manager_term_cfg.py
+++ b/source/extensions/omni.isaac.orbit/omni/isaac/orbit/managers/manager_term_cfg.py
@@ -200,6 +200,14 @@ class RandomizationTermCfg(ManagerTermBaseCfg):
 ##
 
 
+@configclass
+class RewardGroupCfg:
+    """Configuration for a reward group.
+
+    Reserved for future use; it carries no parameters yet.
+    """
+
+
 @configclass
 class RewardTermCfg(ManagerTermBaseCfg):
     """Configuration for a reward term."""
diff --git a/source/extensions/omni.isaac.orbit/omni/isaac/orbit/managers/reward_manager.py b/source/extensions/omni.isaac.orbit/omni/isaac/orbit/managers/reward_manager.py
index 15b5b30e12..01efce0a9b 100644
--- a/source/extensions/omni.isaac.orbit/omni/isaac/orbit/managers/reward_manager.py
+++ b/source/extensions/omni.isaac.orbit/omni/isaac/orbit/managers/reward_manager.py
@@ -12,12 +12,15 @@ from typing import TYPE_CHECKING, Sequence
 
 from .manager_base import ManagerBase, ManagerTermBase
-from .manager_term_cfg import RewardTermCfg
+from .manager_term_cfg import RewardGroupCfg, RewardTermCfg
 
 if TYPE_CHECKING:
     from omni.isaac.orbit.envs import RLTaskEnv
 
 
+DEFAULT_GROUP_NAME = "reward"
+
+
 class RewardManager(ManagerBase):
     """Manager for computing reward signals for a given world.
@@ -26,7 +29,11 @@ class RewardManager(ManagerBase):
     terms configuration.
 
     The reward terms are parsed from a config class containing the manager's settings and each term's
-    parameters. Each reward term should instantiate the :class:`RewardTermCfg` class.
+    parameters.
+
+    Rewards are organized into groups to support multi-critic and CMDP use cases.
+    Each reward group should inherit from the :class:`RewardGroupCfg` class.
+    Within each group, each reward term should inherit from the :class:`RewardTermCfg` class.
 
     .. note::
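To make the grouped layout concrete, a reward config under this scheme could look like the sketch below. The group names (`reward`, `cost`), the term functions (`mdp.is_alive`, `mdp.action_rate_l2`), and the weights are illustrative choices, and the sketch assumes `RewardGroupCfg` is re-exported from `omni.isaac.orbit.managers` alongside `RewardTermCfg`:

```python
import omni.isaac.orbit.envs.mdp as mdp
from omni.isaac.orbit.managers import RewardGroupCfg, RewardTermCfg
from omni.isaac.orbit.utils import configclass


@configclass
class TaskRewardsCfg(RewardGroupCfg):
    """Task reward group (illustrative terms)."""

    alive = RewardTermCfg(func=mdp.is_alive, weight=1.0)


@configclass
class CostsCfg(RewardGroupCfg):
    """Constraint-cost group for a CMDP-style critic (illustrative terms)."""

    action_rate = RewardTermCfg(func=mdp.action_rate_l2, weight=0.1)


@configclass
class RewardsCfg:
    """Reward groups parsed by the RewardManager."""

    reward = TaskRewardsCfg()  # becomes reward group "reward"
    cost = CostsCfg()  # becomes reward group "cost"
```

Each attribute name of `RewardsCfg` becomes a group name, mirroring how term attribute names become term names within a group.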
@@ -43,34 +50,52 @@ def __init__(self, cfg: object, env: RLTaskEnv):
         """Initialize the reward manager.
 
         Args:
-            cfg: The configuration object or dictionary (``dict[str, RewardTermCfg]``).
+            cfg: The configuration object or dictionary (``dict[str, RewardGroupCfg]``).
             env: The environment instance.
         """
+        # Variable to track whether the config defines reward groups or plain reward terms.
+        # It must be set before calling super().__init__() because it is used in _prepare_terms().
+        self.no_group = None
         super().__init__(cfg, env)
-        # prepare extra info to store individual reward term information
+
+        # Allocate storage for reward terms.
         self._episode_sums = dict()
-        for term_name in self._term_names:
-            self._episode_sums[term_name] = torch.zeros(self.num_envs, dtype=torch.float, device=self.device)
-        # create buffer for managing reward per environment
-        self._reward_buf = torch.zeros(self.num_envs, dtype=torch.float, device=self.device)
+        self._reward_buf = {}
+        self._term_names_flat = []  # flat list of all term names, group-qualified where applicable
+        for group_name, group_term_names in self._group_term_names.items():
+            for term_name in group_term_names:
+                sum_term_name = term_name if self.no_group else f"{group_name}/{term_name}"
+                self._episode_sums[sum_term_name] = torch.zeros(self.num_envs, dtype=torch.float, device=self.device)
+
+                self._term_names_flat.append(sum_term_name)
+
+            # create buffer for managing the per-environment reward of each group
+            self._reward_buf[group_name] = torch.zeros(self.num_envs, dtype=torch.float, device=self.device)
 
     def __str__(self) -> str:
         """Returns: A string representation for reward manager."""
-        msg = f"<RewardManager> contains {len(self._term_names)} active terms.\n"
+        # report the number of active terms across all groups
+        msg = f"<RewardManager> contains {len(self._term_names_flat)} active terms.\n"
         # create table for term information
-        table = PrettyTable()
-        table.title = "Active Reward Terms"
-        table.field_names = ["Index", "Name", "Weight"]
-        # set alignment of table columns
-        table.align["Name"] = "l"
-        table.align["Weight"] = "r"
-        # add info on each term
-        for index, (name, term_cfg) in enumerate(zip(self._term_names, self._term_cfgs)):
-            table.add_row([index, name, term_cfg.weight])
-        # convert table to string
-        msg += table.get_string()
-        msg += "\n"
+        for group_name in self._group_term_names.keys():
+            table = PrettyTable()
+            table.title = f"Active Reward Terms in Group: {group_name}"
+            table.field_names = ["Index", "Name", "Weight"]
+            # set alignment of table columns
+            table.align["Name"] = "l"
+            table.align["Weight"] = "r"
+            # add info on each term
+            for index, (name, term_cfg) in enumerate(
+                zip(
+                    self._group_term_names[group_name],
+                    self._group_term_cfgs[group_name],
+                )
+            ):
+                table.add_row([index, name, term_cfg.weight])
+            # convert table to string
+            msg += table.get_string()
+            msg += "\n"
 
         return msg
@@ -81,7 +106,7 @@ def __str__(self) -> str:
     @property
     def active_terms(self) -> list[str]:
         """Name of active reward terms."""
-        return self._term_names
+        return self._term_names_flat
 
     """
     Operations.
     """
@@ -110,8 +135,9 @@ def reset(self, env_ids: Sequence[int] | None = None) -> dict[str, torch.Tensor]:
             # reset episodic sum
             self._episode_sums[key][env_ids] = 0.0
         # reset all the reward terms
-        for term_cfg in self._class_term_cfgs:
-            term_cfg.func.reset(env_ids=env_ids)
+        for group_cfg in self._group_class_term_cfgs.values():
+            for term_cfg in group_cfg:
+                term_cfg.func.reset(env_ids=env_ids)
         # return logged information
         return extras
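Since `compute()` now fills one buffer per group, a grouped config yields a dictionary of reward signals. A hedged usage sketch, assuming the `RewardsCfg` above and that the environment exposes the manager as `env.reward_manager` and the step size as `env.step_dt`:

```python
# inside a training loop; `env` is an RLTaskEnv built with the grouped RewardsCfg
rewards = env.reward_manager.compute(dt=env.step_dt)
# grouped config -> dict of per-group signals, each a tensor of shape (num_envs,)
task_reward = rewards["reward"]  # fed to the task critic
cost = rewards["cost"]  # fed to a cost critic / Lagrangian constraint
```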
""" # reset computation - self._reward_buf[:] = 0.0 - # iterate over all the reward terms - for name, term_cfg in zip(self._term_names, self._term_cfgs): - # skip if weight is zero (kind of a micro-optimization) - if term_cfg.weight == 0.0: - continue - # compute term's value - value = term_cfg.func(self._env, **term_cfg.params) * term_cfg.weight * dt - # update total reward - self._reward_buf += value - # update episodic sum - self._episode_sums[name] += value - - return self._reward_buf + for key in self._reward_buf.keys(): + self._reward_buf[key][:] = 0.0 + # iterate over all reward terms of all groups + for group_name in self._group_term_names.keys(): + # iterate over all the reward terms + for term_name, term_cfg in zip(self._group_term_names[group_name], self._group_term_cfgs[group_name]): + # skip if weight is zero (kind of a micro-optimization) + if term_cfg.weight == 0.0: + continue + # compute term's value + value = term_cfg.func(self._env, **term_cfg.params) * term_cfg.weight * dt + # update total reward + self._reward_buf[group_name] += value + # update episodic sum + name = term_name if self.no_group else f"{group_name}/{term_name}" + self._episode_sums[name] += value + + # Return only Tensor if config has no groups. + if self.no_group: + return self._reward_buf[DEFAULT_GROUP_NAME] + else: + return self._reward_buf """ Operations - Term settings. @@ -157,10 +191,18 @@ def set_term_cfg(self, term_name: str, cfg: RewardTermCfg): Raises: ValueError: If the term name is not found. """ - if term_name not in self._term_names: + # Split term_name at '/' if it has one. + if "/" in term_name: + group_name, term_name = term_name.split("/") + else: + group_name = DEFAULT_GROUP_NAME + + if group_name not in self._group_term_names: + raise ValueError(f"Reward group '{group_name}' not found.") + if term_name not in self._group_term_names[group_name]: raise ValueError(f"Reward term '{term_name}' not found.") # set the configuration - self._term_cfgs[self._term_names.index(term_name)] = cfg + self._group_term_cfgs[group_name][self._group_term_names[group_name].index(term_name)] = cfg def get_term_cfg(self, term_name: str) -> RewardTermCfg: """Gets the configuration for the specified term. @@ -174,10 +216,18 @@ def get_term_cfg(self, term_name: str) -> RewardTermCfg: Raises: ValueError: If the term name is not found. """ - if term_name not in self._term_names: + # Split term_name at '/' if it has one. + if "/" in term_name: + group_name, term_name = term_name.split("/") + else: + group_name = DEFAULT_GROUP_NAME + + if group_name not in self._group_term_names: + raise ValueError(f"Reward group '{group_name}' not found.") + if term_name not in self._group_term_names[group_name]: raise ValueError(f"Reward term '{term_name}' not found.") # return the configuration - return self._term_cfgs[self._term_names.index(term_name)] + return self._group_term_cfgs[group_name][self._group_term_names[group_name].index(term_name)] """ Helper functions. 
@@ -185,38 +235,79 @@ def get_term_cfg(self, term_name: str) -> RewardTermCfg:
 
     def _prepare_terms(self):
         """Prepares a list of reward functions."""
-        # parse remaining reward terms and decimate their information
-        self._term_names: list[str] = list()
-        self._term_cfgs: list[RewardTermCfg] = list()
-        self._class_term_cfgs: list[RewardTermCfg] = list()
+
+        self._group_term_names: dict[str, list[str]] = dict()
+        self._group_term_cfgs: dict[str, list[RewardTermCfg]] = dict()
+        self._group_class_term_cfgs: dict[str, list[RewardTermCfg]] = dict()
         # check if config is dict already
         if isinstance(self.cfg, dict):
             cfg_items = self.cfg.items()
         else:
             cfg_items = self.cfg.__dict__.items()
-        # iterate over all the terms
-        for term_name, term_cfg in cfg_items:
+
+        # Check whether the config uses reward groups or plain terms and fail on a mix of both.
+        for name, cfg in cfg_items:
             # check for non config
-            if term_cfg is None:
+            if cfg is None:
                 continue
-            # check for valid config type
-            if not isinstance(term_cfg, RewardTermCfg):
-                raise TypeError(
-                    f"Configuration for the term '{term_name}' is not of type RewardTermCfg."
-                    f" Received: '{type(term_cfg)}'."
-                )
-            # check for valid weight type
-            if not isinstance(term_cfg.weight, (float, int)):
-                raise TypeError(
-                    f"Weight for the term '{term_name}' is not of type float or int."
-                    f" Received: '{type(term_cfg.weight)}'."
-                )
-            # resolve common parameters
-            self._resolve_common_term_cfg(term_name, term_cfg, min_argc=1)
-            # add function to list
-            self._term_names.append(term_name)
-            self._term_cfgs.append(term_cfg)
-            # check if the term is a class
-            if isinstance(term_cfg.func, ManagerTermBase):
-                self._class_term_cfgs.append(term_cfg)
+            if isinstance(cfg, RewardTermCfg):
+                if self.no_group is None:
+                    self.no_group = True
+                elif self.no_group is False:
+                    raise ValueError("Cannot mix reward groups with reward terms.")
+            else:
+                # any entry that is not a term must be a reward group (or a plain dict of terms)
+                if not isinstance(cfg, (RewardGroupCfg, dict)):
+                    raise TypeError(
+                        f"Configuration for '{name}' is not of type RewardTermCfg or RewardGroupCfg."
+                        f" Received: '{type(cfg)}'."
+                    )
+                if self.no_group is None:
+                    self.no_group = False
+                elif self.no_group is True:
+                    raise ValueError("Cannot mix reward groups with reward terms.")
+
+        # If the config has no groups, wrap all terms into a single default group.
+        if self.no_group:
+            cfg_items = {DEFAULT_GROUP_NAME: dict(cfg_items)}.items()
+
+        # iterate over all the groups
+        for group_name, group_cfg in cfg_items:
+            self._group_term_names[group_name] = list()
+            self._group_term_cfgs[group_name] = list()
+            self._group_class_term_cfgs[group_name] = list()
+
+            # get the group's terms as (name, config) pairs
+            if isinstance(group_cfg, dict):
+                group_cfg_items = group_cfg.items()
+            else:
+                group_cfg_items = group_cfg.__dict__.items()
+
+            # iterate over all the terms in the group
+            for term_name, term_cfg in group_cfg_items:
+                # check for non config
+                if term_cfg is None:
+                    continue
+                # check for valid config type
+                if not isinstance(term_cfg, RewardTermCfg):
+                    raise TypeError(
+                        f"Configuration for the term '{term_name}' is not of type RewardTermCfg."
+                        f" Received: '{type(term_cfg)}'."
+                    )
+                # check for valid weight type
+                if not isinstance(term_cfg.weight, (float, int)):
+                    raise TypeError(
+                        f"Weight for the term '{term_name}' is not of type float or int."
+                        f" Received: '{type(term_cfg.weight)}'."
+                    )
+                # resolve common parameters
+                self._resolve_common_term_cfg(f"{group_name}/{term_name}", term_cfg, min_argc=1)
+                # add term name and config to the group's lists
+                self._group_term_names[group_name].append(term_name)
+                self._group_term_cfgs[group_name].append(term_cfg)
+                # add the term to a separate list if it is a class
+                if isinstance(term_cfg.func, ManagerTermBase):
+                    self._group_class_term_cfgs[group_name].append(term_cfg)
+                    # call reset on the class-based term to initialize its internal state
+                    term_cfg.func.reset()
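Because an ungrouped config is wrapped into the default `"reward"` group during parsing, the manager's external behavior for legacy configs is unchanged. A minimal sketch of that path, assuming `RewardManager` and `RewardTermCfg` are importable from `omni.isaac.orbit.managers` and `env` is an existing `RLTaskEnv`:

```python
import omni.isaac.orbit.envs.mdp as mdp
from omni.isaac.orbit.managers import RewardManager, RewardTermCfg

# a legacy, ungrouped config: a plain dict of terms
legacy_cfg = {"alive": RewardTermCfg(func=mdp.is_alive, weight=1.0)}
manager = RewardManager(legacy_cfg, env)

reward = manager.compute(dt=env.step_dt)  # a plain tensor of shape (num_envs,), not a dict
print(manager.active_terms)  # ["alive"] -- bare term names, no group prefix
```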