feature(zjow): add qgpo policy for new DI-engine pipeline #757

Merged: 39 commits, merged on Feb 4, 2024
Changes from 16 commits

Commits (39)
0e8c8b3
Add CEP
zjowowen Oct 16, 2023
0b40be5
Add CEP
zjowowen Oct 16, 2023
8ca8495
Merge branch 'main' of https://github.com/zjowowen/DI-engine into CEP
zjowowen Oct 17, 2023
ab04888
Add halfcheetah
zjowowen Oct 17, 2023
11563ac
Add halfcheetah
zjowowen Oct 17, 2023
b5e6774
add d4rl envs
zjowowen Oct 17, 2023
94bc61a
change setup.py
zjowowen Oct 17, 2023
b8bc493
polish code
zjowowen Oct 17, 2023
48837e9
change config
zjowowen Oct 17, 2023
e5d0b32
fix lr bug
zjowowen Dec 7, 2023
0f20c91
polish code for qgpo
zjowowen Dec 7, 2023
504a108
polish code for qgpo
zjowowen Dec 7, 2023
b251a2d
merge from main
zjowowen Dec 7, 2023
a36d301
polish code
zjowowen Dec 7, 2023
9263965
polish code
zjowowen Dec 7, 2023
177f44a
polish code
zjowowen Dec 7, 2023
5cddcbe
Merge branch 'main' of https://github.com/zjowowen/DI-engine into CEP-pr
zjowowen Dec 26, 2023
3203918
polish code
zjowowen Dec 26, 2023
824a5f1
polish code
zjowowen Dec 27, 2023
b544200
polish code
zjowowen Dec 27, 2023
36022c8
Merge branch 'main' of https://github.com/zjowowen/DI-engine into CEP-pr
zjowowen Dec 27, 2023
62b4f84
polish code
zjowowen Dec 27, 2023
a45f43b
polish code
zjowowen Dec 27, 2023
ca59da9
polish code
zjowowen Dec 27, 2023
4a45e24
Merge branch 'main' of https://github.com/zjowowen/DI-engine into CEP-pr
zjowowen Dec 27, 2023
8370185
polish code
zjowowen Dec 29, 2023
a2d03b8
Merge branch 'main' of https://github.com/zjowowen/DI-engine into CEP-pr
zjowowen Dec 29, 2023
62d8957
add hopper walker2d qgpo config
zjowowen Jan 2, 2024
1763141
add doc
zjowowen Jan 2, 2024
12e3757
polish code
zjowowen Jan 16, 2024
7fb422d
Merge branch 'main' of https://github.com/zjowowen/DI-engine into CEP-pr
zjowowen Jan 16, 2024
25ca4f5
Merge branch 'main' of https://github.com/zjowowen/DI-engine into CEP-pr
zjowowen Jan 16, 2024
48cd98f
Merge branch 'main' of https://github.com/zjowowen/DI-engine into CEP-pr
zjowowen Jan 31, 2024
8002cd6
polish notation
zjowowen Jan 31, 2024
bec4211
polish notation
zjowowen Jan 31, 2024
ef9d2a1
Merge branch 'main' of https://github.com/zjowowen/DI-engine into CEP-pr
zjowowen Feb 1, 2024
aeae19d
fix bug in data generator
zjowowen Feb 1, 2024
21062c4
Merge branch 'main' of https://github.com/zjowowen/DI-engine into CEP-pr
zjowowen Feb 4, 2024
e28b1c5
polish code
zjowowen Feb 4, 2024
138 changes: 138 additions & 0 deletions ding/example/qgpo.py
@@ -0,0 +1,138 @@
import torch
import gym
import d4rl
from easydict import EasyDict
from ditk import logging
from ding.model import QGPO
from ding.policy import QGPOPolicy
from ding.envs import DingEnvWrapper, BaseEnvManagerV2
from ding.config import compile_config
from ding.framework import task, ding_init
from ding.framework.context import OfflineRLContext
from ding.framework.middleware import trainer, CkptSaver, offline_logger, wandb_offline_logger, termination_checker
from ding.framework.middleware.functional.evaluator import qgpo_interaction_evaluator
from ding.framework.middleware.functional.data_processor import qgpo_support_data_generator, qgpo_offline_data_fetcher
from ding.utils import set_pkg_seed

from dizoo.d4rl.config.halfcheetah_medium_expert_qgpo_config import main_config, create_config


class QGPO_D4RLDataset(torch.utils.data.Dataset):

def __init__(self, cfg, device="cpu"):
self.cfg = cfg
data = d4rl.qlearning_dataset(gym.make(cfg.env_id))
self.device = device
self.states = torch.from_numpy(data['observations']).float().to(self.device)
self.actions = torch.from_numpy(data['actions']).float().to(self.device)
self.next_states = torch.from_numpy(data['next_observations']).float().to(self.device)
reward = torch.from_numpy(data['rewards']).view(-1, 1).float().to(self.device)
self.is_finished = torch.from_numpy(data['terminals']).view(-1, 1).float().to(self.device)

reward_tune = "iql_antmaze" if "antmaze" in cfg.env_id else "iql_locomotion"
if reward_tune == 'normalize':
reward = (reward - reward.mean()) / reward.std()
elif reward_tune == 'iql_antmaze':
reward = reward - 1.0
elif reward_tune == 'iql_locomotion':
min_ret, max_ret = QGPO_D4RLDataset.return_range(data, 1000)
reward /= (max_ret - min_ret)
reward *= 1000
elif reward_tune == 'cql_antmaze':
reward = (reward - 0.5) * 4.0
elif reward_tune == 'antmaze':
reward = (reward - 0.25) * 2.0
self.rewards = reward
print("dql dataloard loaded")

self.len = self.states.shape[0]
print(self.len, "data loaded")

def __getitem__(self, index):
data = {
's': self.states[index % self.len],
'a': self.actions[index % self.len],
'r': self.rewards[index % self.len],
's_': self.next_states[index % self.len],
'd': self.is_finished[index % self.len],
'fake_a': self.fake_actions[index % self.len]
if hasattr(self, "fake_actions") else 0.0, # self.fake_actions <D, 16, A>
'fake_a_': self.fake_next_actions[index % self.len]
if hasattr(self, "fake_next_actions") else 0.0, # self.fake_next_actions <D, 16, A>
}
return data

def __add__(self, other):
pass

def __len__(self):
return self.len

def return_range(dataset, max_episode_steps):
returns, lengths = [], []
ep_ret, ep_len = 0., 0
for r, d in zip(dataset['rewards'], dataset['terminals']):
ep_ret += float(r)
ep_len += 1
if d or ep_len == max_episode_steps:
returns.append(ep_ret)
lengths.append(ep_len)
ep_ret, ep_len = 0., 0
# returns.append(ep_ret) # incomplete trajectory
lengths.append(ep_len) # but still keep track of number of steps
assert sum(lengths) == len(dataset['rewards'])
return min(returns), max(returns)


def main():
    # If you don't have offline data, you need to prepare it first and set the data_path in config.
    # For demonstration, we can also train an RL policy (e.g. SAC) and collect some data.
logging.getLogger().setLevel(logging.INFO)
cfg = compile_config(main_config, policy=QGPOPolicy)
ding_init(cfg)
with task.start(async_mode=False, ctx=OfflineRLContext()):
set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
model = QGPO(cfg=cfg.policy.model)
policy = QGPOPolicy(cfg.policy, model=model)
dataset = QGPO_D4RLDataset(cfg=cfg.dataset, device=policy._device)
if hasattr(cfg.policy, "load_path") and cfg.policy.load_path is not None:
policy_state_dict = torch.load(cfg.policy.load_path, map_location=torch.device("cpu"))
policy.learn_mode.load_state_dict(policy_state_dict)

evaluator_env = BaseEnvManagerV2(
env_fn=[
lambda: DingEnvWrapper(env=gym.make(cfg.env.env_id), cfg=cfg.env, caller="evaluator")
for _ in range(cfg.env.evaluator_env_num)
],
cfg=cfg.env.manager
)

task.use(qgpo_support_data_generator(cfg, dataset, policy))
task.use(qgpo_offline_data_fetcher(cfg, dataset, collate_fn=None))
task.use(trainer(cfg, policy.learn_mode))
task.use(qgpo_interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
task.use(
wandb_offline_logger(
cfg=EasyDict(
dict(
gradient_logger=False,
plot_logger=True,
video_logger=False,
action_logger=False,
return_logger=False,
vis_dataset=False,
)
),
exp_config=cfg,
metric_list=policy._monitor_vars_learn(),
project_name=cfg.exp_name
)
)
task.use(CkptSaver(policy, cfg.exp_name, train_freq=100000))
task.use(offline_logger())
task.use(termination_checker(max_train_iter=500000 + cfg.policy.learn.q_value_stop_training_iter))
task.run()


if __name__ == "__main__":
main()
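
For reference, a minimal sketch of what `return_range` computes, using a tiny hand-made dataset with hypothetical reward and terminal values. It calls the helper through the class (the helper takes no `self`, so it behaves like a static method) and assumes the example file above is importable as `ding.example.qgpo` with its dependencies (gym, d4rl) installed.

import numpy as np

from ding.example.qgpo import QGPO_D4RLDataset

# Hypothetical toy data: two episodes, delimited by the 'terminals' flags.
toy_dataset = {
    'rewards': np.array([1.0, 2.0, 0.5, 3.0, 1.5]),
    'terminals': np.array([False, True, False, False, True]),
}
# Per-episode returns: 1.0 + 2.0 = 3.0 and 0.5 + 3.0 + 1.5 = 5.0
min_ret, max_ret = QGPO_D4RLDataset.return_range(toy_dataset, max_episode_steps=1000)
assert (min_ret, max_ret) == (3.0, 5.0)
# With 'iql_locomotion' reward tuning, rewards are then divided by (max_ret - min_ret) and scaled by 1000.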
126 changes: 123 additions & 3 deletions ding/framework/middleware/functional/data_processor.py
@@ -2,7 +2,9 @@
from typing import TYPE_CHECKING, Callable, List, Union, Tuple, Dict, Optional
from easydict import EasyDict
from ditk import logging
import numpy as np
import torch
import tqdm
from ding.data import Buffer, Dataset, DataLoader, offline_data_save_type
from ding.data.buffer.middleware import PriorityExperienceReplay
from ding.framework import task
@@ -226,7 +228,7 @@ def _fetch(ctx: "OfflineRLContext"):
return _fetch


def offline_data_fetcher(cfg: EasyDict, dataset: Dataset) -> Callable:
def offline_data_fetcher(cfg: EasyDict, dataset: Dataset, collate_fn=lambda x: x) -> Callable:
"""
Overview:
The outer function transforms a Pytorch `Dataset` to `DataLoader`. \
@@ -238,7 +240,7 @@ def offline_data_fetcher(cfg: EasyDict, dataset: Dataset) -> Callable:
- dataset (:obj:`Dataset`): The dataset of type `torch.utils.data.Dataset` which stores the data.
"""
# collate_fn is executed in policy now
dataloader = DataLoader(dataset, batch_size=cfg.policy.learn.batch_size, shuffle=True, collate_fn=lambda x: x)
dataloader = DataLoader(dataset, batch_size=cfg.policy.learn.batch_size, shuffle=True, collate_fn=collate_fn)
dataloader = iter(dataloader)

def _fetch(ctx: "OfflineRLContext"):
@@ -258,7 +260,7 @@ def _fetch(ctx: "OfflineRLContext"):
ctx.train_epoch += 1
del dataloader
dataloader = DataLoader(
dataset, batch_size=cfg.policy.learn.batch_size, shuffle=True, collate_fn=lambda x: x
dataset, batch_size=cfg.policy.learn.batch_size, shuffle=True, collate_fn=collate_fn
)
dataloader = iter(dataloader)
ctx.train_data = next(dataloader)
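
Outside the diff itself, here is a small, self-contained illustration of what parameterizing `collate_fn` changes, assuming torch >= 1.11 for `default_collate`. The previously hard-coded `lambda x: x` keeps each batch as a plain list of samples, while a stacking collate function merges samples into batched tensors.

import torch
from torch.utils.data import DataLoader, TensorDataset, default_collate

ds = TensorDataset(torch.arange(6).float().view(6, 1))
# Identity collate (the old hard-coded behavior): the batch is a plain list of 3 samples.
identity_batch = next(iter(DataLoader(ds, batch_size=3, collate_fn=lambda x: x)))
# default_collate: samples are stacked into a single (3, 1) tensor per field.
stacked_batch = next(iter(DataLoader(ds, batch_size=3, collate_fn=default_collate)))
print(len(identity_batch), identity_batch[0][0].shape)  # 3 samples, each of shape (1,)
print(stacked_batch[0].shape)  # torch.Size([3, 1])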
@@ -316,3 +318,121 @@ def _pusher(ctx: "OnlineRLContext"):
ctx.trajectories = None

return _pusher


def qgpo_support_data_generator(cfg, dataset, policy) -> Callable:

behavior_policy_stop_training_iter = cfg.policy.learn.behavior_policy_stop_training_iter if hasattr(
cfg.policy.learn, 'behavior_policy_stop_training_iter'
) else np.inf
energy_guided_policy_begin_training_iter = cfg.policy.learn.energy_guided_policy_begin_training_iter if hasattr(
cfg.policy.learn, 'energy_guided_policy_begin_training_iter'
) else 0
actions_generated = False

def generate_fake_actions():
policy._model.score_model.q[0].guidance_scale = 0.0
allstates = dataset.states[:].cpu().numpy()
actions_sampled = []
for states in tqdm.tqdm(np.array_split(allstates, allstates.shape[0] // 4096 + 1)):
actions_sampled.append(
policy._model.score_model.sample(
states, sample_per_state=cfg.policy.learn.M, diffusion_steps=cfg.policy.learn.diffusion_steps
)
)
actions = np.concatenate(actions_sampled)

allnextstates = dataset.next_states[:].cpu().numpy()
actions_next_states_sampled = []
for next_states in tqdm.tqdm(np.array_split(allnextstates, allnextstates.shape[0] // 4096 + 1)):
actions_next_states_sampled.append(
policy._model.score_model.sample(
next_states, sample_per_state=cfg.policy.learn.M, diffusion_steps=cfg.policy.learn.diffusion_steps
)
)
actions_next_states = np.concatenate(actions_next_states_sampled)
policy._model.score_model.q[0].guidance_scale = 1.0
return actions, actions_next_states

def _data_generator(ctx: "OfflineRLContext"):
nonlocal actions_generated

if ctx.train_iter >= energy_guided_policy_begin_training_iter:
if ctx.train_iter > behavior_policy_stop_training_iter:
# no need to generate fake actions if fake actions are already generated
if actions_generated:
pass
else:
actions, actions_next_states = generate_fake_actions()
dataset.fake_actions = torch.Tensor(actions.astype(np.float32)).to(cfg.policy.model.device)
dataset.fake_next_actions = torch.Tensor(actions_next_states.astype(np.float32)
).to(cfg.policy.model.device)
actions_generated = True
else:
# generate fake actions
actions, actions_next_states = generate_fake_actions()
dataset.fake_actions = torch.Tensor(actions.astype(np.float32)).to(cfg.policy.model.device)
dataset.fake_next_actions = torch.Tensor(actions_next_states.astype(np.float32)
).to(cfg.policy.model.device)
actions_generated = True
else:
# no need to generate fake actions
pass

return _data_generator


def qgpo_offline_data_fetcher(cfg: EasyDict, dataset: Dataset, collate_fn=lambda x: x) -> Callable:
"""
Overview:
The outer function transforms a Pytorch `Dataset` to `DataLoader`. \
The return function is a generator which each time fetches a batch of data from the previous `DataLoader`.\
Please refer to the link https://pytorch.org/tutorials/beginner/basics/data_tutorial.html \
and https://pytorch.org/docs/stable/data.html for more details.
Arguments:
        - cfg (:obj:`EasyDict`): Config which should contain the following keys: `cfg.policy.learn.batch_size` \
            and `cfg.policy.learn.batch_size_q`.
- dataset (:obj:`Dataset`): The dataset of type `torch.utils.data.Dataset` which stores the data.
"""
# collate_fn is executed in policy now
dataloader = DataLoader(dataset, batch_size=cfg.policy.learn.batch_size, shuffle=True, collate_fn=collate_fn)
dataloader_q = DataLoader(dataset, batch_size=cfg.policy.learn.batch_size_q, shuffle=True, collate_fn=collate_fn)

behavior_policy_stop_training_iter = cfg.policy.learn.behavior_policy_stop_training_iter if hasattr(
cfg.policy.learn, 'behavior_policy_stop_training_iter'
) else np.inf
energy_guided_policy_begin_training_iter = cfg.policy.learn.energy_guided_policy_begin_training_iter if hasattr(
cfg.policy.learn, 'energy_guided_policy_begin_training_iter'
) else 0

def get_behavior_policy_training_data():
while True:
yield from dataloader

data = get_behavior_policy_training_data()

def get_q_training_data():
while True:
yield from dataloader_q

data_q = get_q_training_data()

def _fetch(ctx: "OfflineRLContext"):
"""
Overview:
Every time this generator is iterated, the fetched data will be assigned to ctx.train_data. \
After the dataloader is empty, the attribute `ctx.train_epoch` will be incremented by 1.
Input of ctx:
- train_epoch (:obj:`int`): Number of `train_epoch`.
Output of ctx:
- train_data (:obj:`List[Tensor]`): The fetched data batch.
"""

if ctx.train_iter >= energy_guided_policy_begin_training_iter:
ctx.train_data = next(data_q)
else:
ctx.train_data = next(data)

# TODO apply data update (e.g. priority) in offline setting when necessary
ctx.trained_env_step += len(ctx.train_data)

return _fetch
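
As a minimal sketch of the infinite-cycling pattern that `qgpo_offline_data_fetcher` wraps around its two dataloaders, the snippet below uses a plain list in place of a `DataLoader`: because the inner loop restarts whenever the iterable is exhausted, `next()` never raises `StopIteration` and no epoch bookkeeping is needed.

def cycle(iterable):
    # Restart iteration from the beginning whenever the underlying iterable is exhausted.
    while True:
        yield from iterable

batches = ["batch_0", "batch_1", "batch_2"]
stream = cycle(batches)
fetched = [next(stream) for _ in range(5)]
assert fetched == ["batch_0", "batch_1", "batch_2", "batch_0", "batch_1"]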