Commit 02a6c9b (1 parent: 0940ab4)

Co-Authored-By: Roger Creus <31919499+roger-creus@users.noreply.github.com>

Showing 23 changed files with 953 additions and 398 deletions.
## BaseReward
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/common/prototype/base_reward.py/#L38)
```python
BaseReward(
    envs: VectorEnv, device: str = 'cpu', beta: float = 1.0, kappa: float = 0.0,
    gamma: Optional[float] = None, rwd_norm_type: str = 'rms', obs_norm_type: str = 'rms'
)
```

---
Base class of the reward module.

**Args**

* **envs** (VectorEnv) : The vectorized environments.
* **device** (str) : Device (cpu, cuda, ...) on which the code should be run.
* **beta** (float) : The initial weighting coefficient of the intrinsic rewards.
* **kappa** (float) : The decay rate of the weighting coefficient.
* **gamma** (Optional[float]) : Intrinsic reward discount rate, default is `None`.
* **rwd_norm_type** (str) : Normalization type for intrinsic rewards, one of ['rms', 'minmax', 'none'].
* **obs_norm_type** (str) : Normalization type for observation data, one of ['rms', 'none'].

**Returns**

Instance of the base reward module.
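
`BaseReward` is a base class meant to be subclassed by concrete intrinsic reward modules. Below is a minimal, hypothetical sketch of such a subclass, assuming `BaseReward` is importable from `rllte.common.prototype` (as the source link suggests) and that `compute` and `update` are the methods a concrete module overrides; the reward logic itself (distance between consecutive observations) is only a placeholder, not an rllte algorithm.

```python
from typing import Dict

import torch as th

# Assumption: BaseReward is exposed under rllte.common.prototype,
# matching the source file linked above.
from rllte.common.prototype import BaseReward


class ToyReward(BaseReward):
    """Placeholder module: intrinsic reward = L2 distance between obs and next_obs."""

    def compute(self, samples: Dict[str, th.Tensor], sync: bool = True) -> th.Tensor:
        # samples follows the layout documented below: (n_steps, n_envs, *obs_shape)
        obs = samples["observations"]
        next_obs = samples["next_observations"]
        # flatten the feature dimensions and measure how far the observation moved
        delta = (next_obs - obs).flatten(start_dim=2)
        intrinsic = delta.norm(dim=-1)  # shape (n_steps, n_envs)
        # assumed here: the base class .scale applies the weighting/normalization
        return self.scale(intrinsic)

    def update(self, samples: Dict[str, th.Tensor]) -> None:
        # nothing to train in this toy example
        pass
```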

**Methods:**

### .weight
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/common/prototype/base_reward.py/#L101)
```python
.weight()
```

---
Get the weighting coefficient of the intrinsic rewards.
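
The exact schedule lives in the linked source; as a hedged sketch, a schedule consistent with `beta` (initial coefficient) and `kappa` (decay rate) is exponential decay over the global training step:

```python
beta, kappa, global_step = 1.0, 1e-5, 10_000  # example values
# Hypothetical schedule consistent with the beta/kappa arguments above;
# consult the linked source for the exact formula rllte uses.
weight = beta * (1.0 - kappa) ** global_step
print(weight)  # roughly 0.905 with these numbers
```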

### .scale
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/common/prototype/base_reward.py/#L105)
```python
.scale(
    rewards: th.Tensor
)
```

---
Scale the intrinsic rewards.

**Args**

* **rewards** (th.Tensor) : The intrinsic rewards with shape (n_steps, n_envs).

**Returns**

The scaled intrinsic rewards.
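
Putting `.weight` and the reward normalization together, the scaling step can be thought of roughly as below; this is an illustrative sketch under the 'rms' setting, not the exact rllte implementation.

```python
import torch as th

# Illustrative only: scale a batch of intrinsic rewards by dividing by a running
# std estimate (the 'rms' option) and multiplying by the decayed weight.
rewards = th.rand(128, 8)                    # (n_steps, n_envs)
running_std = rewards.std().clamp(min=1e-8)  # stand-in for a persistent RunningMeanStd
weight = 1.0                                 # value returned by .weight()
scaled = weight * rewards / running_std
```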

### .normalize
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/common/prototype/base_reward.py/#L131)
```python
.normalize(
    x: th.Tensor
)
```

---
Normalize the observation data; especially useful for image-based observations.
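
For image observations, the usual practice is to standardize with running statistics and clip to a fixed range; the sketch below shows that idea with plain tensors, as an approximation of the 'rms' option rather than the exact method in the source.

```python
import torch as th

# Approximate 'rms' observation normalization: standardize with running
# mean/std and clip (a common choice for image inputs).
obs = th.rand(8, 4, 84, 84)                        # (n_envs, *obs_shape)
running_mean, running_std = obs.mean(), obs.std()  # stand-ins for persistent statistics
normalized = ((obs - running_mean) / (running_std + 1e-8)).clamp(-5.0, 5.0)
```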

### .init_normalization
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/common/prototype/base_reward.py/#L142)
```python
.init_normalization()
```

---
Initialize the normalization parameters for observations when RMS normalization is used.

### .watch
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/common/prototype/base_reward.py/#L186)
```python
.watch(
    observations: th.Tensor, actions: th.Tensor, rewards: th.Tensor,
    terminateds: th.Tensor, truncateds: th.Tensor, next_observations: th.Tensor
)
```

---
Watch the interaction process and gather the elements needed for reward computation.

**Args**

* **observations** (th.Tensor) : Observations data with shape (n_envs, *obs_shape).
* **actions** (th.Tensor) : Actions data with shape (n_envs, *action_shape).
* **rewards** (th.Tensor) : Extrinsic rewards data with shape (n_envs).
* **terminateds** (th.Tensor) : Termination signals with shape (n_envs).
* **truncateds** (th.Tensor) : Truncation signals with shape (n_envs).
* **next_observations** (th.Tensor) : Next observations data with shape (n_envs, *obs_shape).

**Returns**

Feedback for the current samples.
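
In a rollout loop, `.watch` is typically called once per environment step so the module can track whatever it needs. A hedged usage sketch, assuming Gymnasium-style vectorized environments; `reward_module` stands in for an instance of a concrete `BaseReward` subclass created elsewhere.

```python
import gymnasium as gym
import torch as th

# reward_module: placeholder for a concrete BaseReward subclass instance.
envs = gym.vector.SyncVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(4)])
obs, _ = envs.reset()
obs = th.as_tensor(obs, dtype=th.float32)

for _ in range(8):
    actions = th.as_tensor(envs.action_space.sample())  # random policy for illustration
    next_obs, rewards, terminateds, truncateds, _ = envs.step(actions.numpy())
    next_obs = th.as_tensor(next_obs, dtype=th.float32)

    # feed the transition to the reward module so it can cache what it needs
    reward_module.watch(
        observations=obs,
        actions=actions,
        rewards=th.as_tensor(rewards, dtype=th.float32),
        terminateds=th.as_tensor(terminateds, dtype=th.float32),
        truncateds=th.as_tensor(truncateds, dtype=th.float32),
        next_observations=next_obs,
    )
    obs = next_obs
```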

### .compute
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/common/prototype/base_reward.py/#L210)
```python
.compute(
    samples: Dict[str, th.Tensor], sync: bool = True
)
```

---
Compute the intrinsic rewards for the current samples.

**Args**

* **samples** (Dict[str, th.Tensor]) : The collected samples. A python dict of tensors whose keys are ['observations', 'actions', 'rewards', 'terminateds', 'truncateds', 'next_observations']. For example, the data shape of 'observations' is (n_steps, n_envs, *obs_shape).
* **sync** (bool) : Whether to update the reward module after the `compute` call, default is `True`.

**Returns**

The intrinsic rewards.
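
After a rollout has been collected, the per-step tensors are stacked into the samples dict and passed to `.compute`. A hedged sketch with random data standing in for a real rollout, with `reward_module` again a placeholder for a concrete subclass instance:

```python
import torch as th

n_steps, n_envs, obs_shape = 128, 8, (4,)

# Random tensors standing in for a collected rollout; keys follow the list above.
samples = {
    "observations":      th.rand(n_steps, n_envs, *obs_shape),
    "actions":           th.randint(0, 2, (n_steps, n_envs)),
    "rewards":           th.rand(n_steps, n_envs),
    "terminateds":       th.zeros(n_steps, n_envs),
    "truncateds":        th.zeros(n_steps, n_envs),
    "next_observations": th.rand(n_steps, n_envs, *obs_shape),
}

intrinsic_rewards = reward_module.compute(samples)      # sync=True also updates the module
total_rewards = samples["rewards"] + intrinsic_rewards  # combine with extrinsic rewards
```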

### .update
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/common/prototype/base_reward.py/#L241)
```python
.update(
    samples: Dict[str, th.Tensor]
)
```

---
Update the reward module if necessary.

**Args**

* **samples** (Dict[str, th.Tensor]) : The collected samples, in the same format as for the `compute` function.

**Returns**

None.
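
When `compute` is called with `sync=True` (the default), an explicit `update` call should not be needed, since the module is updated after computing the rewards. The separate method is useful when the two steps are decoupled; a hedged sketch of that pattern, reusing the `samples` dict and `reward_module` placeholder from above:

```python
# Decoupled usage (illustrative): compute rewards without the internal update,
# then trigger the update explicitly, e.g. on a different schedule or worker.
intrinsic_rewards = reward_module.compute(samples, sync=False)
reward_module.update(samples)
```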