Skip to content

Commit

Permalink
feature(nyp): add DQfD algorithm (#48)
Browse files Browse the repository at this point in the history
* add_dqfd

* Is_expert to is_expert

* modify according to the last commnets

* value_gamma; done; marginloss; sqil compatibility

* finally shorten the code, revise config

* revise config, style

* add_readme/two_more_config

* correct format

Co-authored-by: niuyazhe <niuyazhe@sensetime.com>
  • Loading branch information
Will-Nie and PaParaZz1 authored Oct 16, 2021
1 parent 8efee98 commit e2ca873
Show file tree
Hide file tree
Showing 17 changed files with 1,222 additions and 14 deletions.
13 changes: 7 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -120,12 +120,13 @@ ding -m serial -e cartpole -p dqn -s 0
| 20 | [CollaQ](https://arxiv.org/pdf/2010.08531.pdf) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [policy/collaq](https://github.com/opendilab/DI-engine/blob/main/ding/policy/collaq.py) | ding -m serial -c smac_3s5z_collaq_config.py -s 0 |
| 21 | [GAIL](https://arxiv.org/pdf/1606.03476.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [reward_model/gail](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/gail_irl_model.py) | ding -m serial_reward_model -c cartpole_dqn_config.py -s 0 |
| 22 | [SQIL](https://arxiv.org/pdf/1905.11108.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [entry/sqil](https://github.com/opendilab/DI-engine/blob/main/ding/entry/serial_entry_sqil.py) | ding -m serial_sqil -c cartpole_sqil_config.py -s 0 |
| 23 | [HER](https://arxiv.org/pdf/1707.01495.pdf) | ![exp](https://img.shields.io/badge/-exploration-orange) | [reward_model/her](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/her_reward_model.py) | python3 -u bitflip_her_dqn.py |
| 24 | [RND](https://arxiv.org/abs/1810.12894) | ![exp](https://img.shields.io/badge/-exploration-orange) | [reward_model/rnd](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/rnd_reward_model.py) | python3 -u cartpole_ppo_rnd_main.py |
| 25 | [CQL](https://arxiv.org/pdf/2006.04779.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [policy/cql](https://github.com/opendilab/DI-engine/blob/main/ding/policy/cql.py) | python3 -u d4rl_cql_main.py |
| 26 | [PER](https://arxiv.org/pdf/1511.05952.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [worker/replay_buffer](https://github.com/opendilab/DI-engine/blob/main/ding/worker/replay_buffer/advanced_buffer.py) | `rainbow demo` |
| 27 | [GAE](https://arxiv.org/pdf/1506.02438.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [rl_utils/gae](https://github.com/opendilab/DI-engine/blob/main/ding/rl_utils/gae.py) | `ppo demo` |
| 28 | [D4PG](https://arxiv.org/pdf/1804.08617.pdf) | ![continuous](https://img.shields.io/badge/-continous-green) | [policy/d4pg](https://github.com/opendilab/DI-engine/blob/main/ding/policy/d4pg.py) | python3 -u pendulum_d4pg_config.py |
| 23 | [DQFD](https://arxiv.org/pdf/1704.03732.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![IL](https://img.shields.io/badge/-discrete-brightgreen) | [policy/dqfd](https://github.com/opendilab/DI-engine/blob/main/ding/policy/dqfd.py) | ding -m serial_dqfd -c cartpole_dqfd_config.py -s 0 |
| 24 | [HER](https://arxiv.org/pdf/1707.01495.pdf) | ![exp](https://img.shields.io/badge/-exploration-orange) | [reward_model/her](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/her_reward_model.py) | python3 -u bitflip_her_dqn.py |
| 25 | [RND](https://arxiv.org/abs/1810.12894) | ![exp](https://img.shields.io/badge/-exploration-orange) | [reward_model/rnd](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/rnd_reward_model.py) | python3 -u cartpole_ppo_rnd_main.py |
| 26 | [CQL](https://arxiv.org/pdf/2006.04779.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [policy/cql](https://github.com/opendilab/DI-engine/blob/main/ding/policy/cql.py) | python3 -u d4rl_cql_main.py |
| 27 | [PER](https://arxiv.org/pdf/1511.05952.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [worker/replay_buffer](https://github.com/opendilab/DI-engine/blob/main/ding/worker/replay_buffer/advanced_buffer.py) | `rainbow demo` |
| 28 | [GAE](https://arxiv.org/pdf/1506.02438.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [rl_utils/gae](https://github.com/opendilab/DI-engine/blob/main/ding/rl_utils/gae.py) | `ppo demo` |
| 29 | [D4PG](https://arxiv.org/pdf/1804.08617.pdf) | ![continuous](https://img.shields.io/badge/-continous-green) | [policy/d4pg](https://github.com/opendilab/DI-engine/blob/main/ding/policy/d4pg.py) | python3 -u pendulum_d4pg_config.py |

![discrete](https://img.shields.io/badge/-discrete-brightgreen) means discrete action space, which is only label in normal DRL algorithms(1-15)

Expand Down
2 changes: 2 additions & 0 deletions ding/entry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,7 @@
from .serial_entry_offline import serial_pipeline_offline
from .serial_entry_il import serial_pipeline_il
from .serial_entry_reward_model import serial_pipeline_reward_model
from .serial_entry_dqfd import serial_pipeline_dqfd
from .serial_entry_sqil import serial_pipeline_sqil
from .parallel_entry import parallel_pipeline
from .application_entry import eval, collect_demo_data
8 changes: 7 additions & 1 deletion ding/entry/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def print_registry(ctx: Context, param: Option, value: str):
@click.option(
'-m',
'--mode',
type=click.Choice(['serial', 'serial_onpolicy', 'serial_sqil', 'parallel', 'dist', 'eval']),
type=click.Choice(['serial', 'serial_onpolicy', 'serial_sqil', 'serial_dqfd', 'parallel', 'dist', 'eval']),
help='serial-train or parallel-train or dist-train or eval'
)
@click.option('-c', '--config', type=str, help='Path to DRL experiment config')
Expand Down Expand Up @@ -157,6 +157,12 @@ def cli(
config = get_predefined_config(env, policy)
expert_config = input("Enter the name of the config you used to generate your expert model: ")
serial_pipeline_sqil(config, expert_config, seed, max_iterations=train_iter)
elif mode == 'serial_dqfd':
from .serial_entry_dqfd import serial_pipeline_dqfd
if config is None:
config = get_predefined_config(env, policy)
expert_config = input("Enter the name of the config you used to generate your expert model: ")
serial_pipeline_dqfd(config, expert_config, seed, max_iterations=train_iter)
elif mode == 'parallel':
from .parallel_entry import parallel_pipeline
parallel_pipeline(config, seed, enable_total_log, disable_flask_log)
Expand Down
Loading

0 comments on commit e2ca873

Please sign in to comment.