feature(zlx): League training with slime volley env #23

Merged · 12 commits · Oct 8, 2021
4 changes: 2 additions & 2 deletions ding/model/template/vac.py
@@ -121,8 +121,8 @@ def __init__(
else:
self.actor = [self.actor_encoder, self.actor_head]
self.critic = [self.critic_encoder, self.critic_head]
# for convenience of call some apis(such as: self.critic.parameters()), but may cause
# misunderstanding when print(self)
# Convenient for calling some apis (e.g. self.critic.parameters()),
# but may cause misunderstanding when `print(self)`
self.actor = nn.ModuleList(self.actor)
self.critic = nn.ModuleList(self.critic)

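A minimal illustration of the point made in that comment (toy code, not from vac.py): submodules held in a plain Python list are not registered with the parent module, so `self.critic.parameters()` would return nothing, whereas `nn.ModuleList` registers them.

```python
import torch.nn as nn

# Toy example, not from vac.py: submodules held in a plain Python list are not registered
# with the parent module, so .parameters() (and .to()/.cuda()) miss them; wrapping the
# list in nn.ModuleList registers each submodule.
class PlainList(nn.Module):
    def __init__(self):
        super().__init__()
        self.critic = [nn.Linear(4, 4), nn.Linear(4, 2)]      # plain list: not registered

class WithModuleList(nn.Module):
    def __init__(self):
        super().__init__()
        self.critic = nn.ModuleList([nn.Linear(4, 4), nn.Linear(4, 2)])  # registered

print(sum(p.numel() for p in PlainList().parameters()))       # 0
print(sum(p.numel() for p in WithModuleList().parameters()))  # 30
```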
2 changes: 1 addition & 1 deletion dizoo/atari/config/serial/pong/pong_ppo_config.py
@@ -14,8 +14,8 @@
),
policy=dict(
cuda=True,
# (bool) whether to use on-policy training pipeline(on-policy means behaviour policy and training policy are the same)
# (bool) whether use on-policy training pipeline(behaviour policy and training policy are the same)
on_policy=False,
model=dict(
obs_shape=[4, 84, 84],
action_shape=6,
@@ -9,6 +9,7 @@
stop_value=195,
),
policy=dict(
random_collect_size=100,
cuda=False,
on_policy=True,
continuous=False,
2 changes: 1 addition & 1 deletion dizoo/league_demo/league_demo_ppo_main.py
@@ -120,7 +120,7 @@ def main(cfg, seed=0, max_iterations=int(1e10)):
main_player = league.get_player_by_id(main_key)
main_learner = learners[main_key]
main_collector = collectors[main_key]
# collect_mode ppo use multimonial sample for selecting action
# collect_mode ppo use multinomial sample for selecting action
evaluator1_cfg = copy.deepcopy(cfg.policy.eval.evaluator)
evaluator1_cfg.stop_value = cfg.env.stop_value[0]
evaluator1 = BattleInteractionSerialEvaluator(
2 changes: 1 addition & 1 deletion dizoo/league_demo/selfplay_demo_ppo_main.py
@@ -87,7 +87,7 @@ def main(cfg, seed=0, max_iterations=int(1e10)):
tb_logger,
exp_name=cfg.exp_name
)
# collect_mode ppo use multimonial sample for selecting action
# collect_mode ppo use multinomial sample for selecting action
evaluator1_cfg = copy.deepcopy(cfg.policy.eval.evaluator)
evaluator1_cfg.stop_value = cfg.env.stop_value[0]
evaluator1 = BattleInteractionSerialEvaluator(
Empty file added dizoo/slime_volley/__init__.py
1 change: 1 addition & 0 deletions dizoo/slime_volley/config/__init__.py
@@ -0,0 +1 @@
from .slime_volley_league_ppo_config import slime_volley_league_ppo_config
78 changes: 78 additions & 0 deletions dizoo/slime_volley/config/slime_volley_league_ppo_config.py
@@ -0,0 +1,78 @@
from easydict import EasyDict

slime_volley_league_ppo_config = dict(
exp_name="slime_volley_league_ppo",
env=dict(
collector_env_num=8,
@zxzzz0 zxzzz0 Oct 20, 2021
(1) CPU problem
We've tried this and found that only one core out of our 64 gets 100% utilization; the remaining cores sit at roughly 2% utilization the whole time. The situation stays the same even if we change collector_env_num from 8 to 64. Do you know how to fully utilize all cores? Should we change SyncSubprocessEnvManager to something else? @PaParaZz1
(2) GPU problem
We have one machine with 2 GPUs, each with 8 GiB of memory. Although we've made the learner batch_size 16x larger, the first GPU still uses only 1.6 GiB of memory and the second GPU's usage is always 0 GiB.
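One possibly relevant detail on the CPU side (a sketch, not a confirmed fix): the selfplay entry in this PR builds its environments with BaseEnvManager, which steps every environment in the same process. A subprocess-based manager runs each env in its own worker, which is usually how the collector spreads over more cores. Assuming SyncSubprocessEnvManager is exported from ding.envs and takes the same (env_fn, cfg) arguments as BaseEnvManager, the swap inside main() would look roughly like:

```python
# Hypothetical tweak inside main() of slime_volley_selfplay_ppo_main.py (shown further
# below), after compile_config. SyncSubprocessEnvManager would also replace BaseEnvManager
# in the compile_config(...) call so the manager defaults are compiled for it.
from ding.envs import SyncSubprocessEnvManager  # assumed export, alongside BaseEnvManager

collector_env = SyncSubprocessEnvManager(
    env_fn=[partial(SlimeVolleyEnv, collector_env_cfg) for _ in range(collector_env_num)],
    cfg=cfg.env.manager,
)
evaluator_env = SyncSubprocessEnvManager(
    env_fn=[partial(SlimeVolleyEnv, evaluator_env_cfg) for _ in range(evaluator_env_num)],
    cfg=cfg.env.manager,
)
```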

Member
For the GPU problem, I'd like to know what kind of multi-GPU implementation you use: torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel, or some other method? You can also open another issue to track these two problems.

@zxzzz0 zxzzz0 Oct 20, 2021
For the GPU problem, we don't use any multi-GPU implementation; we simply run slime_volley_league_ppo_config.py from the master branch. But we want to fully utilize this single machine with multiple GPUs. I've tried adding multi_gpu=True to the learner config and wrapping main with `with DistContext():`, but it raised the errors we reported here. It looks like this path doesn't support a single machine with multiple GPUs.

Member
OK, I will add torch.nn.DataParallel support to DI-engine before 10.25 for the single-machine multi-GPU case. You can follow the related PR.

Thanks. In the related PR please also update slime_volley_ppo_config.py to use it.
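Until that lands, a generic sketch of single-machine multi-GPU wrapping with plain torch.nn.DataParallel (stock PyTorch, not the DI-engine multi_gpu / DistContext path; wiring a wrapped model into PPOPolicy and the learner is not shown here):

```python
import torch
import torch.nn as nn
from ding.model import VAC

# Hypothetical sketch only. nn.DataParallel replicates the module on every visible GPU
# and splits each input batch across them; attribute access on the wrapped model then
# goes through .module, so extra glue may be needed before handing it to PPOPolicy.
model = VAC(
    obs_shape=12,
    action_shape=6,
    encoder_hidden_size_list=[32, 32],
    critic_head_hidden_size=32,
    actor_head_hidden_size=32,
    share_encoder=False,
)
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    model = nn.DataParallel(model).cuda()
```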

evaluator_env_num=10,
n_evaluator_episode=100,
stop_value=0,
# Single-agent env for evaluator; Double-agent env for collector.
# Should be assigned True or False in code.
is_evaluator=None,
manager=dict(shared_memory=False, ),
env_id="SlimeVolley-v0",
),
policy=dict(
cuda=False,
continuous=False,
model=dict(
obs_shape=12,
action_shape=6,
encoder_hidden_size_list=[32, 32],
critic_head_hidden_size=32,
actor_head_hidden_size=32,
share_encoder=False,
),
learn=dict(
update_per_collect=3,
batch_size=32,
learning_rate=0.00001,
value_weight=0.5,
entropy_weight=0.0,
clip_ratio=0.2,
),
collect=dict(
n_episode=128, unroll_len=1, discount_factor=1.0, gae_lambda=1.0, collector=dict(get_train_sample=True, )
),
other=dict(
league=dict(
@zxzzz0 zxzzz0 Oct 3, 2021
Could we enable the league metric (TrueSkill) here, as we did in the league demo? It would be useful for monitoring training progress.
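For reference, a hedged sketch of what such a rating block might look like, using the default parameters of the trueskill package; the exact key names used by the league demo config may differ:

```python
from easydict import EasyDict

# Hypothetical sketch only: TrueSkill rating parameters in the style of the league demo,
# using the trueskill package defaults (mu=25, sigma=mu/3, beta=sigma/2, tau=sigma/100).
league_metric = EasyDict(
    mu=25.0,
    sigma=25.0 / 3,
    beta=25.0 / 6,
    tau=25.0 / 300,
    draw_probability=0.10,
)
# This would then be attached to the league config above, e.g. as league.metric.
```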

player_category=['default'],
path_policy="slime_volley_league_ppo/policy",
active_players=dict(
main_player=1,
main_exploiter=1,
league_exploiter=1,
),
main_player=dict(
one_phase_step=200,
branch_probs=dict(
pfsp=0.5,
sp=1.0,
),
strong_win_rate=0.7,
),
main_exploiter=dict(
one_phase_step=200,
branch_probs=dict(main_players=1.0, ),
strong_win_rate=0.7,
min_valid_win_rate=0.3,
),
league_exploiter=dict(
one_phase_step=200,
branch_probs=dict(pfsp=1.0, ),
strong_win_rate=0.7,
mutate_prob=0.0,
),
use_pretrain=False,
@PaParaZz1 Following up on this: is it possible to add another config, e.g. xxx_league_PPO_using_pretrain_config.py, where use_pretrain is set to True rather than False, so that we don't train from scratch? (Motivation: training from scratch can be slow, and we believe starting from a pretrained model would benefit training.)

In this new config, each time the model is mutated it would be reset to the pretrained model. Assume the pretrained model was obtained offline by running PPO against the built-in bot (dizoo/slime_volley/config/slime_volley_ppo_config.py) for a short period of time.
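A minimal sketch of what such a derived config could look like (the file/variable name is hypothetical and only the use_pretrain flag is flipped; how the pretrained checkpoint is located and loaded is left to the league implementation):

```python
import copy

from dizoo.slime_volley.config import slime_volley_league_ppo_config

# Hypothetical derived config, not part of this PR: identical to the league config above,
# but with use_pretrain enabled so mutated players reset to a pretrained model instead of
# a random initialization.
slime_volley_league_ppo_using_pretrain_config = copy.deepcopy(slime_volley_league_ppo_config)
slime_volley_league_ppo_using_pretrain_config.exp_name = 'slime_volley_league_ppo_using_pretrain'
slime_volley_league_ppo_using_pretrain_config.policy.other.league.use_pretrain = True
```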

use_pretrain_init_historical=False,
payoff=dict(
type='battle',
decay=0.99,
min_win_rate_games=8,
)
),
),
),
)
slime_volley_league_ppo_config = EasyDict(slime_volley_league_ppo_config)
57 changes: 57 additions & 0 deletions dizoo/slime_volley/config/slime_volley_ppo_config.py
@@ -0,0 +1,57 @@
from easydict import EasyDict
from ding.entry import serial_pipeline_onpolicy

slime_volley_ppo_config = dict(
exp_name='slime_volley_ppo',
env=dict(
collector_env_num=8,
evaluator_env_num=5,
n_evaluator_episode=5,
agent_vs_agent=False,
stop_value=1000000,
env_id="SlimeVolley-v0",
),
policy=dict(
cuda=False,
on_policy=True,
continuous=False,
model=dict(
obs_shape=12,
action_shape=6,
encoder_hidden_size_list=[64, 64],
critic_head_hidden_size=64,
actor_head_hidden_size=64,
share_encoder=False,
),
learn=dict(
epoch_per_collect=5,
batch_size=64,
learning_rate=3e-4,
value_weight=0.5,
entropy_weight=0.0,
clip_ratio=0.2,
),
collect=dict(
n_sample=4096,
unroll_len=1,
discount_factor=0.99,
gae_lambda=0.95,
),
),
)
slime_volley_ppo_config = EasyDict(slime_volley_ppo_config)
main_config = slime_volley_ppo_config
slime_volley_ppo_create_config = dict(
env=dict(
type='slime_volley',
import_names=['dizoo.slime_volley.envs.slime_volley_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(type='ppo'),
)
slime_volley_ppo_create_config = EasyDict(slime_volley_ppo_create_config)
create_config = slime_volley_ppo_create_config


if __name__ == "__main__":
serial_pipeline_onpolicy([main_config, create_config], seed=0)
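As a usage note, the same pipeline can be launched with a tweaked copy of this config; a small sketch (the overridden values are arbitrary examples):

```python
import copy

from ding.entry import serial_pipeline_onpolicy
from dizoo.slime_volley.config.slime_volley_ppo_config import create_config, main_config

# Hypothetical usage, not part of this PR: copy the config, override a couple of fields
# for a quick local run, then launch the same on-policy pipeline as the __main__ block above.
cfg = copy.deepcopy(main_config)
cfg.exp_name = 'slime_volley_ppo_smoke_test'
cfg.policy.collect.n_sample = 512
serial_pipeline_onpolicy([cfg, copy.deepcopy(create_config)], seed=0)
```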
93 changes: 93 additions & 0 deletions dizoo/slime_volley/entry/slime_volley_selfplay_ppo_main.py
@@ -0,0 +1,93 @@
import os
import gym
import numpy as np
import copy
import torch
from tensorboardX import SummaryWriter
from functools import partial

from ding.config import compile_config
from ding.worker import BaseLearner, Episode1v1Collector, NaiveReplayBuffer, BaseSerialEvaluator
from ding.envs import BaseEnvManager
from ding.policy import PPOPolicy
from ding.model import VAC
from ding.utils import set_pkg_seed
from dizoo.slime_volley.envs import SlimeVolleyEnv
from dizoo.slime_volley.config import slime_volley_league_ppo_config


def main(cfg, seed=0, max_iterations=int(1e10)):
cfg.exp_name = 'slime_volley_selfplay_ppo'
cfg = compile_config(
cfg,
BaseEnvManager,
PPOPolicy,
BaseLearner,
Episode1v1Collector,
BaseSerialEvaluator,
NaiveReplayBuffer,
save_cfg=True
)
collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
collector_env_cfg = copy.deepcopy(cfg.env)
collector_env_cfg.is_evaluator = False
evaluator_env_cfg = copy.deepcopy(cfg.env)
evaluator_env_cfg.is_evaluator = True
collector_env = BaseEnvManager(
env_fn=[partial(SlimeVolleyEnv, collector_env_cfg) for _ in range(collector_env_num)], cfg=cfg.env.manager
)
evaluator_env = BaseEnvManager(
env_fn=[partial(SlimeVolleyEnv, evaluator_env_cfg) for _ in range(evaluator_env_num)], cfg=cfg.env.manager
)

collector_env.seed(seed)
evaluator_env.seed(seed, dynamic_seed=False)
set_pkg_seed(seed, use_cuda=cfg.policy.cuda)

model1 = VAC(**cfg.policy.model)
policy1 = PPOPolicy(cfg.policy, model=model1)
model2 = VAC(**cfg.policy.model)
policy2 = PPOPolicy(cfg.policy, model=model2)

tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
learner1 = BaseLearner(
cfg.policy.learn.learner, policy1.learn_mode, tb_logger, exp_name=cfg.exp_name, instance_name='learner1'
)
learner2 = BaseLearner(
cfg.policy.learn.learner, policy2.learn_mode, tb_logger, exp_name=cfg.exp_name, instance_name='learner2'
)
collector = Episode1v1Collector(
cfg.policy.collect.collector,
collector_env, [policy1.collect_mode, policy2.collect_mode],
tb_logger,
exp_name=cfg.exp_name
)
evaluator_cfg = copy.deepcopy(cfg.policy.eval.evaluator)
evaluator_cfg.stop_value = cfg.env.stop_value
evaluator = BaseSerialEvaluator(
evaluator_cfg,
evaluator_env,
policy1.eval_mode,
tb_logger,
exp_name=cfg.exp_name,
instance_name='builtin_ai_evaluator'
)

stop_flag = False
for _ in range(max_iterations):
if evaluator.should_eval(learner1.train_iter):
stop_flag, reward = evaluator.eval(learner1.save_checkpoint, learner1.train_iter, collector.envstep)
tb_logger.add_scalar('fixed_evaluator_step/reward_mean', reward, collector.envstep)
if stop_flag:
break
train_data, _ = collector.collect(train_iter=learner1.train_iter)
for data in train_data:
for d in data:
d['adv'] = d['reward']
for i in range(cfg.policy.learn.update_per_collect):
learner1.train(train_data[0], collector.envstep)
learner2.train(train_data[1], collector.envstep)


if __name__ == "__main__":
main(slime_volley_league_ppo_config)
1 change: 1 addition & 0 deletions dizoo/slime_volley/envs/__init__.py
@@ -0,0 +1 @@
from .slime_volley_env import SlimeVolleyEnv