Commit 5431ba8: solve conflicts
Super1ce committed May 25, 2023
2 parents 9162870 + 5804402
Showing 169 changed files with 1,497 additions and 600 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/release.yml
@@ -10,7 +10,7 @@ jobs:
strategy:
matrix:
os:
- 'ubuntu-18.04'
- ubuntu-latest
python-version: [3.7]

steps:
22 changes: 22 additions & 0 deletions CHANGELOG
@@ -1,3 +1,25 @@
2023.05.25(v0.4.8)
- env: fix gym hybrid reward dtype bug (#664)
- env: fix atari env id noframeskip bug (#655)
- env: fix typo in gym any_trading env (#654)
- env: update td3bc d4rl config (#659)
- env: polish bipedalwalker config
- algo: add EDAC offline RL algorithm (#639)
- algo: add LN and GN norm_type support in ResBlock (#660)
- algo: add normal value norm baseline for PPOF (#658)
- algo: polish last layer init/norm in MLP (#650)
- algo: polish TD3 monitor variable
- feature: add MAPPO/MASAC task example (#661)
- feature: add PPO example for complex env observation (#644)
- feature: add barrier middleware (#570)
- fix: abnormal collector log and add record_random_collect option (#662)
- fix: to_item compatibility bug (#646)
- fix: trainer dtype transform compatibility bug
- fix: pettingzoo 1.23.0 compatibility bug
- fix: ensemble head unittest bug
- style: fix incompatible gym version bug in Dockerfile.env (#653)
- style: add more algorithm docs

2023.04.11(v0.4.7)
- env: add dmc2gym env support and baseline (#451)
- env: update pettingzoo to the latest version (#597)
3 changes: 2 additions & 1 deletion README.md
@@ -32,7 +32,7 @@
[![Contributors](https://img.shields.io/github/contributors/opendilab/DI-engine)](https://github.com/opendilab/DI-engine/graphs/contributors)
[![GitHub license](https://img.shields.io/github/license/opendilab/DI-engine)](https://github.com/opendilab/DI-engine/blob/master/LICENSE)

Updated on 2023.04.11 DI-engine-v0.4.7
Updated on 2023.05.25 DI-engine-v0.4.8


## Introduction to DI-engine
@@ -252,6 +252,7 @@ P.S: The `.py` file in `Runnable Demo` can be found in `dizoo`
| 50 | [ST-DIM](https://arxiv.org/pdf/1906.08226.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/loss/contrastive_loss](https://github.com/opendilab/DI-engine/blob/main/ding/torch_utils/loss/contrastive_loss.py) | ding -m serial -c cartpole_dqn_stdim_config.py -s 0 |
| 51 | [PLR](https://arxiv.org/pdf/2010.03934.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [PLR doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/plr.html)<br>[data/level_replay/level_sampler](https://github.com/opendilab/DI-engine/blob/main/ding/data/level_replay/level_sampler.py) | python3 -u bigfish_plr_config.py -s 0 |
| 52 | [PCGrad](https://arxiv.org/pdf/2001.06782.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/optimizer_helper/PCGrad](https://github.com/opendilab/DI-engine/blob/main/ding/data/torch_utils/optimizer_helper.py) | python3 -u multi_mnist_pcgrad_main.py -s 0 |
| 53 | [EDAC](https://arxiv.org/pdf/2110.01548.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [EDAC doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/edac.html)<br>[policy/edac](https://github.com/opendilab/DI-engine/blob/main/ding/policy/edac.py) | python3 -u d4rl_edac_main.py |
</details>


2 changes: 1 addition & 1 deletion conda/meta.yaml
@@ -1,7 +1,7 @@
{% set data = load_setup_py_data() %}
package:
name: di-engine
version: v0.4.7
version: v0.4.8

source:
path: ..
2 changes: 1 addition & 1 deletion ding/__init__.py
@@ -1,7 +1,7 @@
import os

__TITLE__ = 'DI-engine'
__VERSION__ = 'v0.4.7'
__VERSION__ = 'v0.4.8'
__DESCRIPTION__ = 'Decision AI Engine'
__AUTHOR__ = "OpenDILab Contributors"
__AUTHOR_EMAIL__ = "opendilab@pjlab.org.cn"
2 changes: 1 addition & 1 deletion ding/bonus/ppof.py
@@ -91,7 +91,7 @@ def __init__(
action_shape = action_space.shape

# Four types of value normalization are supported currently
assert self.cfg.value_norm in ['popart', 'value_rescale', 'symlog']
assert self.cfg.value_norm in ['popart', 'value_rescale', 'symlog', 'baseline']
if model is None:
if self.cfg.value_norm != 'popart':
model = PPOFModel(
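
For context, the new 'baseline' option appears to be a plain running mean/std normalization of value targets, in contrast to PopArt, value rescaling, and symlog. The following framework-independent sketch illustrates that idea only; it is an assumption for illustration, not DI-engine's actual implementation.

import numpy as np

class RunningValueNorm:
    """Track a running mean/std of value targets and normalize with it (illustrative sketch)."""

    def __init__(self, eps: float = 1e-8):
        self.mean, self.var, self.count, self.eps = 0.0, 1.0, 0.0, eps

    def update(self, returns: np.ndarray) -> None:
        # Fold batch statistics into the running statistics (parallel Welford update).
        batch_mean, batch_var, batch_count = returns.mean(), returns.var(), returns.size
        total = self.count + batch_count
        delta = batch_mean - self.mean
        new_mean = self.mean + delta * batch_count / total
        m2 = self.var * self.count + batch_var * batch_count + delta ** 2 * self.count * batch_count / total
        self.mean, self.var, self.count = new_mean, m2 / total, total

    def normalize(self, returns: np.ndarray) -> np.ndarray:
        return (returns - self.mean) / np.sqrt(self.var + self.eps)

    def denormalize(self, values: np.ndarray) -> np.ndarray:
        return values * np.sqrt(self.var + self.eps) + self.mean
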
8 changes: 3 additions & 5 deletions ding/config/config.py
@@ -306,7 +306,7 @@ def compile_collector_config(
other=dict(replay_buffer=dict()),
)
policy_config_template = EasyDict(policy_config_template)
env_config_template = dict(manager=dict(), stop_value=int(1e10))
env_config_template = dict(manager=dict(), stop_value=int(1e10), n_evaluator_episode=4)
env_config_template = EasyDict(env_config_template)


@@ -451,14 +451,12 @@ def compile_config(
default_config['reward_model'] = reward_model_config
if len(world_model_config) > 0:
default_config['world_model'] = world_model_config
stop_value_flag = 'stop_value' in cfg.env
cfg = deep_merge_dicts(default_config, cfg)
cfg.seed = seed
# check important key in config
if evaluator in [InteractionSerialEvaluator, BattleInteractionSerialEvaluator]: # env interaction evaluation
if stop_value_flag: # data generation task doesn't need these fields
cfg.policy.eval.evaluator.n_episode = cfg.env.n_evaluator_episode
cfg.policy.eval.evaluator.stop_value = cfg.env.stop_value
cfg.policy.eval.evaluator.stop_value = cfg.env.stop_value
cfg.policy.eval.evaluator.n_episode = cfg.env.n_evaluator_episode
if 'exp_name' not in cfg:
cfg.exp_name = 'default_experiment'
if save_cfg:
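
A minimal sketch of how the new template default behaves when a user env config omits both fields; the import path for the merge helper is assumed:

from easydict import EasyDict
from ding.utils import deep_merge_dicts  # assumed import path

env_config_template = EasyDict(dict(manager=dict(), stop_value=int(1e10), n_evaluator_episode=4))
user_env_cfg = EasyDict(dict(manager=dict()))  # hypothetical config that omits both fields

merged = deep_merge_dicts(env_config_template, user_env_cfg)
# merged.stop_value == int(1e10) and merged.n_evaluator_episode == 4, so compile_config
# can always copy them into cfg.policy.eval.evaluator without the old stop_value_flag guard.
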
4 changes: 3 additions & 1 deletion ding/entry/utils.py
@@ -59,7 +59,9 @@ def random_collect(
if policy_cfg.collect.collector.type == 'episode':
new_data = collector.collect(n_episode=policy_cfg.random_collect_size, policy_kwargs=collect_kwargs)
else:
new_data = collector.collect(n_sample=policy_cfg.random_collect_size, policy_kwargs=collect_kwargs)
new_data = collector.collect(
n_sample=policy_cfg.random_collect_size, record_random_collect=False, policy_kwargs=collect_kwargs
) # 'record_random_collect=False' suppresses the collector's log output for this random warm-up collection
if postprocess_data_fn is not None:
new_data = postprocess_data_fn(new_data)
replay_buffer.push(new_data, cur_collector_envstep=0)
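
A hedged usage sketch of the new flag at the collector call site; the surrounding objects (policy_cfg, collector, replay_buffer) are assumed to come from a standard serial entry, and the 'eps' kwarg is only an example of a policy_kwargs entry:

warmup_data = collector.collect(
    n_sample=policy_cfg.random_collect_size,
    record_random_collect=False,  # skip the per-collect log output during random warm-up
    policy_kwargs={'eps': 1.0},
)
replay_buffer.push(warmup_data, cur_collector_envstep=0)
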
42 changes: 42 additions & 0 deletions ding/example/edac.py
@@ -0,0 +1,42 @@
import gym
from ditk import logging
from ding.model import QACEnsemble
from ding.policy import EDACPolicy
from ding.envs import DingEnvWrapper, BaseEnvManagerV2
from ding.data import create_dataset
from ding.config import compile_config
from ding.framework import task, ding_init
from ding.framework.context import OfflineRLContext
from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, offline_data_fetcher, offline_logger
from ding.utils import set_pkg_seed
from dizoo.d4rl.envs import D4RLEnv
from dizoo.d4rl.config.halfcheetah_medium_edac_config import main_config, create_config


def main():
# If you don't have offline data, you need to prepare it first and set the data_path in the config
# For demonstration, we can also train an RL policy (e.g. SAC) and collect some data
logging.getLogger().setLevel(logging.INFO)
cfg = compile_config(main_config, create_cfg=create_config, auto=True)
ding_init(cfg)
with task.start(async_mode=False, ctx=OfflineRLContext()):
evaluator_env = BaseEnvManagerV2(
env_fn=[lambda: D4RLEnv(cfg.env) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager
)

set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)

dataset = create_dataset(cfg)
model = QACEnsemble(**cfg.policy.model)
policy = EDACPolicy(cfg.policy, model=model)

task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
task.use(offline_data_fetcher(cfg, dataset))
task.use(trainer(cfg, policy.learn_mode))
task.use(CkptSaver(policy, cfg.exp_name, train_freq=1e4))
task.use(offline_logger())
task.run()


if __name__ == "__main__":
main()
45 changes: 45 additions & 0 deletions ding/example/mappo.py
@@ -0,0 +1,45 @@
import gym
from ditk import logging
from ding.model import MAVAC
from ding.policy import PPOPolicy
from ding.envs import DingEnvWrapper, SubprocessEnvManagerV2
from ding.data import DequeBuffer
from ding.config import compile_config
from ding.framework import task, ding_init
from ding.framework.context import OnlineRLContext
from ding.framework.middleware import multistep_trainer, StepCollector, interaction_evaluator, CkptSaver, \
gae_estimator, online_logger, termination_checker
from ding.utils import set_pkg_seed
from dizoo.petting_zoo.config.ptz_simple_spread_mappo_config import main_config, create_config
from dizoo.petting_zoo.envs.petting_zoo_simple_spread_env import PettingZooEnv


def main():
logging.getLogger().setLevel(logging.INFO)
cfg = compile_config(main_config, create_cfg=create_config, auto=True)
ding_init(cfg)
with task.start(async_mode=False, ctx=OnlineRLContext()):
collector_env = SubprocessEnvManagerV2(
env_fn=[lambda: PettingZooEnv(cfg.env) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager
)
evaluator_env = SubprocessEnvManagerV2(
env_fn=[lambda: PettingZooEnv(cfg.env) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager
)

set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)

model = MAVAC(**cfg.policy.model)
policy = PPOPolicy(cfg.policy, model=model)

task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
task.use(StepCollector(cfg, policy.collect_mode, collector_env))
task.use(gae_estimator(cfg, policy.collect_mode))
task.use(multistep_trainer(policy.learn_mode, log_freq=100))
task.use(CkptSaver(policy, cfg.exp_name, train_freq=1000))
task.use(online_logger(train_show_freq=10))
task.use(termination_checker(max_env_step=int(1e6)))
task.run()


if __name__ == "__main__":
main()
49 changes: 49 additions & 0 deletions ding/example/masac.py
@@ -0,0 +1,49 @@
import gym
from ditk import logging
from ding.model import MAQAC
from ding.policy import SACDiscretePolicy
from ding.envs import DingEnvWrapper, BaseEnvManagerV2
from ding.data import DequeBuffer
from ding.config import compile_config
from ding.framework import task, ding_init
from ding.framework.context import OnlineRLContext
from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, CkptSaver, \
data_pusher, online_logger, termination_checker, eps_greedy_handler
from ding.utils import set_pkg_seed
from dizoo.petting_zoo.config.ptz_simple_spread_masac_config import main_config, create_config
from dizoo.petting_zoo.envs.petting_zoo_simple_spread_env import PettingZooEnv


def main():
logging.getLogger().setLevel(logging.INFO)
cfg = compile_config(main_config, create_cfg=create_config, auto=True)
ding_init(cfg)
with task.start(async_mode=False, ctx=OnlineRLContext()):
collector_env = BaseEnvManagerV2(
env_fn=[lambda: PettingZooEnv(cfg.env) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager
)
evaluator_env = BaseEnvManagerV2(
env_fn=[lambda: PettingZooEnv(cfg.env) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager
)

set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)

model = MAQAC(**cfg.policy.model)
buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
policy = SACDiscretePolicy(cfg.policy, model=model)

task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
task.use(eps_greedy_handler(cfg))
task.use(
StepCollector(cfg, policy.collect_mode, collector_env, random_collect_size=cfg.policy.random_collect_size)
)
task.use(data_pusher(cfg, buffer_))
task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_, log_freq=100))
task.use(CkptSaver(policy, cfg.exp_name, train_freq=1000))
task.use(online_logger(train_show_freq=10))
task.use(termination_checker(max_env_step=int(1e6)))
task.run()


if __name__ == "__main__":
main()
15 changes: 9 additions & 6 deletions ding/example/td3.py
@@ -1,13 +1,13 @@
import gym
from ditk import logging
from ding.model.template.qac import QAC
from ding.model import QAC
from ding.policy import TD3Policy
from ding.envs import DingEnvWrapper, BaseEnvManagerV2
from ding.envs import BaseEnvManagerV2
from ding.data import DequeBuffer
from ding.config import compile_config
from ding.framework import task
from ding.framework import task, ding_init
from ding.framework.context import OnlineRLContext
from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, CkptSaver
from ding.framework.middleware import data_pusher, StepCollector, interaction_evaluator, \
CkptSaver, OffPolicyLearner, termination_checker, online_logger
from ding.utils import set_pkg_seed
from dizoo.classic_control.pendulum.envs.pendulum_env import PendulumEnv
from dizoo.classic_control.pendulum.config.pendulum_td3_config import main_config, create_config
@@ -16,6 +16,7 @@
def main():
logging.getLogger().setLevel(logging.INFO)
cfg = compile_config(main_config, create_cfg=create_config, auto=True)
ding_init(cfg)
with task.start(async_mode=False, ctx=OnlineRLContext()):
collector_env = BaseEnvManagerV2(
env_fn=[lambda: PendulumEnv(cfg.env) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager
@@ -28,7 +29,7 @@ def main():

model = QAC(**cfg.policy.model)
buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
policy = TD3Policy(cfg.policy, model)
policy = TD3Policy(cfg.policy, model=model)

task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
task.use(
@@ -37,6 +38,8 @@ def main():
task.use(data_pusher(cfg, buffer_))
task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
task.use(termination_checker(max_train_iter=10000))
task.use(online_logger())
task.run()


2 changes: 1 addition & 1 deletion ding/framework/middleware/functional/logger.py
@@ -73,7 +73,7 @@ def _logger(ctx: "OnlineRLContext"):
else:
output = ctx.train_output
for k, v in output.items():
if k in ['priority']:
if k in ['priority', 'td_error_priority']:
continue
if "[scalars]" in k:
new_k = k.split(']')[-1]
9 changes: 5 additions & 4 deletions ding/framework/middleware/functional/trainer.py
@@ -7,13 +7,14 @@
from ding.framework import task, OfflineRLContext, OnlineRLContext


def trainer(cfg: EasyDict, policy: Policy) -> Callable:
def trainer(cfg: EasyDict, policy: Policy, log_freq: int = 100) -> Callable:
"""
Overview:
The middleware that executes a single training process.
Arguments:
- cfg (:obj:`EasyDict`): Config.
- policy (:obj:`Policy`): The policy to be trained in step-by-step mode.
- log_freq (:obj:`int`): The frequency (in iterations) of log output.
"""

def _train(ctx: Union["OnlineRLContext", "OfflineRLContext"]):
@@ -31,7 +32,7 @@ def _train(ctx: Union["OnlineRLContext", "OfflineRLContext"]):
return
data = ctx.train_data
train_output = policy.forward(ctx.train_data)
if ctx.train_iter % cfg.policy.learn.learner.hook.log_show_after_iter == 0:
if ctx.train_iter % log_freq == 0:
if isinstance(ctx, OnlineRLContext):
logging.info(
'Training: Train Iter({})\tEnv Step({})\tLoss({:.3f})'.format(
Expand All @@ -50,13 +51,13 @@ def _train(ctx: Union["OnlineRLContext", "OfflineRLContext"]):
return _train


def multistep_trainer(policy: Policy, log_freq: int) -> Callable:
def multistep_trainer(policy: Policy, log_freq: int = 100) -> Callable:
"""
Overview:
The middleware that executes training for a target num of steps.
Arguments:
- policy (:obj:`Policy`): The policy specialized for multi-step training.
- int (:obj:`int`): The frequency (iteration) of showing log.
- log_freq (:obj:`int`): The frequency (in iterations) of log output.
"""
last_log_iter = -1

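
With the dependency on cfg.policy.learn.learner.hook.log_show_after_iter removed, the logging interval is passed directly when wiring the middleware. A short usage sketch; the 50 and 100 values are arbitrary:

task.use(trainer(cfg, policy.learn_mode, log_freq=50))        # single-step trainer
task.use(multistep_trainer(policy.learn_mode, log_freq=100))  # multi-step variant, as in ding/example/mappo.py
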
12 changes: 8 additions & 4 deletions ding/framework/middleware/learner.py
@@ -8,6 +8,8 @@

if TYPE_CHECKING:
from ding.framework import Context, OnlineRLContext
from ding.policy import Policy
from ding.reward_model import BaseRewardModel


class OffPolicyLearner:
@@ -25,17 +27,19 @@ def __new__(cls, *args, **kwargs):
def __init__(
self,
cfg: EasyDict,
policy,
policy: 'Policy',
buffer_: Union[Buffer, List[Tuple[Buffer, float]], Dict[str, Buffer]],
reward_model=None
reward_model: Optional['BaseRewardModel'] = None,
log_freq: int = 100,
) -> None:
"""
Arguments:
- cfg (:obj:`EasyDict`): Config.
- policy (:obj:`Policy`): The policy to be trained.
- buffer\_ (:obj:`Buffer`): The replay buffer to store the data for training.
- reward_model (:obj:`nn.Module`): Additional reward estimator likes RND, ICM, etc. \
- buffer (:obj:`Buffer`): The replay buffer to store the data for training.
- reward_model (:obj:`BaseRewardModel`): Additional reward estimator like RND, ICM, etc. \
default to None.
- log_freq (:obj:`int`): The frequency (in iterations) of log output.
"""
self.cfg = cfg
self._fetcher = task.wrap(offpolicy_data_fetcher(cfg, buffer_))
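
A hedged usage sketch of the updated constructor, matching the call in ding/example/masac.py; log_freq defaults to 100 and reward_model stays optional:

task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_, log_freq=200))
# An intrinsic-reward estimator (any BaseRewardModel subclass) would be passed
# as reward_model=...; when it is None, only the replay buffer data is used.
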
2 changes: 1 addition & 1 deletion ding/model/common/__init__.py
100644 → 100755
@@ -1,5 +1,5 @@
from .head import DiscreteHead, DuelingHead, DistributionHead, RainbowHead, QRDQNHead, \
QuantileHead, FQFHead, RegressionHead, ReparameterizationHead, MultiHead, BranchingHead, head_cls_map, \
independent_normal_dist, AttentionPolicyHead, PopArtVHead
independent_normal_dist, AttentionPolicyHead, PopArtVHead, EnsembleHead
from .encoder import ConvEncoder, FCEncoder, IMPALAConvEncoder
from .utils import create_model