Skip to content

Commit

Permalink
docs: fix offline data collection docs (#289)
Browse files Browse the repository at this point in the history
  • Loading branch information
Gaiejj authored Nov 18, 2023
1 parent d55958a commit 6d31328
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 12 deletions.
6 changes: 3 additions & 3 deletions benchmarks/offline/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@ pip install safety_gymnasium
## Training agents used to generate data

```bash
omnisafe train --env-id SafetyAntVelocity-v1 --algo PPO
omnisafe train --env-id SafetyAntVelocity-v1 --algo PPOLag
```

## Collect offline data

The `PATH_TO_AGENT` is the path of the directory containing the `torch_save`.

```python
from omnisafe.common.offline.data_collector import OfflineDataCollector

Expand All @@ -40,8 +41,7 @@ from omnisafe.common.offline.data_collector import OfflineDataCollector
env_name = 'SafetyAntVelocity-v1'
size = 1_000_000
agents = [
('./runs/PPO', 'epoch-500', 500_000),
('./runs/PPOLag', 'epoch-500', 500_000),
('PATH_TO_AGENT', 'epoch-500.pt', 1_000_000),
]
save_dir = './data'

Expand Down
18 changes: 9 additions & 9 deletions examples/collect_offline_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,17 @@
# also, please make sure you have run:
# python train_policy.py --algo PPO --env ENVID
# where ENVID is the environment from which you want to collect data.
# The `PATH_TO_AGENT` is the directory path containing the `torch_save`.

ENV_NAME = 'SafetyPointCircle1-v0'
SIZE = 2_000_000
AGENTS = [
('./runs/PPO', 'epoch-500', 1_000_000),
('./runs/PPOLag', 'epoch-500', 1_000_000),
env_name = 'SafetyAntVelocity-v1'
size = 1_000_000
agents = [
('PATH_TO_AGENT', 'epoch-500.pt', 1_000_000),
]
SAVE_DIR = './data'
save_dir = './data'

if __name__ == '__main__':
col = OfflineDataCollector(SIZE, ENV_NAME)
for agent, model_name, num in AGENTS:
col = OfflineDataCollector(size, env_name)
for agent, model_name, num in agents:
col.register_agent(agent, model_name, num)
col.collect(SAVE_DIR)
col.collect(save_dir)

0 comments on commit 6d31328

Please sign in to comment.