Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature(zjow): add qgpo policy for new DI-engine pipeline #757

Merged
merged 39 commits into from
Feb 4, 2024
Merged
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
0e8c8b3
Add CEP
zjowowen Oct 16, 2023
0b40be5
Add CEP
zjowowen Oct 16, 2023
8ca8495
Merge branch 'main' of https://github.com/zjowowen/DI-engine into CEP
zjowowen Oct 17, 2023
ab04888
Add halfcheetah
zjowowen Oct 17, 2023
11563ac
Add halfcheetah
zjowowen Oct 17, 2023
b5e6774
add d4rl envs
zjowowen Oct 17, 2023
94bc61a
change setup.py
zjowowen Oct 17, 2023
b8bc493
polish code
zjowowen Oct 17, 2023
48837e9
change config
zjowowen Oct 17, 2023
e5d0b32
fix lr bug
zjowowen Dec 7, 2023
0f20c91
polish code for qgpo
zjowowen Dec 7, 2023
504a108
polish code for qgpo
zjowowen Dec 7, 2023
b251a2d
merge from main
zjowowen Dec 7, 2023
a36d301
polish code
zjowowen Dec 7, 2023
9263965
polish code
zjowowen Dec 7, 2023
177f44a
polish code
zjowowen Dec 7, 2023
5cddcbe
Merge branch 'main' of https://github.com/zjowowen/DI-engine into CEP-pr
zjowowen Dec 26, 2023
3203918
polish code
zjowowen Dec 26, 2023
824a5f1
polish code
zjowowen Dec 27, 2023
b544200
polish code
zjowowen Dec 27, 2023
36022c8
Merge branch 'main' of https://github.com/zjowowen/DI-engine into CEP-pr
zjowowen Dec 27, 2023
62b4f84
polish code
zjowowen Dec 27, 2023
a45f43b
polish code
zjowowen Dec 27, 2023
ca59da9
polish code
zjowowen Dec 27, 2023
4a45e24
Merge branch 'main' of https://github.com/zjowowen/DI-engine into CEP-pr
zjowowen Dec 27, 2023
8370185
polish code
zjowowen Dec 29, 2023
a2d03b8
Merge branch 'main' of https://github.com/zjowowen/DI-engine into CEP-pr
zjowowen Dec 29, 2023
62d8957
add hopper walker2d qgpo config
zjowowen Jan 2, 2024
1763141
add doc
zjowowen Jan 2, 2024
12e3757
polish code
zjowowen Jan 16, 2024
7fb422d
Merge branch 'main' of https://github.com/zjowowen/DI-engine into CEP-pr
zjowowen Jan 16, 2024
25ca4f5
Merge branch 'main' of https://github.com/zjowowen/DI-engine into CEP-pr
zjowowen Jan 16, 2024
48cd98f
Merge branch 'main' of https://github.com/zjowowen/DI-engine into CEP-pr
zjowowen Jan 31, 2024
8002cd6
polish notation
zjowowen Jan 31, 2024
bec4211
polish notation
zjowowen Jan 31, 2024
ef9d2a1
Merge branch 'main' of https://github.com/zjowowen/DI-engine into CEP-pr
zjowowen Feb 1, 2024
aeae19d
fix bug in data generator
zjowowen Feb 1, 2024
21062c4
Merge branch 'main' of https://github.com/zjowowen/DI-engine into CEP-pr
zjowowen Feb 4, 2024
e28b1c5
polish code
zjowowen Feb 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -251,16 +251,17 @@ P.S: The `.py` file in `Runnable Demo` can be found in `dizoo`
| 43 | [TD3BC](https://arxiv.org/pdf/2106.06860.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [TD3BC doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/td3_bc.html)<br>[policy/td3_bc](https://github.com/opendilab/DI-engine/blob/main/ding/policy/td3_bc.py) | python3 -u d4rl_td3_bc_main.py |
| 44 | [Decision Transformer](https://arxiv.org/pdf/2106.01345.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [policy/dt](https://github.com/opendilab/DI-engine/blob/main/ding/policy/decision_transformer.py) | python3 -u d4rl_dt_main.py |
| 45 | [EDAC](https://arxiv.org/pdf/2110.01548.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [EDAC doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/edac.html)<br>[policy/edac](https://github.com/opendilab/DI-engine/blob/main/ding/policy/edac.py) | python3 -u d4rl_edac_main.py |
| 46 | MBSAC([SAC](https://arxiv.org/abs/1801.01290)+[MVE](https://arxiv.org/abs/1803.00101)+[SVG](https://arxiv.org/abs/1510.09142)) | ![continuous](https://img.shields.io/badge/-continous-green)![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [policy/mbpolicy/mbsac](https://github.com/opendilab/DI-engine/blob/main/ding/policy/mbpolicy/mbsac.py) | python3 -u pendulum_mbsac_mbpo_config.py \ python3 -u pendulum_mbsac_ddppo_config.py |
| 47 | STEVESAC([SAC](https://arxiv.org/abs/1801.01290)+[STEVE](https://arxiv.org/abs/1807.01675)+[SVG](https://arxiv.org/abs/1510.09142)) | ![continuous](https://img.shields.io/badge/-continous-green)![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [policy/mbpolicy/mbsac](https://github.com/opendilab/DI-engine/blob/main/ding/policy/mbpolicy/mbsac.py) | python3 -u pendulum_stevesac_mbpo_config.py |
| 48 | [MBPO](https://arxiv.org/pdf/1906.08253.pdf) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [MBPO doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/mbpo.html)<br>[world_model/mbpo](https://github.com/opendilab/DI-engine/blob/main/ding/world_model/mbpo.py) | python3 -u pendulum_sac_mbpo_config.py |
| 49 | [DDPPO](https://openreview.net/forum?id=rzvOQrnclO0) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [world_model/ddppo](https://github.com/opendilab/DI-engine/blob/main/ding/world_model/ddppo.py) | python3 -u pendulum_mbsac_ddppo_config.py |
| 50 | [DreamerV3](https://arxiv.org/pdf/2301.04104.pdf) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [world_model/dreamerv3](https://github.com/opendilab/DI-engine/blob/main/ding/world_model/dreamerv3.py) | python3 -u cartpole_balance_dreamer_config.py |
| 51 | [PER](https://arxiv.org/pdf/1511.05952.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [worker/replay_buffer](https://github.com/opendilab/DI-engine/blob/main/ding/worker/replay_buffer/advanced_buffer.py) | `rainbow demo` |
| 52 | [GAE](https://arxiv.org/pdf/1506.02438.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [rl_utils/gae](https://github.com/opendilab/DI-engine/blob/main/ding/rl_utils/gae.py) | `ppo demo` |
| 53 | [ST-DIM](https://arxiv.org/pdf/1906.08226.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/loss/contrastive_loss](https://github.com/opendilab/DI-engine/blob/main/ding/torch_utils/loss/contrastive_loss.py) | ding -m serial -c cartpole_dqn_stdim_config.py -s 0 |
| 54 | [PLR](https://arxiv.org/pdf/2010.03934.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [PLR doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/plr.html)<br>[data/level_replay/level_sampler](https://github.com/opendilab/DI-engine/blob/main/ding/data/level_replay/level_sampler.py) | python3 -u bigfish_plr_config.py -s 0 |
| 55 | [PCGrad](https://arxiv.org/pdf/2001.06782.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/optimizer_helper/PCGrad](https://github.com/opendilab/DI-engine/blob/main/ding/data/torch_utils/optimizer_helper.py) | python3 -u multi_mnist_pcgrad_main.py -s 0 |
| 46 | [QGPO](https://arxiv.org/pdf/2304.12824.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [QGPO doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/qgpo.html)<br>[policy/qgpo](https://github.com/opendilab/DI-engine/blob/main/ding/policy/qgpo.py) | python3 -u ding/example/qgpo.py |
| 47 | MBSAC([SAC](https://arxiv.org/abs/1801.01290)+[MVE](https://arxiv.org/abs/1803.00101)+[SVG](https://arxiv.org/abs/1510.09142)) | ![continuous](https://img.shields.io/badge/-continous-green)![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [policy/mbpolicy/mbsac](https://github.com/opendilab/DI-engine/blob/main/ding/policy/mbpolicy/mbsac.py) | python3 -u pendulum_mbsac_mbpo_config.py \ python3 -u pendulum_mbsac_ddppo_config.py |
| 48 | STEVESAC([SAC](https://arxiv.org/abs/1801.01290)+[STEVE](https://arxiv.org/abs/1807.01675)+[SVG](https://arxiv.org/abs/1510.09142)) | ![continuous](https://img.shields.io/badge/-continous-green)![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [policy/mbpolicy/mbsac](https://github.com/opendilab/DI-engine/blob/main/ding/policy/mbpolicy/mbsac.py) | python3 -u pendulum_stevesac_mbpo_config.py |
| 49 | [MBPO](https://arxiv.org/pdf/1906.08253.pdf) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [MBPO doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/mbpo.html)<br>[world_model/mbpo](https://github.com/opendilab/DI-engine/blob/main/ding/world_model/mbpo.py) | python3 -u pendulum_sac_mbpo_config.py |
| 50 | [DDPPO](https://openreview.net/forum?id=rzvOQrnclO0) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [world_model/ddppo](https://github.com/opendilab/DI-engine/blob/main/ding/world_model/ddppo.py) | python3 -u pendulum_mbsac_ddppo_config.py |
| 51 | [DreamerV3](https://arxiv.org/pdf/2301.04104.pdf) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [world_model/dreamerv3](https://github.com/opendilab/DI-engine/blob/main/ding/world_model/dreamerv3.py) | python3 -u cartpole_balance_dreamer_config.py |
| 52 | [PER](https://arxiv.org/pdf/1511.05952.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [worker/replay_buffer](https://github.com/opendilab/DI-engine/blob/main/ding/worker/replay_buffer/advanced_buffer.py) | `rainbow demo` |
| 53 | [GAE](https://arxiv.org/pdf/1506.02438.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [rl_utils/gae](https://github.com/opendilab/DI-engine/blob/main/ding/rl_utils/gae.py) | `ppo demo` |
| 54 | [ST-DIM](https://arxiv.org/pdf/1906.08226.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/loss/contrastive_loss](https://github.com/opendilab/DI-engine/blob/main/ding/torch_utils/loss/contrastive_loss.py) | ding -m serial -c cartpole_dqn_stdim_config.py -s 0 |
| 55 | [PLR](https://arxiv.org/pdf/2010.03934.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [PLR doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/plr.html)<br>[data/level_replay/level_sampler](https://github.com/opendilab/DI-engine/blob/main/ding/data/level_replay/level_sampler.py) | python3 -u bigfish_plr_config.py -s 0 |
| 56 | [PCGrad](https://arxiv.org/pdf/2001.06782.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/optimizer_helper/PCGrad](https://github.com/opendilab/DI-engine/blob/main/ding/data/torch_utils/optimizer_helper.py) | python3 -u multi_mnist_pcgrad_main.py -s 0 |
</details>


Expand Down
141 changes: 141 additions & 0 deletions ding/example/qgpo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
import torch
import gym
import d4rl
from easydict import EasyDict
from ditk import logging
from ding.model import QGPO
from ding.policy import QGPOPolicy
from ding.envs import DingEnvWrapper, BaseEnvManagerV2
from ding.config import compile_config
from ding.framework import task, ding_init
from ding.framework.context import OfflineRLContext
from ding.framework.middleware import trainer, CkptSaver, offline_logger, wandb_offline_logger, termination_checker
from ding.framework.middleware.functional.evaluator import qgpo_interaction_evaluator
from ding.framework.middleware.functional.data_processor import qgpo_support_data_generator, qgpo_offline_data_fetcher
from ding.utils import set_pkg_seed

from dizoo.d4rl.config.halfcheetah_medium_expert_qgpo_config import main_config, create_config


class QGPOD4RLDataset(torch.utils.data.Dataset):
"""
Overview:
Dataset for QGPO algorithm. The training of QGPO algorithm is based on contrastive energy prediction, \
which needs true action and fake action. The true action is sampled from the dataset, and the fake action \
is sampled from the action support generated by the behavior policy.
"""

def __init__(self, cfg, device="cpu"):
self.cfg = cfg
data = d4rl.qlearning_dataset(gym.make(cfg.env_id))
self.device = device
self.states = torch.from_numpy(data['observations']).float().to(self.device)
self.actions = torch.from_numpy(data['actions']).float().to(self.device)
self.next_states = torch.from_numpy(data['next_observations']).float().to(self.device)
reward = torch.from_numpy(data['rewards']).view(-1, 1).float().to(self.device)
self.is_finished = torch.from_numpy(data['terminals']).view(-1, 1).float().to(self.device)

reward_tune = "iql_antmaze" if "antmaze" in cfg.env_id else "iql_locomotion"
if reward_tune == 'normalize':
reward = (reward - reward.mean()) / reward.std()
elif reward_tune == 'iql_antmaze':
reward = reward - 1.0
elif reward_tune == 'iql_locomotion':
min_ret, max_ret = QGPOD4RLDataset.return_range(data, 1000)
reward /= (max_ret - min_ret)
reward *= 1000
elif reward_tune == 'cql_antmaze':
reward = (reward - 0.5) * 4.0
elif reward_tune == 'antmaze':
reward = (reward - 0.25) * 2.0
self.rewards = reward
print("dql dataloard loaded")
PaParaZz1 marked this conversation as resolved.
Show resolved Hide resolved

self.len = self.states.shape[0]
print(self.len, "data loaded")

def __getitem__(self, index):
data = {
's': self.states[index % self.len],
'a': self.actions[index % self.len],
'r': self.rewards[index % self.len],
's_': self.next_states[index % self.len],
'd': self.is_finished[index % self.len],
'fake_a': self.fake_actions[index % self.len]
PaParaZz1 marked this conversation as resolved.
Show resolved Hide resolved
if hasattr(self, "fake_actions") else 0.0, # self.fake_actions <D, 16, A>
'fake_a_': self.fake_next_actions[index % self.len]
if hasattr(self, "fake_next_actions") else 0.0, # self.fake_next_actions <D, 16, A>
}
return data

def __len__(self):
return self.len

def return_range(dataset, max_episode_steps):
returns, lengths = [], []
ep_ret, ep_len = 0., 0
for r, d in zip(dataset['rewards'], dataset['terminals']):
ep_ret += float(r)
ep_len += 1
if d or ep_len == max_episode_steps:
returns.append(ep_ret)
lengths.append(ep_len)
ep_ret, ep_len = 0., 0
# returns.append(ep_ret) # incomplete trajectory
lengths.append(ep_len) # but still keep track of number of steps
assert sum(lengths) == len(dataset['rewards'])
return min(returns), max(returns)


def main():
# If you don't have offline data, you need to prepare if first and set the data_path in config
# For demostration, we also can train a RL policy (e.g. SAC) and collect some data
logging.getLogger().setLevel(logging.INFO)
cfg = compile_config(main_config, policy=QGPOPolicy)
ding_init(cfg)
with task.start(async_mode=False, ctx=OfflineRLContext()):
set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
model = QGPO(cfg=cfg.policy.model)
policy = QGPOPolicy(cfg.policy, model=model)
dataset = QGPOD4RLDataset(cfg=cfg.dataset, device=policy._device)
if hasattr(cfg.policy, "load_path") and cfg.policy.load_path is not None:
policy_state_dict = torch.load(cfg.policy.load_path, map_location=torch.device("cpu"))
policy.learn_mode.load_state_dict(policy_state_dict)

evaluator_env = BaseEnvManagerV2(
env_fn=[
lambda: DingEnvWrapper(env=gym.make(cfg.env.env_id), cfg=cfg.env, caller="evaluator")
for _ in range(cfg.env.evaluator_env_num)
],
cfg=cfg.env.manager
)

task.use(qgpo_support_data_generator(cfg, dataset, policy))
PaParaZz1 marked this conversation as resolved.
Show resolved Hide resolved
task.use(qgpo_offline_data_fetcher(cfg, dataset, collate_fn=None))
task.use(trainer(cfg, policy.learn_mode))
task.use(qgpo_interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
task.use(
wandb_offline_logger(
cfg=EasyDict(
dict(
gradient_logger=False,
plot_logger=True,
video_logger=False,
action_logger=False,
return_logger=False,
vis_dataset=False,
)
),
exp_config=cfg,
metric_list=policy._monitor_vars_learn(),
project_name=cfg.exp_name
)
)
task.use(CkptSaver(policy, cfg.exp_name, train_freq=100000))
task.use(offline_logger())
task.use(termination_checker(max_train_iter=500000 + cfg.policy.learn.q_value_stop_training_iter))
task.run()


if __name__ == "__main__":
main()
Loading
Loading