diff --git a/README.md b/README.md
index 4833c8fec0..39b7e5f46c 100644
--- a/README.md
+++ b/README.md
@@ -254,16 +254,17 @@ P.S: The `.py` file in `Runnable Demo` can be found in `dizoo`
 | 43 | [TD3BC](https://arxiv.org/pdf/2106.06860.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [TD3BC doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/td3_bc.html)<br>[policy/td3_bc](https://github.com/opendilab/DI-engine/blob/main/ding/policy/td3_bc.py) | python3 -u d4rl_td3_bc_main.py |
 | 44 | [Decision Transformer](https://arxiv.org/pdf/2106.01345.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [policy/dt](https://github.com/opendilab/DI-engine/blob/main/ding/policy/dt.py) | python3 -u d4rl_dt_mujoco.py |
 | 45 | [EDAC](https://arxiv.org/pdf/2110.01548.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [EDAC doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/edac.html)<br>
[policy/edac](https://github.com/opendilab/DI-engine/blob/main/ding/policy/edac.py) | python3 -u d4rl_edac_main.py | -| 46 | MBSAC([SAC](https://arxiv.org/abs/1801.01290)+[MVE](https://arxiv.org/abs/1803.00101)+[SVG](https://arxiv.org/abs/1510.09142)) | ![continuous](https://img.shields.io/badge/-continous-green)![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [policy/mbpolicy/mbsac](https://github.com/opendilab/DI-engine/blob/main/ding/policy/mbpolicy/mbsac.py) | python3 -u pendulum_mbsac_mbpo_config.py \ python3 -u pendulum_mbsac_ddppo_config.py | -| 47 | STEVESAC([SAC](https://arxiv.org/abs/1801.01290)+[STEVE](https://arxiv.org/abs/1807.01675)+[SVG](https://arxiv.org/abs/1510.09142)) | ![continuous](https://img.shields.io/badge/-continous-green)![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [policy/mbpolicy/mbsac](https://github.com/opendilab/DI-engine/blob/main/ding/policy/mbpolicy/mbsac.py) | python3 -u pendulum_stevesac_mbpo_config.py | -| 48 | [MBPO](https://arxiv.org/pdf/1906.08253.pdf) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [MBPO doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/mbpo.html)
[world_model/mbpo](https://github.com/opendilab/DI-engine/blob/main/ding/world_model/mbpo.py) | python3 -u pendulum_sac_mbpo_config.py | -| 49 | [DDPPO](https://openreview.net/forum?id=rzvOQrnclO0) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [world_model/ddppo](https://github.com/opendilab/DI-engine/blob/main/ding/world_model/ddppo.py) | python3 -u pendulum_mbsac_ddppo_config.py | -| 50 | [DreamerV3](https://arxiv.org/pdf/2301.04104.pdf) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [world_model/dreamerv3](https://github.com/opendilab/DI-engine/blob/main/ding/world_model/dreamerv3.py) | python3 -u cartpole_balance_dreamer_config.py | -| 51 | [PER](https://arxiv.org/pdf/1511.05952.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [worker/replay_buffer](https://github.com/opendilab/DI-engine/blob/main/ding/worker/replay_buffer/advanced_buffer.py) | `rainbow demo` | -| 52 | [GAE](https://arxiv.org/pdf/1506.02438.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [rl_utils/gae](https://github.com/opendilab/DI-engine/blob/main/ding/rl_utils/gae.py) | `ppo demo` | -| 53 | [ST-DIM](https://arxiv.org/pdf/1906.08226.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/loss/contrastive_loss](https://github.com/opendilab/DI-engine/blob/main/ding/torch_utils/loss/contrastive_loss.py) | ding -m serial -c cartpole_dqn_stdim_config.py -s 0 | -| 54 | [PLR](https://arxiv.org/pdf/2010.03934.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [PLR doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/plr.html)
[data/level_replay/level_sampler](https://github.com/opendilab/DI-engine/blob/main/ding/data/level_replay/level_sampler.py) | python3 -u bigfish_plr_config.py -s 0 | -| 55 | [PCGrad](https://arxiv.org/pdf/2001.06782.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/optimizer_helper/PCGrad](https://github.com/opendilab/DI-engine/blob/main/ding/data/torch_utils/optimizer_helper.py) | python3 -u multi_mnist_pcgrad_main.py -s 0 | +| 46 | [QGPO](https://arxiv.org/pdf/2304.12824.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [QGPO doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/qgpo.html)
[policy/qgpo](https://github.com/opendilab/DI-engine/blob/main/ding/policy/qgpo.py) | python3 -u ding/example/qgpo.py | +| 47 | MBSAC([SAC](https://arxiv.org/abs/1801.01290)+[MVE](https://arxiv.org/abs/1803.00101)+[SVG](https://arxiv.org/abs/1510.09142)) | ![continuous](https://img.shields.io/badge/-continous-green)![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [policy/mbpolicy/mbsac](https://github.com/opendilab/DI-engine/blob/main/ding/policy/mbpolicy/mbsac.py) | python3 -u pendulum_mbsac_mbpo_config.py \ python3 -u pendulum_mbsac_ddppo_config.py | +| 48 | STEVESAC([SAC](https://arxiv.org/abs/1801.01290)+[STEVE](https://arxiv.org/abs/1807.01675)+[SVG](https://arxiv.org/abs/1510.09142)) | ![continuous](https://img.shields.io/badge/-continous-green)![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [policy/mbpolicy/mbsac](https://github.com/opendilab/DI-engine/blob/main/ding/policy/mbpolicy/mbsac.py) | python3 -u pendulum_stevesac_mbpo_config.py | +| 49 | [MBPO](https://arxiv.org/pdf/1906.08253.pdf) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [MBPO doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/mbpo.html)
[world_model/mbpo](https://github.com/opendilab/DI-engine/blob/main/ding/world_model/mbpo.py) | python3 -u pendulum_sac_mbpo_config.py | +| 50 | [DDPPO](https://openreview.net/forum?id=rzvOQrnclO0) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [world_model/ddppo](https://github.com/opendilab/DI-engine/blob/main/ding/world_model/ddppo.py) | python3 -u pendulum_mbsac_ddppo_config.py | +| 51 | [DreamerV3](https://arxiv.org/pdf/2301.04104.pdf) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [world_model/dreamerv3](https://github.com/opendilab/DI-engine/blob/main/ding/world_model/dreamerv3.py) | python3 -u cartpole_balance_dreamer_config.py | +| 52 | [PER](https://arxiv.org/pdf/1511.05952.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [worker/replay_buffer](https://github.com/opendilab/DI-engine/blob/main/ding/worker/replay_buffer/advanced_buffer.py) | `rainbow demo` | +| 53 | [GAE](https://arxiv.org/pdf/1506.02438.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [rl_utils/gae](https://github.com/opendilab/DI-engine/blob/main/ding/rl_utils/gae.py) | `ppo demo` | +| 54 | [ST-DIM](https://arxiv.org/pdf/1906.08226.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/loss/contrastive_loss](https://github.com/opendilab/DI-engine/blob/main/ding/torch_utils/loss/contrastive_loss.py) | ding -m serial -c cartpole_dqn_stdim_config.py -s 0 | +| 55 | [PLR](https://arxiv.org/pdf/2010.03934.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [PLR doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/plr.html)
[data/level_replay/level_sampler](https://github.com/opendilab/DI-engine/blob/main/ding/data/level_replay/level_sampler.py) | python3 -u bigfish_plr_config.py -s 0 | +| 56 | [PCGrad](https://arxiv.org/pdf/2001.06782.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/optimizer_helper/PCGrad](https://github.com/opendilab/DI-engine/blob/main/ding/data/torch_utils/optimizer_helper.py) | python3 -u multi_mnist_pcgrad_main.py -s 0 | diff --git a/ding/example/qgpo.py b/ding/example/qgpo.py new file mode 100644 index 0000000000..ed844974be --- /dev/null +++ b/ding/example/qgpo.py @@ -0,0 +1,170 @@ +import torch +import gym +import d4rl +from easydict import EasyDict +from ditk import logging +from ding.model import QGPO +from ding.policy import QGPOPolicy +from ding.envs import DingEnvWrapper, BaseEnvManagerV2 +from ding.config import compile_config +from ding.framework import task, ding_init +from ding.framework.context import OfflineRLContext +from ding.framework.middleware import trainer, CkptSaver, offline_logger, wandb_offline_logger, termination_checker +from ding.framework.middleware.functional.evaluator import interaction_evaluator +from ding.framework.middleware.functional.data_processor import qgpo_support_data_generator, qgpo_offline_data_fetcher +from ding.utils import set_pkg_seed + +from dizoo.d4rl.config.halfcheetah_medium_expert_qgpo_config import main_config, create_config + + +class QGPOD4RLDataset(torch.utils.data.Dataset): + """ + Overview: + Dataset for QGPO algorithm. The training of QGPO algorithm is based on contrastive energy prediction, \ + which needs true action and fake action. The true action is sampled from the dataset, and the fake action \ + is sampled from the action support generated by the behavior policy. + Interface: + ``__init__``, ``__getitem__``, ``__len__``. + """ + + def __init__(self, cfg, device="cpu"): + """ + Overview: + Initialization method of QGPOD4RLDataset class + Arguments: + - cfg (:obj:`EasyDict`): Config dict + - device (:obj:`str`): Device name + """ + + self.cfg = cfg + data = d4rl.qlearning_dataset(gym.make(cfg.env_id)) + self.device = device + self.states = torch.from_numpy(data['observations']).float().to(self.device) + self.actions = torch.from_numpy(data['actions']).float().to(self.device) + self.next_states = torch.from_numpy(data['next_observations']).float().to(self.device) + reward = torch.from_numpy(data['rewards']).view(-1, 1).float().to(self.device) + self.is_finished = torch.from_numpy(data['terminals']).view(-1, 1).float().to(self.device) + + reward_tune = "iql_antmaze" if "antmaze" in cfg.env_id else "iql_locomotion" + if reward_tune == 'normalize': + reward = (reward - reward.mean()) / reward.std() + elif reward_tune == 'iql_antmaze': + reward = reward - 1.0 + elif reward_tune == 'iql_locomotion': + min_ret, max_ret = QGPOD4RLDataset.return_range(data, 1000) + reward /= (max_ret - min_ret) + reward *= 1000 + elif reward_tune == 'cql_antmaze': + reward = (reward - 0.5) * 4.0 + elif reward_tune == 'antmaze': + reward = (reward - 0.25) * 2.0 + self.rewards = reward + self.len = self.states.shape[0] + logging.info(f"{self.len} data loaded in QGPOD4RLDataset") + + def __getitem__(self, index): + """ + Overview: + Get data by index + Arguments: + - index (:obj:`int`): Index of data + Returns: + - data (:obj:`dict`): Data dict + + .. 
note::
+            The data dict contains the following keys:
+            - s (:obj:`torch.Tensor`): State
+            - a (:obj:`torch.Tensor`): Action
+            - r (:obj:`torch.Tensor`): Reward
+            - s_ (:obj:`torch.Tensor`): Next state
+            - d (:obj:`torch.Tensor`): Is finished
+            - fake_a (:obj:`torch.Tensor`): Fake action for contrastive energy prediction and qgpo training \
+                (fake action is sampled from the action support generated by the behavior policy)
+            - fake_a_ (:obj:`torch.Tensor`): Fake next action for contrastive energy prediction and qgpo training \
+                (fake action is sampled from the action support generated by the behavior policy)
+        """
+
+        data = {
+            's': self.states[index % self.len],
+            'a': self.actions[index % self.len],
+            'r': self.rewards[index % self.len],
+            's_': self.next_states[index % self.len],
+            'd': self.is_finished[index % self.len],
+            'fake_a': self.fake_actions[index % self.len]
+            if hasattr(self, "fake_actions") else 0.0,  # self.fake_actions
+            'fake_a_': self.fake_next_actions[index % self.len]
+            if hasattr(self, "fake_next_actions") else 0.0,  # self.fake_next_actions
+        }
+        return data
+
+    def __len__(self):
+        return self.len
+
+    def return_range(dataset, max_episode_steps):
+        returns, lengths = [], []
+        ep_ret, ep_len = 0., 0
+        for r, d in zip(dataset['rewards'], dataset['terminals']):
+            ep_ret += float(r)
+            ep_len += 1
+            if d or ep_len == max_episode_steps:
+                returns.append(ep_ret)
+                lengths.append(ep_len)
+                ep_ret, ep_len = 0., 0
+        # returns.append(ep_ret)  # incomplete trajectory
+        lengths.append(ep_len)  # but still keep track of number of steps
+        assert sum(lengths) == len(dataset['rewards'])
+        return min(returns), max(returns)
+
+
+def main():
+    # If you don't have offline data, you need to prepare it first and set the data_path in config
+    # For demonstration, we can also train an RL policy (e.g.
SAC) and collect some data + logging.getLogger().setLevel(logging.INFO) + cfg = compile_config(main_config, policy=QGPOPolicy) + ding_init(cfg) + with task.start(async_mode=False, ctx=OfflineRLContext()): + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + model = QGPO(cfg=cfg.policy.model) + policy = QGPOPolicy(cfg.policy, model=model) + dataset = QGPOD4RLDataset(cfg=cfg.dataset, device=policy._device) + if hasattr(cfg.policy, "load_path") and cfg.policy.load_path is not None: + policy_state_dict = torch.load(cfg.policy.load_path, map_location=torch.device("cpu")) + policy.learn_mode.load_state_dict(policy_state_dict) + + task.use(qgpo_support_data_generator(cfg, dataset, policy)) + task.use(qgpo_offline_data_fetcher(cfg, dataset, collate_fn=None)) + task.use(trainer(cfg, policy.learn_mode)) + for guidance_scale in cfg.policy.eval.guidance_scale: + evaluator_env = BaseEnvManagerV2( + env_fn=[ + lambda: DingEnvWrapper(env=gym.make(cfg.env.env_id), cfg=cfg.env, caller="evaluator") + for _ in range(cfg.env.evaluator_env_num) + ], + cfg=cfg.env.manager + ) + task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env, guidance_scale=guidance_scale)) + task.use( + wandb_offline_logger( + cfg=EasyDict( + dict( + gradient_logger=False, + plot_logger=True, + video_logger=False, + action_logger=False, + return_logger=False, + vis_dataset=False, + ) + ), + exp_config=cfg, + metric_list=policy._monitor_vars_learn(), + project_name=cfg.exp_name + ) + ) + task.use(CkptSaver(policy, cfg.exp_name, train_freq=100000)) + task.use(offline_logger()) + task.use(termination_checker(max_train_iter=500000 + cfg.policy.learn.q_value_stop_training_iter)) + task.run() + + +if __name__ == "__main__": + main() diff --git a/ding/framework/context.py b/ding/framework/context.py index 6fb35eec13..12f144766a 100644 --- a/ding/framework/context.py +++ b/ding/framework/context.py @@ -68,6 +68,7 @@ class OnlineRLContext(Context): last_eval_value: int = -np.inf eval_output: List = dataclasses.field(default_factory=dict) # wandb + info_for_logging: Dict = dataclasses.field(default_factory=dict) wandb_url: str = "" def __post_init__(self): @@ -93,6 +94,7 @@ class OfflineRLContext(Context): last_eval_value: int = -np.inf eval_output: List = dataclasses.field(default_factory=dict) # wandb + info_for_logging: Dict = dataclasses.field(default_factory=dict) wandb_url: str = "" def __post_init__(self): diff --git a/ding/framework/middleware/functional/data_processor.py b/ding/framework/middleware/functional/data_processor.py index cbcc39e7a2..5882716ef3 100644 --- a/ding/framework/middleware/functional/data_processor.py +++ b/ding/framework/middleware/functional/data_processor.py @@ -2,7 +2,9 @@ from typing import TYPE_CHECKING, Callable, List, Union, Tuple, Dict, Optional from easydict import EasyDict from ditk import logging +import numpy as np import torch +import tqdm from ding.data import Buffer, Dataset, DataLoader, offline_data_save_type from ding.data.buffer.middleware import PriorityExperienceReplay from ding.framework import task @@ -229,7 +231,7 @@ def _fetch(ctx: "OfflineRLContext"): return _fetch -def offline_data_fetcher(cfg: EasyDict, dataset: Dataset) -> Callable: +def offline_data_fetcher(cfg: EasyDict, dataset: Dataset, collate_fn=lambda x: x) -> Callable: """ Overview: The outer function transforms a Pytorch `Dataset` to `DataLoader`. 
\ @@ -241,7 +243,7 @@ def offline_data_fetcher(cfg: EasyDict, dataset: Dataset) -> Callable: - dataset (:obj:`Dataset`): The dataset of type `torch.utils.data.Dataset` which stores the data. """ # collate_fn is executed in policy now - dataloader = DataLoader(dataset, batch_size=cfg.policy.learn.batch_size, shuffle=True, collate_fn=lambda x: x) + dataloader = DataLoader(dataset, batch_size=cfg.policy.learn.batch_size, shuffle=True, collate_fn=collate_fn) dataloader = iter(dataloader) def _fetch(ctx: "OfflineRLContext"): @@ -261,7 +263,7 @@ def _fetch(ctx: "OfflineRLContext"): ctx.train_epoch += 1 del dataloader dataloader = DataLoader( - dataset, batch_size=cfg.policy.learn.batch_size, shuffle=True, collate_fn=lambda x: x + dataset, batch_size=cfg.policy.learn.batch_size, shuffle=True, collate_fn=collate_fn ) dataloader = iter(dataloader) ctx.train_data = next(dataloader) @@ -319,3 +321,125 @@ def _pusher(ctx: "OnlineRLContext"): ctx.trajectories = None return _pusher + + +def qgpo_support_data_generator(cfg, dataset, policy) -> Callable: + + behavior_policy_stop_training_iter = cfg.policy.learn.behavior_policy_stop_training_iter if hasattr( + cfg.policy.learn, 'behavior_policy_stop_training_iter' + ) else np.inf + energy_guided_policy_begin_training_iter = cfg.policy.learn.energy_guided_policy_begin_training_iter if hasattr( + cfg.policy.learn, 'energy_guided_policy_begin_training_iter' + ) else 0 + actions_generated = False + + def generate_fake_actions(): + allstates = dataset.states[:].cpu().numpy() + actions_sampled = [] + for states in tqdm.tqdm(np.array_split(allstates, allstates.shape[0] // 4096 + 1)): + actions_sampled.append( + policy._model.sample( + states, + sample_per_state=cfg.policy.learn.M, + diffusion_steps=cfg.policy.learn.diffusion_steps, + guidance_scale=0.0, + ) + ) + actions = np.concatenate(actions_sampled) + + allnextstates = dataset.next_states[:].cpu().numpy() + actions_next_states_sampled = [] + for next_states in tqdm.tqdm(np.array_split(allnextstates, allnextstates.shape[0] // 4096 + 1)): + actions_next_states_sampled.append( + policy._model.sample( + next_states, + sample_per_state=cfg.policy.learn.M, + diffusion_steps=cfg.policy.learn.diffusion_steps, + guidance_scale=0.0, + ) + ) + actions_next_states = np.concatenate(actions_next_states_sampled) + return actions, actions_next_states + + def _data_generator(ctx: "OfflineRLContext"): + nonlocal actions_generated + + if ctx.train_iter >= energy_guided_policy_begin_training_iter: + if ctx.train_iter > behavior_policy_stop_training_iter: + # no need to generate fake actions if fake actions are already generated + if actions_generated: + pass + else: + actions, actions_next_states = generate_fake_actions() + dataset.fake_actions = torch.Tensor(actions.astype(np.float32)).to(cfg.policy.model.device) + dataset.fake_next_actions = torch.Tensor(actions_next_states.astype(np.float32) + ).to(cfg.policy.model.device) + actions_generated = True + else: + # generate fake actions + actions, actions_next_states = generate_fake_actions() + dataset.fake_actions = torch.Tensor(actions.astype(np.float32)).to(cfg.policy.model.device) + dataset.fake_next_actions = torch.Tensor(actions_next_states.astype(np.float32) + ).to(cfg.policy.model.device) + actions_generated = True + else: + # no need to generate fake actions + pass + + return _data_generator + + +def qgpo_offline_data_fetcher(cfg: EasyDict, dataset: Dataset, collate_fn=lambda x: x) -> Callable: + """ + Overview: + The outer function transforms a Pytorch `Dataset` to 
`DataLoader`. \ + The return function is a generator which each time fetches a batch of data from the previous `DataLoader`.\ + Please refer to the link https://pytorch.org/tutorials/beginner/basics/data_tutorial.html \ + and https://pytorch.org/docs/stable/data.html for more details. + Arguments: + - cfg (:obj:`EasyDict`): Config which should contain the following keys: `cfg.policy.learn.batch_size`. + - dataset (:obj:`Dataset`): The dataset of type `torch.utils.data.Dataset` which stores the data. + """ + # collate_fn is executed in policy now + dataloader = DataLoader(dataset, batch_size=cfg.policy.learn.batch_size, shuffle=True, collate_fn=collate_fn) + dataloader_q = DataLoader(dataset, batch_size=cfg.policy.learn.batch_size_q, shuffle=True, collate_fn=collate_fn) + + behavior_policy_stop_training_iter = cfg.policy.learn.behavior_policy_stop_training_iter if hasattr( + cfg.policy.learn, 'behavior_policy_stop_training_iter' + ) else np.inf + energy_guided_policy_begin_training_iter = cfg.policy.learn.energy_guided_policy_begin_training_iter if hasattr( + cfg.policy.learn, 'energy_guided_policy_begin_training_iter' + ) else 0 + + def get_behavior_policy_training_data(): + while True: + yield from dataloader + + data = get_behavior_policy_training_data() + + def get_q_training_data(): + while True: + yield from dataloader_q + + data_q = get_q_training_data() + + def _fetch(ctx: "OfflineRLContext"): + """ + Overview: + Every time this generator is iterated, the fetched data will be assigned to ctx.train_data. \ + After the dataloader is empty, the attribute `ctx.train_epoch` will be incremented by 1. + Input of ctx: + - train_epoch (:obj:`int`): Number of `train_epoch`. + Output of ctx: + - train_data (:obj:`List[Tensor]`): The fetched data batch. + """ + + if ctx.train_iter >= energy_guided_policy_begin_training_iter: + ctx.train_data = next(data_q) + else: + ctx.train_data = next(data) + + # TODO apply data update (e.g. priority) in offline setting when necessary + ctx.trained_env_step += len(ctx.train_data) + + return _fetch diff --git a/ding/framework/middleware/functional/evaluator.py b/ding/framework/middleware/functional/evaluator.py index 611bbcdea6..3fa407b6de 100644 --- a/ding/framework/middleware/functional/evaluator.py +++ b/ding/framework/middleware/functional/evaluator.py @@ -210,7 +210,9 @@ def get_episode_output(self): return output -def interaction_evaluator(cfg: EasyDict, policy: Policy, env: BaseEnvManager, render: bool = False) -> Callable: +def interaction_evaluator( + cfg: EasyDict, policy: Policy, env: BaseEnvManager, render: bool = False, **kwargs +) -> Callable: """ Overview: The middleware that executes the evaluation. @@ -219,6 +221,7 @@ def interaction_evaluator(cfg: EasyDict, policy: Policy, env: BaseEnvManager, re - policy (:obj:`Policy`): The policy to be evaluated. - env (:obj:`BaseEnvManager`): The env for the evaluation. - render (:obj:`bool`): Whether to render env images and policy logits. + - kwargs: (:obj:`Any`): Other arguments for specific evaluation. 
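For reference, a minimal standalone sketch of the closure pattern this change relies on: keyword arguments given to the evaluator factory are bound once and forwarded verbatim to every `policy.forward` call (in the QGPO example above this is `guidance_scale`). The `fake_policy_forward` function below is a hypothetical stand-in, not DI-engine code.

def make_evaluator(policy_forward, **kwargs):
    # kwargs are captured when the evaluator is created, e.g. guidance_scale=2.0
    def _evaluate(obs):
        # mirrors `policy.forward(obs, **kwargs)` inside interaction_evaluator
        return policy_forward(obs, **kwargs) if kwargs else policy_forward(obs)

    return _evaluate


def fake_policy_forward(obs, guidance_scale=0.0):
    # hypothetical stand-in for QGPOPolicy._forward_eval
    return {"action": obs * guidance_scale}


evaluate = make_evaluator(fake_policy_forward, guidance_scale=2.0)
print(evaluate(3.0))  # {'action': 6.0}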
""" if task.router.is_active and not task.has_role(task.role.EVALUATOR): return task.void() @@ -239,8 +242,13 @@ def _evaluate(ctx: Union["OnlineRLContext", "OfflineRLContext"]): # evaluation will be executed if the task begins or enough train_iter after last evaluation if ctx.last_eval_iter != -1 and \ - (ctx.train_iter - ctx.last_eval_iter < cfg.policy.eval.evaluator.eval_freq): - return + (ctx.train_iter - ctx.last_eval_iter < cfg.policy.eval.evaluator.eval_freq): + if ctx.train_iter != ctx.last_eval_iter: + return + if len(kwargs) > 0: + kwargs_str = '/'.join([f'{k}({v})' for k, v in kwargs.items()]) + else: + kwargs_str = '' if env.closed: env.launch() @@ -252,7 +260,10 @@ def _evaluate(ctx: Union["OnlineRLContext", "OfflineRLContext"]): while not eval_monitor.is_finished(): obs = ttorch.as_tensor(env.ready_obs).to(dtype=ttorch.float32) obs = {i: obs[i] for i in range(get_shape0(obs))} # TBD - inference_output = policy.forward(obs) + if len(kwargs) > 0: + inference_output = policy.forward(obs, **kwargs) + else: + inference_output = policy.forward(obs) if render: eval_monitor.update_video(env.ready_imgs) eval_monitor.update_output(inference_output) @@ -275,12 +286,14 @@ def _evaluate(ctx: Union["OnlineRLContext", "OfflineRLContext"]): stop_flag = episode_return >= cfg.env.stop_value and ctx.train_iter > 0 if isinstance(ctx, OnlineRLContext): logging.info( - 'Evaluation: Train Iter({})\tEnv Step({})\tEpisode Return({:.3f})'.format( - ctx.train_iter, ctx.env_step, episode_return + 'Evaluation: Train Iter({}) Env Step({}) Episode Return({:.3f}) {}'.format( + ctx.train_iter, ctx.env_step, episode_return, kwargs_str ) ) elif isinstance(ctx, OfflineRLContext): - logging.info('Evaluation: Train Iter({})\tEval Reward({:.3f})'.format(ctx.train_iter, episode_return)) + logging.info( + 'Evaluation: Train Iter({}) Eval Return({:.3f}) {}'.format(ctx.train_iter, episode_return, kwargs_str) + ) else: raise TypeError("not supported ctx type: {}".format(type(ctx))) ctx.last_eval_iter = ctx.train_iter @@ -299,6 +312,16 @@ def _evaluate(ctx: Union["OnlineRLContext", "OfflineRLContext"]): else: ctx.eval_output['output'] = output # for compatibility + if len(kwargs) > 0: + ctx.info_for_logging.update( + { + f'{kwargs_str}/eval_episode_return': episode_return, + f'{kwargs_str}/eval_episode_return_min': episode_return_min, + f'{kwargs_str}/eval_episode_return_max': episode_return_max, + f'{kwargs_str}/eval_episode_return_std': episode_return_std, + } + ) + if stop_flag: task.finish = True diff --git a/ding/framework/middleware/functional/logger.py b/ding/framework/middleware/functional/logger.py index 9f62e2f429..5f968a5fd3 100644 --- a/ding/framework/middleware/functional/logger.py +++ b/ding/framework/middleware/functional/logger.py @@ -611,6 +611,15 @@ def _plot(ctx: "OfflineRLContext"): ) if ctx.eval_value != -np.inf: + if hasattr(ctx, "info_for_logging"): + """ + .. note:: + The info_for_logging is a dict that contains the information to be logged. + Users can add their own information to the dict. + All the information in the dict will be logged to wandb. 
+ """ + info_for_logging.update(ctx.info_for_logging) + if hasattr(ctx, "eval_value_min"): info_for_logging.update({ "episode return min": ctx.eval_value_min, diff --git a/ding/model/common/__init__.py b/ding/model/common/__init__.py index 4bf7d8be5a..b63aedfcd9 100755 --- a/ding/model/common/__init__.py +++ b/ding/model/common/__init__.py @@ -1,5 +1,5 @@ from .head import DiscreteHead, DuelingHead, DistributionHead, RainbowHead, QRDQNHead, StochasticDuelingHead, \ QuantileHead, FQFHead, RegressionHead, ReparameterizationHead, MultiHead, BranchingHead, head_cls_map, \ independent_normal_dist, AttentionPolicyHead, PopArtVHead, EnsembleHead -from .encoder import ConvEncoder, FCEncoder, IMPALAConvEncoder +from .encoder import ConvEncoder, FCEncoder, IMPALAConvEncoder, GaussianFourierProjectionTimeEncoder from .utils import create_model diff --git a/ding/model/common/encoder.py b/ding/model/common/encoder.py index 82dab4808a..8936f6af4d 100644 --- a/ding/model/common/encoder.py +++ b/ding/model/common/encoder.py @@ -2,6 +2,7 @@ from functools import reduce import operator import math +import numpy as np import torch import torch.nn as nn from torch.nn import functional as F @@ -470,3 +471,46 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if self.final_relu: x = torch.relu(x) return x + + +class GaussianFourierProjectionTimeEncoder(nn.Module): + """ + Overview: + Gaussian random features for encoding time steps. + This module is used as the encoder of time in generative models such as diffusion model. + Interfaces: + ``__init__``, ``forward``. + """ + + def __init__(self, embed_dim, scale=30.): + """ + Overview: + Initialize the Gaussian Fourier Projection Time Encoder according to arguments. + Arguments: + - embed_dim (:obj:`int`): The dimension of the output embedding vector. + - scale (:obj:`float`): The scale of the Gaussian random features. + """ + super().__init__() + # Randomly sample weights during initialization. These weights are fixed + # during optimization and are not trainable. + self.W = nn.Parameter(torch.randn(embed_dim // 2) * scale * 2 * np.pi, requires_grad=False) + + def forward(self, x): + """ + Overview: + Return the output embedding vector of the input time step. + Arguments: + - x (:obj:`torch.Tensor`): Input time step tensor. + Returns: + - output (:obj:`torch.Tensor`): Output embedding vector. + Shapes: + - x (:obj:`torch.Tensor`): :math:`(B,)`, where B is batch size. + - output (:obj:`torch.Tensor`): :math:`(B, embed_dim)`, where B is batch size, embed_dim is the \ + dimension of the output embedding vector. 
+ Examples: + >>> encoder = GaussianFourierProjectionTimeEncoder(128) + >>> x = torch.randn(100) + >>> output = encoder(x) + """ + x_proj = x[..., None] * self.W[None, :] + return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) diff --git a/ding/model/common/tests/test_encoder.py b/ding/model/common/tests/test_encoder.py index cd8a5bf752..d73f10561a 100644 --- a/ding/model/common/tests/test_encoder.py +++ b/ding/model/common/tests/test_encoder.py @@ -2,7 +2,7 @@ import numpy as np import pytest -from ding.model import ConvEncoder, FCEncoder, IMPALAConvEncoder +from ding.model import ConvEncoder, FCEncoder, IMPALAConvEncoder, GaussianFourierProjectionTimeEncoder from ding.torch_utils import is_differentiable B = 4 @@ -61,3 +61,10 @@ def test_impalaconv_encoder(self): outputs = model(inputs) self.output_check(model, outputs) assert outputs.shape == (B, 256) + + def test_GaussianFourierProjectionTimeEncoder(self): + inputs = torch.randn(B) + model = GaussianFourierProjectionTimeEncoder(128) + print(model) + outputs = model(inputs) + assert outputs.shape == (B, 128) diff --git a/ding/model/template/__init__.py b/ding/model/template/__init__.py index c9dc17791c..8e902f1504 100755 --- a/ding/model/template/__init__.py +++ b/ding/model/template/__init__.py @@ -26,5 +26,6 @@ from .procedure_cloning import ProcedureCloningMCTS, ProcedureCloningBFS from .bcq import BCQ from .edac import EDAC +from .qgpo import QGPO from .ebm import EBM, AutoregressiveEBM from .havac import HAVAC diff --git a/ding/model/template/qgpo.py b/ding/model/template/qgpo.py new file mode 100644 index 0000000000..135433c42c --- /dev/null +++ b/ding/model/template/qgpo.py @@ -0,0 +1,465 @@ +############################################################# +# This QGPO model is a modification implementation from https://github.com/ChenDRAG/CEP-energy-guided-diffusion +############################################################# + +from easydict import EasyDict +import torch +import torch.nn as nn +import torch.nn.functional as F +import copy +from ding.torch_utils import MLP +from ding.torch_utils.diffusion_SDE.dpm_solver_pytorch import DPM_Solver, NoiseScheduleVP +from ding.model.common.encoder import GaussianFourierProjectionTimeEncoder +from ding.torch_utils.network.res_block import TemporalSpatialResBlock + + +def marginal_prob_std(t, device): + """ + Overview: + Compute the mean and standard deviation of $p_{0t}(x(t) | x(0))$. + Arguments: + - t (:obj:`torch.Tensor`): The input time. + - device (:obj:`torch.device`): The device to use. + """ + + t = torch.tensor(t, device=device) + beta_1 = 20.0 + beta_0 = 0.1 + log_mean_coeff = -0.25 * t ** 2 * (beta_1 - beta_0) - 0.5 * t * beta_0 + alpha_t = torch.exp(log_mean_coeff) + std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff)) + return alpha_t, std + + +class TwinQ(nn.Module): + """ + Overview: + Twin Q network for QGPO, which has two Q networks. + Interfaces: + ``__init__``, ``forward``, ``both`` + """ + + def __init__(self, action_dim, state_dim): + """ + Overview: + Initialization of Twin Q. + Arguments: + - action_dim (:obj:`int`): The dimension of action. + - state_dim (:obj:`int`): The dimension of state. 
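As a standalone illustration of the time embedding defined earlier (plain torch, independent of DI-engine), `GaussianFourierProjectionTimeEncoder` projects each time step onto fixed random frequencies and returns the concatenated sin/cos features; the sketch below re-implements the same computation to show the shapes.

import math
import torch

embed_dim, scale = 64, 30.0
# fixed random frequencies, as sampled in GaussianFourierProjectionTimeEncoder.__init__
W = torch.randn(embed_dim // 2) * scale * 2 * math.pi

t = torch.rand(16)  # a batch of diffusion time steps
x_proj = t[..., None] * W[None, :]
embedding = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
print(embedding.shape)  # torch.Size([16, 64])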
+ """ + super().__init__() + self.q1 = MLP( + in_channels=state_dim + action_dim, + hidden_channels=256, + out_channels=1, + activation=nn.ReLU(), + layer_num=4, + output_activation=False + ) + self.q2 = MLP( + in_channels=state_dim + action_dim, + hidden_channels=256, + out_channels=1, + activation=nn.ReLU(), + layer_num=4, + output_activation=False + ) + + def both(self, action, condition=None): + """ + Overview: + Return the output of two Q networks. + Arguments: + - action (:obj:`torch.Tensor`): The input action. + - condition (:obj:`torch.Tensor`): The input condition. + """ + as_ = torch.cat([action, condition], -1) if condition is not None else action + return self.q1(as_), self.q2(as_) + + def forward(self, action, condition=None): + """ + Overview: + Return the minimum output of two Q networks. + Arguments: + - action (:obj:`torch.Tensor`): The input action. + - condition (:obj:`torch.Tensor`): The input condition. + """ + return torch.min(*self.both(action, condition)) + + +class GuidanceQt(nn.Module): + """ + Overview: + Energy Guidance Qt network for QGPO. \ + In the origin paper, the energy guidance is trained by CEP method. + Interfaces: + ``__init__``, ``forward`` + """ + + def __init__(self, action_dim, state_dim, time_embed_dim=32): + """ + Overview: + Initialization of Guidance Qt. + Arguments: + - action_dim (:obj:`int`): The dimension of action. + - state_dim (:obj:`int`): The dimension of state. + - time_embed_dim (:obj:`int`): The dimension of time embedding. \ + The time embedding is a Gaussian Fourier Feature tensor. + """ + super().__init__() + self.qt = MLP( + in_channels=action_dim + time_embed_dim + state_dim, + hidden_channels=256, + out_channels=1, + activation=torch.nn.SiLU(), + layer_num=4, + output_activation=False + ) + self.embed = nn.Sequential( + GaussianFourierProjectionTimeEncoder(embed_dim=time_embed_dim), nn.Linear(time_embed_dim, time_embed_dim) + ) + + def forward(self, action, t, condition=None): + """ + Overview: + Return the output of Guidance Qt. + Arguments: + - action (:obj:`torch.Tensor`): The input action. + - t (:obj:`torch.Tensor`): The input time. + - condition (:obj:`torch.Tensor`): The input condition. + """ + embed = self.embed(t) + ats = torch.cat([action, embed, condition], -1) if condition is not None else torch.cat([action, embed], -1) + return self.qt(ats) + + +class QGPOCritic(nn.Module): + """ + Overview: + QGPO critic network. + Interfaces: + ``__init__``, ``forward``, ``calculateQ``, ``calculate_guidance`` + """ + + def __init__(self, device, cfg, action_dim, state_dim) -> None: + """ + Overview: + Initialization of QGPO critic. + Arguments: + - device (:obj:`torch.device`): The device to use. + - cfg (:obj:`EasyDict`): The config dict. + - action_dim (:obj:`int`): The dimension of action. + - state_dim (:obj:`int`): The dimension of state. + """ + + super().__init__() + # is state_dim is 0 means unconditional guidance + assert state_dim > 0 + # only apply to conditional sampling here + self.device = device + self.q0 = TwinQ(action_dim, state_dim).to(self.device) + self.q0_target = copy.deepcopy(self.q0).requires_grad_(False).to(self.device) + self.qt = GuidanceQt(action_dim, state_dim).to(self.device) + + self.alpha = cfg.alpha + self.q_alpha = cfg.q_alpha + + def calculate_guidance(self, a, t, condition=None, guidance_scale=1.0): + """ + Overview: + Calculate the guidance for conditional sampling. + Arguments: + - a (:obj:`torch.Tensor`): The input action. + - t (:obj:`torch.Tensor`): The input time. 
+ - condition (:obj:`torch.Tensor`): The input condition. + - guidance_scale (:obj:`float`): The scale of guidance. + """ + + with torch.enable_grad(): + a.requires_grad_(True) + Q_t = self.qt(a, t, condition) + guidance = guidance_scale * torch.autograd.grad(torch.sum(Q_t), a)[0] + return guidance.detach() + + def forward(self, a, condition=None): + """ + Overview: + Return the output of QGPO critic. + Arguments: + - a (:obj:`torch.Tensor`): The input action. + - condition (:obj:`torch.Tensor`): The input condition. + """ + + return self.q0(a, condition) + + def calculateQ(self, a, condition=None): + """ + Overview: + Return the output of QGPO critic. + Arguments: + - a (:obj:`torch.Tensor`): The input action. + - condition (:obj:`torch.Tensor`): The input condition. + """ + + return self(a, condition) + + +class ScoreNet(nn.Module): + """ + Overview: + Score-based generative model for QGPO. + Interfaces: + ``__init__``, ``forward`` + """ + + def __init__(self, device, input_dim, output_dim, embed_dim=32): + """ + Overview: + Initialization of ScoreNet. + Arguments: + - device (:obj:`torch.device`): The device to use. + - input_dim (:obj:`int`): The dimension of input. + - output_dim (:obj:`int`): The dimension of output. + - embed_dim (:obj:`int`): The dimension of time embedding. \ + The time embedding is a Gaussian Fourier Feature tensor. + """ + + super().__init__() + + # origin score base + self.output_dim = output_dim + self.embed = nn.Sequential( + GaussianFourierProjectionTimeEncoder(embed_dim=embed_dim), nn.Linear(embed_dim, embed_dim) + ) + + self.device = device + self.pre_sort_condition = nn.Sequential(nn.Linear(input_dim - output_dim, 32), torch.nn.SiLU()) + self.sort_t = nn.Sequential( + nn.Linear(64, 128), + torch.nn.SiLU(), + nn.Linear(128, 128), + ) + self.down_block1 = TemporalSpatialResBlock(output_dim, 512) + self.down_block2 = TemporalSpatialResBlock(512, 256) + self.down_block3 = TemporalSpatialResBlock(256, 128) + self.middle1 = TemporalSpatialResBlock(128, 128) + self.up_block3 = TemporalSpatialResBlock(256, 256) + self.up_block2 = TemporalSpatialResBlock(512, 512) + self.last = nn.Linear(1024, output_dim) + + def forward(self, x, t, condition): + """ + Overview: + Return the output of ScoreNet. + Arguments: + - x (:obj:`torch.Tensor`): The input tensor. + - t (:obj:`torch.Tensor`): The input time. + - condition (:obj:`torch.Tensor`): The input condition. + """ + + embed = self.embed(t) + embed = torch.cat([self.pre_sort_condition(condition), embed], dim=-1) + embed = self.sort_t(embed) + d1 = self.down_block1(x, embed) + d2 = self.down_block2(d1, embed) + d3 = self.down_block3(d2, embed) + u3 = self.middle1(d3, embed) + u2 = self.up_block3(torch.cat([d3, u3], dim=-1), embed) + u1 = self.up_block2(torch.cat([d2, u2], dim=-1), embed) + u0 = torch.cat([d1, u1], dim=-1) + h = self.last(u0) + self.h = h + # Normalize output + return h / marginal_prob_std(t, device=self.device)[1][..., None] + + +class QGPO(nn.Module): + """ + Overview: + Model of QGPO algorithm. + Interfaces: + ``__init__``, ``calculateQ``, ``select_actions``, ``sample``, ``score_model_loss_fn``, ``q_loss_fn``, \ + ``qt_loss_fn`` + """ + + def __init__(self, cfg: EasyDict) -> None: + """ + Overview: + Initialization of QGPO. + Arguments: + - cfg (:obj:`EasyDict`): The config dict. 
+ """ + + super(QGPO, self).__init__() + self.device = cfg.device + self.obs_dim = cfg.obs_dim + self.action_dim = cfg.action_dim + + self.noise_schedule = NoiseScheduleVP(schedule='linear') + + self.score_model = ScoreNet( + device=self.device, + input_dim=self.obs_dim + self.action_dim, + output_dim=self.action_dim, + ) + + self.q = QGPOCritic(self.device, cfg.qgpo_critic, action_dim=self.action_dim, state_dim=self.obs_dim) + + def calculateQ(self, s, a): + """ + Overview: + Calculate the Q value. + Arguments: + - s (:obj:`torch.Tensor`): The input state. + - a (:obj:`torch.Tensor`): The input action. + """ + + return self.q(a, s) + + def select_actions(self, states, diffusion_steps=15, guidance_scale=1.0): + """ + Overview: + Select actions for conditional sampling. + Arguments: + - states (:obj:`list`): The input states. + - diffusion_steps (:obj:`int`): The diffusion steps. + - guidance_scale (:obj:`float`): The scale of guidance. + """ + + def forward_dpm_wrapper_fn(x, t): + score = self.score_model(x, t, condition=states) + result = -(score + + self.q.calculate_guidance(x, t, states, guidance_scale=guidance_scale)) * marginal_prob_std( + t, device=self.device + )[1][..., None] + return result + + self.eval() + multiple_input = True + with torch.no_grad(): + states = torch.FloatTensor(states).to(self.device) + if states.dim == 1: + states = states.unsqueeze(0) + multiple_input = False + num_states = states.shape[0] + + init_x = torch.randn(states.shape[0], self.action_dim, device=self.device) + results = DPM_Solver( + forward_dpm_wrapper_fn, self.noise_schedule, predict_x0=True + ).sample( + init_x, steps=diffusion_steps, order=2 + ).cpu().numpy() + + actions = results.reshape(num_states, self.action_dim).copy() # + + out_actions = [actions[i] for i in range(actions.shape[0])] if multiple_input else actions[0] + self.train() + return out_actions + + def sample(self, states, sample_per_state=16, diffusion_steps=15, guidance_scale=1.0): + """ + Overview: + Sample actions for conditional sampling. + Arguments: + - states (:obj:`list`): The input states. + - sample_per_state (:obj:`int`): The number of samples per state. + - diffusion_steps (:obj:`int`): The diffusion steps. + - guidance_scale (:obj:`float`): The scale of guidance. + """ + + def forward_dpm_wrapper_fn(x, t): + score = self.score_model(x, t, condition=states) + result = -(score + self.q.calculate_guidance(x, t, states, guidance_scale=guidance_scale)) \ + * marginal_prob_std(t, device=self.device)[1][..., None] + return result + + self.eval() + num_states = states.shape[0] + with torch.no_grad(): + states = torch.FloatTensor(states).to(self.device) + states = torch.repeat_interleave(states, sample_per_state, dim=0) + + init_x = torch.randn(states.shape[0], self.action_dim, device=self.device) + results = DPM_Solver( + forward_dpm_wrapper_fn, self.noise_schedule, predict_x0=True + ).sample( + init_x, steps=diffusion_steps, order=2 + ).cpu().numpy() + + actions = results[:, :].reshape(num_states, sample_per_state, self.action_dim).copy() + + self.train() + return actions + + def score_model_loss_fn(self, x, s, eps=1e-3): + """ + Overview: + The loss function for training score-based generative models. + Arguments: + model: A PyTorch model instance that represents a \ + time-dependent score-based model. + x: A mini-batch of training data. + eps: A tolerance value for numerical stability. + """ + + random_t = torch.rand(x.shape[0], device=x.device) * (1. 
- eps) + eps + z = torch.randn_like(x) + alpha_t, std = marginal_prob_std(random_t, device=x.device) + perturbed_x = x * alpha_t[:, None] + z * std[:, None] + score = self.score_model(perturbed_x, random_t, condition=s) + loss = torch.mean(torch.sum((score * std[:, None] + z) ** 2, dim=(1, ))) + return loss + + def q_loss_fn(self, a, s, r, s_, d, fake_a_, discount=0.99): + """ + Overview: + The loss function for training Q function. + Arguments: + - a (:obj:`torch.Tensor`): The input action. + - s (:obj:`torch.Tensor`): The input state. + - r (:obj:`torch.Tensor`): The input reward. + - s_ (:obj:`torch.Tensor`): The input next state. + - d (:obj:`torch.Tensor`): The input done. + - fake_a_ (:obj:`torch.Tensor`): The input fake action. + - discount (:obj:`float`): The discount factor. + """ + + with torch.no_grad(): + softmax = nn.Softmax(dim=1) + next_energy = self.q.q0_target(fake_a_, torch.stack([s_] * fake_a_.shape[1], axis=1)).detach().squeeze() + next_v = torch.sum(softmax(self.q.q_alpha * next_energy) * next_energy, dim=-1, keepdim=True) + # Update Q function + targets = r + (1. - d.float()) * discount * next_v.detach() + qs = self.q.q0.both(a, s) + q_loss = sum(F.mse_loss(q, targets) for q in qs) / len(qs) + + return q_loss + + def qt_loss_fn(self, s, fake_a): + """ + Overview: + The loss function for training Guidance Qt. + Arguments: + - s (:obj:`torch.Tensor`): The input state. + - fake_a (:obj:`torch.Tensor`): The input fake action. + """ + + # input many s anction , + energy = self.q.q0_target(fake_a, torch.stack([s] * fake_a.shape[1], axis=1)).detach().squeeze() + + # CEP guidance method, as proposed in the paper + logsoftmax = nn.LogSoftmax(dim=1) + softmax = nn.Softmax(dim=1) + + x0_data_energy = energy * self.q.alpha + random_t = torch.rand((fake_a.shape[0], ), device=self.device) * (1. 
- 1e-3) + 1e-3 + random_t = torch.stack([random_t] * fake_a.shape[1], dim=1) + z = torch.randn_like(fake_a) + alpha_t, std = marginal_prob_std(random_t, device=self.device) + perturbed_fake_a = fake_a * alpha_t[..., None] + z * std[..., None] + xt_model_energy = self.q.qt(perturbed_fake_a, random_t, torch.stack([s] * fake_a.shape[1], axis=1)).squeeze() + p_label = softmax(x0_data_energy) + + # + qt_loss = -torch.mean(torch.sum(p_label * logsoftmax(xt_model_energy), axis=-1)) + return qt_loss diff --git a/ding/policy/__init__.py b/ding/policy/__init__.py index c85883a0af..1f202da3bb 100755 --- a/ding/policy/__init__.py +++ b/ding/policy/__init__.py @@ -51,6 +51,7 @@ from .pc import ProcedureCloningBFSPolicy from .bcq import BCQPolicy +from .qgpo import QGPOPolicy # new-type policy from .ppof import PPOFPolicy diff --git a/ding/policy/qgpo.py b/ding/policy/qgpo.py new file mode 100644 index 0000000000..cfa3cb19c1 --- /dev/null +++ b/ding/policy/qgpo.py @@ -0,0 +1,269 @@ +############################################################# +# This QGPO model is a modification implementation from https://github.com/ChenDRAG/CEP-energy-guided-diffusion +############################################################# + +from typing import List, Dict, Any +import torch +from ding.utils import POLICY_REGISTRY +from ding.utils.data import default_collate +from ding.torch_utils import to_device +from .base_policy import Policy + + +@POLICY_REGISTRY.register('qgpo') +class QGPOPolicy(Policy): + """ + Overview: + Policy class of QGPO algorithm + Contrastive Energy Prediction for Exact Energy-Guided Diffusion Sampling in Offline Reinforcement Learning + https://arxiv.org/abs/2304.12824 + Interfaces: + ``__init__``, ``forward``, ``learn``, ``eval``, ``state_dict``, ``load_state_dict`` + """ + + config = dict( + # (str) RL policy register name (refer to function "POLICY_REGISTRY"). + type='qgpo', + # (bool) Whether to use cuda for network. + cuda=False, + # (bool type) on_policy: Determine whether on-policy or off-policy. + # on-policy setting influences the behaviour of buffer. + # Default False in QGPO. + on_policy=False, + multi_agent=False, + model=dict( + qgpo_critic=dict( + # (float) The scale of the energy guidance when training qt. + # \pi_{behavior}\exp(f(s,a)) \propto \pi_{behavior}\exp(alpha * Q(s,a)) + alpha=3, + # (float) The scale of the energy guidance when training q0. 
+ # \mathcal{T}Q(s,a)=r(s,a)+\mathbb{E}_{s'\sim P(s'|s,a),a'\sim\pi_{support}(a'|s')}Q(s',a') + # \pi_{support} \propto \pi_{behavior}\exp(q_alpha * Q(s,a)) + q_alpha=1, + ), + device='cuda', + # obs_dim + # action_dim + ), + learn=dict( + # learning rate for behavior model training + learning_rate=1e-4, + # batch size during the training of behavior model + batch_size=4096, + # batch size during the training of q value + batch_size_q=256, + # number of fake action support + M=16, + # number of diffusion time steps + diffusion_steps=15, + # training iterations when behavior model is fixed + behavior_policy_stop_training_iter=600000, + # training iterations when energy-guided policy begin training + energy_guided_policy_begin_training_iter=600000, + # training iterations when q value stop training, default None means no limit + q_value_stop_training_iter=1100000, + ), + eval=dict( + # energy guidance scale for policy in evaluation + # \pi_{evaluation} \propto \pi_{behavior}\exp(guidance_scale * alpha * Q(s,a)) + guidance_scale=[0.0, 1.0, 2.0, 3.0, 5.0, 8.0, 10.0], + ), + ) + + def _init_learn(self) -> None: + """ + Overview: + Learn mode initialization method. For QGPO, it mainly contains the optimizer, \ + algorithm-specific arguments such as qt_update_momentum, discount, behavior_policy_stop_training_iter, \ + energy_guided_policy_begin_training_iter and q_value_stop_training_iter, etc. + This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``. + """ + self.cuda = self._cfg.cuda + + self.behavior_model_optimizer = torch.optim.Adam( + self._model.score_model.parameters(), lr=self._cfg.learn.learning_rate + ) + self.q_optimizer = torch.optim.Adam(self._model.q.q0.parameters(), lr=3e-4) + self.qt_optimizer = torch.optim.Adam(self._model.q.qt.parameters(), lr=3e-4) + + self.qt_update_momentum = 0.005 + self.discount = 0.99 + + self.behavior_policy_stop_training_iter = self._cfg.learn.behavior_policy_stop_training_iter + self.energy_guided_policy_begin_training_iter = self._cfg.learn.energy_guided_policy_begin_training_iter + self.q_value_stop_training_iter = self._cfg.learn.q_value_stop_training_iter + + def _forward_learn(self, data: dict) -> Dict[str, Any]: + """ + Overview: + Forward function for learning mode. + The training of QGPO algorithm is based on contrastive energy prediction, \ + which needs true action and fake action. The true action is sampled from the dataset, and the fake action \ + is sampled from the action support generated by the behavior policy. + The training process is divided into two stages: + 1. Train the behavior model, which is modeled as a diffusion model by parameterizing the score function. + 2. Train the Q function by fake action support generated by the behavior model. + 3. Train the energy-guided policy by the Q function. + Arguments: + - data (:obj:`dict`): Dict type data. + Returns: + - result (:obj:`dict`): Dict type data of algorithm results. 
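To make the in-support Bellman target described above concrete, here is a small self-contained sketch of the soft value computed in `q_loss_fn`: the M fake next actions are weighted by a softmax over their scaled Q-values, and the result enters the usual TD target. All tensors below are random placeholders.

import torch

batch, M = 8, 16
q_alpha, discount = 1.0, 0.99

next_energy = torch.randn(batch, M)  # Q(s', a'_fake) for M support actions
reward = torch.randn(batch, 1)
done = torch.zeros(batch, 1)

weights = torch.softmax(q_alpha * next_energy, dim=-1)
next_v = torch.sum(weights * next_energy, dim=-1, keepdim=True)  # soft value over the action support
target_q = reward + (1.0 - done) * discount * next_v  # Bellman target for q0
print(target_q.shape)  # torch.Size([8, 1])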
+ """ + + if self.cuda: + data = to_device(data, self._device) + + s = data['s'] + a = data['a'] + r = data['r'] + s_ = data['s_'] + d = data['d'] + fake_a = data['fake_a'] + fake_a_ = data['fake_a_'] + + # training behavior model + if self.behavior_policy_stop_training_iter > 0: + + behavior_model_training_loss = self._model.score_model_loss_fn(a, s) + + self.behavior_model_optimizer.zero_grad() + behavior_model_training_loss.backward() + self.behavior_model_optimizer.step() + + self.behavior_policy_stop_training_iter -= 1 + behavior_model_training_loss = behavior_model_training_loss.item() + else: + behavior_model_training_loss = 0 + + # training Q function + self.energy_guided_policy_begin_training_iter -= 1 + self.q_value_stop_training_iter -= 1 + if self.energy_guided_policy_begin_training_iter < 0: + if self.q_value_stop_training_iter > 0: + q0_loss = self._model.q_loss_fn(a, s, r, s_, d, fake_a_, discount=self.discount) + + self.q_optimizer.zero_grad() + q0_loss.backward() + self.q_optimizer.step() + + # Update target + for param, target_param in zip(self._model.q.q0.parameters(), self._model.q.q0_target.parameters()): + target_param.data.copy_( + self.qt_update_momentum * param.data + (1 - self.qt_update_momentum) * target_param.data + ) + + q0_loss = q0_loss.item() + + else: + q0_loss = 0 + qt_loss = self._model.qt_loss_fn(s, fake_a) + + self.qt_optimizer.zero_grad() + qt_loss.backward() + self.qt_optimizer.step() + + qt_loss = qt_loss.item() + + else: + q0_loss = 0 + qt_loss = 0 + + total_loss = behavior_model_training_loss + q0_loss + qt_loss + + return dict( + total_loss=total_loss, + behavior_model_training_loss=behavior_model_training_loss, + q0_loss=q0_loss, + qt_loss=qt_loss, + ) + + def _init_collect(self) -> None: + """ + Overview: + Collect mode initialization method. Not supported for QGPO. + """ + pass + + def _forward_collect(self) -> None: + """ + Overview: + Forward function for collect mode. Not supported for QGPO. + """ + pass + + def _init_eval(self) -> None: + """ + Overview: + Eval mode initialization method. For QGPO, it mainly contains the guidance_scale and diffusion_steps, etc. + This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``. + """ + + self.diffusion_steps = self._cfg.eval.diffusion_steps + + def _forward_eval(self, data: dict, guidance_scale: float) -> dict: + """ + Overview: + Forward function for eval mode. The eval process is based on the energy-guided policy, \ + which is modeled as a diffusion model by parameterizing the score function. + Arguments: + - data (:obj:`dict`): Dict type data. + - guidance_scale (:obj:`float`): The scale of the energy guidance. + Returns: + - output (:obj:`dict`): Dict type data of algorithm output. + """ + + data_id = list(data.keys()) + states = default_collate(list(data.values())) + actions = self._model.select_actions( + states, diffusion_steps=self.diffusion_steps, guidance_scale=guidance_scale + ) + output = actions + + return {i: {"action": d} for i, d in zip(data_id, output)} + + def _get_train_sample(self, transitions: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Overview: + Get the train sample from the replay buffer, currently not supported for QGPO. + Arguments: + - transitions (:obj:`List[Dict[str, Any]]`): The data from replay buffer. + Returns: + - samples (:obj:`List[Dict[str, Any]]`): The data for training. + """ + pass + + def _process_transition(self) -> None: + """ + Overview: + Process the transition data, currently not supported for QGPO. 
+ """ + pass + + def _state_dict_learn(self) -> Dict[str, Any]: + """ + Overview: + Return the state dict for saving. + Returns: + - state_dict (:obj:`Dict[str, Any]`): Dict type data of state dict. + """ + return { + 'model': self._model.state_dict(), + 'behavior_model_optimizer': self.behavior_model_optimizer.state_dict(), + } + + def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: + """ + Overview: + Load the state dict. + Arguments: + - state_dict (:obj:`Dict[str, Any]`): Dict type data of state dict. + """ + self._model.load_state_dict(state_dict['model']) + self.behavior_model_optimizer.load_state_dict(state_dict['behavior_model_optimizer']) + + def _monitor_vars_learn(self) -> List[str]: + """ + Overview: + Return the variables names to be monitored. + """ + return ['total_loss', 'behavior_model_training_loss', 'q0_loss', 'qt_loss'] diff --git a/ding/torch_utils/diffusion_SDE/__init__.py b/ding/torch_utils/diffusion_SDE/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/ding/torch_utils/diffusion_SDE/dpm_solver_pytorch.py b/ding/torch_utils/diffusion_SDE/dpm_solver_pytorch.py new file mode 100644 index 0000000000..4f43c10510 --- /dev/null +++ b/ding/torch_utils/diffusion_SDE/dpm_solver_pytorch.py @@ -0,0 +1,1288 @@ +############################################################# +# This DPM-Solver snippet is from https://github.com/ChenDRAG/CEP-energy-guided-diffusion +# wich is based on https://github.com/LuChengTHU/dpm-solver +############################################################# + +import torch +import math + + +class NoiseScheduleVP: + + def __init__( + self, + schedule='discrete', + betas=None, + alphas_cumprod=None, + continuous_beta_0=0.1, + continuous_beta_1=20., + ): + """Create a wrapper class for the forward SDE (VP type). + + *** + Update: We support discrete-time diffusion models by implementing a picewise linear interpolation \ + for log_alpha_t. We recommend to use schedule='discrete' for the discrete-time diffusion models, \ + especially for high-resolution images. + *** + + The forward SDE ensures that the condition distribution \ + q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ). + We further define lambda_t = log(alpha_t) - log(sigma_t), \ + which is the half-logSNR (described in the DPM-Solver paper). + Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. + For t in [0, T], we have: + log_alpha_t = self.marginal_log_mean_coeff(t) + sigma_t = self.marginal_std(t) + lambda_t = self.marginal_lambda(t) + + Moreover, as lambda(t) is an invertible function, we also support its inverse function: + + t = self.inverse_lambda(lambda_t) + + =============================================================== + + We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and \ + continuous-time DPMs (trained on t in [t_0, T]). + + 1. For discrete-time DPMs: + + For discrete-time DPMs trained on n = 0, 1, ..., N-1, \ + we convert the discrete steps to continuous time steps by: + t_i = (i + 1) / N + e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1. + We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3. + + Args: + betas: A `torch.Tensor`. The beta array for the discrete-time DPM. \ + (See the original DDPM paper for details) + alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. \ + (See the original DDPM paper for details) + + Note that we always have alphas_cumprod = cumprod(betas). 
\ + Therefore, we only need to set one of `betas` and `alphas_cumprod`. + + **Important**: Please pay special attention for the args for `alphas_cumprod`: + The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. \ + Specifically, DDPMs assume that + q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ). + Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. + In fact, we have + alpha_{t_n} = \sqrt{\hat{alpha_n}}, + and + log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}). + + + 2. For continuous-time DPMs: + + We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). + The hyperparameters for the noise schedule are the default settings in DDPM and improved-DDPM: + + Args: + beta_min: A `float` number. The smallest beta for the linear schedule. + beta_max: A `float` number. The largest beta for the linear schedule. + cosine_s: A `float` number. The hyperparameter in the cosine schedule. + cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule. + T: A `float` number. The ending time of the forward process. + + =============================================================== + + Args: + schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs, + 'linear' or 'cosine' for continuous-time DPMs. + Returns: + A wrapper object of the forward SDE (VP type). + + =============================================================== + + Example: + + # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', betas=betas) + + # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array \ + for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod) + + # For continuous-time DPMs (VPSDE), linear schedule: + >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.) + + """ + + if schedule not in ['discrete', 'linear', 'cosine']: + raise ValueError( + "Unsupported noise schedule {}. \ + The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(schedule) + ) + + self.schedule = schedule + if schedule == 'discrete': + if betas is not None: + log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0) + else: + assert alphas_cumprod is not None + log_alphas = 0.5 * torch.log(alphas_cumprod) + self.total_N = len(log_alphas) + self.T = 1. + self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)) + self.log_alpha_array = log_alphas.reshape(( + 1, + -1, + )) + else: + self.total_N = 1000 + self.beta_0 = continuous_beta_0 + self.beta_1 = continuous_beta_1 + self.cosine_s = 0.008 + self.cosine_beta_max = 999. + self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi + ) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s + self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.)) + self.schedule = schedule + if schedule == 'cosine': + # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T. + # Note that T = 0.9946 may be not the optimal setting. However, we find it works well. + self.T = 0.9946 + else: + self.T = 1. + + def marginal_log_mean_coeff(self, t): + """ + Compute log(alpha_t) of a given continuous-time label t in [0, T]. 
+ """ + if self.schedule == 'discrete': + return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), + self.log_alpha_array.to(t.device)).reshape((-1)) + elif self.schedule == 'linear': + return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 + elif self.schedule == 'cosine': + log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.)) + log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0 + return log_alpha_t + + def marginal_alpha(self, t): + """ + Compute alpha_t of a given continuous-time label t in [0, T]. + """ + return torch.exp(self.marginal_log_mean_coeff(t)) + + def marginal_std(self, t): + """ + Compute sigma_t of a given continuous-time label t in [0, T]. + """ + return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t))) + + def marginal_lambda(self, t): + """ + Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. + """ + log_mean_coeff = self.marginal_log_mean_coeff(t) + log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff)) + return log_mean_coeff - log_std + + def inverse_lambda(self, lamb): + """ + Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. + """ + if self.schedule == 'linear': + tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1, )).to(lamb)) + Delta = self.beta_0 ** 2 + tmp + return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) + elif self.schedule == 'discrete': + log_alpha = -0.5 * torch.logaddexp(torch.zeros((1, )).to(lamb.device), -2. * lamb) + t = interpolate_fn( + log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), + torch.flip(self.t_array.to(lamb.device), [1]) + ) + return t.reshape((-1, )) + else: + log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1, )).to(lamb)) + t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0) + ) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s + t = t_fn(log_alpha) + return t + + +def model_wrapper( + model, + noise_schedule, + model_type="noise", + model_kwargs={}, + guidance_type="uncond", + condition=None, + unconditional_condition=None, + guidance_scale=1., + classifier_fn=None, + classifier_kwargs={}, +): + """Create a wrapper function for the noise prediction model. + + DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to + firstly wrap the model function to a noise prediction model that accepts the continuous time as the input. + + We support four types of the diffusion model by setting `model_type`: + + 1. "noise": noise prediction model. (Trained by predicting noise). + + 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0). + + 3. "v": velocity prediction model. (Trained by predicting the velocity). + The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2]. + + [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models." + arXiv preprint arXiv:2202.00512 (2022). + [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models." + arXiv preprint arXiv:2210.02303 (2022). + + 4. "score": marginal score function. (Trained by denoising score matching). 
+ Note that the score function and the noise prediction model follows a simple relationship: + ``` + noise(x_t, t) = -sigma_t * score(x_t, t) + ``` + + We support three types of guided sampling by DPMs by setting `guidance_type`: + 1. "uncond": unconditional sampling by DPMs. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + + 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + + The input `classifier_fn` has the following format: + `` + classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond) + `` + + [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis," + in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794. + + 3. "classifier-free": classifier-free guidance sampling by conditional DPMs. + The input `model` has the following format: + `` + model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score + `` + And if cond == `unconditional_condition`, the model output is the unconditional DPM output. + + [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance." + arXiv preprint arXiv:2207.12598 (2022). + + + The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999) + or continuous-time labels (i.e. epsilon to T). + + We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise: + `` + def model_fn(x, t_continuous) -> noise: + t_input = get_model_input_time(t_continuous) + return noise_pred(model, x, t_input, **model_kwargs) + `` + where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver. + + =============================================================== + + Args: + model: A diffusion model with the corresponding format described above. + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + model_type: A `str`. The parameterization type of the diffusion model. + "noise" or "x_start" or "v" or "score". + model_kwargs: A `dict`. A dict for the other inputs of the model function. + guidance_type: A `str`. The type of the guidance for sampling. + "uncond" or "classifier" or "classifier-free". + condition: A pytorch tensor. The condition for the guided sampling. + Only used for "classifier" or "classifier-free" guidance type. + unconditional_condition: A pytorch tensor. The condition for the unconditional sampling. + Only used for "classifier-free" guidance type. + guidance_scale: A `float`. The scale for the guided sampling. + classifier_fn: A classifier function. Only used for the classifier guidance. + classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function. + Returns: + A noise prediction model that accepts the noised data and the continuous time as the inputs. + """ + + def get_model_input_time(t_continuous): + """ + Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time. + For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N]. + For continuous-time DPMs, we just use `t_continuous`. + """ + if noise_schedule.schedule == 'discrete': + return (t_continuous - 1. / noise_schedule.total_N) * 1000. 
+ else: + return t_continuous + + def noise_pred_fn(x, t_continuous, cond=None): + if t_continuous.reshape((-1, )).shape[0] == 1: + t_continuous = t_continuous.expand((x.shape[0])) + t_input = get_model_input_time(t_continuous) + if cond is None: + output = model(x, t_input, **model_kwargs) + else: + output = model(x, t_input, cond, **model_kwargs) + if model_type == "noise": + return output + elif model_type == "x_start": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + dims = x.dim() + return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims) + elif model_type == "v": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + dims = x.dim() + return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x + elif model_type == "score": + sigma_t = noise_schedule.marginal_std(t_continuous) + dims = x.dim() + return -expand_dims(sigma_t, dims) * output + + def cond_grad_fn(x, t_input): + """ + Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t). + """ + with torch.enable_grad(): + x_in = x.detach().requires_grad_(True) + log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs) + return torch.autograd.grad(log_prob.sum(), x_in)[0] + + def model_fn(x, t_continuous): + """ + The noise predicition model function that is used for DPM-Solver. + """ + if t_continuous.reshape((-1, )).shape[0] == 1: + t_continuous = t_continuous.expand((x.shape[0])) + if guidance_type == "uncond": + return noise_pred_fn(x, t_continuous) + elif guidance_type == "classifier": + assert classifier_fn is not None + t_input = get_model_input_time(t_continuous) + cond_grad = cond_grad_fn(x, t_input) + sigma_t = noise_schedule.marginal_std(t_continuous) + noise = noise_pred_fn(x, t_continuous) + return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad + elif guidance_type == "classifier-free": + if guidance_scale == 1. or unconditional_condition is None: + return noise_pred_fn(x, t_continuous, cond=condition) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t_continuous] * 2) + c_in = torch.cat([unconditional_condition, condition]) + noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) + return noise_uncond + guidance_scale * (noise - noise_uncond) + + assert model_type in ["noise", "x_start", "v"] + assert guidance_type in ["uncond", "classifier", "classifier-free"] + return model_fn + + +class DPM_Solver: + + def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.): + """Construct a DPM-Solver. + + We support both the noise prediction model ("predicting epsilon") and \ + the data prediction model ("predicting x0"). + If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver). + If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++). + In such case, we further support the "dynamic thresholding" in [1] when `thresholding` is True. + The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs \ + with large guidance scales. + + Args: + model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]): + `` + def model_fn(x, t_continuous): + return noise + `` + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model. 
+ thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1]. + max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. \ + The max value for thresholding. + + [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, \ + Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. \ + Photorealistic text-to-image diffusion models with deep language understanding. \ + arXiv preprint arXiv:2205.11487, 2022b. + """ + self.model = model_fn + self.noise_schedule = noise_schedule + self.predict_x0 = predict_x0 + self.thresholding = thresholding + self.max_val = max_val + + def noise_prediction_fn(self, x, t): + """ + Return the noise prediction model. + """ + return self.model(x, t) + + def data_prediction_fn(self, x, t): + """ + Return the data prediction model (with thresholding). + """ + noise = self.noise_prediction_fn(x, t) + dims = x.dim() + alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) + x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims) + if self.thresholding: + p = 0.995 # A hyperparameter in the paper of "Imagen" [1]. + s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) + s = expand_dims(torch.maximum(s, torch.ones_like(s).to(s.device)), dims) + x0 = torch.clamp(x0, -s, s) / (s / self.max_val) + return x0 + + def model_fn(self, x, t): + """ + Convert the model to the noise prediction model or the data prediction model. + """ + if self.predict_x0: + return self.data_prediction_fn(x, t) + else: + return self.noise_prediction_fn(x, t) + + def get_time_steps(self, skip_type, t_T, t_0, N, device): + """Compute the intermediate time steps for sampling. + + Args: + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. \ + (**Recommended for high-resolutional data**.) + - 'time_quadratic': quadratic time for the time steps. \ + (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + N: A `int`. The total number of the spacing of the time steps. + device: A torch device. + Returns: + A pytorch tensor of the time steps, with the shape (N + 1,). + """ + if skip_type == 'logSNR': + lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) + lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) + logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device) + return self.noise_schedule.inverse_lambda(logSNR_steps) + elif skip_type == 'time_uniform': + return torch.linspace(t_T, t_0, N + 1).to(device) + elif skip_type == 'time_quadratic': + t_order = 2 + t = torch.linspace(t_T ** (1. / t_order), t_0 ** (1. / t_order), N + 1).pow(t_order).to(device) + return t + else: + raise ValueError( + "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type) + ) + + def get_orders_for_singlestep_solver(self, steps, order): + """ + Get the order of each step for sampling by the singlestep DPM-Solver. + + We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast". 
+ Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is: + - If order == 1: + We take `steps` of DPM-Solver-1 (i.e. DDIM). + - If order == 2: + - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of DPM-Solver-2. + - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If order == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, \ + and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2. + + ============================================ + Args: + order: A `int`. The max order for the solver (2 or 3). + steps: A `int`. The total number of function evaluations (NFE). + Returns: + orders: A list of the solver order of each step. + """ + if order == 3: + K = steps // 3 + 1 + if steps % 3 == 0: + orders = [ + 3, + ] * (K - 2) + [2, 1] + elif steps % 3 == 1: + orders = [ + 3, + ] * (K - 1) + [1] + else: + orders = [ + 3, + ] * (K - 1) + [2] + return orders + elif order == 2: + K = steps // 2 + if steps % 2 == 0: + # orders = [2,] * K + K = steps // 2 + 1 + orders = [ + 2, + ] * (K - 2) + [ + 1, + ] * 2 + else: + orders = [ + 2, + ] * K + [1] + return orders + elif order == 1: + return [ + 1, + ] * steps + else: + raise ValueError("'order' must be '1' or '2' or '3'.") + + def denoise_fn(self, x, s): + """ + Denoise at the final step, which is equivalent to solve the ODE \ + from lambda_s to infty by first-order discretization. + """ + return self.data_prediction_fn(x, s) + + def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False): + """ + DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. 
+ """ + ns = self.noise_schedule + dims = x.dim() + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + if self.predict_x0: + phi_1 = torch.expm1(-h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = (expand_dims(sigma_t / sigma_s, dims) * x - expand_dims(alpha_t * phi_1, dims) * model_s) + if return_intermediate: + return x_t, {'model_s': model_s} + else: + return x_t + else: + phi_1 = torch.expm1(h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - + expand_dims(sigma_t * phi_1, dims) * model_s + ) + if return_intermediate: + return x_t, {'model_s': model_s} + else: + return x_t + + def singlestep_dpm_solver_second_update( + self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, solver_type='dpm_solver' + ): + """ + Singlestep solver DPM-Solver-2 from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + r1: A `float`. The hyperparameter of the second-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` \ + (the intermediate time). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpm_solver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) + if r1 is None: + r1 = 0.5 + ns = self.noise_schedule + dims = x.dim() + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + s1 = ns.inverse_lambda(lambda_s1) + log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff( + s1 + ), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t) + alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t) + + if self.predict_x0: + phi_11 = torch.expm1(-r1 * h) + phi_1 = torch.expm1(-h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = (expand_dims(sigma_s1 / sigma_s, dims) * x - expand_dims(alpha_s1 * phi_11, dims) * model_s) + model_s1 = self.model_fn(x_s1, s1) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x - expand_dims(alpha_t * phi_1, dims) * model_s - + (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s) + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x - expand_dims(alpha_t * phi_1, dims) * model_s + + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) 
/ h + 1.), dims) * (model_s1 - model_s) + ) + else: + phi_11 = torch.expm1(r1 * h) + phi_1 = torch.expm1(h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = ( + expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x - + expand_dims(sigma_s1 * phi_11, dims) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - + expand_dims(sigma_t * phi_1, dims) * model_s - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * + (model_s1 - model_s) + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - + expand_dims(sigma_t * phi_1, dims) * model_s - + (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * (model_s1 - model_s) + ) + if return_intermediate: + return x_t, {'model_s': model_s, 'model_s1': model_s1} + else: + return x_t + + def singlestep_dpm_solver_third_update( + self, + x, + s, + t, + r1=1. / 3., + r2=2. / 3., + model_s=None, + model_s1=None, + return_intermediate=False, + solver_type='dpm_solver' + ): + """ + Singlestep solver DPM-Solver-3 from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + r1: A `float`. The hyperparameter of the third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`). + If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` \ + (the intermediate times). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpm_solver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) + if r1 is None: + r1 = 1. / 3. + if r2 is None: + r2 = 2. / 3. + ns = self.noise_schedule + dims = x.dim() + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + lambda_s2 = lambda_s + r2 * h + s1 = ns.inverse_lambda(lambda_s1) + s2 = ns.inverse_lambda(lambda_s2) + log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff( + s + ), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std( + s2 + ), ns.marginal_std(t) + alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t) + + if self.predict_x0: + phi_11 = torch.expm1(-r1 * h) + phi_12 = torch.expm1(-r2 * h) + phi_1 = torch.expm1(-h) + phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1. + phi_2 = phi_1 / h + 1. 
+ phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = (expand_dims(sigma_s1 / sigma_s, dims) * x - expand_dims(alpha_s1 * phi_11, dims) * model_s) + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ( + expand_dims(sigma_s2 / sigma_s, dims) * x - expand_dims(alpha_s2 * phi_12, dims) * model_s + + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s) + ) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x - expand_dims(alpha_t * phi_1, dims) * model_s + + (1. / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s) + ) + elif solver_type == 'taylor': + D1_0 = (1. / r1) * (model_s1 - model_s) + D1_1 = (1. / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2. * (D1_1 - D1_0) / (r2 - r1) + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x - expand_dims(alpha_t * phi_1, dims) * model_s + + expand_dims(alpha_t * phi_2, dims) * D1 - expand_dims(alpha_t * phi_3, dims) * D2 + ) + else: + phi_11 = torch.expm1(r1 * h) + phi_12 = torch.expm1(r2 * h) + phi_1 = torch.expm1(h) + phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1. + phi_2 = phi_1 / h - 1. + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = ( + expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x - + expand_dims(sigma_s1 * phi_11, dims) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ( + expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x - + expand_dims(sigma_s2 * phi_12, dims) * model_s - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * + (model_s1 - model_s) + ) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - + expand_dims(sigma_t * phi_1, dims) * model_s - (1. / r2) * expand_dims(sigma_t * phi_2, dims) * + (model_s2 - model_s) + ) + elif solver_type == 'taylor': + D1_0 = (1. / r1) * (model_s1 - model_s) + D1_1 = (1. / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2. * (D1_1 - D1_0) / (r2 - r1) + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - + expand_dims(sigma_t * phi_1, dims) * model_s - expand_dims(sigma_t * phi_2, dims) * D1 - + expand_dims(sigma_t * phi_3, dims) * D2 + ) + + if return_intermediate: + return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2} + else: + return x_t + + def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"): + """ + Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. 
+ """ + if solver_type not in ['dpm_solver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) + ns = self.noise_schedule + dims = x.dim() + model_prev_1, model_prev_0 = model_prev_list + t_prev_1, t_prev_0 = t_prev_list + lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda( + t_prev_0 + ), ns.marginal_lambda(t) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0 = h_0 / h + D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1) + if self.predict_x0: + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(sigma_t / sigma_prev_0, dims) * x - + expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 - + 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0 + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(sigma_t / sigma_prev_0, dims) * x - + expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 + + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0 + ) + else: + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x - + expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 - + 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0 + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x - + expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 - + expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0 + ) + return x_t + + def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpm_solver'): + """ + Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + ns = self.noise_schedule + dims = x.dim() + model_prev_2, model_prev_1, model_prev_0 = model_prev_list + t_prev_2, t_prev_1, t_prev_0 = t_prev_list + lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda( + t_prev_1 + ), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_1 = lambda_prev_1 - lambda_prev_2 + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0, r1 = h_0 / h, h_1 / h + D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1) + D1_1 = expand_dims(1. / r1, dims) * (model_prev_1 - model_prev_2) + D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1) + D2 = expand_dims(1. 
/ (r0 + r1), dims) * (D1_0 - D1_1) + if self.predict_x0: + x_t = ( + expand_dims(sigma_t / sigma_prev_0, dims) * x - expand_dims(alpha_t * + (torch.exp(-h) - 1.), dims) * model_prev_0 + + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1 - + expand_dims(alpha_t * ((torch.exp(-h) - 1. + h) / h ** 2 - 0.5), dims) * D2 + ) + else: + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x - + expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 - + expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1 - + expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h ** 2 - 0.5), dims) * D2 + ) + return x_t + + def singlestep_dpm_solver_update( + self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None, r2=None + ): + """ + Singlestep DPM-Solver with the order `order` from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + return_intermediate: A `bool`. If true, also return the model value at time `s`, \ + `s1` and `s2` (the intermediate times). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + r1: A `float`. The hyperparameter of the second-order or third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if order == 1: + return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate) + elif order == 2: + return self.singlestep_dpm_solver_second_update( + x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1 + ) + elif order == 3: + return self.singlestep_dpm_solver_third_update( + x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1, r2=r2 + ) + else: + raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) + + def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpm_solver'): + """ + Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. 
+ """ + if order == 1: + return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1]) + elif order == 2: + return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) + elif order == 3: + return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) + else: + raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) + + def dpm_solver_adaptive( + self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, solver_type='dpm_solver' + ): + """ + The adaptive step size solver based on singlestep DPM-Solver. + + Args: + x: A pytorch tensor. The initial value at time `t_T`. + order: A `int`. The (higher) order of the solver. We only support order == 2 or 3. + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + h_init: A `float`. The initial step size (for logSNR). + atol: A `float`. The absolute tolerance of the solver. \ + For image data, the default setting is 0.0078, followed [1]. + rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05. + theta: A `float`. The safety hyperparameter for adapting the step size. \ + The default setting is 0.9, followed [1]. + t_err: A `float`. The tolerance for the time. \ + We solve the diffusion ODE until the absolute error between the \ + current time and `t_0` is less than `t_err`. The default setting is 1e-5. + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_0: A pytorch tensor. The approximated solution at time `t_0`. + + [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, \ + "Gotta go fast when generating data with score-based models," \ + arXiv preprint arXiv:2105.14080, 2021. + """ + ns = self.noise_schedule + s = t_T * torch.ones((x.shape[0], )).to(x) + lambda_s = ns.marginal_lambda(s) + lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x)) + h = h_init * torch.ones_like(s).to(x) + x_prev = x + nfe = 0 + if order == 2: + r1 = 0.5 + lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True) + higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update( + x, s, t, r1=r1, solver_type=solver_type, **kwargs + ) + elif order == 3: + r1, r2 = 1. / 3., 2. / 3. + lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update( + x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type + ) + higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update( + x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs + ) + else: + raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order)) + while torch.abs((s - t_0)).mean() > t_err: + t = ns.inverse_lambda(lambda_s + h) + x_lower, lower_noise_kwargs = lower_update(x, s, t) + x_higher = higher_update(x, s, t, **lower_noise_kwargs) + delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev))) + norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)) + E = norm_fn((x_higher - x_lower) / delta).max() + if torch.all(E <= 1.): + x = x_higher + s = t + x_prev = x_lower + lambda_s = ns.marginal_lambda(s) + h = torch.min(theta * h * torch.float_power(E, -1. 
/ order).float(), lambda_0 - lambda_s) + nfe += order + print('adaptive solver nfe', nfe) + return x + + def sample( + self, + x, + steps=20, + t_start=None, + t_end=None, + order=3, + skip_type='time_uniform', + method='singlestep', + denoise=False, + solver_type='dpm_solver', + atol=0.0078, + rtol=0.05, + ): + """ + Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`. + + ===================================================== + + We support the following algorithms for both noise prediction model and data prediction model: + - 'singlestep': + Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), \ + which combines different orders of singlestep DPM-Solver. + We combine all the singlestep solvers with order <= `order` \ + to use up all the function evaluations (steps). + The total number of function evaluations (NFE) == `steps`. + Given a fixed NFE == `steps`, the sampling procedure is: + - If `order` == 1: + - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2. + - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 \ + and 1 step of DPM-Solver-1. + - If `order` == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, \ + and 1 step of singlestep DPM-Solver-2 \ + and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 \ + and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 \ + and 1 step of singlestep DPM-Solver-2. + - 'multistep': + Multistep DPM-Solver with the order of `order`. + The total number of function evaluations (NFE) == `steps`. + We initialize the first `order` values by lower order multistep solvers. + Given a fixed NFE == `steps`, the sampling procedure is: + Denote K = steps. + - If `order` == 1: + - We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2. + - If `order` == 3: + - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, \ + then (K - 2) step of multistep DPM-Solver-3. + - 'singlestep_fixed': + Fixed order singlestep DPM-Solver \ + (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3). + We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, \ + with total [`steps` // `order`] * `order` NFE. + - 'adaptive': + Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper). + We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`. + You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` \ + to balance the computatation costs (NFE) and the sample quality. + - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2. + - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 \ + and singlestep DPM-Solver-3. + + ===================================================== + + Some advices for choosing the algorithm: + - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs: + Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`. + e.g. 
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False) + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3, + skip_type='time_uniform', method='singlestep') + - For **guided sampling with large guidance scale** by DPMs: + Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`. + e.g. + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True) + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2, + skip_type='time_uniform', method='multistep') + + We support three types of `skip_type`: + - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images** + - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**. + - 'time_quadratic': quadratic time for the time steps. + + ===================================================== + Args: + x: A pytorch tensor. The initial value at time `t_start` + e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution. + steps: A `int`. The total number of function evaluations (NFE). + t_start: A `float`. The starting time of the sampling. + If `T` is None, we use self.noise_schedule.T (default is 1.0). + t_end: A `float`. The ending time of the sampling. + If `t_end` is None, we use 1. / self.noise_schedule.total_N. + e.g. if total_N == 1000, we have `t_end` == 1e-3. + For discrete-time DPMs: + - We recommend `t_end` == 1. / self.noise_schedule.total_N. + For continuous-time DPMs: + - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15. + order: A `int`. The order of DPM-Solver. + skip_type: A `str`. The type for the spacing of the time steps. \ + 'time_uniform' or 'logSNR' or 'time_quadratic'. + method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'. + denoise: A `bool`. Whether to denoise at the final step. Default is False. + If `denoise` is True, the total NFE is (`steps` + 1). + solver_type: A `str`. The taylor expansion type for the solver. \ + `dpm_solver` or `taylor`. We recommend `dpm_solver`. + atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + Returns: + x_end: A pytorch tensor. The approximated solution at time `t_end`. + + """ + t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end + t_T = self.noise_schedule.T if t_start is None else t_start + device = x.device + if method == 'adaptive': + with torch.no_grad(): + x = self.dpm_solver_adaptive( + x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, solver_type=solver_type + ) + elif method == 'multistep': + assert steps >= order + timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) + assert timesteps.shape[0] - 1 == steps + with torch.no_grad(): + vec_t = timesteps[0].expand((x.shape[0])) + model_prev_list = [self.model_fn(x, vec_t)] + t_prev_list = [vec_t] + # Init the first `order` values by lower order multistep DPM-Solver. + for init_order in range(1, order): + vec_t = timesteps[init_order].expand(x.shape[0]) + x = self.multistep_dpm_solver_update( + x, model_prev_list, t_prev_list, vec_t, init_order, solver_type=solver_type + ) + model_prev_list.append(self.model_fn(x, vec_t)) + t_prev_list.append(vec_t) + # Compute the remaining values by `order`-th order multistep DPM-Solver. 
+ for step in range(order, steps + 1): + vec_t = timesteps[step].expand(x.shape[0]) + x = self.multistep_dpm_solver_update( + x, model_prev_list, t_prev_list, vec_t, order, solver_type=solver_type + ) + for i in range(order - 1): + t_prev_list[i] = t_prev_list[i + 1] + model_prev_list[i] = model_prev_list[i + 1] + t_prev_list[-1] = vec_t + # We do not need to evaluate the final model value. + if step < steps: + model_prev_list[-1] = self.model_fn(x, vec_t) + elif method in ['singlestep', 'singlestep_fixed']: + if method == 'singlestep': + orders = self.get_orders_for_singlestep_solver(steps=steps, order=order) + timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) + elif method == 'singlestep_fixed': + K = steps // order + orders = [ + order, + ] * K + timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=(K * order), device=device) + with torch.no_grad(): + i = 0 + for order in orders: + vec_s, vec_t = timesteps[i].expand(x.shape[0]), timesteps[i + order].expand(x.shape[0]) + h = self.noise_schedule.marginal_lambda(timesteps[i + order] + ) - self.noise_schedule.marginal_lambda(timesteps[i]) + r1 = None if order <= 1 else ( + self.noise_schedule.marginal_lambda(timesteps[i + 1]) - + self.noise_schedule.marginal_lambda(timesteps[i]) + ) / h + r2 = None if order <= 2 else ( + self.noise_schedule.marginal_lambda(timesteps[i + 2]) - + self.noise_schedule.marginal_lambda(timesteps[i]) + ) / h + x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2) + i += order + if denoise: + x = self.denoise_fn(x, torch.ones((x.shape[0], )).to(device) * t_0) + return x + + +############################################################# +# other utility functions +############################################################# + + +def interpolate_fn(x, xp, yp): + """ + A piecewise linear function y = f(x), using xp and yp as keypoints. + We implement f(x) in a differentiable way (i.e. applicable for autograd). + The function f(x) is well-defined for all x-axis. \ + (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.) + + Args: + x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels + (we use C = 1 for DPM-Solver). + xp: PyTorch tensor with shape [C, K], where K is the number of keypoints. + yp: PyTorch tensor with shape [C, K]. + Returns: + The function values f(x), with shape [N, C]. 
+ """ + N, K = x.shape[0], xp.shape[1] + all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2) + sorted_all_x, x_indices = torch.sort(all_x, dim=2) + x_idx = torch.argmin(x_indices, dim=2) + cand_start_idx = x_idx - 1 + start_idx = torch.where( + torch.eq(x_idx, 0), + torch.tensor(1, device=x.device), + torch.where( + torch.eq(x_idx, K), + torch.tensor(K - 2, device=x.device), + cand_start_idx, + ), + ) + end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1) + start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2) + end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2) + start_idx2 = torch.where( + torch.eq(x_idx, 0), + torch.tensor(0, device=x.device), + torch.where( + torch.eq(x_idx, K), + torch.tensor(K - 2, device=x.device), + cand_start_idx, + ), + ) + y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1) + start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2) + end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2) + cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x) + return cand + + +def expand_dims(v, dims): + """ + Expand the tensor `v` to the dim `dims`. + + Args: + `v`: a PyTorch tensor with shape [N]. + `dim`: a `int`. + Returns: + a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`. + """ + return v[(..., ) + (None, ) * (dims - 1)] diff --git a/ding/torch_utils/network/activation.py b/ding/torch_utils/network/activation.py index b3c8fcda4c..e97d6645b2 100644 --- a/ding/torch_utils/network/activation.py +++ b/ding/torch_utils/network/activation.py @@ -159,6 +159,7 @@ def build_activation(activation: str, inplace: bool = None) -> nn.Module: "sigmoid": nn.Sigmoid(), "softplus": nn.Softplus(), "elu": nn.ELU(), + "silu": torch.nn.SiLU(inplace=inplace), "square": Lambda(lambda x: x ** 2), "identity": Lambda(lambda x: x), } diff --git a/ding/torch_utils/network/res_block.py b/ding/torch_utils/network/res_block.py index a698869b08..82908ce509 100644 --- a/ding/torch_utils/network/res_block.py +++ b/ding/torch_utils/network/res_block.py @@ -153,3 +153,47 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if self.dropout is not None: x = self.dropout(x) return x + + +class TemporalSpatialResBlock(nn.Module): + """ + Overview: + Residual Block using MLP layers for both temporal and spatial input. + t → time_mlp → h1 → dense2 → h2 → out + ↗+ ↗+ + x → dense1 → ↗ + ↘ ↗ + → modify_x → → → → + """ + + def __init__(self, input_dim, output_dim, t_dim=128, activation=torch.nn.SiLU()): + """ + Overview: + Init the temporal spatial residual block. + Arguments: + - input_dim (:obj:`int`): The number of channels in the input tensor. + - output_dim (:obj:`int`): The number of channels in the output tensor. + - t_dim (:obj:`int`): The dimension of the temporal input. + - activation (:obj:`nn.Module`): The optional activation function. 
+ """ + super().__init__() + # temporal input is the embedding of time, which is a Gaussian Fourier Feature tensor + self.time_mlp = nn.Sequential( + activation, + nn.Linear(t_dim, output_dim), + ) + self.dense1 = nn.Sequential(nn.Linear(input_dim, output_dim), activation) + self.dense2 = nn.Sequential(nn.Linear(output_dim, output_dim), activation) + self.modify_x = nn.Linear(input_dim, output_dim) if input_dim != output_dim else nn.Identity() + + def forward(self, x, t) -> torch.Tensor: + """ + Overview: + Return the redisual block output. + Arguments: + - x (:obj:`torch.Tensor`): The input tensor. + - t (:obj:`torch.Tensor`): The temporal input tensor. + """ + h1 = self.dense1(x) + self.time_mlp(t) + h2 = self.dense2(h1) + return h2 + self.modify_x(x) diff --git a/dizoo/d4rl/config/halfcheetah_medium_expert_qgpo_config.py b/dizoo/d4rl/config/halfcheetah_medium_expert_qgpo_config.py new file mode 100644 index 0000000000..787a421d76 --- /dev/null +++ b/dizoo/d4rl/config/halfcheetah_medium_expert_qgpo_config.py @@ -0,0 +1,51 @@ +from easydict import EasyDict + +main_config = dict( + exp_name='halfcheetah_medium_expert_v2_QGPO_seed0', + env=dict( + env_id="halfcheetah-medium-expert-v2", + evaluator_env_num=8, + n_evaluator_episode=8, + ), + dataset=dict( + env_id="halfcheetah-medium-expert-v2", + ), + policy=dict( + cuda=True, + on_policy=False, + #load_path='./halfcheetah_medium_expert_v2_QGPO_seed0/ckpt/iteration_600000.pth.tar', + model=dict( + qgpo_critic=dict( + alpha=3, + q_alpha=1, + ), + device='cuda', + obs_dim=17, + action_dim=6, + ), + learn=dict( + learning_rate=1e-4, + batch_size=4096, + batch_size_q=256, + M=16, + diffusion_steps=15, + behavior_policy_stop_training_iter=600000, + energy_guided_policy_begin_training_iter=600000, + q_value_stop_training_iter=1100000, + ), + eval=dict( + guidance_scale=[0.0, 1.0, 2.0, 3.0, 5.0, 8.0, 10.0], + diffusion_steps=15, + evaluator=dict(eval_freq=50000, ), + ), + ), +) +main_config = EasyDict(main_config) + +create_config = dict( + env_manager=dict(type='base'), + policy=dict( + type='qgpo', + ), +) +create_config = EasyDict(create_config) diff --git a/dizoo/d4rl/config/hopper_medium_expert_qgpo_config.py b/dizoo/d4rl/config/hopper_medium_expert_qgpo_config.py new file mode 100644 index 0000000000..b0941bc13c --- /dev/null +++ b/dizoo/d4rl/config/hopper_medium_expert_qgpo_config.py @@ -0,0 +1,51 @@ +from easydict import EasyDict + +main_config = dict( + exp_name='hopper_medium_expert_v2_QGPO_seed0', + env=dict( + env_id="hopper-medium-expert-v2", + evaluator_env_num=8, + n_evaluator_episode=8, + ), + dataset=dict( + env_id="hopper-medium-expert-v2", + ), + policy=dict( + cuda=True, + on_policy=False, + #load_path='./hopper_medium_expert_v2_QGPO_seed0/ckpt/iteration_600000.pth.tar', + model=dict( + qgpo_critic=dict( + alpha=3, + q_alpha=1, + ), + device='cuda', + obs_dim=11, + action_dim=3, + ), + learn=dict( + learning_rate=1e-4, + batch_size=4096, + batch_size_q=256, + M=16, + diffusion_steps=15, + behavior_policy_stop_training_iter=600000, + energy_guided_policy_begin_training_iter=600000, + q_value_stop_training_iter=1100000, + ), + eval=dict( + guidance_scale=[0.0, 1.0, 2.0, 3.0, 5.0, 8.0, 10.0], + diffusion_steps=15, + evaluator=dict(eval_freq=50000, ), + ), + ), +) +main_config = EasyDict(main_config) + +create_config = dict( + env_manager=dict(type='base'), + policy=dict( + type='qgpo', + ), +) +create_config = EasyDict(create_config) diff --git a/dizoo/d4rl/config/walker2d_medium_expert_qgpo_config.py 
b/dizoo/d4rl/config/walker2d_medium_expert_qgpo_config.py new file mode 100644 index 0000000000..b4b3bd7bb6 --- /dev/null +++ b/dizoo/d4rl/config/walker2d_medium_expert_qgpo_config.py @@ -0,0 +1,51 @@ +from easydict import EasyDict + +main_config = dict( + exp_name='walker2d_medium_expert_v2_QGPO_seed0', + env=dict( + env_id="walker2d-medium-expert-v2", + evaluator_env_num=8, + n_evaluator_episode=8, + ), + dataset=dict( + env_id="walker2d-medium-expert-v2", + ), + policy=dict( + cuda=True, + on_policy=False, + #load_path='./walker2d_medium_expert_v2_QGPO_seed0/ckpt/iteration_600000.pth.tar', + model=dict( + qgpo_critic=dict( + alpha=3, + q_alpha=1, + ), + device='cuda', + obs_dim=17, + action_dim=6, + ), + learn=dict( + learning_rate=1e-4, + batch_size=4096, + batch_size_q=256, + M=16, + diffusion_steps=15, + behavior_policy_stop_training_iter=600000, + energy_guided_policy_begin_training_iter=600000, + q_value_stop_training_iter=1100000, + ), + eval=dict( + guidance_scale=[0.0, 1.0, 2.0, 3.0, 5.0, 8.0, 10.0], + diffusion_steps=15, + evaluator=dict(eval_freq=50000, ), + ), + ), +) +main_config = EasyDict(main_config) + +create_config = dict( + env_manager=dict(type='base'), + policy=dict( + type='qgpo', + ), +) +create_config = EasyDict(create_config)
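For orientation, the sketch below shows how the vendored solver in `ding/torch_utils/diffusion_SDE/dpm_solver_pytorch.py` is typically driven, following the usage pattern given in its own docstrings (`NoiseScheduleVP` → `model_wrapper` → `DPM_Solver.sample`). Everything outside that call pattern is illustrative only: `ToyNoiseModel`, the linear beta schedule, and the tensor shapes are placeholders, not part of this diff or of the QGPO policy itself.

```python
import torch

from ding.torch_utils.diffusion_SDE.dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver


class ToyNoiseModel(torch.nn.Module):
    """A stand-in epsilon-prediction network: maps (x, t) to a tensor shaped like x (hypothetical)."""

    def __init__(self, dim: int = 6):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(dim + 1, 64),
            torch.nn.SiLU(),
            torch.nn.Linear(64, dim),
        )

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # The discrete-time label t (shape [batch]) is appended as one scalar feature per sample.
        return self.net(torch.cat([x, t.reshape(-1, 1).float() / 1000.], dim=-1))


if __name__ == "__main__":
    # Assumed linear beta schedule for a discrete-time DPM with N = 1000 steps.
    betas = torch.linspace(1e-4, 2e-2, 1000)
    ns = NoiseScheduleVP(schedule='discrete', betas=betas)

    # Wrap the discrete-time epsilon model into a continuous-time noise prediction function.
    model_fn = model_wrapper(ToyNoiseModel(dim=6), ns, model_type="noise", guidance_type="uncond")

    # Data-prediction solver (DPM-Solver++) with multistep updates, as recommended in the docstrings.
    dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True)
    x_T = torch.randn(8, 6)  # start from standard Gaussian noise at t = T
    x_0 = dpm_solver.sample(x_T, steps=15, order=2, skip_type='time_uniform', method='multistep')
    print(x_0.shape)  # torch.Size([8, 6])
```

The `steps`, `order` and `skip_type` arguments are the knobs discussed in the `sample` docstring; the `diffusion_steps=15` entries in the QGPO configs above presumably correspond to the same `steps` budget, although the policy-side wiring is not shown in this excerpt.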