diff --git a/README.md b/README.md
index 4833c8fec0..39b7e5f46c 100644
--- a/README.md
+++ b/README.md
@@ -254,16 +254,17 @@ P.S: The `.py` file in `Runnable Demo` can be found in `dizoo`
| 43 | [TD3BC](https://arxiv.org/pdf/2106.06860.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [TD3BC doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/td3_bc.html)<br>[policy/td3_bc](https://github.com/opendilab/DI-engine/blob/main/ding/policy/td3_bc.py) | python3 -u d4rl_td3_bc_main.py |
| 44 | [Decision Transformer](https://arxiv.org/pdf/2106.01345.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [policy/dt](https://github.com/opendilab/DI-engine/blob/main/ding/policy/dt.py) | python3 -u d4rl_dt_mujoco.py |
| 45 | [EDAC](https://arxiv.org/pdf/2110.01548.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [EDAC doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/edac.html)<br>[policy/edac](https://github.com/opendilab/DI-engine/blob/main/ding/policy/edac.py) | python3 -u d4rl_edac_main.py |
-| 46 | MBSAC([SAC](https://arxiv.org/abs/1801.01290)+[MVE](https://arxiv.org/abs/1803.00101)+[SVG](https://arxiv.org/abs/1510.09142)) | ![continuous](https://img.shields.io/badge/-continous-green)![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [policy/mbpolicy/mbsac](https://github.com/opendilab/DI-engine/blob/main/ding/policy/mbpolicy/mbsac.py) | python3 -u pendulum_mbsac_mbpo_config.py \ python3 -u pendulum_mbsac_ddppo_config.py |
-| 47 | STEVESAC([SAC](https://arxiv.org/abs/1801.01290)+[STEVE](https://arxiv.org/abs/1807.01675)+[SVG](https://arxiv.org/abs/1510.09142)) | ![continuous](https://img.shields.io/badge/-continous-green)![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [policy/mbpolicy/mbsac](https://github.com/opendilab/DI-engine/blob/main/ding/policy/mbpolicy/mbsac.py) | python3 -u pendulum_stevesac_mbpo_config.py |
-| 48 | [MBPO](https://arxiv.org/pdf/1906.08253.pdf) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [MBPO doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/mbpo.html)<br>[world_model/mbpo](https://github.com/opendilab/DI-engine/blob/main/ding/world_model/mbpo.py) | python3 -u pendulum_sac_mbpo_config.py |
-| 49 | [DDPPO](https://openreview.net/forum?id=rzvOQrnclO0) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [world_model/ddppo](https://github.com/opendilab/DI-engine/blob/main/ding/world_model/ddppo.py) | python3 -u pendulum_mbsac_ddppo_config.py |
-| 50 | [DreamerV3](https://arxiv.org/pdf/2301.04104.pdf) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [world_model/dreamerv3](https://github.com/opendilab/DI-engine/blob/main/ding/world_model/dreamerv3.py) | python3 -u cartpole_balance_dreamer_config.py |
-| 51 | [PER](https://arxiv.org/pdf/1511.05952.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [worker/replay_buffer](https://github.com/opendilab/DI-engine/blob/main/ding/worker/replay_buffer/advanced_buffer.py) | `rainbow demo` |
-| 52 | [GAE](https://arxiv.org/pdf/1506.02438.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [rl_utils/gae](https://github.com/opendilab/DI-engine/blob/main/ding/rl_utils/gae.py) | `ppo demo` |
-| 53 | [ST-DIM](https://arxiv.org/pdf/1906.08226.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/loss/contrastive_loss](https://github.com/opendilab/DI-engine/blob/main/ding/torch_utils/loss/contrastive_loss.py) | ding -m serial -c cartpole_dqn_stdim_config.py -s 0 |
-| 54 | [PLR](https://arxiv.org/pdf/2010.03934.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [PLR doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/plr.html)<br>[data/level_replay/level_sampler](https://github.com/opendilab/DI-engine/blob/main/ding/data/level_replay/level_sampler.py) | python3 -u bigfish_plr_config.py -s 0 |
-| 55 | [PCGrad](https://arxiv.org/pdf/2001.06782.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/optimizer_helper/PCGrad](https://github.com/opendilab/DI-engine/blob/main/ding/data/torch_utils/optimizer_helper.py) | python3 -u multi_mnist_pcgrad_main.py -s 0 |
+| 46 | [QGPO](https://arxiv.org/pdf/2304.12824.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [QGPO doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/qgpo.html)<br>[policy/qgpo](https://github.com/opendilab/DI-engine/blob/main/ding/policy/qgpo.py) | python3 -u ding/example/qgpo.py |
+| 47 | MBSAC([SAC](https://arxiv.org/abs/1801.01290)+[MVE](https://arxiv.org/abs/1803.00101)+[SVG](https://arxiv.org/abs/1510.09142)) | ![continuous](https://img.shields.io/badge/-continous-green)![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [policy/mbpolicy/mbsac](https://github.com/opendilab/DI-engine/blob/main/ding/policy/mbpolicy/mbsac.py) | python3 -u pendulum_mbsac_mbpo_config.py \ python3 -u pendulum_mbsac_ddppo_config.py |
+| 48 | STEVESAC([SAC](https://arxiv.org/abs/1801.01290)+[STEVE](https://arxiv.org/abs/1807.01675)+[SVG](https://arxiv.org/abs/1510.09142)) | ![continuous](https://img.shields.io/badge/-continous-green)![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [policy/mbpolicy/mbsac](https://github.com/opendilab/DI-engine/blob/main/ding/policy/mbpolicy/mbsac.py) | python3 -u pendulum_stevesac_mbpo_config.py |
+| 49 | [MBPO](https://arxiv.org/pdf/1906.08253.pdf) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [MBPO doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/mbpo.html)<br>[world_model/mbpo](https://github.com/opendilab/DI-engine/blob/main/ding/world_model/mbpo.py) | python3 -u pendulum_sac_mbpo_config.py |
+| 50 | [DDPPO](https://openreview.net/forum?id=rzvOQrnclO0) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [world_model/ddppo](https://github.com/opendilab/DI-engine/blob/main/ding/world_model/ddppo.py) | python3 -u pendulum_mbsac_ddppo_config.py |
+| 51 | [DreamerV3](https://arxiv.org/pdf/2301.04104.pdf) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [world_model/dreamerv3](https://github.com/opendilab/DI-engine/blob/main/ding/world_model/dreamerv3.py) | python3 -u cartpole_balance_dreamer_config.py |
+| 52 | [PER](https://arxiv.org/pdf/1511.05952.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [worker/replay_buffer](https://github.com/opendilab/DI-engine/blob/main/ding/worker/replay_buffer/advanced_buffer.py) | `rainbow demo` |
+| 53 | [GAE](https://arxiv.org/pdf/1506.02438.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [rl_utils/gae](https://github.com/opendilab/DI-engine/blob/main/ding/rl_utils/gae.py) | `ppo demo` |
+| 54 | [ST-DIM](https://arxiv.org/pdf/1906.08226.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/loss/contrastive_loss](https://github.com/opendilab/DI-engine/blob/main/ding/torch_utils/loss/contrastive_loss.py) | ding -m serial -c cartpole_dqn_stdim_config.py -s 0 |
+| 55 | [PLR](https://arxiv.org/pdf/2010.03934.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [PLR doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/plr.html)<br>[data/level_replay/level_sampler](https://github.com/opendilab/DI-engine/blob/main/ding/data/level_replay/level_sampler.py) | python3 -u bigfish_plr_config.py -s 0 |
+| 56 | [PCGrad](https://arxiv.org/pdf/2001.06782.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/optimizer_helper/PCGrad](https://github.com/opendilab/DI-engine/blob/main/ding/data/torch_utils/optimizer_helper.py) | python3 -u multi_mnist_pcgrad_main.py -s 0 |
diff --git a/ding/example/qgpo.py b/ding/example/qgpo.py
new file mode 100644
index 0000000000..ed844974be
--- /dev/null
+++ b/ding/example/qgpo.py
@@ -0,0 +1,170 @@
+import torch
+import gym
+import d4rl
+from easydict import EasyDict
+from ditk import logging
+from ding.model import QGPO
+from ding.policy import QGPOPolicy
+from ding.envs import DingEnvWrapper, BaseEnvManagerV2
+from ding.config import compile_config
+from ding.framework import task, ding_init
+from ding.framework.context import OfflineRLContext
+from ding.framework.middleware import trainer, CkptSaver, offline_logger, wandb_offline_logger, termination_checker
+from ding.framework.middleware.functional.evaluator import interaction_evaluator
+from ding.framework.middleware.functional.data_processor import qgpo_support_data_generator, qgpo_offline_data_fetcher
+from ding.utils import set_pkg_seed
+
+from dizoo.d4rl.config.halfcheetah_medium_expert_qgpo_config import main_config, create_config
+
+
+class QGPOD4RLDataset(torch.utils.data.Dataset):
+ """
+ Overview:
+        Dataset for the QGPO algorithm. The training of QGPO is based on contrastive energy prediction, \
+        which requires both true actions and fake actions. The true actions are sampled from the dataset, \
+        and the fake actions are sampled from the action support generated by the behavior policy.
+ Interface:
+ ``__init__``, ``__getitem__``, ``__len__``.
+ """
+
+ def __init__(self, cfg, device="cpu"):
+ """
+ Overview:
+ Initialization method of QGPOD4RLDataset class
+ Arguments:
+ - cfg (:obj:`EasyDict`): Config dict
+ - device (:obj:`str`): Device name
+ """
+
+ self.cfg = cfg
+ data = d4rl.qlearning_dataset(gym.make(cfg.env_id))
+ self.device = device
+ self.states = torch.from_numpy(data['observations']).float().to(self.device)
+ self.actions = torch.from_numpy(data['actions']).float().to(self.device)
+ self.next_states = torch.from_numpy(data['next_observations']).float().to(self.device)
+ reward = torch.from_numpy(data['rewards']).view(-1, 1).float().to(self.device)
+ self.is_finished = torch.from_numpy(data['terminals']).view(-1, 1).float().to(self.device)
+
+ reward_tune = "iql_antmaze" if "antmaze" in cfg.env_id else "iql_locomotion"
+ if reward_tune == 'normalize':
+ reward = (reward - reward.mean()) / reward.std()
+ elif reward_tune == 'iql_antmaze':
+ reward = reward - 1.0
+ elif reward_tune == 'iql_locomotion':
+ min_ret, max_ret = QGPOD4RLDataset.return_range(data, 1000)
+ reward /= (max_ret - min_ret)
+ reward *= 1000
+ elif reward_tune == 'cql_antmaze':
+ reward = (reward - 0.5) * 4.0
+ elif reward_tune == 'antmaze':
+ reward = (reward - 0.25) * 2.0
+ self.rewards = reward
+ self.len = self.states.shape[0]
+ logging.info(f"{self.len} data loaded in QGPOD4RLDataset")
+
+ def __getitem__(self, index):
+ """
+ Overview:
+ Get data by index
+ Arguments:
+ - index (:obj:`int`): Index of data
+ Returns:
+ - data (:obj:`dict`): Data dict
+
+ .. note::
+ The data dict contains the following keys:
+ - s (:obj:`torch.Tensor`): State
+ - a (:obj:`torch.Tensor`): Action
+ - r (:obj:`torch.Tensor`): Reward
+ - s_ (:obj:`torch.Tensor`): Next state
+ - d (:obj:`torch.Tensor`): Is finished
+ - fake_a (:obj:`torch.Tensor`): Fake action for contrastive energy prediction and qgpo training \
+ (fake action is sampled from the action support generated by the behavior policy)
+ - fake_a_ (:obj:`torch.Tensor`): Fake next action for contrastive energy prediction and qgpo training \
+ (fake action is sampled from the action support generated by the behavior policy)
+ """
+
+ data = {
+ 's': self.states[index % self.len],
+ 'a': self.actions[index % self.len],
+ 'r': self.rewards[index % self.len],
+ 's_': self.next_states[index % self.len],
+ 'd': self.is_finished[index % self.len],
+ 'fake_a': self.fake_actions[index % self.len]
+ if hasattr(self, "fake_actions") else 0.0, # self.fake_actions
+ 'fake_a_': self.fake_next_actions[index % self.len]
+ if hasattr(self, "fake_next_actions") else 0.0, # self.fake_next_actions
+ }
+ return data
+
+ def __len__(self):
+ return self.len
+
+    @staticmethod
+    def return_range(dataset, max_episode_steps):
+ returns, lengths = [], []
+ ep_ret, ep_len = 0., 0
+ for r, d in zip(dataset['rewards'], dataset['terminals']):
+ ep_ret += float(r)
+ ep_len += 1
+ if d or ep_len == max_episode_steps:
+ returns.append(ep_ret)
+ lengths.append(ep_len)
+ ep_ret, ep_len = 0., 0
+ # returns.append(ep_ret) # incomplete trajectory
+ lengths.append(ep_len) # but still keep track of number of steps
+ assert sum(lengths) == len(dataset['rewards'])
+ return min(returns), max(returns)
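+
+# A minimal usage sketch of the dataset above (assumes a working D4RL installation and the
+# halfcheetah-medium-expert-v2 environment; shown for illustration only):
+#   >>> ds = QGPOD4RLDataset(EasyDict(env_id="halfcheetah-medium-expert-v2"), device="cpu")
+#   >>> batch = ds[0]   # dict with keys 's', 'a', 'r', 's_', 'd', 'fake_a', 'fake_a_'
+#   >>> len(ds)         # number of transitions in the D4RL qlearning dataset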
+
+
+def main():
+    # If you don't have offline data, you need to prepare it first and set the data_path in the config.
+    # For demonstration, you can also train an RL policy (e.g. SAC) and collect some data.
+ logging.getLogger().setLevel(logging.INFO)
+ cfg = compile_config(main_config, policy=QGPOPolicy)
+ ding_init(cfg)
+ with task.start(async_mode=False, ctx=OfflineRLContext()):
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+ model = QGPO(cfg=cfg.policy.model)
+ policy = QGPOPolicy(cfg.policy, model=model)
+ dataset = QGPOD4RLDataset(cfg=cfg.dataset, device=policy._device)
+ if hasattr(cfg.policy, "load_path") and cfg.policy.load_path is not None:
+ policy_state_dict = torch.load(cfg.policy.load_path, map_location=torch.device("cpu"))
+ policy.learn_mode.load_state_dict(policy_state_dict)
+
+ task.use(qgpo_support_data_generator(cfg, dataset, policy))
+ task.use(qgpo_offline_data_fetcher(cfg, dataset, collate_fn=None))
+ task.use(trainer(cfg, policy.learn_mode))
+ for guidance_scale in cfg.policy.eval.guidance_scale:
+ evaluator_env = BaseEnvManagerV2(
+ env_fn=[
+ lambda: DingEnvWrapper(env=gym.make(cfg.env.env_id), cfg=cfg.env, caller="evaluator")
+ for _ in range(cfg.env.evaluator_env_num)
+ ],
+ cfg=cfg.env.manager
+ )
+ task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env, guidance_scale=guidance_scale))
+ task.use(
+ wandb_offline_logger(
+ cfg=EasyDict(
+ dict(
+ gradient_logger=False,
+ plot_logger=True,
+ video_logger=False,
+ action_logger=False,
+ return_logger=False,
+ vis_dataset=False,
+ )
+ ),
+ exp_config=cfg,
+ metric_list=policy._monitor_vars_learn(),
+ project_name=cfg.exp_name
+ )
+ )
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=100000))
+ task.use(offline_logger())
+ task.use(termination_checker(max_train_iter=500000 + cfg.policy.learn.q_value_stop_training_iter))
+ task.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/ding/framework/context.py b/ding/framework/context.py
index 6fb35eec13..12f144766a 100644
--- a/ding/framework/context.py
+++ b/ding/framework/context.py
@@ -68,6 +68,7 @@ class OnlineRLContext(Context):
last_eval_value: int = -np.inf
eval_output: List = dataclasses.field(default_factory=dict)
# wandb
+ info_for_logging: Dict = dataclasses.field(default_factory=dict)
wandb_url: str = ""
def __post_init__(self):
@@ -93,6 +94,7 @@ class OfflineRLContext(Context):
last_eval_value: int = -np.inf
eval_output: List = dataclasses.field(default_factory=dict)
# wandb
+ info_for_logging: Dict = dataclasses.field(default_factory=dict)
wandb_url: str = ""
def __post_init__(self):
diff --git a/ding/framework/middleware/functional/data_processor.py b/ding/framework/middleware/functional/data_processor.py
index cbcc39e7a2..5882716ef3 100644
--- a/ding/framework/middleware/functional/data_processor.py
+++ b/ding/framework/middleware/functional/data_processor.py
@@ -2,7 +2,9 @@
from typing import TYPE_CHECKING, Callable, List, Union, Tuple, Dict, Optional
from easydict import EasyDict
from ditk import logging
+import numpy as np
import torch
+import tqdm
from ding.data import Buffer, Dataset, DataLoader, offline_data_save_type
from ding.data.buffer.middleware import PriorityExperienceReplay
from ding.framework import task
@@ -229,7 +231,7 @@ def _fetch(ctx: "OfflineRLContext"):
return _fetch
-def offline_data_fetcher(cfg: EasyDict, dataset: Dataset) -> Callable:
+def offline_data_fetcher(cfg: EasyDict, dataset: Dataset, collate_fn=lambda x: x) -> Callable:
"""
Overview:
The outer function transforms a Pytorch `Dataset` to `DataLoader`. \
@@ -241,7 +243,7 @@ def offline_data_fetcher(cfg: EasyDict, dataset: Dataset) -> Callable:
- dataset (:obj:`Dataset`): The dataset of type `torch.utils.data.Dataset` which stores the data.
"""
# collate_fn is executed in policy now
- dataloader = DataLoader(dataset, batch_size=cfg.policy.learn.batch_size, shuffle=True, collate_fn=lambda x: x)
+ dataloader = DataLoader(dataset, batch_size=cfg.policy.learn.batch_size, shuffle=True, collate_fn=collate_fn)
dataloader = iter(dataloader)
def _fetch(ctx: "OfflineRLContext"):
@@ -261,7 +263,7 @@ def _fetch(ctx: "OfflineRLContext"):
ctx.train_epoch += 1
del dataloader
dataloader = DataLoader(
- dataset, batch_size=cfg.policy.learn.batch_size, shuffle=True, collate_fn=lambda x: x
+ dataset, batch_size=cfg.policy.learn.batch_size, shuffle=True, collate_fn=collate_fn
)
dataloader = iter(dataloader)
ctx.train_data = next(dataloader)
@@ -319,3 +321,125 @@ def _pusher(ctx: "OnlineRLContext"):
ctx.trajectories = None
return _pusher
+
+
+def qgpo_support_data_generator(cfg, dataset, policy) -> Callable:
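+    """
+    Overview:
+        Middleware that generates the fake action support required by QGPO's contrastive energy prediction. \
+        Once the energy-guided policy starts training, the behavior (diffusion) model samples \
+        ``cfg.policy.learn.M`` candidate actions for every state and next state in the dataset, and attaches \
+        the results to the dataset as ``fake_actions`` and ``fake_next_actions``.
+    Arguments:
+        - cfg (:obj:`EasyDict`): The config dict of the QGPO policy.
+        - dataset (:obj:`Dataset`): The offline dataset, which must expose ``states`` and ``next_states``.
+        - policy (:obj:`Policy`): The QGPO policy whose behavior model is used for sampling.
+    """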
+
+ behavior_policy_stop_training_iter = cfg.policy.learn.behavior_policy_stop_training_iter if hasattr(
+ cfg.policy.learn, 'behavior_policy_stop_training_iter'
+ ) else np.inf
+ energy_guided_policy_begin_training_iter = cfg.policy.learn.energy_guided_policy_begin_training_iter if hasattr(
+ cfg.policy.learn, 'energy_guided_policy_begin_training_iter'
+ ) else 0
+ actions_generated = False
+
+ def generate_fake_actions():
+ allstates = dataset.states[:].cpu().numpy()
+ actions_sampled = []
+ for states in tqdm.tqdm(np.array_split(allstates, allstates.shape[0] // 4096 + 1)):
+ actions_sampled.append(
+ policy._model.sample(
+ states,
+ sample_per_state=cfg.policy.learn.M,
+ diffusion_steps=cfg.policy.learn.diffusion_steps,
+ guidance_scale=0.0,
+ )
+ )
+ actions = np.concatenate(actions_sampled)
+
+ allnextstates = dataset.next_states[:].cpu().numpy()
+ actions_next_states_sampled = []
+ for next_states in tqdm.tqdm(np.array_split(allnextstates, allnextstates.shape[0] // 4096 + 1)):
+ actions_next_states_sampled.append(
+ policy._model.sample(
+ next_states,
+ sample_per_state=cfg.policy.learn.M,
+ diffusion_steps=cfg.policy.learn.diffusion_steps,
+ guidance_scale=0.0,
+ )
+ )
+ actions_next_states = np.concatenate(actions_next_states_sampled)
+ return actions, actions_next_states
+
+ def _data_generator(ctx: "OfflineRLContext"):
+ nonlocal actions_generated
+
+        if ctx.train_iter >= energy_guided_policy_begin_training_iter:
+            # regenerate the fake action support while the behavior policy is still training;
+            # once the behavior policy is frozen, generate the support only once
+            if ctx.train_iter <= behavior_policy_stop_training_iter or not actions_generated:
+                actions, actions_next_states = generate_fake_actions()
+                dataset.fake_actions = torch.Tensor(actions.astype(np.float32)).to(cfg.policy.model.device)
+                dataset.fake_next_actions = torch.Tensor(actions_next_states.astype(np.float32)
+                                                         ).to(cfg.policy.model.device)
+                actions_generated = True
+
+ return _data_generator
+
+
+def qgpo_offline_data_fetcher(cfg: EasyDict, dataset: Dataset, collate_fn=lambda x: x) -> Callable:
+ """
+ Overview:
+ The outer function transforms a Pytorch `Dataset` to `DataLoader`. \
+        The outer function wraps the Pytorch `Dataset` into two `DataLoader`s, one with `batch_size` for \
+        behavior-model training and one with `batch_size_q` for Q-function training. \
+        The returned function fetches a batch from the appropriate `DataLoader` on each call. \
+ and https://pytorch.org/docs/stable/data.html for more details.
+ Arguments:
+        - cfg (:obj:`EasyDict`): Config which should contain the following keys: `cfg.policy.learn.batch_size` \
+            and `cfg.policy.learn.batch_size_q`.
+ - dataset (:obj:`Dataset`): The dataset of type `torch.utils.data.Dataset` which stores the data.
+ """
+ # collate_fn is executed in policy now
+ dataloader = DataLoader(dataset, batch_size=cfg.policy.learn.batch_size, shuffle=True, collate_fn=collate_fn)
+ dataloader_q = DataLoader(dataset, batch_size=cfg.policy.learn.batch_size_q, shuffle=True, collate_fn=collate_fn)
+
+ behavior_policy_stop_training_iter = cfg.policy.learn.behavior_policy_stop_training_iter if hasattr(
+ cfg.policy.learn, 'behavior_policy_stop_training_iter'
+ ) else np.inf
+ energy_guided_policy_begin_training_iter = cfg.policy.learn.energy_guided_policy_begin_training_iter if hasattr(
+ cfg.policy.learn, 'energy_guided_policy_begin_training_iter'
+ ) else 0
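+    # Before ``energy_guided_policy_begin_training_iter`` training iterations, batches are served from the
+    # behavior-model dataloader (``batch_size``); afterwards, from the Q-training dataloader (``batch_size_q``).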
+
+ def get_behavior_policy_training_data():
+ while True:
+ yield from dataloader
+
+ data = get_behavior_policy_training_data()
+
+ def get_q_training_data():
+ while True:
+ yield from dataloader_q
+
+ data_q = get_q_training_data()
+
+ def _fetch(ctx: "OfflineRLContext"):
+ """
+ Overview:
+            Every time this function is called, a batch of data is fetched and assigned to ctx.train_data. \
+            Before the energy-guided policy begins training, batches come from the behavior-model dataloader; \
+            afterwards, they come from the Q-training dataloader. Both dataloaders are cycled indefinitely.
+        Input of ctx:
+            - train_iter (:obj:`int`): Number of training iterations, used to choose the dataloader.
+        Output of ctx:
+            - train_data (:obj:`List[Tensor]`): The fetched data batch.
+ """
+
+ if ctx.train_iter >= energy_guided_policy_begin_training_iter:
+ ctx.train_data = next(data_q)
+ else:
+ ctx.train_data = next(data)
+
+ # TODO apply data update (e.g. priority) in offline setting when necessary
+ ctx.trained_env_step += len(ctx.train_data)
+
+ return _fetch
diff --git a/ding/framework/middleware/functional/evaluator.py b/ding/framework/middleware/functional/evaluator.py
index 611bbcdea6..3fa407b6de 100644
--- a/ding/framework/middleware/functional/evaluator.py
+++ b/ding/framework/middleware/functional/evaluator.py
@@ -210,7 +210,9 @@ def get_episode_output(self):
return output
-def interaction_evaluator(cfg: EasyDict, policy: Policy, env: BaseEnvManager, render: bool = False) -> Callable:
+def interaction_evaluator(
+ cfg: EasyDict, policy: Policy, env: BaseEnvManager, render: bool = False, **kwargs
+) -> Callable:
"""
Overview:
The middleware that executes the evaluation.
@@ -219,6 +221,7 @@ def interaction_evaluator(cfg: EasyDict, policy: Policy, env: BaseEnvManager, re
- policy (:obj:`Policy`): The policy to be evaluated.
- env (:obj:`BaseEnvManager`): The env for the evaluation.
- render (:obj:`bool`): Whether to render env images and policy logits.
+        - kwargs (:obj:`Any`): Other keyword arguments for algorithm-specific evaluation, \
+            which are forwarded to ``policy.forward`` (e.g. ``guidance_scale`` for QGPO).
"""
if task.router.is_active and not task.has_role(task.role.EVALUATOR):
return task.void()
@@ -239,8 +242,13 @@ def _evaluate(ctx: Union["OnlineRLContext", "OfflineRLContext"]):
# evaluation will be executed if the task begins or enough train_iter after last evaluation
if ctx.last_eval_iter != -1 and \
- (ctx.train_iter - ctx.last_eval_iter < cfg.policy.eval.evaluator.eval_freq):
- return
+ (ctx.train_iter - ctx.last_eval_iter < cfg.policy.eval.evaluator.eval_freq):
+ if ctx.train_iter != ctx.last_eval_iter:
+ return
+ if len(kwargs) > 0:
+ kwargs_str = '/'.join([f'{k}({v})' for k, v in kwargs.items()])
+ else:
+ kwargs_str = ''
if env.closed:
env.launch()
@@ -252,7 +260,10 @@ def _evaluate(ctx: Union["OnlineRLContext", "OfflineRLContext"]):
while not eval_monitor.is_finished():
obs = ttorch.as_tensor(env.ready_obs).to(dtype=ttorch.float32)
obs = {i: obs[i] for i in range(get_shape0(obs))} # TBD
- inference_output = policy.forward(obs)
+ if len(kwargs) > 0:
+ inference_output = policy.forward(obs, **kwargs)
+ else:
+ inference_output = policy.forward(obs)
if render:
eval_monitor.update_video(env.ready_imgs)
eval_monitor.update_output(inference_output)
@@ -275,12 +286,14 @@ def _evaluate(ctx: Union["OnlineRLContext", "OfflineRLContext"]):
stop_flag = episode_return >= cfg.env.stop_value and ctx.train_iter > 0
if isinstance(ctx, OnlineRLContext):
logging.info(
- 'Evaluation: Train Iter({})\tEnv Step({})\tEpisode Return({:.3f})'.format(
- ctx.train_iter, ctx.env_step, episode_return
+ 'Evaluation: Train Iter({}) Env Step({}) Episode Return({:.3f}) {}'.format(
+ ctx.train_iter, ctx.env_step, episode_return, kwargs_str
)
)
elif isinstance(ctx, OfflineRLContext):
- logging.info('Evaluation: Train Iter({})\tEval Reward({:.3f})'.format(ctx.train_iter, episode_return))
+ logging.info(
+ 'Evaluation: Train Iter({}) Eval Return({:.3f}) {}'.format(ctx.train_iter, episode_return, kwargs_str)
+ )
else:
raise TypeError("not supported ctx type: {}".format(type(ctx)))
ctx.last_eval_iter = ctx.train_iter
@@ -299,6 +312,16 @@ def _evaluate(ctx: Union["OnlineRLContext", "OfflineRLContext"]):
else:
ctx.eval_output['output'] = output # for compatibility
+ if len(kwargs) > 0:
+ ctx.info_for_logging.update(
+ {
+ f'{kwargs_str}/eval_episode_return': episode_return,
+ f'{kwargs_str}/eval_episode_return_min': episode_return_min,
+ f'{kwargs_str}/eval_episode_return_max': episode_return_max,
+ f'{kwargs_str}/eval_episode_return_std': episode_return_std,
+ }
+ )
+
if stop_flag:
task.finish = True
diff --git a/ding/framework/middleware/functional/logger.py b/ding/framework/middleware/functional/logger.py
index 9f62e2f429..5f968a5fd3 100644
--- a/ding/framework/middleware/functional/logger.py
+++ b/ding/framework/middleware/functional/logger.py
@@ -611,6 +611,15 @@ def _plot(ctx: "OfflineRLContext"):
)
if ctx.eval_value != -np.inf:
+ if hasattr(ctx, "info_for_logging"):
+ """
+ .. note::
+ The info_for_logging is a dict that contains the information to be logged.
+ Users can add their own information to the dict.
+ All the information in the dict will be logged to wandb.
+ """
+ info_for_logging.update(ctx.info_for_logging)
+
if hasattr(ctx, "eval_value_min"):
info_for_logging.update({
"episode return min": ctx.eval_value_min,
diff --git a/ding/model/common/__init__.py b/ding/model/common/__init__.py
index 4bf7d8be5a..b63aedfcd9 100755
--- a/ding/model/common/__init__.py
+++ b/ding/model/common/__init__.py
@@ -1,5 +1,5 @@
from .head import DiscreteHead, DuelingHead, DistributionHead, RainbowHead, QRDQNHead, StochasticDuelingHead, \
QuantileHead, FQFHead, RegressionHead, ReparameterizationHead, MultiHead, BranchingHead, head_cls_map, \
independent_normal_dist, AttentionPolicyHead, PopArtVHead, EnsembleHead
-from .encoder import ConvEncoder, FCEncoder, IMPALAConvEncoder
+from .encoder import ConvEncoder, FCEncoder, IMPALAConvEncoder, GaussianFourierProjectionTimeEncoder
from .utils import create_model
diff --git a/ding/model/common/encoder.py b/ding/model/common/encoder.py
index 82dab4808a..8936f6af4d 100644
--- a/ding/model/common/encoder.py
+++ b/ding/model/common/encoder.py
@@ -2,6 +2,7 @@
from functools import reduce
import operator
import math
+import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
@@ -470,3 +471,46 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.final_relu:
x = torch.relu(x)
return x
+
+
+class GaussianFourierProjectionTimeEncoder(nn.Module):
+ """
+ Overview:
+ Gaussian random features for encoding time steps.
+ This module is used as the encoder of time in generative models such as diffusion model.
+ Interfaces:
+ ``__init__``, ``forward``.
+ """
+
+ def __init__(self, embed_dim, scale=30.):
+ """
+ Overview:
+ Initialize the Gaussian Fourier Projection Time Encoder according to arguments.
+ Arguments:
+ - embed_dim (:obj:`int`): The dimension of the output embedding vector.
+ - scale (:obj:`float`): The scale of the Gaussian random features.
+ """
+ super().__init__()
+ # Randomly sample weights during initialization. These weights are fixed
+ # during optimization and are not trainable.
+ self.W = nn.Parameter(torch.randn(embed_dim // 2) * scale * 2 * np.pi, requires_grad=False)
+
+ def forward(self, x):
+ """
+ Overview:
+ Return the output embedding vector of the input time step.
+ Arguments:
+ - x (:obj:`torch.Tensor`): Input time step tensor.
+ Returns:
+ - output (:obj:`torch.Tensor`): Output embedding vector.
+ Shapes:
+ - x (:obj:`torch.Tensor`): :math:`(B,)`, where B is batch size.
+ - output (:obj:`torch.Tensor`): :math:`(B, embed_dim)`, where B is batch size, embed_dim is the \
+ dimension of the output embedding vector.
+ Examples:
+ >>> encoder = GaussianFourierProjectionTimeEncoder(128)
+ >>> x = torch.randn(100)
+ >>> output = encoder(x)
+ """
+ x_proj = x[..., None] * self.W[None, :]
+ return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
diff --git a/ding/model/common/tests/test_encoder.py b/ding/model/common/tests/test_encoder.py
index cd8a5bf752..d73f10561a 100644
--- a/ding/model/common/tests/test_encoder.py
+++ b/ding/model/common/tests/test_encoder.py
@@ -2,7 +2,7 @@
import numpy as np
import pytest
-from ding.model import ConvEncoder, FCEncoder, IMPALAConvEncoder
+from ding.model import ConvEncoder, FCEncoder, IMPALAConvEncoder, GaussianFourierProjectionTimeEncoder
from ding.torch_utils import is_differentiable
B = 4
@@ -61,3 +61,10 @@ def test_impalaconv_encoder(self):
outputs = model(inputs)
self.output_check(model, outputs)
assert outputs.shape == (B, 256)
+
+ def test_GaussianFourierProjectionTimeEncoder(self):
+ inputs = torch.randn(B)
+ model = GaussianFourierProjectionTimeEncoder(128)
+ print(model)
+ outputs = model(inputs)
+ assert outputs.shape == (B, 128)
diff --git a/ding/model/template/__init__.py b/ding/model/template/__init__.py
index c9dc17791c..8e902f1504 100755
--- a/ding/model/template/__init__.py
+++ b/ding/model/template/__init__.py
@@ -26,5 +26,6 @@
from .procedure_cloning import ProcedureCloningMCTS, ProcedureCloningBFS
from .bcq import BCQ
from .edac import EDAC
+from .qgpo import QGPO
from .ebm import EBM, AutoregressiveEBM
from .havac import HAVAC
diff --git a/ding/model/template/qgpo.py b/ding/model/template/qgpo.py
new file mode 100644
index 0000000000..135433c42c
--- /dev/null
+++ b/ding/model/template/qgpo.py
@@ -0,0 +1,465 @@
+#############################################################
+# This QGPO model is a modified implementation based on https://github.com/ChenDRAG/CEP-energy-guided-diffusion
+#############################################################
+
+from easydict import EasyDict
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import copy
+from ding.torch_utils import MLP
+from ding.torch_utils.diffusion_SDE.dpm_solver_pytorch import DPM_Solver, NoiseScheduleVP
+from ding.model.common.encoder import GaussianFourierProjectionTimeEncoder
+from ding.torch_utils.network.res_block import TemporalSpatialResBlock
+
+
+def marginal_prob_std(t, device):
+ """
+ Overview:
+        Compute the scaling coefficient alpha_t and the standard deviation sigma_t of the perturbation kernel \
+        $p_{0t}(x(t) | x(0)) = N(alpha_t x(0), sigma_t^2 I)$ under the linear VP noise schedule.
+    Arguments:
+        - t (:obj:`torch.Tensor`): The input time.
+        - device (:obj:`torch.device`): The device to use.
+    Returns:
+        - alpha_t (:obj:`torch.Tensor`): The scaling coefficient of the mean.
+        - std (:obj:`torch.Tensor`): The standard deviation of the perturbation kernel.
+ """
+
+ t = torch.tensor(t, device=device)
+ beta_1 = 20.0
+ beta_0 = 0.1
+ log_mean_coeff = -0.25 * t ** 2 * (beta_1 - beta_0) - 0.5 * t * beta_0
+ alpha_t = torch.exp(log_mean_coeff)
+ std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff))
+ return alpha_t, std
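+
+# A quick sanity check of the linear VP schedule above (illustrative only, not used during training):
+# as t -> 0, alpha_t -> 1 and std -> 0, so x(t) stays close to the data; as t -> 1, alpha_t is roughly 6.6e-3
+# and std is roughly 1, so x(1) is approximately standard Gaussian noise.
+#   >>> alpha_t, std = marginal_prob_std(torch.tensor([1e-3, 1.0]), device='cpu')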
+
+
+class TwinQ(nn.Module):
+ """
+ Overview:
+ Twin Q network for QGPO, which has two Q networks.
+ Interfaces:
+ ``__init__``, ``forward``, ``both``
+ """
+
+ def __init__(self, action_dim, state_dim):
+ """
+ Overview:
+ Initialization of Twin Q.
+ Arguments:
+ - action_dim (:obj:`int`): The dimension of action.
+ - state_dim (:obj:`int`): The dimension of state.
+ """
+ super().__init__()
+ self.q1 = MLP(
+ in_channels=state_dim + action_dim,
+ hidden_channels=256,
+ out_channels=1,
+ activation=nn.ReLU(),
+ layer_num=4,
+ output_activation=False
+ )
+ self.q2 = MLP(
+ in_channels=state_dim + action_dim,
+ hidden_channels=256,
+ out_channels=1,
+ activation=nn.ReLU(),
+ layer_num=4,
+ output_activation=False
+ )
+
+ def both(self, action, condition=None):
+ """
+ Overview:
+ Return the output of two Q networks.
+ Arguments:
+ - action (:obj:`torch.Tensor`): The input action.
+ - condition (:obj:`torch.Tensor`): The input condition.
+ """
+ as_ = torch.cat([action, condition], -1) if condition is not None else action
+ return self.q1(as_), self.q2(as_)
+
+ def forward(self, action, condition=None):
+ """
+ Overview:
+ Return the minimum output of two Q networks.
+ Arguments:
+ - action (:obj:`torch.Tensor`): The input action.
+ - condition (:obj:`torch.Tensor`): The input condition.
+ """
+ return torch.min(*self.both(action, condition))
+
+
+class GuidanceQt(nn.Module):
+ """
+ Overview:
+ Energy Guidance Qt network for QGPO. \
+ In the origin paper, the energy guidance is trained by CEP method.
+ Interfaces:
+ ``__init__``, ``forward``
+ """
+
+ def __init__(self, action_dim, state_dim, time_embed_dim=32):
+ """
+ Overview:
+ Initialization of Guidance Qt.
+ Arguments:
+ - action_dim (:obj:`int`): The dimension of action.
+ - state_dim (:obj:`int`): The dimension of state.
+ - time_embed_dim (:obj:`int`): The dimension of time embedding. \
+ The time embedding is a Gaussian Fourier Feature tensor.
+ """
+ super().__init__()
+ self.qt = MLP(
+ in_channels=action_dim + time_embed_dim + state_dim,
+ hidden_channels=256,
+ out_channels=1,
+ activation=torch.nn.SiLU(),
+ layer_num=4,
+ output_activation=False
+ )
+ self.embed = nn.Sequential(
+ GaussianFourierProjectionTimeEncoder(embed_dim=time_embed_dim), nn.Linear(time_embed_dim, time_embed_dim)
+ )
+
+ def forward(self, action, t, condition=None):
+ """
+ Overview:
+ Return the output of Guidance Qt.
+ Arguments:
+ - action (:obj:`torch.Tensor`): The input action.
+ - t (:obj:`torch.Tensor`): The input time.
+ - condition (:obj:`torch.Tensor`): The input condition.
+ """
+ embed = self.embed(t)
+ ats = torch.cat([action, embed, condition], -1) if condition is not None else torch.cat([action, embed], -1)
+ return self.qt(ats)
+
+
+class QGPOCritic(nn.Module):
+ """
+ Overview:
+ QGPO critic network.
+ Interfaces:
+ ``__init__``, ``forward``, ``calculateQ``, ``calculate_guidance``
+ """
+
+ def __init__(self, device, cfg, action_dim, state_dim) -> None:
+ """
+ Overview:
+ Initialization of QGPO critic.
+ Arguments:
+ - device (:obj:`torch.device`): The device to use.
+ - cfg (:obj:`EasyDict`): The config dict.
+ - action_dim (:obj:`int`): The dimension of action.
+ - state_dim (:obj:`int`): The dimension of state.
+ """
+
+ super().__init__()
+        # state_dim == 0 would mean unconditional guidance; only conditional sampling is supported here
+        assert state_dim > 0
+ self.device = device
+ self.q0 = TwinQ(action_dim, state_dim).to(self.device)
+ self.q0_target = copy.deepcopy(self.q0).requires_grad_(False).to(self.device)
+ self.qt = GuidanceQt(action_dim, state_dim).to(self.device)
+
+ self.alpha = cfg.alpha
+ self.q_alpha = cfg.q_alpha
+
+ def calculate_guidance(self, a, t, condition=None, guidance_scale=1.0):
+ """
+ Overview:
+ Calculate the guidance for conditional sampling.
+ Arguments:
+ - a (:obj:`torch.Tensor`): The input action.
+ - t (:obj:`torch.Tensor`): The input time.
+ - condition (:obj:`torch.Tensor`): The input condition.
+ - guidance_scale (:obj:`float`): The scale of guidance.
+ """
+
+ with torch.enable_grad():
+ a.requires_grad_(True)
+ Q_t = self.qt(a, t, condition)
+ guidance = guidance_scale * torch.autograd.grad(torch.sum(Q_t), a)[0]
+ return guidance.detach()
+
+ def forward(self, a, condition=None):
+ """
+ Overview:
+ Return the output of QGPO critic.
+ Arguments:
+ - a (:obj:`torch.Tensor`): The input action.
+ - condition (:obj:`torch.Tensor`): The input condition.
+ """
+
+ return self.q0(a, condition)
+
+ def calculateQ(self, a, condition=None):
+ """
+ Overview:
+ Return the output of QGPO critic.
+ Arguments:
+ - a (:obj:`torch.Tensor`): The input action.
+ - condition (:obj:`torch.Tensor`): The input condition.
+ """
+
+ return self(a, condition)
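+
+# A rough shape sketch for QGPOCritic (hypothetical MuJoCo-like dimensions, shown for illustration only):
+#   >>> critic = QGPOCritic("cpu", EasyDict(alpha=3, q_alpha=1), action_dim=6, state_dim=17)
+#   >>> a, s, t = torch.randn(4, 6), torch.randn(4, 17), torch.rand(4)
+#   >>> critic(a, s).shape                        # torch.Size([4, 1]), min of the twin Q values
+#   >>> critic.calculate_guidance(a, t, s).shape  # torch.Size([4, 6]), gradient of Q_t w.r.t. the action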
+
+
+class ScoreNet(nn.Module):
+ """
+ Overview:
+ Score-based generative model for QGPO.
+ Interfaces:
+ ``__init__``, ``forward``
+ """
+
+ def __init__(self, device, input_dim, output_dim, embed_dim=32):
+ """
+ Overview:
+ Initialization of ScoreNet.
+ Arguments:
+ - device (:obj:`torch.device`): The device to use.
+ - input_dim (:obj:`int`): The dimension of input.
+ - output_dim (:obj:`int`): The dimension of output.
+ - embed_dim (:obj:`int`): The dimension of time embedding. \
+ The time embedding is a Gaussian Fourier Feature tensor.
+ """
+
+ super().__init__()
+
+        # score-based model backbone, following the original implementation
+ self.output_dim = output_dim
+ self.embed = nn.Sequential(
+ GaussianFourierProjectionTimeEncoder(embed_dim=embed_dim), nn.Linear(embed_dim, embed_dim)
+ )
+
+ self.device = device
+ self.pre_sort_condition = nn.Sequential(nn.Linear(input_dim - output_dim, 32), torch.nn.SiLU())
+ self.sort_t = nn.Sequential(
+ nn.Linear(64, 128),
+ torch.nn.SiLU(),
+ nn.Linear(128, 128),
+ )
+ self.down_block1 = TemporalSpatialResBlock(output_dim, 512)
+ self.down_block2 = TemporalSpatialResBlock(512, 256)
+ self.down_block3 = TemporalSpatialResBlock(256, 128)
+ self.middle1 = TemporalSpatialResBlock(128, 128)
+ self.up_block3 = TemporalSpatialResBlock(256, 256)
+ self.up_block2 = TemporalSpatialResBlock(512, 512)
+ self.last = nn.Linear(1024, output_dim)
+
+ def forward(self, x, t, condition):
+ """
+ Overview:
+ Return the output of ScoreNet.
+ Arguments:
+ - x (:obj:`torch.Tensor`): The input tensor.
+ - t (:obj:`torch.Tensor`): The input time.
+ - condition (:obj:`torch.Tensor`): The input condition.
+ """
+
+ embed = self.embed(t)
+ embed = torch.cat([self.pre_sort_condition(condition), embed], dim=-1)
+ embed = self.sort_t(embed)
+ d1 = self.down_block1(x, embed)
+ d2 = self.down_block2(d1, embed)
+ d3 = self.down_block3(d2, embed)
+ u3 = self.middle1(d3, embed)
+ u2 = self.up_block3(torch.cat([d3, u3], dim=-1), embed)
+ u1 = self.up_block2(torch.cat([d2, u2], dim=-1), embed)
+ u0 = torch.cat([d1, u1], dim=-1)
+ h = self.last(u0)
+ self.h = h
+ # Normalize output
+ return h / marginal_prob_std(t, device=self.device)[1][..., None]
+
+
+class QGPO(nn.Module):
+ """
+ Overview:
+ Model of QGPO algorithm.
+ Interfaces:
+ ``__init__``, ``calculateQ``, ``select_actions``, ``sample``, ``score_model_loss_fn``, ``q_loss_fn``, \
+ ``qt_loss_fn``
+ """
+
+ def __init__(self, cfg: EasyDict) -> None:
+ """
+ Overview:
+ Initialization of QGPO.
+ Arguments:
+ - cfg (:obj:`EasyDict`): The config dict.
+ """
+
+ super(QGPO, self).__init__()
+ self.device = cfg.device
+ self.obs_dim = cfg.obs_dim
+ self.action_dim = cfg.action_dim
+
+ self.noise_schedule = NoiseScheduleVP(schedule='linear')
+
+ self.score_model = ScoreNet(
+ device=self.device,
+ input_dim=self.obs_dim + self.action_dim,
+ output_dim=self.action_dim,
+ )
+
+ self.q = QGPOCritic(self.device, cfg.qgpo_critic, action_dim=self.action_dim, state_dim=self.obs_dim)
+
+ def calculateQ(self, s, a):
+ """
+ Overview:
+ Calculate the Q value.
+ Arguments:
+ - s (:obj:`torch.Tensor`): The input state.
+ - a (:obj:`torch.Tensor`): The input action.
+ """
+
+ return self.q(a, s)
+
+ def select_actions(self, states, diffusion_steps=15, guidance_scale=1.0):
+ """
+ Overview:
+ Select actions for conditional sampling.
+ Arguments:
+ - states (:obj:`list`): The input states.
+ - diffusion_steps (:obj:`int`): The diffusion steps.
+ - guidance_scale (:obj:`float`): The scale of guidance.
+ """
+
+ def forward_dpm_wrapper_fn(x, t):
+ score = self.score_model(x, t, condition=states)
+ result = -(score +
+ self.q.calculate_guidance(x, t, states, guidance_scale=guidance_scale)) * marginal_prob_std(
+ t, device=self.device
+ )[1][..., None]
+ return result
+
+ self.eval()
+ multiple_input = True
+ with torch.no_grad():
+ states = torch.FloatTensor(states).to(self.device)
+            if states.dim() == 1:
+ states = states.unsqueeze(0)
+ multiple_input = False
+ num_states = states.shape[0]
+
+ init_x = torch.randn(states.shape[0], self.action_dim, device=self.device)
+ results = DPM_Solver(
+ forward_dpm_wrapper_fn, self.noise_schedule, predict_x0=True
+ ).sample(
+ init_x, steps=diffusion_steps, order=2
+ ).cpu().numpy()
+
+            actions = results.reshape(num_states, self.action_dim).copy()
+
+ out_actions = [actions[i] for i in range(actions.shape[0])] if multiple_input else actions[0]
+ self.train()
+ return out_actions
+
+ def sample(self, states, sample_per_state=16, diffusion_steps=15, guidance_scale=1.0):
+ """
+ Overview:
+ Sample actions for conditional sampling.
+ Arguments:
+ - states (:obj:`list`): The input states.
+ - sample_per_state (:obj:`int`): The number of samples per state.
+ - diffusion_steps (:obj:`int`): The diffusion steps.
+ - guidance_scale (:obj:`float`): The scale of guidance.
+ """
+
+ def forward_dpm_wrapper_fn(x, t):
+ score = self.score_model(x, t, condition=states)
+ result = -(score + self.q.calculate_guidance(x, t, states, guidance_scale=guidance_scale)) \
+ * marginal_prob_std(t, device=self.device)[1][..., None]
+ return result
+
+ self.eval()
+ num_states = states.shape[0]
+ with torch.no_grad():
+ states = torch.FloatTensor(states).to(self.device)
+ states = torch.repeat_interleave(states, sample_per_state, dim=0)
+
+ init_x = torch.randn(states.shape[0], self.action_dim, device=self.device)
+ results = DPM_Solver(
+ forward_dpm_wrapper_fn, self.noise_schedule, predict_x0=True
+ ).sample(
+ init_x, steps=diffusion_steps, order=2
+ ).cpu().numpy()
+
+ actions = results[:, :].reshape(num_states, sample_per_state, self.action_dim).copy()
+
+ self.train()
+ return actions
+
+ def score_model_loss_fn(self, x, s, eps=1e-3):
+ """
+ Overview:
+ The loss function for training score-based generative models.
+        Arguments:
+            - x (:obj:`torch.Tensor`): A mini-batch of actions (training data for the diffusion model).
+            - s (:obj:`torch.Tensor`): The conditioning states of the mini-batch.
+            - eps (:obj:`float`): A tolerance value for numerical stability.
+ """
+
+ random_t = torch.rand(x.shape[0], device=x.device) * (1. - eps) + eps
+ z = torch.randn_like(x)
+ alpha_t, std = marginal_prob_std(random_t, device=x.device)
+ perturbed_x = x * alpha_t[:, None] + z * std[:, None]
+ score = self.score_model(perturbed_x, random_t, condition=s)
+ loss = torch.mean(torch.sum((score * std[:, None] + z) ** 2, dim=(1, )))
+ return loss
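+    # Note: the loss above is standard denoising score matching. With x_t = alpha_t * x + std_t * z,
+    # the optimal score satisfies score(x_t, t) = -z / std_t, so the residual (score * std + z) is driven to zero.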
+
+ def q_loss_fn(self, a, s, r, s_, d, fake_a_, discount=0.99):
+ """
+ Overview:
+ The loss function for training Q function.
+ Arguments:
+ - a (:obj:`torch.Tensor`): The input action.
+ - s (:obj:`torch.Tensor`): The input state.
+ - r (:obj:`torch.Tensor`): The input reward.
+ - s_ (:obj:`torch.Tensor`): The input next state.
+ - d (:obj:`torch.Tensor`): The input done.
+ - fake_a_ (:obj:`torch.Tensor`): The input fake action.
+ - discount (:obj:`float`): The discount factor.
+ """
+
+ with torch.no_grad():
+ softmax = nn.Softmax(dim=1)
+ next_energy = self.q.q0_target(fake_a_, torch.stack([s_] * fake_a_.shape[1], axis=1)).detach().squeeze()
+ next_v = torch.sum(softmax(self.q.q_alpha * next_energy) * next_energy, dim=-1, keepdim=True)
+ # Update Q function
+ targets = r + (1. - d.float()) * discount * next_v.detach()
+ qs = self.q.q0.both(a, s)
+ q_loss = sum(F.mse_loss(q, targets) for q in qs) / len(qs)
+
+ return q_loss
+
+ def qt_loss_fn(self, s, fake_a):
+ """
+ Overview:
+ The loss function for training Guidance Qt.
+ Arguments:
+ - s (:obj:`torch.Tensor`): The input state.
+ - fake_a (:obj:`torch.Tensor`): The input fake action.
+ """
+
+        # energy of each fake action in the support set, conditioned on the (repeated) state s
+ energy = self.q.q0_target(fake_a, torch.stack([s] * fake_a.shape[1], axis=1)).detach().squeeze()
+
+ # CEP guidance method, as proposed in the paper
+ logsoftmax = nn.LogSoftmax(dim=1)
+ softmax = nn.Softmax(dim=1)
+
+ x0_data_energy = energy * self.q.alpha
+ random_t = torch.rand((fake_a.shape[0], ), device=self.device) * (1. - 1e-3) + 1e-3
+ random_t = torch.stack([random_t] * fake_a.shape[1], dim=1)
+ z = torch.randn_like(fake_a)
+ alpha_t, std = marginal_prob_std(random_t, device=self.device)
+ perturbed_fake_a = fake_a * alpha_t[..., None] + z * std[..., None]
+ xt_model_energy = self.q.qt(perturbed_fake_a, random_t, torch.stack([s] * fake_a.shape[1], axis=1)).squeeze()
+ p_label = softmax(x0_data_energy)
+
+        # CEP loss: cross-entropy between the softmax of the data energy (alpha * Q) and the softmax of the
+        # predicted intermediate energy Q_t over the fake action support
+ qt_loss = -torch.mean(torch.sum(p_label * logsoftmax(xt_model_energy), axis=-1))
+ return qt_loss
diff --git a/ding/policy/__init__.py b/ding/policy/__init__.py
index c85883a0af..1f202da3bb 100755
--- a/ding/policy/__init__.py
+++ b/ding/policy/__init__.py
@@ -51,6 +51,7 @@
from .pc import ProcedureCloningBFSPolicy
from .bcq import BCQPolicy
+from .qgpo import QGPOPolicy
# new-type policy
from .ppof import PPOFPolicy
diff --git a/ding/policy/qgpo.py b/ding/policy/qgpo.py
new file mode 100644
index 0000000000..cfa3cb19c1
--- /dev/null
+++ b/ding/policy/qgpo.py
@@ -0,0 +1,269 @@
+#############################################################
+# This QGPO policy is a modified implementation based on https://github.com/ChenDRAG/CEP-energy-guided-diffusion
+#############################################################
+
+from typing import List, Dict, Any
+import torch
+from ding.utils import POLICY_REGISTRY
+from ding.utils.data import default_collate
+from ding.torch_utils import to_device
+from .base_policy import Policy
+
+
+@POLICY_REGISTRY.register('qgpo')
+class QGPOPolicy(Policy):
+ """
+ Overview:
+ Policy class of QGPO algorithm
+ Contrastive Energy Prediction for Exact Energy-Guided Diffusion Sampling in Offline Reinforcement Learning
+ https://arxiv.org/abs/2304.12824
+ Interfaces:
+ ``__init__``, ``forward``, ``learn``, ``eval``, ``state_dict``, ``load_state_dict``
+ """
+
+ config = dict(
+ # (str) RL policy register name (refer to function "POLICY_REGISTRY").
+ type='qgpo',
+ # (bool) Whether to use cuda for network.
+ cuda=False,
+ # (bool type) on_policy: Determine whether on-policy or off-policy.
+ # on-policy setting influences the behaviour of buffer.
+ # Default False in QGPO.
+ on_policy=False,
+ multi_agent=False,
+ model=dict(
+ qgpo_critic=dict(
+ # (float) The scale of the energy guidance when training qt.
+ # \pi_{behavior}\exp(f(s,a)) \propto \pi_{behavior}\exp(alpha * Q(s,a))
+ alpha=3,
+ # (float) The scale of the energy guidance when training q0.
+ # \mathcal{T}Q(s,a)=r(s,a)+\mathbb{E}_{s'\sim P(s'|s,a),a'\sim\pi_{support}(a'|s')}Q(s',a')
+ # \pi_{support} \propto \pi_{behavior}\exp(q_alpha * Q(s,a))
+ q_alpha=1,
+ ),
+ device='cuda',
+ # obs_dim
+ # action_dim
+ ),
+ learn=dict(
+ # learning rate for behavior model training
+ learning_rate=1e-4,
+ # batch size during the training of behavior model
+ batch_size=4096,
+ # batch size during the training of q value
+ batch_size_q=256,
+ # number of fake action support
+ M=16,
+ # number of diffusion time steps
+ diffusion_steps=15,
+            # number of training iterations after which the behavior model is fixed
+            behavior_policy_stop_training_iter=600000,
+            # number of training iterations after which the energy-guided policy begins training
+            energy_guided_policy_begin_training_iter=600000,
+            # number of training iterations after which the q value stops training
+            q_value_stop_training_iter=1100000,
+ ),
+ eval=dict(
+ # energy guidance scale for policy in evaluation
+ # \pi_{evaluation} \propto \pi_{behavior}\exp(guidance_scale * alpha * Q(s,a))
+ guidance_scale=[0.0, 1.0, 2.0, 3.0, 5.0, 8.0, 10.0],
+ ),
+ )
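+
+    # A user config (e.g. ``dizoo/d4rl/config/halfcheetah_medium_expert_qgpo_config.py``) is expected to fill in
+    # the environment-dependent fields left out above, for example (hypothetical values):
+    #   model=dict(qgpo_critic=dict(alpha=3, q_alpha=1), device='cuda', obs_dim=17, action_dim=6),
+    #   eval=dict(guidance_scale=[0.0, 1.0, 2.0], diffusion_steps=15, evaluator=dict(eval_freq=500)),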
+
+ def _init_learn(self) -> None:
+ """
+ Overview:
+ Learn mode initialization method. For QGPO, it mainly contains the optimizer, \
+ algorithm-specific arguments such as qt_update_momentum, discount, behavior_policy_stop_training_iter, \
+ energy_guided_policy_begin_training_iter and q_value_stop_training_iter, etc.
+ This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.
+ """
+ self.cuda = self._cfg.cuda
+
+ self.behavior_model_optimizer = torch.optim.Adam(
+ self._model.score_model.parameters(), lr=self._cfg.learn.learning_rate
+ )
+ self.q_optimizer = torch.optim.Adam(self._model.q.q0.parameters(), lr=3e-4)
+ self.qt_optimizer = torch.optim.Adam(self._model.q.qt.parameters(), lr=3e-4)
+
+ self.qt_update_momentum = 0.005
+ self.discount = 0.99
+
+ self.behavior_policy_stop_training_iter = self._cfg.learn.behavior_policy_stop_training_iter
+ self.energy_guided_policy_begin_training_iter = self._cfg.learn.energy_guided_policy_begin_training_iter
+ self.q_value_stop_training_iter = self._cfg.learn.q_value_stop_training_iter
+
+ def _forward_learn(self, data: dict) -> Dict[str, Any]:
+ """
+ Overview:
+ Forward function for learning mode.
+            The training of the QGPO algorithm is based on contrastive energy prediction, \
+            which requires both true actions and fake actions. The true actions are sampled from the dataset, \
+            and the fake actions are sampled from the action support generated by the behavior policy.
+            The training process is divided into three stages:
+            1. Train the behavior model, which is modeled as a diffusion model by parameterizing the score function.
+            2. Train the Q function using the fake action support generated by the behavior model.
+            3. Train the energy guidance network (Qt) with the learned Q function.
+ Arguments:
+ - data (:obj:`dict`): Dict type data.
+ Returns:
+ - result (:obj:`dict`): Dict type data of algorithm results.
+ """
+
+ if self.cuda:
+ data = to_device(data, self._device)
+
+ s = data['s']
+ a = data['a']
+ r = data['r']
+ s_ = data['s_']
+ d = data['d']
+ fake_a = data['fake_a']
+ fake_a_ = data['fake_a_']
+
+ # training behavior model
+ if self.behavior_policy_stop_training_iter > 0:
+
+ behavior_model_training_loss = self._model.score_model_loss_fn(a, s)
+
+ self.behavior_model_optimizer.zero_grad()
+ behavior_model_training_loss.backward()
+ self.behavior_model_optimizer.step()
+
+ self.behavior_policy_stop_training_iter -= 1
+ behavior_model_training_loss = behavior_model_training_loss.item()
+ else:
+ behavior_model_training_loss = 0
+
+ # training Q function
+ self.energy_guided_policy_begin_training_iter -= 1
+ self.q_value_stop_training_iter -= 1
+ if self.energy_guided_policy_begin_training_iter < 0:
+ if self.q_value_stop_training_iter > 0:
+ q0_loss = self._model.q_loss_fn(a, s, r, s_, d, fake_a_, discount=self.discount)
+
+ self.q_optimizer.zero_grad()
+ q0_loss.backward()
+ self.q_optimizer.step()
+
+ # Update target
+ for param, target_param in zip(self._model.q.q0.parameters(), self._model.q.q0_target.parameters()):
+ target_param.data.copy_(
+ self.qt_update_momentum * param.data + (1 - self.qt_update_momentum) * target_param.data
+ )
+
+ q0_loss = q0_loss.item()
+
+ else:
+ q0_loss = 0
+ qt_loss = self._model.qt_loss_fn(s, fake_a)
+
+ self.qt_optimizer.zero_grad()
+ qt_loss.backward()
+ self.qt_optimizer.step()
+
+ qt_loss = qt_loss.item()
+
+ else:
+ q0_loss = 0
+ qt_loss = 0
+
+ total_loss = behavior_model_training_loss + q0_loss + qt_loss
+
+ return dict(
+ total_loss=total_loss,
+ behavior_model_training_loss=behavior_model_training_loss,
+ q0_loss=q0_loss,
+ qt_loss=qt_loss,
+ )
+
+ def _init_collect(self) -> None:
+ """
+ Overview:
+ Collect mode initialization method. Not supported for QGPO.
+ """
+ pass
+
+ def _forward_collect(self) -> None:
+ """
+ Overview:
+ Forward function for collect mode. Not supported for QGPO.
+ """
+ pass
+
+ def _init_eval(self) -> None:
+ """
+ Overview:
+            Eval mode initialization method. For QGPO, it mainly sets the number of diffusion steps used in sampling.
+ This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``.
+ """
+
+ self.diffusion_steps = self._cfg.eval.diffusion_steps
+
+ def _forward_eval(self, data: dict, guidance_scale: float) -> dict:
+ """
+ Overview:
+ Forward function for eval mode. The eval process is based on the energy-guided policy, \
+ which is modeled as a diffusion model by parameterizing the score function.
+ Arguments:
+ - data (:obj:`dict`): Dict type data.
+ - guidance_scale (:obj:`float`): The scale of the energy guidance.
+ Returns:
+ - output (:obj:`dict`): Dict type data of algorithm output.
+ """
+
+ data_id = list(data.keys())
+ states = default_collate(list(data.values()))
+ actions = self._model.select_actions(
+ states, diffusion_steps=self.diffusion_steps, guidance_scale=guidance_scale
+ )
+ output = actions
+
+ return {i: {"action": d} for i, d in zip(data_id, output)}
+
+ def _get_train_sample(self, transitions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ """
+ Overview:
+ Get the train sample from the replay buffer, currently not supported for QGPO.
+ Arguments:
+ - transitions (:obj:`List[Dict[str, Any]]`): The data from replay buffer.
+ Returns:
+ - samples (:obj:`List[Dict[str, Any]]`): The data for training.
+ """
+ pass
+
+ def _process_transition(self) -> None:
+ """
+ Overview:
+ Process the transition data, currently not supported for QGPO.
+ """
+ pass
+
+ def _state_dict_learn(self) -> Dict[str, Any]:
+ """
+ Overview:
+ Return the state dict for saving.
+ Returns:
+ - state_dict (:obj:`Dict[str, Any]`): Dict type data of state dict.
+ """
+ return {
+ 'model': self._model.state_dict(),
+ 'behavior_model_optimizer': self.behavior_model_optimizer.state_dict(),
+ }
+
+ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
+ """
+ Overview:
+ Load the state dict.
+ Arguments:
+ - state_dict (:obj:`Dict[str, Any]`): Dict type data of state dict.
+ """
+ self._model.load_state_dict(state_dict['model'])
+ self.behavior_model_optimizer.load_state_dict(state_dict['behavior_model_optimizer'])
+
+ def _monitor_vars_learn(self) -> List[str]:
+ """
+ Overview:
+ Return the variables names to be monitored.
+ """
+ return ['total_loss', 'behavior_model_training_loss', 'q0_loss', 'qt_loss']
diff --git a/ding/torch_utils/diffusion_SDE/__init__.py b/ding/torch_utils/diffusion_SDE/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/ding/torch_utils/diffusion_SDE/dpm_solver_pytorch.py b/ding/torch_utils/diffusion_SDE/dpm_solver_pytorch.py
new file mode 100644
index 0000000000..4f43c10510
--- /dev/null
+++ b/ding/torch_utils/diffusion_SDE/dpm_solver_pytorch.py
@@ -0,0 +1,1288 @@
+#############################################################
+# This DPM-Solver snippet is from https://github.com/ChenDRAG/CEP-energy-guided-diffusion
+# which is based on https://github.com/LuChengTHU/dpm-solver
+#############################################################
+
+import torch
+import math
+
+
+class NoiseScheduleVP:
+
+ def __init__(
+ self,
+ schedule='discrete',
+ betas=None,
+ alphas_cumprod=None,
+ continuous_beta_0=0.1,
+ continuous_beta_1=20.,
+ ):
+ """Create a wrapper class for the forward SDE (VP type).
+
+ ***
+        Update: We support discrete-time diffusion models by implementing a piecewise linear interpolation \
+ for log_alpha_t. We recommend to use schedule='discrete' for the discrete-time diffusion models, \
+ especially for high-resolution images.
+ ***
+
+        The forward SDE ensures that the conditional distribution \
+ q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
+ We further define lambda_t = log(alpha_t) - log(sigma_t), \
+ which is the half-logSNR (described in the DPM-Solver paper).
+ Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t.
+ For t in [0, T], we have:
+ log_alpha_t = self.marginal_log_mean_coeff(t)
+ sigma_t = self.marginal_std(t)
+ lambda_t = self.marginal_lambda(t)
+
+ Moreover, as lambda(t) is an invertible function, we also support its inverse function:
+
+ t = self.inverse_lambda(lambda_t)
+
+ ===============================================================
+
+ We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and \
+ continuous-time DPMs (trained on t in [t_0, T]).
+
+ 1. For discrete-time DPMs:
+
+ For discrete-time DPMs trained on n = 0, 1, ..., N-1, \
+ we convert the discrete steps to continuous time steps by:
+ t_i = (i + 1) / N
+ e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
+ We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
+
+ Args:
+ betas: A `torch.Tensor`. The beta array for the discrete-time DPM. \
+ (See the original DDPM paper for details)
+ alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. \
+ (See the original DDPM paper for details)
+
+            Note that we always have alphas_cumprod = cumprod(1 - betas). \
+ Therefore, we only need to set one of `betas` and `alphas_cumprod`.
+
+ **Important**: Please pay special attention for the args for `alphas_cumprod`:
+ The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. \
+ Specifically, DDPMs assume that
+ q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
+ Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver.
+ In fact, we have
+ alpha_{t_n} = \sqrt{\hat{alpha_n}},
+ and
+ log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
+
+
+ 2. For continuous-time DPMs:
+
+ We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM).
+ The hyperparameters for the noise schedule are the default settings in DDPM and improved-DDPM:
+
+ Args:
+ beta_min: A `float` number. The smallest beta for the linear schedule.
+ beta_max: A `float` number. The largest beta for the linear schedule.
+ cosine_s: A `float` number. The hyperparameter in the cosine schedule.
+ cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
+ T: A `float` number. The ending time of the forward process.
+
+ ===============================================================
+
+ Args:
+ schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
+ 'linear' or 'cosine' for continuous-time DPMs.
+ Returns:
+ A wrapper object of the forward SDE (VP type).
+
+ ===============================================================
+
+ Example:
+
+ # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
+ >>> ns = NoiseScheduleVP('discrete', betas=betas)
+
+ # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array \
+ for n = 0, 1, ..., N - 1):
+ >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
+
+ # For continuous-time DPMs (VPSDE), linear schedule:
+ >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
+
+ """
+
+ if schedule not in ['discrete', 'linear', 'cosine']:
+            raise ValueError(
+                "Unsupported noise schedule {}. "
+                "The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(schedule)
+            )
+
+ self.schedule = schedule
+ if schedule == 'discrete':
+ if betas is not None:
+ log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
+ else:
+ assert alphas_cumprod is not None
+ log_alphas = 0.5 * torch.log(alphas_cumprod)
+ self.total_N = len(log_alphas)
+ self.T = 1.
+ self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1))
+ self.log_alpha_array = log_alphas.reshape((
+ 1,
+ -1,
+ ))
+ else:
+ self.total_N = 1000
+ self.beta_0 = continuous_beta_0
+ self.beta_1 = continuous_beta_1
+ self.cosine_s = 0.008
+ self.cosine_beta_max = 999.
+ self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi
+ ) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
+ self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
+ self.schedule = schedule
+ if schedule == 'cosine':
+ # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
+ # Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
+ self.T = 0.9946
+ else:
+ self.T = 1.
+
+ def marginal_log_mean_coeff(self, t):
+ """
+ Compute log(alpha_t) of a given continuous-time label t in [0, T].
+ """
+ if self.schedule == 'discrete':
+ return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device),
+ self.log_alpha_array.to(t.device)).reshape((-1))
+ elif self.schedule == 'linear':
+ return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
+ elif self.schedule == 'cosine':
+ log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
+ log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
+ return log_alpha_t
+
+ def marginal_alpha(self, t):
+ """
+ Compute alpha_t of a given continuous-time label t in [0, T].
+ """
+ return torch.exp(self.marginal_log_mean_coeff(t))
+
+ def marginal_std(self, t):
+ """
+ Compute sigma_t of a given continuous-time label t in [0, T].
+ """
+ return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
+
+ def marginal_lambda(self, t):
+ """
+ Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
+ """
+ log_mean_coeff = self.marginal_log_mean_coeff(t)
+ log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
+ return log_mean_coeff - log_std
+
+ def inverse_lambda(self, lamb):
+ """
+ Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
+ """
+ if self.schedule == 'linear':
+ tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1, )).to(lamb))
+ Delta = self.beta_0 ** 2 + tmp
+ return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
+ elif self.schedule == 'discrete':
+ log_alpha = -0.5 * torch.logaddexp(torch.zeros((1, )).to(lamb.device), -2. * lamb)
+ t = interpolate_fn(
+ log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]),
+ torch.flip(self.t_array.to(lamb.device), [1])
+ )
+ return t.reshape((-1, ))
+ else:
+ log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1, )).to(lamb))
+ t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)
+ ) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
+ t = t_fn(log_alpha)
+ return t
+
+
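+# Illustrative sketch (not part of the upstream DPM-Solver code): how the schedule quantities
+# relate for a discrete-time DPM. The beta values below are arbitrary and chosen only for
+# demonstration.
+#
+# >>> betas = torch.linspace(1e-4, 2e-2, 1000)  # a DDPM-style linear beta schedule (example values)
+# >>> ns = NoiseScheduleVP('discrete', betas=betas)
+# >>> t = torch.tensor([0.5])  # continuous-time label in (0, 1]
+# >>> alpha_t, sigma_t, lambda_t = ns.marginal_alpha(t), ns.marginal_std(t), ns.marginal_lambda(t)
+# >>> # lambda_t equals log(alpha_t) - log(sigma_t), and inverse_lambda approximately recovers t:
+# >>> ns.inverse_lambda(lambda_t)  # ~ tensor([0.5])
+
+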
+def model_wrapper(
+ model,
+ noise_schedule,
+ model_type="noise",
+ model_kwargs={},
+ guidance_type="uncond",
+ condition=None,
+ unconditional_condition=None,
+ guidance_scale=1.,
+ classifier_fn=None,
+ classifier_kwargs={},
+):
+ """Create a wrapper function for the noise prediction model.
+
+ DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
+ firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.
+
+ We support four types of the diffusion model by setting `model_type`:
+
+ 1. "noise": noise prediction model. (Trained by predicting noise).
+
+ 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
+
+ 3. "v": velocity prediction model. (Trained by predicting the velocity).
+        The "v" prediction is derived in detail in Appendix D of [1], and is used in Imagen-Video [2].
+
+ [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
+ arXiv preprint arXiv:2202.00512 (2022).
+ [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
+ arXiv preprint arXiv:2210.02303 (2022).
+
+ 4. "score": marginal score function. (Trained by denoising score matching).
+        Note that the score function and the noise prediction model follow a simple relationship:
+ ```
+ noise(x_t, t) = -sigma_t * score(x_t, t)
+ ```
+
+ We support three types of guided sampling by DPMs by setting `guidance_type`:
+ 1. "uncond": unconditional sampling by DPMs.
+ The input `model` has the following format:
+ ``
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
+ ``
+
+ 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
+ The input `model` has the following format:
+ ``
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
+ ``
+
+ The input `classifier_fn` has the following format:
+ ``
+ classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
+ ``
+
+ [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
+ in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
+
+ 3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
+ The input `model` has the following format:
+ ``
+ model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
+ ``
+ And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
+
+ [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
+ arXiv preprint arXiv:2207.12598 (2022).
+
+
+ The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
+ or continuous-time labels (i.e. epsilon to T).
+
+ We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
+ ``
+ def model_fn(x, t_continuous) -> noise:
+ t_input = get_model_input_time(t_continuous)
+ return noise_pred(model, x, t_input, **model_kwargs)
+ ``
+ where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
+
+ ===============================================================
+
+ Args:
+ model: A diffusion model with the corresponding format described above.
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
+ model_type: A `str`. The parameterization type of the diffusion model.
+ "noise" or "x_start" or "v" or "score".
+ model_kwargs: A `dict`. A dict for the other inputs of the model function.
+ guidance_type: A `str`. The type of the guidance for sampling.
+ "uncond" or "classifier" or "classifier-free".
+ condition: A pytorch tensor. The condition for the guided sampling.
+ Only used for "classifier" or "classifier-free" guidance type.
+ unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
+ Only used for "classifier-free" guidance type.
+ guidance_scale: A `float`. The scale for the guided sampling.
+ classifier_fn: A classifier function. Only used for the classifier guidance.
+ classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
+ Returns:
+ A noise prediction model that accepts the noised data and the continuous time as the inputs.
+ """
+
+ def get_model_input_time(t_continuous):
+ """
+ Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
+ For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
+ For continuous-time DPMs, we just use `t_continuous`.
+ """
+ if noise_schedule.schedule == 'discrete':
+ return (t_continuous - 1. / noise_schedule.total_N) * 1000.
+ else:
+ return t_continuous
+
+ def noise_pred_fn(x, t_continuous, cond=None):
+ if t_continuous.reshape((-1, )).shape[0] == 1:
+ t_continuous = t_continuous.expand((x.shape[0]))
+ t_input = get_model_input_time(t_continuous)
+ if cond is None:
+ output = model(x, t_input, **model_kwargs)
+ else:
+ output = model(x, t_input, cond, **model_kwargs)
+ if model_type == "noise":
+ return output
+ elif model_type == "x_start":
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
+ dims = x.dim()
+ return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
+ elif model_type == "v":
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
+ dims = x.dim()
+ return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
+ elif model_type == "score":
+ sigma_t = noise_schedule.marginal_std(t_continuous)
+ dims = x.dim()
+ return -expand_dims(sigma_t, dims) * output
+
+ def cond_grad_fn(x, t_input):
+ """
+ Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
+ """
+ with torch.enable_grad():
+ x_in = x.detach().requires_grad_(True)
+ log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
+ return torch.autograd.grad(log_prob.sum(), x_in)[0]
+
+ def model_fn(x, t_continuous):
+ """
+        The noise prediction model function that is used for DPM-Solver.
+ """
+ if t_continuous.reshape((-1, )).shape[0] == 1:
+ t_continuous = t_continuous.expand((x.shape[0]))
+ if guidance_type == "uncond":
+ return noise_pred_fn(x, t_continuous)
+ elif guidance_type == "classifier":
+ assert classifier_fn is not None
+ t_input = get_model_input_time(t_continuous)
+ cond_grad = cond_grad_fn(x, t_input)
+ sigma_t = noise_schedule.marginal_std(t_continuous)
+ noise = noise_pred_fn(x, t_continuous)
+ return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
+ elif guidance_type == "classifier-free":
+ if guidance_scale == 1. or unconditional_condition is None:
+ return noise_pred_fn(x, t_continuous, cond=condition)
+ else:
+ x_in = torch.cat([x] * 2)
+ t_in = torch.cat([t_continuous] * 2)
+ c_in = torch.cat([unconditional_condition, condition])
+ noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
+ return noise_uncond + guidance_scale * (noise - noise_uncond)
+
+    # All four documented model types are accepted; "score" is handled in noise_pred_fn above.
+    assert model_type in ["noise", "x_start", "v", "score"]
+ assert guidance_type in ["uncond", "classifier", "classifier-free"]
+ return model_fn
+
+
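+# Illustrative sketch (not from the upstream code): wrapping a conditional noise model for
+# classifier-free guidance. `my_model`, `ns`, `cond` and `uncond` are placeholder names, not
+# objects defined in this repository.
+#
+# >>> model_fn = model_wrapper(
+# ...     my_model, ns, model_type="noise",
+# ...     guidance_type="classifier-free",
+# ...     condition=cond, unconditional_condition=uncond,
+# ...     guidance_scale=2.0,
+# ... )
+# >>> # model_fn(x, t_continuous) now returns guided noise predictions, ready for DPM_Solver.
+
+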
+class DPM_Solver:
+
+ def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.):
+ """Construct a DPM-Solver.
+
+ We support both the noise prediction model ("predicting epsilon") and \
+ the data prediction model ("predicting x0").
+ If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver).
+ If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++).
+ In such case, we further support the "dynamic thresholding" in [1] when `thresholding` is True.
+ The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs \
+ with large guidance scales.
+
+ Args:
+ model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):
+ ``
+ def model_fn(x, t_continuous):
+ return noise
+ ``
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
+ predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model.
+ thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1].
+ max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. \
+ The max value for thresholding.
+
+ [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, \
+ Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. \
+ Photorealistic text-to-image diffusion models with deep language understanding. \
+ arXiv preprint arXiv:2205.11487, 2022b.
+ """
+ self.model = model_fn
+ self.noise_schedule = noise_schedule
+ self.predict_x0 = predict_x0
+ self.thresholding = thresholding
+ self.max_val = max_val
+
+ def noise_prediction_fn(self, x, t):
+ """
+ Return the noise prediction model.
+ """
+ return self.model(x, t)
+
+ def data_prediction_fn(self, x, t):
+ """
+ Return the data prediction model (with thresholding).
+ """
+ noise = self.noise_prediction_fn(x, t)
+ dims = x.dim()
+ alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
+ x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
+ if self.thresholding:
+ p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
+ s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
+ s = expand_dims(torch.maximum(s, torch.ones_like(s).to(s.device)), dims)
+ x0 = torch.clamp(x0, -s, s) / (s / self.max_val)
+ return x0
+
+ def model_fn(self, x, t):
+ """
+ Convert the model to the noise prediction model or the data prediction model.
+ """
+ if self.predict_x0:
+ return self.data_prediction_fn(x, t)
+ else:
+ return self.noise_prediction_fn(x, t)
+
+ def get_time_steps(self, skip_type, t_T, t_0, N, device):
+ """Compute the intermediate time steps for sampling.
+
+ Args:
+ skip_type: A `str`. The type for the spacing of the time steps. We support three types:
+ - 'logSNR': uniform logSNR for the time steps.
+ - 'time_uniform': uniform time for the time steps. \
+ (**Recommended for high-resolutional data**.)
+ - 'time_quadratic': quadratic time for the time steps. \
+ (Used in DDIM for low-resolutional data.)
+ t_T: A `float`. The starting time of the sampling (default is T).
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
+ N: A `int`. The total number of the spacing of the time steps.
+ device: A torch device.
+ Returns:
+ A pytorch tensor of the time steps, with the shape (N + 1,).
+ """
+ if skip_type == 'logSNR':
+ lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
+ lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
+ logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
+ return self.noise_schedule.inverse_lambda(logSNR_steps)
+ elif skip_type == 'time_uniform':
+ return torch.linspace(t_T, t_0, N + 1).to(device)
+ elif skip_type == 'time_quadratic':
+ t_order = 2
+ t = torch.linspace(t_T ** (1. / t_order), t_0 ** (1. / t_order), N + 1).pow(t_order).to(device)
+ return t
+ else:
+ raise ValueError(
+ "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)
+ )
+
+ def get_orders_for_singlestep_solver(self, steps, order):
+ """
+ Get the order of each step for sampling by the singlestep DPM-Solver.
+
+        We combine DPM-Solver-1, 2 and 3 to use all the function evaluations, which is named "DPM-Solver-fast".
+ Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is:
+ - If order == 1:
+ We take `steps` of DPM-Solver-1 (i.e. DDIM).
+ - If order == 2:
+                - Denote K = (steps // 2). We take (K + 1) intermediate time steps for sampling.
+                - If steps % 2 == 0, we use (K - 1) steps of DPM-Solver-2 and 2 steps of DPM-Solver-1.
+                - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1.
+ - If order == 3:
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
+ - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, \
+ and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
+ - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
+ - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.
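+
+            For example (illustrative): steps = 10 with order == 3 gives orders = [3, 3, 3, 1], i.e. 10 NFE in total.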
+
+ ============================================
+ Args:
+ order: A `int`. The max order for the solver (2 or 3).
+ steps: A `int`. The total number of function evaluations (NFE).
+ Returns:
+ orders: A list of the solver order of each step.
+ """
+ if order == 3:
+ K = steps // 3 + 1
+ if steps % 3 == 0:
+ orders = [
+ 3,
+ ] * (K - 2) + [2, 1]
+ elif steps % 3 == 1:
+ orders = [
+ 3,
+ ] * (K - 1) + [1]
+ else:
+ orders = [
+ 3,
+ ] * (K - 1) + [2]
+ return orders
+ elif order == 2:
+ K = steps // 2
+ if steps % 2 == 0:
+ # orders = [2,] * K
+ K = steps // 2 + 1
+ orders = [
+ 2,
+ ] * (K - 2) + [
+ 1,
+ ] * 2
+ else:
+ orders = [
+ 2,
+ ] * K + [1]
+ return orders
+ elif order == 1:
+ return [
+ 1,
+ ] * steps
+ else:
+ raise ValueError("'order' must be '1' or '2' or '3'.")
+
+ def denoise_fn(self, x, s):
+ """
+        Denoise at the final step, which is equivalent to solving the ODE \
+        from lambda_s to infinity by a first-order discretization.
+ """
+ return self.data_prediction_fn(x, s)
+
+ def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False):
+ """
+ DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.
+
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
+ return_intermediate: A `bool`. If true, also return the model value at time `s`.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ ns = self.noise_schedule
+ dims = x.dim()
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
+ h = lambda_t - lambda_s
+ log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t)
+ sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
+ alpha_t = torch.exp(log_alpha_t)
+
+ if self.predict_x0:
+ phi_1 = torch.expm1(-h)
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ x_t = (expand_dims(sigma_t / sigma_s, dims) * x - expand_dims(alpha_t * phi_1, dims) * model_s)
+ if return_intermediate:
+ return x_t, {'model_s': model_s}
+ else:
+ return x_t
+ else:
+ phi_1 = torch.expm1(h)
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x -
+ expand_dims(sigma_t * phi_1, dims) * model_s
+ )
+ if return_intermediate:
+ return x_t, {'model_s': model_s}
+ else:
+ return x_t
+
+ def singlestep_dpm_solver_second_update(
+ self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, solver_type='dpm_solver'
+ ):
+ """
+ Singlestep solver DPM-Solver-2 from time `s` to time `t`.
+
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ r1: A `float`. The hyperparameter of the second-order solver.
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
+ return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` \
+ (the intermediate time).
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+                The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ if solver_type not in ['dpm_solver', 'taylor']:
+ raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
+ if r1 is None:
+ r1 = 0.5
+ ns = self.noise_schedule
+ dims = x.dim()
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
+ h = lambda_t - lambda_s
+ lambda_s1 = lambda_s + r1 * h
+ s1 = ns.inverse_lambda(lambda_s1)
+ log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(
+ s1
+ ), ns.marginal_log_mean_coeff(t)
+ sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t)
+ alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)
+
+ if self.predict_x0:
+ phi_11 = torch.expm1(-r1 * h)
+ phi_1 = torch.expm1(-h)
+
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ x_s1 = (expand_dims(sigma_s1 / sigma_s, dims) * x - expand_dims(alpha_s1 * phi_11, dims) * model_s)
+ model_s1 = self.model_fn(x_s1, s1)
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(sigma_t / sigma_s, dims) * x - expand_dims(alpha_t * phi_1, dims) * model_s -
+ (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s)
+ )
+ elif solver_type == 'taylor':
+ x_t = (
+ expand_dims(sigma_t / sigma_s, dims) * x - expand_dims(alpha_t * phi_1, dims) * model_s +
+ (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * (model_s1 - model_s)
+ )
+ else:
+ phi_11 = torch.expm1(r1 * h)
+ phi_1 = torch.expm1(h)
+
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ x_s1 = (
+ expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x -
+ expand_dims(sigma_s1 * phi_11, dims) * model_s
+ )
+ model_s1 = self.model_fn(x_s1, s1)
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x -
+ expand_dims(sigma_t * phi_1, dims) * model_s - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) *
+ (model_s1 - model_s)
+ )
+ elif solver_type == 'taylor':
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x -
+ expand_dims(sigma_t * phi_1, dims) * model_s -
+ (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * (model_s1 - model_s)
+ )
+ if return_intermediate:
+ return x_t, {'model_s': model_s, 'model_s1': model_s1}
+ else:
+ return x_t
+
+ def singlestep_dpm_solver_third_update(
+ self,
+ x,
+ s,
+ t,
+ r1=1. / 3.,
+ r2=2. / 3.,
+ model_s=None,
+ model_s1=None,
+ return_intermediate=False,
+ solver_type='dpm_solver'
+ ):
+ """
+ Singlestep solver DPM-Solver-3 from time `s` to time `t`.
+
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ r1: A `float`. The hyperparameter of the third-order solver.
+ r2: A `float`. The hyperparameter of the third-order solver.
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
+ model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`).
+ If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it.
+ return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` \
+ (the intermediate times).
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+                The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ if solver_type not in ['dpm_solver', 'taylor']:
+ raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
+ if r1 is None:
+ r1 = 1. / 3.
+ if r2 is None:
+ r2 = 2. / 3.
+ ns = self.noise_schedule
+ dims = x.dim()
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
+ h = lambda_t - lambda_s
+ lambda_s1 = lambda_s + r1 * h
+ lambda_s2 = lambda_s + r2 * h
+ s1 = ns.inverse_lambda(lambda_s1)
+ s2 = ns.inverse_lambda(lambda_s2)
+ log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(
+ s
+ ), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t)
+ sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(
+ s2
+ ), ns.marginal_std(t)
+ alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t)
+
+ if self.predict_x0:
+ phi_11 = torch.expm1(-r1 * h)
+ phi_12 = torch.expm1(-r2 * h)
+ phi_1 = torch.expm1(-h)
+ phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.
+ phi_2 = phi_1 / h + 1.
+ phi_3 = phi_2 / h - 0.5
+
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ if model_s1 is None:
+ x_s1 = (expand_dims(sigma_s1 / sigma_s, dims) * x - expand_dims(alpha_s1 * phi_11, dims) * model_s)
+ model_s1 = self.model_fn(x_s1, s1)
+ x_s2 = (
+ expand_dims(sigma_s2 / sigma_s, dims) * x - expand_dims(alpha_s2 * phi_12, dims) * model_s +
+ r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s)
+ )
+ model_s2 = self.model_fn(x_s2, s2)
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(sigma_t / sigma_s, dims) * x - expand_dims(alpha_t * phi_1, dims) * model_s +
+ (1. / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s)
+ )
+ elif solver_type == 'taylor':
+ D1_0 = (1. / r1) * (model_s1 - model_s)
+ D1_1 = (1. / r2) * (model_s2 - model_s)
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
+ D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
+ x_t = (
+ expand_dims(sigma_t / sigma_s, dims) * x - expand_dims(alpha_t * phi_1, dims) * model_s +
+ expand_dims(alpha_t * phi_2, dims) * D1 - expand_dims(alpha_t * phi_3, dims) * D2
+ )
+ else:
+ phi_11 = torch.expm1(r1 * h)
+ phi_12 = torch.expm1(r2 * h)
+ phi_1 = torch.expm1(h)
+ phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.
+ phi_2 = phi_1 / h - 1.
+ phi_3 = phi_2 / h - 0.5
+
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ if model_s1 is None:
+ x_s1 = (
+ expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x -
+ expand_dims(sigma_s1 * phi_11, dims) * model_s
+ )
+ model_s1 = self.model_fn(x_s1, s1)
+ x_s2 = (
+ expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x -
+ expand_dims(sigma_s2 * phi_12, dims) * model_s - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) *
+ (model_s1 - model_s)
+ )
+ model_s2 = self.model_fn(x_s2, s2)
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x -
+ expand_dims(sigma_t * phi_1, dims) * model_s - (1. / r2) * expand_dims(sigma_t * phi_2, dims) *
+ (model_s2 - model_s)
+ )
+ elif solver_type == 'taylor':
+ D1_0 = (1. / r1) * (model_s1 - model_s)
+ D1_1 = (1. / r2) * (model_s2 - model_s)
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
+ D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x -
+ expand_dims(sigma_t * phi_1, dims) * model_s - expand_dims(sigma_t * phi_2, dims) * D1 -
+ expand_dims(sigma_t * phi_3, dims) * D2
+ )
+
+ if return_intermediate:
+ return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2}
+ else:
+ return x_t
+
+ def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"):
+ """
+ Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.
+
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+                The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ if solver_type not in ['dpm_solver', 'taylor']:
+ raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
+ ns = self.noise_schedule
+ dims = x.dim()
+ model_prev_1, model_prev_0 = model_prev_list
+ t_prev_1, t_prev_0 = t_prev_list
+ lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(
+ t_prev_0
+ ), ns.marginal_lambda(t)
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
+ alpha_t = torch.exp(log_alpha_t)
+
+ h_0 = lambda_prev_0 - lambda_prev_1
+ h = lambda_t - lambda_prev_0
+ r0 = h_0 / h
+ D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
+ if self.predict_x0:
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(sigma_t / sigma_prev_0, dims) * x -
+ expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 -
+ 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0
+ )
+ elif solver_type == 'taylor':
+ x_t = (
+ expand_dims(sigma_t / sigma_prev_0, dims) * x -
+ expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 +
+ expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0
+ )
+ else:
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x -
+ expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 -
+ 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0
+ )
+ elif solver_type == 'taylor':
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x -
+ expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 -
+ expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0
+ )
+ return x_t
+
+ def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpm_solver'):
+ """
+ Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.
+
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+                The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ ns = self.noise_schedule
+ dims = x.dim()
+ model_prev_2, model_prev_1, model_prev_0 = model_prev_list
+ t_prev_2, t_prev_1, t_prev_0 = t_prev_list
+ lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(
+ t_prev_1
+ ), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
+ alpha_t = torch.exp(log_alpha_t)
+
+ h_1 = lambda_prev_1 - lambda_prev_2
+ h_0 = lambda_prev_0 - lambda_prev_1
+ h = lambda_t - lambda_prev_0
+ r0, r1 = h_0 / h, h_1 / h
+ D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
+ D1_1 = expand_dims(1. / r1, dims) * (model_prev_1 - model_prev_2)
+ D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1)
+ D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1)
+ if self.predict_x0:
+ x_t = (
+ expand_dims(sigma_t / sigma_prev_0, dims) * x - expand_dims(alpha_t *
+ (torch.exp(-h) - 1.), dims) * model_prev_0 +
+ expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1 -
+ expand_dims(alpha_t * ((torch.exp(-h) - 1. + h) / h ** 2 - 0.5), dims) * D2
+ )
+ else:
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x -
+ expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 -
+ expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1 -
+ expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h ** 2 - 0.5), dims) * D2
+ )
+ return x_t
+
+ def singlestep_dpm_solver_update(
+ self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None, r2=None
+ ):
+ """
+ Singlestep DPM-Solver with the order `order` from time `s` to time `t`.
+
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
+ return_intermediate: A `bool`. If true, also return the model value at time `s`, \
+ `s1` and `s2` (the intermediate times).
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+                The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
+ r1: A `float`. The hyperparameter of the second-order or third-order solver.
+ r2: A `float`. The hyperparameter of the third-order solver.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ if order == 1:
+ return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)
+ elif order == 2:
+ return self.singlestep_dpm_solver_second_update(
+ x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1
+ )
+ elif order == 3:
+ return self.singlestep_dpm_solver_third_update(
+ x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1, r2=r2
+ )
+ else:
+ raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
+
+ def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpm_solver'):
+ """
+ Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.
+
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+                The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ if order == 1:
+ return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1])
+ elif order == 2:
+ return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
+ elif order == 3:
+ return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
+ else:
+ raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
+
+ def dpm_solver_adaptive(
+ self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, solver_type='dpm_solver'
+ ):
+ """
+ The adaptive step size solver based on singlestep DPM-Solver.
+
+ Args:
+ x: A pytorch tensor. The initial value at time `t_T`.
+ order: A `int`. The (higher) order of the solver. We only support order == 2 or 3.
+ t_T: A `float`. The starting time of the sampling (default is T).
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
+ h_init: A `float`. The initial step size (for logSNR).
+ atol: A `float`. The absolute tolerance of the solver. \
+                For image data, the default setting is 0.0078, following [1].
+ rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
+ theta: A `float`. The safety hyperparameter for adapting the step size. \
+                The default setting is 0.9, following [1].
+ t_err: A `float`. The tolerance for the time. \
+ We solve the diffusion ODE until the absolute error between the \
+ current time and `t_0` is less than `t_err`. The default setting is 1e-5.
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+                The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
+ Returns:
+ x_0: A pytorch tensor. The approximated solution at time `t_0`.
+
+ [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, \
+ "Gotta go fast when generating data with score-based models," \
+ arXiv preprint arXiv:2105.14080, 2021.
+ """
+ ns = self.noise_schedule
+ s = t_T * torch.ones((x.shape[0], )).to(x)
+ lambda_s = ns.marginal_lambda(s)
+ lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
+ h = h_init * torch.ones_like(s).to(x)
+ x_prev = x
+ nfe = 0
+ if order == 2:
+ r1 = 0.5
+ lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True)
+ higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(
+ x, s, t, r1=r1, solver_type=solver_type, **kwargs
+ )
+ elif order == 3:
+ r1, r2 = 1. / 3., 2. / 3.
+ lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(
+ x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type
+ )
+ higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(
+ x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs
+ )
+ else:
+ raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order))
+ while torch.abs((s - t_0)).mean() > t_err:
+ t = ns.inverse_lambda(lambda_s + h)
+ x_lower, lower_noise_kwargs = lower_update(x, s, t)
+ x_higher = higher_update(x, s, t, **lower_noise_kwargs)
+ delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)))
+ norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True))
+ E = norm_fn((x_higher - x_lower) / delta).max()
+ if torch.all(E <= 1.):
+ x = x_higher
+ s = t
+ x_prev = x_lower
+ lambda_s = ns.marginal_lambda(s)
+ h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s)
+ nfe += order
+ print('adaptive solver nfe', nfe)
+ return x
+
+ def sample(
+ self,
+ x,
+ steps=20,
+ t_start=None,
+ t_end=None,
+ order=3,
+ skip_type='time_uniform',
+ method='singlestep',
+ denoise=False,
+ solver_type='dpm_solver',
+ atol=0.0078,
+ rtol=0.05,
+ ):
+ """
+ Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
+
+ =====================================================
+
+ We support the following algorithms for both noise prediction model and data prediction model:
+ - 'singlestep':
+ Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), \
+ which combines different orders of singlestep DPM-Solver.
+ We combine all the singlestep solvers with order <= `order` \
+ to use up all the function evaluations (steps).
+ The total number of function evaluations (NFE) == `steps`.
+ Given a fixed NFE == `steps`, the sampling procedure is:
+ - If `order` == 1:
+ - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM).
+ - If `order` == 2:
+                            - Denote K = (steps // 2). We take (K + 1) intermediate time steps for sampling.
+                            - If steps % 2 == 0, we use (K - 1) steps of singlestep DPM-Solver-2 \
+                                and 2 steps of DPM-Solver-1.
+                            - If steps % 2 == 1, we use K steps of singlestep DPM-Solver-2 \
+                                and 1 step of DPM-Solver-1.
+ - If `order` == 3:
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
+ - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, \
+ and 1 step of singlestep DPM-Solver-2 \
+ and 1 step of DPM-Solver-1.
+ - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 \
+ and 1 step of DPM-Solver-1.
+ - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 \
+ and 1 step of singlestep DPM-Solver-2.
+ - 'multistep':
+ Multistep DPM-Solver with the order of `order`.
+ The total number of function evaluations (NFE) == `steps`.
+ We initialize the first `order` values by lower order multistep solvers.
+ Given a fixed NFE == `steps`, the sampling procedure is:
+ Denote K = steps.
+ - If `order` == 1:
+ - We use K steps of DPM-Solver-1 (i.e. DDIM).
+ - If `order` == 2:
+ - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2.
+ - If `order` == 3:
+ - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, \
+ then (K - 2) step of multistep DPM-Solver-3.
+ - 'singlestep_fixed':
+ Fixed order singlestep DPM-Solver \
+ (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3).
+ We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, \
+ with total [`steps` // `order`] * `order` NFE.
+ - 'adaptive':
+ Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper).
+ We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`.
+ You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` \
+                                to balance the computation cost (NFE) and the sample quality.
+ - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2.
+ - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 \
+ and singlestep DPM-Solver-3.
+
+ =====================================================
+
+        Some advice on choosing the algorithm:
+ - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
+ Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`.
+ e.g.
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False)
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
+ skip_type='time_uniform', method='singlestep')
+ - For **guided sampling with large guidance scale** by DPMs:
+ Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`.
+ e.g.
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True)
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
+ skip_type='time_uniform', method='multistep')
+
+ We support three types of `skip_type`:
+ - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images**
+ - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**.
+ - 'time_quadratic': quadratic time for the time steps.
+
+ =====================================================
+ Args:
+ x: A pytorch tensor. The initial value at time `t_start`
+ e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution.
+ steps: A `int`. The total number of function evaluations (NFE).
+ t_start: A `float`. The starting time of the sampling.
+                If `t_start` is None, we use self.noise_schedule.T (default is 1.0).
+ t_end: A `float`. The ending time of the sampling.
+ If `t_end` is None, we use 1. / self.noise_schedule.total_N.
+ e.g. if total_N == 1000, we have `t_end` == 1e-3.
+ For discrete-time DPMs:
+ - We recommend `t_end` == 1. / self.noise_schedule.total_N.
+ For continuous-time DPMs:
+ - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15.
+ order: A `int`. The order of DPM-Solver.
+ skip_type: A `str`. The type for the spacing of the time steps. \
+ 'time_uniform' or 'logSNR' or 'time_quadratic'.
+ method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
+ denoise: A `bool`. Whether to denoise at the final step. Default is False.
+ If `denoise` is True, the total NFE is (`steps` + 1).
+ solver_type: A `str`. The taylor expansion type for the solver. \
+ `dpm_solver` or `taylor`. We recommend `dpm_solver`.
+ atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
+ rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
+ Returns:
+ x_end: A pytorch tensor. The approximated solution at time `t_end`.
+
+ """
+ t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
+ t_T = self.noise_schedule.T if t_start is None else t_start
+ device = x.device
+ if method == 'adaptive':
+ with torch.no_grad():
+ x = self.dpm_solver_adaptive(
+ x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, solver_type=solver_type
+ )
+ elif method == 'multistep':
+ assert steps >= order
+ timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
+ assert timesteps.shape[0] - 1 == steps
+ with torch.no_grad():
+ vec_t = timesteps[0].expand((x.shape[0]))
+ model_prev_list = [self.model_fn(x, vec_t)]
+ t_prev_list = [vec_t]
+ # Init the first `order` values by lower order multistep DPM-Solver.
+ for init_order in range(1, order):
+ vec_t = timesteps[init_order].expand(x.shape[0])
+ x = self.multistep_dpm_solver_update(
+ x, model_prev_list, t_prev_list, vec_t, init_order, solver_type=solver_type
+ )
+ model_prev_list.append(self.model_fn(x, vec_t))
+ t_prev_list.append(vec_t)
+ # Compute the remaining values by `order`-th order multistep DPM-Solver.
+ for step in range(order, steps + 1):
+ vec_t = timesteps[step].expand(x.shape[0])
+ x = self.multistep_dpm_solver_update(
+ x, model_prev_list, t_prev_list, vec_t, order, solver_type=solver_type
+ )
+ for i in range(order - 1):
+ t_prev_list[i] = t_prev_list[i + 1]
+ model_prev_list[i] = model_prev_list[i + 1]
+ t_prev_list[-1] = vec_t
+ # We do not need to evaluate the final model value.
+ if step < steps:
+ model_prev_list[-1] = self.model_fn(x, vec_t)
+ elif method in ['singlestep', 'singlestep_fixed']:
+ if method == 'singlestep':
+ orders = self.get_orders_for_singlestep_solver(steps=steps, order=order)
+ timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
+ elif method == 'singlestep_fixed':
+ K = steps // order
+ orders = [
+ order,
+ ] * K
+ timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=(K * order), device=device)
+ with torch.no_grad():
+ i = 0
+ for order in orders:
+ vec_s, vec_t = timesteps[i].expand(x.shape[0]), timesteps[i + order].expand(x.shape[0])
+ h = self.noise_schedule.marginal_lambda(timesteps[i + order]
+ ) - self.noise_schedule.marginal_lambda(timesteps[i])
+ r1 = None if order <= 1 else (
+ self.noise_schedule.marginal_lambda(timesteps[i + 1]) -
+ self.noise_schedule.marginal_lambda(timesteps[i])
+ ) / h
+ r2 = None if order <= 2 else (
+ self.noise_schedule.marginal_lambda(timesteps[i + 2]) -
+ self.noise_schedule.marginal_lambda(timesteps[i])
+ ) / h
+ x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2)
+ i += order
+ if denoise:
+ x = self.denoise_fn(x, torch.ones((x.shape[0], )).to(device) * t_0)
+ return x
+
+
+#############################################################
+# other utility functions
+#############################################################
+
+
+def interpolate_fn(x, xp, yp):
+ """
+ A piecewise linear function y = f(x), using xp and yp as keypoints.
+ We implement f(x) in a differentiable way (i.e. applicable for autograd).
+    The function f(x) is well-defined for all x. \
+    (For x beyond the bounds of xp, we use the outermost points of xp to define the linear function.)
+
+ Args:
+ x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels
+ (we use C = 1 for DPM-Solver).
+ xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
+ yp: PyTorch tensor with shape [C, K].
+ Returns:
+ The function values f(x), with shape [N, C].
+ """
+ N, K = x.shape[0], xp.shape[1]
+ all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
+ sorted_all_x, x_indices = torch.sort(all_x, dim=2)
+ x_idx = torch.argmin(x_indices, dim=2)
+ cand_start_idx = x_idx - 1
+ start_idx = torch.where(
+ torch.eq(x_idx, 0),
+ torch.tensor(1, device=x.device),
+ torch.where(
+ torch.eq(x_idx, K),
+ torch.tensor(K - 2, device=x.device),
+ cand_start_idx,
+ ),
+ )
+ end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
+ start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
+ end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
+ start_idx2 = torch.where(
+ torch.eq(x_idx, 0),
+ torch.tensor(0, device=x.device),
+ torch.where(
+ torch.eq(x_idx, K),
+ torch.tensor(K - 2, device=x.device),
+ cand_start_idx,
+ ),
+ )
+ y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
+ start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
+ end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
+ cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
+ return cand
+
+
+def expand_dims(v, dims):
+ """
+    Expand the tensor `v` so that its total number of dimensions is `dims`.
+
+ Args:
+ `v`: a PyTorch tensor with shape [N].
+        `dims`: an `int`. The target total number of dimensions.
+ Returns:
+ a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
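+
+    For example, expand_dims(v, 4) for `v` with shape [N] returns a view with shape [N, 1, 1, 1].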
+ """
+ return v[(..., ) + (None, ) * (dims - 1)]
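+
+
+# Illustrative end-to-end sketch (not part of the upstream file): sampling from a discrete-time DPM
+# with DPM-Solver++ (data prediction + multistep). `my_model`, `betas`, `batch_size` and `data_shape`
+# are placeholders for a trained noise prediction model and its training setup.
+#
+# >>> ns = NoiseScheduleVP('discrete', betas=betas)
+# >>> model_fn = model_wrapper(my_model, ns, model_type="noise", guidance_type="uncond")
+# >>> dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True)
+# >>> x_T = torch.randn(batch_size, *data_shape)
+# >>> x_0 = dpm_solver.sample(x_T, steps=20, order=2, skip_type='time_uniform', method='multistep')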
diff --git a/ding/torch_utils/network/activation.py b/ding/torch_utils/network/activation.py
index b3c8fcda4c..e97d6645b2 100644
--- a/ding/torch_utils/network/activation.py
+++ b/ding/torch_utils/network/activation.py
@@ -159,6 +159,7 @@ def build_activation(activation: str, inplace: bool = None) -> nn.Module:
"sigmoid": nn.Sigmoid(),
"softplus": nn.Softplus(),
"elu": nn.ELU(),
+ "silu": torch.nn.SiLU(inplace=inplace),
"square": Lambda(lambda x: x ** 2),
"identity": Lambda(lambda x: x),
}
diff --git a/ding/torch_utils/network/res_block.py b/ding/torch_utils/network/res_block.py
index a698869b08..82908ce509 100644
--- a/ding/torch_utils/network/res_block.py
+++ b/ding/torch_utils/network/res_block.py
@@ -153,3 +153,47 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.dropout is not None:
x = self.dropout(x)
return x
+
+
+class TemporalSpatialResBlock(nn.Module):
+ """
+ Overview:
+ Residual Block using MLP layers for both temporal and spatial input.
+        Residual block that applies MLP layers to both a temporal and a spatial input:
+            h1  = dense1(x) + time_mlp(t)
+            h2  = dense2(h1)
+            out = h2 + modify_x(x)
+ """
+
+ def __init__(self, input_dim, output_dim, t_dim=128, activation=torch.nn.SiLU()):
+ """
+ Overview:
+ Init the temporal spatial residual block.
+ Arguments:
+            - input_dim (:obj:`int`): The dimension of the spatial input tensor.
+            - output_dim (:obj:`int`): The dimension of the output tensor.
+            - t_dim (:obj:`int`): The dimension of the temporal input (time embedding).
+            - activation (:obj:`nn.Module`): The activation function, default is ``torch.nn.SiLU()``.
+ """
+ super().__init__()
+ # temporal input is the embedding of time, which is a Gaussian Fourier Feature tensor
+ self.time_mlp = nn.Sequential(
+ activation,
+ nn.Linear(t_dim, output_dim),
+ )
+ self.dense1 = nn.Sequential(nn.Linear(input_dim, output_dim), activation)
+ self.dense2 = nn.Sequential(nn.Linear(output_dim, output_dim), activation)
+ self.modify_x = nn.Linear(input_dim, output_dim) if input_dim != output_dim else nn.Identity()
+
+    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+ """
+ Overview:
+            Return the residual block output.
+ Arguments:
+ - x (:obj:`torch.Tensor`): The input tensor.
+ - t (:obj:`torch.Tensor`): The temporal input tensor.
+ """
+ h1 = self.dense1(x) + self.time_mlp(t)
+ h2 = self.dense2(h1)
+ return h2 + self.modify_x(x)
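+
+
+# Illustrative usage sketch (not part of the patch): a single forward pass with arbitrary example
+# dimensions.
+#
+# >>> block = TemporalSpatialResBlock(input_dim=6, output_dim=32, t_dim=128)
+# >>> x = torch.randn(4, 6)     # spatial input, e.g. a batch of actions
+# >>> t = torch.randn(4, 128)   # temporal embedding, e.g. Gaussian Fourier features of time
+# >>> block(x, t).shape         # torch.Size([4, 32])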
diff --git a/dizoo/d4rl/config/halfcheetah_medium_expert_qgpo_config.py b/dizoo/d4rl/config/halfcheetah_medium_expert_qgpo_config.py
new file mode 100644
index 0000000000..787a421d76
--- /dev/null
+++ b/dizoo/d4rl/config/halfcheetah_medium_expert_qgpo_config.py
@@ -0,0 +1,51 @@
+from easydict import EasyDict
+
+main_config = dict(
+ exp_name='halfcheetah_medium_expert_v2_QGPO_seed0',
+ env=dict(
+ env_id="halfcheetah-medium-expert-v2",
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ ),
+ dataset=dict(
+ env_id="halfcheetah-medium-expert-v2",
+ ),
+ policy=dict(
+ cuda=True,
+ on_policy=False,
+ #load_path='./halfcheetah_medium_expert_v2_QGPO_seed0/ckpt/iteration_600000.pth.tar',
+ model=dict(
+ qgpo_critic=dict(
+ alpha=3,
+ q_alpha=1,
+ ),
+ device='cuda',
+ obs_dim=17,
+ action_dim=6,
+ ),
+ learn=dict(
+ learning_rate=1e-4,
+ batch_size=4096,
+ batch_size_q=256,
+ M=16,
+ diffusion_steps=15,
+ behavior_policy_stop_training_iter=600000,
+ energy_guided_policy_begin_training_iter=600000,
+ q_value_stop_training_iter=1100000,
+ ),
+ eval=dict(
+ guidance_scale=[0.0, 1.0, 2.0, 3.0, 5.0, 8.0, 10.0],
+ diffusion_steps=15,
+ evaluator=dict(eval_freq=50000, ),
+ ),
+ ),
+)
+main_config = EasyDict(main_config)
+
+create_config = dict(
+ env_manager=dict(type='base'),
+ policy=dict(
+ type='qgpo',
+ ),
+)
+create_config = EasyDict(create_config)
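+
+# Note: this file only declares the experiment configuration; it is typically consumed by a QGPO
+# entry script that builds the policy from `main_config` and `create_config`. The exact launcher
+# (e.g. a d4rl_qgpo_main.py-style entry) is an assumption here and may differ in the repository.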
diff --git a/dizoo/d4rl/config/hopper_medium_expert_qgpo_config.py b/dizoo/d4rl/config/hopper_medium_expert_qgpo_config.py
new file mode 100644
index 0000000000..b0941bc13c
--- /dev/null
+++ b/dizoo/d4rl/config/hopper_medium_expert_qgpo_config.py
@@ -0,0 +1,51 @@
+from easydict import EasyDict
+
+main_config = dict(
+ exp_name='hopper_medium_expert_v2_QGPO_seed0',
+ env=dict(
+ env_id="hopper-medium-expert-v2",
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ ),
+ dataset=dict(
+ env_id="hopper-medium-expert-v2",
+ ),
+ policy=dict(
+ cuda=True,
+ on_policy=False,
+ #load_path='./hopper_medium_expert_v2_QGPO_seed0/ckpt/iteration_600000.pth.tar',
+ model=dict(
+ qgpo_critic=dict(
+ alpha=3,
+ q_alpha=1,
+ ),
+ device='cuda',
+ obs_dim=11,
+ action_dim=3,
+ ),
+ learn=dict(
+ learning_rate=1e-4,
+ batch_size=4096,
+ batch_size_q=256,
+ M=16,
+ diffusion_steps=15,
+ behavior_policy_stop_training_iter=600000,
+ energy_guided_policy_begin_training_iter=600000,
+ q_value_stop_training_iter=1100000,
+ ),
+ eval=dict(
+ guidance_scale=[0.0, 1.0, 2.0, 3.0, 5.0, 8.0, 10.0],
+ diffusion_steps=15,
+ evaluator=dict(eval_freq=50000, ),
+ ),
+ ),
+)
+main_config = EasyDict(main_config)
+
+create_config = dict(
+ env_manager=dict(type='base'),
+ policy=dict(
+ type='qgpo',
+ ),
+)
+create_config = EasyDict(create_config)
diff --git a/dizoo/d4rl/config/walker2d_medium_expert_qgpo_config.py b/dizoo/d4rl/config/walker2d_medium_expert_qgpo_config.py
new file mode 100644
index 0000000000..b4b3bd7bb6
--- /dev/null
+++ b/dizoo/d4rl/config/walker2d_medium_expert_qgpo_config.py
@@ -0,0 +1,51 @@
+from easydict import EasyDict
+
+main_config = dict(
+ exp_name='walker2d_medium_expert_v2_QGPO_seed0',
+ env=dict(
+ env_id="walker2d-medium-expert-v2",
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ ),
+ dataset=dict(
+ env_id="walker2d-medium-expert-v2",
+ ),
+ policy=dict(
+ cuda=True,
+ on_policy=False,
+ #load_path='./walker2d_medium_expert_v2_QGPO_seed0/ckpt/iteration_600000.pth.tar',
+ model=dict(
+ qgpo_critic=dict(
+ alpha=3,
+ q_alpha=1,
+ ),
+ device='cuda',
+ obs_dim=17,
+ action_dim=6,
+ ),
+ learn=dict(
+ learning_rate=1e-4,
+ batch_size=4096,
+ batch_size_q=256,
+ M=16,
+ diffusion_steps=15,
+ behavior_policy_stop_training_iter=600000,
+ energy_guided_policy_begin_training_iter=600000,
+ q_value_stop_training_iter=1100000,
+ ),
+ eval=dict(
+ guidance_scale=[0.0, 1.0, 2.0, 3.0, 5.0, 8.0, 10.0],
+ diffusion_steps=15,
+ evaluator=dict(eval_freq=50000, ),
+ ),
+ ),
+)
+main_config = EasyDict(main_config)
+
+create_config = dict(
+ env_manager=dict(type='base'),
+ policy=dict(
+ type='qgpo',
+ ),
+)
+create_config = EasyDict(create_config)