
Commit 02a6c9b: update docs

Co-Authored-By: Roger Creus <31919499+roger-creus@users.noreply.github.com>
yuanmingqi and roger-creus committed Aug 17, 2024
1 parent 0940ab4 commit 02a6c9b
Showing 23 changed files with 953 additions and 398 deletions.
4 changes: 2 additions & 2 deletions docs/api_docs/agent/drqv2.md
@@ -60,7 +60,7 @@ DrQv2 agent instance.
Update the agent and return training metrics such as actor loss, critic_loss, etc.

### .update_critic
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/agent/drqv2.py/#L189)
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/agent/drqv2.py/#L177)
```python
.update_critic(
obs: th.Tensor, actions: th.Tensor, rewards: th.Tensor, discount: th.Tensor,
@@ -86,7 +86,7 @@ Update the critic network.
None.

### .update_actor
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/agent/drqv2.py/#L236)
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/agent/drqv2.py/#L224)
```python
.update_actor(
obs: th.Tensor
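
As a usage sketch for the DrQv2 agent documented above: the `make_dmc_env` helper and the `device` keyword are assumptions taken from the project's quickstart rather than this diff, and the update methods are driven internally by `.train()`.

```python
# Minimal sketch, assuming rllte's `make_dmc_env` helper and a `device`
# keyword on DrQv2; .update(), .update_critic(), and .update_actor() are
# called internally by .train().
from rllte.agent import DrQv2
from rllte.env import make_dmc_env

if __name__ == "__main__":
    device = "cpu"
    env = make_dmc_env(env_id="cartpole_balance", device=device)
    eval_env = make_dmc_env(env_id="cartpole_balance", device=device)
    agent = DrQv2(env=env, eval_env=eval_env, device=device, tag="drqv2_dmc", seed=1)
    agent.train(num_train_steps=5000)
```
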
8 changes: 4 additions & 4 deletions docs/api_docs/agent/legacy/ddpg.md
@@ -2,7 +2,7 @@


## DDPG
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/agent/legacy/ddpg.py/#L41)
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/agent/legacy/ddpg.py/#L42)
```python
DDPG(
env: VecEnv, eval_env: Optional[VecEnv] = None, tag: str = 'default', seed: int = 1,
@@ -51,7 +51,7 @@ DDPG agent instance.


### .update
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/agent/legacy/ddpg.py/#L145)
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/agent/legacy/ddpg.py/#L146)
```python
.update()
```
@@ -60,7 +60,7 @@ DDPG agent instance.
Update the agent and return training metrics such as actor loss, critic_loss, etc.

### .update_critic
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/agent/legacy/ddpg.py/#L186)
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/agent/legacy/ddpg.py/#L175)
```python
.update_critic(
obs: th.Tensor, actions: th.Tensor, rewards: th.Tensor, terminateds: th.Tensor,
@@ -87,7 +87,7 @@ Update the critic network.
None.

### .update_actor
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/agent/legacy/ddpg.py/#L235)
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/agent/legacy/ddpg.py/#L225)
```python
.update_actor(
obs: th.Tensor
4 changes: 2 additions & 2 deletions docs/api_docs/agent/legacy/sac.md
@@ -76,7 +76,7 @@ Get the temperature coefficient.
Update the agent and return training metrics such as actor loss, critic_loss, etc.

### .update_critic
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/agent/legacy/sac.py/#L206)
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/agent/legacy/sac.py/#L194)
```python
.update_critic(
obs: th.Tensor, actions: th.Tensor, rewards: th.Tensor, terminateds: th.Tensor,
@@ -103,7 +103,7 @@ Update the critic network.
None.

### .update_actor_and_alpha
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/agent/legacy/sac.py/#L256)
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/agent/legacy/sac.py/#L244)
```python
.update_actor_and_alpha(
obs: th.Tensor
6 changes: 3 additions & 3 deletions docs/api_docs/agent/legacy/sacd.md
@@ -76,7 +76,7 @@ Get the temperature coefficient.
Update the agent and return training metrics such as actor loss, critic_loss, etc.

### .deal_with_zero_probs
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/agent/legacy/sacd.py/#L206)
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/agent/legacy/sacd.py/#L194)
```python
.deal_with_zero_probs(
action_probs: th.Tensor
@@ -97,7 +97,7 @@ Deal with situation of 0.0 probabilities.
Action probabilities and its log values.

### .update_critic
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/agent/legacy/sacd.py/#L220)
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/agent/legacy/sacd.py/#L208)
```python
.update_critic(
obs: th.Tensor, actions: th.Tensor, rewards: th.Tensor, terminateds: th.Tensor,
@@ -124,7 +124,7 @@ Update the critic network.
None.

### .update_actor_and_alpha
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/agent/legacy/sacd.py/#L270)
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/agent/legacy/sacd.py/#L258)
```python
.update_actor_and_alpha(
obs: th.Tensor
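
For illustration, the `.deal_with_zero_probs` helper documented above guards the discrete-SAC log-probability computation. The following is a hedged re-implementation of that standard trick, not the library's exact code.

```python
import torch as th

def deal_with_zero_probs(action_probs: th.Tensor):
    """Replace exact zeros with a tiny epsilon so that log-probabilities stay finite."""
    z = (action_probs == 0.0).float() * 1e-8
    log_action_probs = th.log(action_probs + z)
    return action_probs, log_action_probs
```
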
16 changes: 8 additions & 8 deletions docs/api_docs/common/prototype/base_agent.md
@@ -34,7 +34,7 @@ Base agent instance.


### .freeze
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/common/prototype/base_agent.py/#L151)
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/common/prototype/base_agent.py/#L153)
```python
.freeze(
**kwargs
@@ -45,7 +45,7 @@ Base agent instance.
Freeze the agent and get ready for training.

### .check
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/common/prototype/base_agent.py/#L172)
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/common/prototype/base_agent.py/#L174)
```python
.check()
```
@@ -54,7 +54,7 @@ Freeze the agent and get ready for training.
Check the compatibility of selected modules.

### .set
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/common/prototype/base_agent.py/#L198)
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/common/prototype/base_agent.py/#L200)
```python
.set(
encoder: Optional[Encoder] = None, policy: Optional[Policy] = None,
@@ -85,7 +85,7 @@ Set a module for the agent.
None.

### .mode
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/common/prototype/base_agent.py/#L238)
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/common/prototype/base_agent.py/#L240)
```python
.mode(
training: bool = True
@@ -106,7 +106,7 @@ Set the training mode.
None.

### .save
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/common/prototype/base_agent.py/#L250)
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/common/prototype/base_agent.py/#L252)
```python
.save()
```
@@ -115,7 +115,7 @@ None.
Save the agent.

### .update
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/common/prototype/base_agent.py/#L262)
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/common/prototype/base_agent.py/#L264)
```python
.update(
*args, **kwargs
@@ -126,7 +126,7 @@ Save the agent.
Update function of the agent.

### .train
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/common/prototype/base_agent.py/#L266)
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/common/prototype/base_agent.py/#L268)
```python
.train(
num_train_steps: int, init_model_path: Optional[str], log_interval: int,
@@ -154,7 +154,7 @@ Training function.
None.

### .eval
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/common/prototype/base_agent.py/#L292)
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/common/prototype/base_agent.py/#L294)
```python
.eval(
num_eval_episodes: int
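
A hedged sketch of the `.set()` workflow documented above: `PPO`, `make_atari_env`, and `EspeholtResidualEncoder` are assumptions borrowed from other parts of rllte and may not match this commit exactly.

```python
# Hedged sketch of module replacement via .set(); the agent, env helper, and
# encoder names are assumptions taken from rllte's tutorials.
from rllte.agent import PPO
from rllte.env import make_atari_env
from rllte.xploit.encoder import EspeholtResidualEncoder

if __name__ == "__main__":
    env = make_atari_env(device="cpu", num_envs=8)
    agent = PPO(env=env, device="cpu", tag="ppo_atari")
    encoder = EspeholtResidualEncoder(
        observation_space=env.observation_space, feature_dim=512
    )
    agent.set(encoder=encoder)  # swap in the custom encoder before training
    agent.train(num_train_steps=5000)
```
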
137 changes: 92 additions & 45 deletions docs/api_docs/common/prototype/base_reward.md
@@ -1,110 +1,157 @@
#


## BaseIntrinsicRewardModule
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/common/prototype/base_reward.py/#L35)
## BaseReward
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/common/prototype/base_reward.py/#L38)
```python
BaseIntrinsicRewardModule(
observation_space: gym.Space, action_space: gym.Space, device: str = 'cpu',
beta: float = 0.05, kappa: float = 2.5e-05
BaseReward(
envs: VectorEnv, device: str = 'cpu', beta: float = 1.0, kappa: float = 0.0,
gamma: Optional[float] = None, rwd_norm_type: str = 'rms', obs_norm_type: str = 'rms'
)
```


---
Base class of intrinsic reward module.
Base class of reward module.


**Args**

* **observation_space** (gym.Space) : The observation space of environment.
* **action_space** (gym.Space) : The action space of environment.
* **envs** (VectorEnv) : The vectorized environments.
* **device** (str) : Device (cpu, cuda, ...) on which the code should be run.
* **beta** (float) : The initial weighting coefficient of the intrinsic rewards.
* **kappa** (float) : The decay rate.
* **kappa** (float) : The decay rate of the weighting coefficient.
* **gamma** (Optional[float]) : Intrinsic reward discount rate, default is `None`.
* **rwd_norm_type** (str) : Normalization type for intrinsic rewards from ['rms', 'minmax', 'none'].
* **obs_norm_type** (str) : Normalization type for observations data from ['rms', 'none'].


**Returns**

Instance of the base intrinsic reward module.
Instance of the base reward module.


**Methods:**


### .compute_irs
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/common/prototype/base_reward.py/#L70)
### .weight
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/common/prototype/base_reward.py/#L101)
```python
.compute_irs(
samples: Dict, step: int = 0
.weight()
```

---
Get the weighting coefficient of the intrinsic rewards.

### .scale
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/common/prototype/base_reward.py/#L105)
```python
.scale(
rewards: th.Tensor
)
```

---
Compute the intrinsic rewards for current samples.
Scale the intrinsic rewards.


**Args**

* **samples** (Dict) : The collected samples. A python dict like
{obs (n_steps, n_envs, *obs_shape) <class 'th.Tensor'>,
actions (n_steps, n_envs, *action_shape) <class 'th.Tensor'>,
rewards (n_steps, n_envs) <class 'th.Tensor'>,
next_obs (n_steps, n_envs, *obs_shape) <class 'th.Tensor'>}.
* **step** (int) : The global training step.
* **rewards** (th.Tensor) : The intrinsic rewards with shape (n_steps, n_envs).


**Returns**

The intrinsic rewards.
The scaled intrinsic rewards.

### .update
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/common/prototype/base_reward.py/#L86)
### .normalize
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/common/prototype/base_reward.py/#L131)
```python
.update(
samples: Dict
.normalize(
x: th.Tensor
)
```

---
Normalize the observation data; this is especially useful for image-based observations.

### .init_normalization
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/common/prototype/base_reward.py/#L142)
```python
.init_normalization()
```

---
Initialize the normalization parameters for observations if RMS normalization is used.

### .watch
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/common/prototype/base_reward.py/#L186)
```python
.watch(
observations: th.Tensor, actions: th.Tensor, rewards: th.Tensor,
terminateds: th.Tensor, truncateds: th.Tensor, next_observations: th.Tensor
)
```

---
Watch the interaction processes and obtain necessary elements for reward computation.


**Args**

* **observations** (th.Tensor) : Observations data with shape (n_envs, *obs_shape).
* **actions** (th.Tensor) : Actions data with shape (n_envs, *action_shape).
* **rewards** (th.Tensor) : Extrinsic rewards data with shape (n_envs).
* **terminateds** (th.Tensor) : Termination signals with shape (n_envs).
* **truncateds** (th.Tensor) : Truncation signals with shape (n_envs).
* **next_observations** (th.Tensor) : Next observations data with shape (n_envs, *obs_shape).


**Returns**

Feedback for the current samples.
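
A hedged sketch of driving `.watch()` from a Gymnasium-style vectorized environment; `envs`, `policy`, and the concrete reward module `irs` are caller-supplied placeholders.

```python
import torch as th

def collect_rollout(envs, policy, irs, n_steps: int, device: str = "cpu"):
    """Feed every transition of one rollout to a BaseReward subclass via .watch()."""
    obs, _ = envs.reset()
    for _ in range(n_steps):
        actions = policy(th.as_tensor(obs, device=device))
        next_obs, rewards, terms, truncs, _ = envs.step(actions.cpu().numpy())
        irs.watch(
            observations=th.as_tensor(obs, device=device),
            actions=actions,
            rewards=th.as_tensor(rewards, device=device),
            terminateds=th.as_tensor(terms, device=device),
            truncateds=th.as_tensor(truncs, device=device),
            next_observations=th.as_tensor(next_obs, device=device),
        )
        obs = next_obs
```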

### .compute
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/common/prototype/base_reward.py/#L210)
```python
.compute(
samples: Dict[str, th.Tensor], sync: bool = True
)
```

---
Update the intrinsic reward module if necessary.
Compute the rewards for current samples.


**Args**

* **samples** : The collected samples. A python dict like
{obs (n_steps, n_envs, *obs_shape) <class 'th.Tensor'>,
actions (n_steps, n_envs, *action_shape) <class 'th.Tensor'>,
rewards (n_steps, n_envs) <class 'th.Tensor'>,
next_obs (n_steps, n_envs, *obs_shape) <class 'th.Tensor'>}.
* **samples** (Dict[str, th.Tensor]) : The collected samples, a Python dict of tensors
whose keys are ['observations', 'actions', 'rewards', 'terminateds', 'truncateds', 'next_observations'].
For example, the data shape of 'observations' is (n_steps, n_envs, *obs_shape).
* **sync** (bool) : Whether to update the reward module after the `compute` function, default is `True`.


**Returns**

None
The intrinsic rewards.
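
To make the expected `samples` layout concrete, a minimal sketch with hypothetical shapes follows; `RE3` stands in for any concrete `BaseReward` subclass, and its constructor defaults beyond `envs` and `device` are assumed.

```python
# Hypothetical shapes for illustration; RE3 is only an example subclass.
import gymnasium as gym
import torch as th
from rllte.xplore.reward import RE3

envs = gym.vector.SyncVectorEnv([lambda: gym.make("Pendulum-v1") for _ in range(4)])
irs = RE3(envs=envs, device="cpu")

n_steps, n_envs = 128, envs.num_envs
obs_shape = envs.single_observation_space.shape
samples = {
    "observations":      th.zeros(n_steps, n_envs, *obs_shape),
    "actions":           th.zeros(n_steps, n_envs, 1),
    "rewards":           th.zeros(n_steps, n_envs),
    "terminateds":       th.zeros(n_steps, n_envs),
    "truncateds":        th.zeros(n_steps, n_envs),
    "next_observations": th.zeros(n_steps, n_envs, *obs_shape),
}
intrinsic_rewards = irs.compute(samples)  # shape (n_steps, n_envs)
```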

### .add
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/common/prototype/base_reward.py/#L104)
### .update
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/common/prototype/base_reward.py/#L241)
```python
.add(
samples: Dict
.update(
samples: Dict[str, th.Tensor]
)
```

---
Add the samples to the intrinsic reward module if necessary.
Used for modules like `RE3` that have a storage component.
Update the reward module if necessary.


**Args**

* **samples** : The collected samples. A python dict like
{obs (n_steps, n_envs, *obs_shape) <class 'th.Tensor'>,
actions (n_steps, n_envs, *action_shape) <class 'th.Tensor'>,
rewards (n_steps, n_envs) <class 'th.Tensor'>,
next_obs (n_steps, n_envs, *obs_shape) <class 'th.Tensor'>}.
* **samples** (Dict[str, th.Tensor]) : The collected samples, in the same format as for the `compute` function.


**Returns**

None
None.
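
One reading of the `sync` flag, inferred from the docstrings above rather than stated in this diff: with `sync=True` the module updates itself inside `.compute()`, while `sync=False` defers that to an explicit `.update()` call.

```python
# Continuing the sketch above, and assuming this reading of `sync` is correct.
intrinsic_rewards = irs.compute(samples, sync=False)  # skip the internal update
irs.update(samples)                                   # update explicitly
```
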
2 changes: 1 addition & 1 deletion docs/api_docs/common/prototype/off_policy_agent.md
@@ -75,7 +75,7 @@ Training function.
None.

### .eval
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/common/prototype/off_policy_agent.py/#L205)
[source](https://github.com/RLE-Foundation/rllte/blob/main/rllte/common/prototype/off_policy_agent.py/#L211)
```python
.eval(
num_eval_episodes: int
