polish(nyz): polish offpolicy RL multi-gpu DDP training #679

Merged
merged 5 commits on Jul 13, 2023
4 changes: 2 additions & 2 deletions ding/config/config.py
@@ -13,7 +13,7 @@
from easydict import EasyDict
from copy import deepcopy

from ding.utils import deep_merge_dicts
from ding.utils import deep_merge_dicts, get_rank
from ding.envs import get_env_cls, get_env_manager_cls, BaseEnvManager
from ding.policy import get_policy_cls
from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, Coordinator, \
@@ -459,7 +459,7 @@ def compile_config(
cfg.policy.eval.evaluator.n_episode = cfg.env.n_evaluator_episode
if 'exp_name' not in cfg:
cfg.exp_name = 'default_experiment'
if save_cfg:
if save_cfg and get_rank() == 0:
if os.path.exists(cfg.exp_name) and renew_dir:
cfg.exp_name += datetime.datetime.now().strftime("_%y%m%d_%H%M%S")
try:
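The guard added above (save_cfg and get_rank() == 0) keeps every process except rank 0 from creating the experiment directory and writing the compiled config. A minimal standalone sketch of that pattern, assuming nothing about DI-engine's actual file layout (save_experiment_config and the snapshot file name are placeholders):

import os
import torch.distributed as dist

def get_rank() -> int:
    # Matches the fallback behaviour expected here: rank 0 when torch.distributed
    # is not initialized, i.e. plain single-process training.
    return dist.get_rank() if dist.is_available() and dist.is_initialized() else 0

def save_experiment_config(exp_name: str, cfg_text: str) -> None:
    if get_rank() != 0:
        return  # non-zero ranks skip filesystem side effects to avoid races
    os.makedirs(exp_name, exist_ok=True)
    with open(os.path.join(exp_name, 'config_snapshot.py'), 'w') as f:
        f.write(cfg_text)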
33 changes: 17 additions & 16 deletions ding/entry/serial_entry.py
@@ -11,7 +11,7 @@
create_serial_collector, create_serial_evaluator
from ding.config import read_config, compile_config
from ding.policy import create_policy
from ding.utils import set_pkg_seed
from ding.utils import set_pkg_seed, get_rank
from .utils import random_collect


@@ -61,7 +61,7 @@ def serial_pipeline(
policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command'])

# Create worker components: learner, collector, evaluator, replay buffer, commander.
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) if get_rank() == 0 else None
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
collector = create_serial_collector(
cfg.policy.collect.collector,
@@ -119,18 +119,19 @@ def serial_pipeline(

# Learner's after_run hook.
learner.call_hook('after_run')
import time
import pickle
import numpy as np
with open(os.path.join(cfg.exp_name, 'result.pkl'), 'wb') as f:
eval_value_raw = [d['eval_episode_return'] for d in eval_info]
final_data = {
'stop': stop,
'env_step': collector.envstep,
'train_iter': learner.train_iter,
'eval_value': np.mean(eval_value_raw),
'eval_value_raw': eval_value_raw,
'finish_time': time.ctime(),
}
pickle.dump(final_data, f)
if get_rank() == 0:
import time
import pickle
import numpy as np
with open(os.path.join(cfg.exp_name, 'result.pkl'), 'wb') as f:
eval_value_raw = [d['eval_episode_return'] for d in eval_info]
final_data = {
'stop': stop,
'env_step': collector.envstep,
'train_iter': learner.train_iter,
'eval_value': np.mean(eval_value_raw),
'eval_value_raw': eval_value_raw,
'finish_time': time.ctime(),
}
pickle.dump(final_data, f)
return policy
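With the TensorBoard writer and the result.pkl dump confined to rank 0, the same serial_pipeline can be started once per GPU. A hedged usage sketch; the cartpole config import, the multi_gpu flag location, and the launch command are assumptions, not taken verbatim from the repository's examples:

from ding.entry import serial_pipeline
from ding.utils import DDPContext, to_ddp_config
# Assumed example config pair; any (main_config, create_config) prepared for DDP works.
from dizoo.classic_control.cartpole.config.cartpole_dqn_config import main_config, create_config

if __name__ == '__main__':
    main_config.policy.learn.multi_gpu = True  # assumed flag enabling gradient allreduce in the policy
    with DDPContext():  # sets up and tears down the torch.distributed process group
        main_config = to_ddp_config(main_config)  # split batch_size / n_sample across ranks
        serial_pipeline((main_config, create_config), seed=0)

# Started once per GPU, for example:
#   python -m torch.distributed.launch --nproc_per_node=2 ddp_train.py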
4 changes: 2 additions & 2 deletions ding/framework/middleware/tests/test_distributer.py
@@ -53,7 +53,7 @@ def collector_context(ctx: OnlineRLContext):
task.run(max_step=3)


@pytest.mark.unittest
@pytest.mark.tmp
def test_context_exchanger():
Parallel.runner(n_parallel_workers=2)(context_exchanger_main)

@@ -170,7 +170,7 @@ def pred(ctx):
task.run(2)


@pytest.mark.unittest
@pytest.mark.tmp
def test_model_exchanger():
Parallel.runner(n_parallel_workers=2, startup_interval=0)(model_exchanger_main)

9 changes: 6 additions & 3 deletions ding/policy/policy_factory.py
@@ -35,9 +35,12 @@ def forward(data: Dict[int, Any], *args, **kwargs) -> Dict[int, Any]:
actions = {}
for env_id in data:
if not isinstance(action_space, list):
action = torch.as_tensor(action_space.sample())
if isinstance(action_space, gym.spaces.MultiDiscrete):
action = [torch.LongTensor([v]) for v in action]
if isinstance(action_space, gym.spaces.Discrete):
action = torch.LongTensor([action_space.sample()])
elif isinstance(action_space, gym.spaces.MultiDiscrete):
action = [torch.LongTensor([v]) for v in action_space.sample()]
else:
action = torch.as_tensor(action_space.sample())
actions[env_id] = {'action': action}
elif 'global_state' in data[env_id].keys():
# for smac
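The new branching gives each gym space a consistent tensor layout: Discrete becomes a length-1 LongTensor, MultiDiscrete a list of length-1 LongTensors (one per sub-action), and anything else (e.g. Box) a plain tensor of the sampled value. A standalone sketch of just that logic, assuming gym and torch are installed:

import gym
import torch

def sample_random_action(action_space: gym.Space):
    if isinstance(action_space, gym.spaces.Discrete):
        return torch.LongTensor([action_space.sample()])               # shape (1,)
    elif isinstance(action_space, gym.spaces.MultiDiscrete):
        return [torch.LongTensor([v]) for v in action_space.sample()]  # list of shape-(1,) tensors
    else:
        return torch.as_tensor(action_space.sample())                  # e.g. Box -> float tensor

print(sample_random_action(gym.spaces.Discrete(4)))
print(sample_random_action(gym.spaces.MultiDiscrete([3, 2])))
print(sample_random_action(gym.spaces.Box(low=-1.0, high=1.0, shape=(2,))))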
2 changes: 1 addition & 1 deletion ding/torch_utils/data_helper.py
@@ -309,7 +309,7 @@ def to_item(data: Any, ignore_error: bool = True) -> Any:
if ignore_error:
try:
new_data[k] = to_item(v)
except ValueError:
except (ValueError, RuntimeError):
pass
else:
new_data[k] = to_item(v)
5 changes: 3 additions & 2 deletions ding/utils/__init__.py
@@ -31,9 +31,10 @@
from .fast_copy import fastcopy
from .bfs_helper import get_vi_sequence

if ding.enable_linklink:
if ding.enable_linklink: # False by default
from .linklink_dist_helper import get_rank, get_world_size, dist_mode, dist_init, dist_finalize, \
allreduce, broadcast, DistContext, allreduce_async, synchronize
else:
from .pytorch_ddp_dist_helper import get_rank, get_world_size, dist_mode, dist_init, dist_finalize, \
allreduce, broadcast, DistContext, allreduce_async, synchronize
allreduce, broadcast, DDPContext, allreduce_async, synchronize, reduce_data, broadcast_object_list, \
to_ddp_config, allreduce_data
44 changes: 41 additions & 3 deletions ding/utils/pytorch_ddp_dist_helper.py
@@ -1,6 +1,7 @@
import os
from typing import Callable, Tuple, List, Any
from typing import Callable, Tuple, List, Any, Union
from easydict import EasyDict

import os
import numpy as np
import torch
import torch.distributed as dist
@@ -30,6 +31,7 @@ def get_world_size() -> int:

broadcast = dist.broadcast
allgather = dist.all_gather
broadcast_object_list = dist.broadcast_object_list


def allreduce(x: torch.Tensor) -> None:
@@ -42,6 +44,35 @@ def allreduce_async(name: str, x: torch.Tensor) -> None:
dist.all_reduce(x, async_op=True)


def reduce_data(x: Union[int, float, torch.Tensor], dst: int) -> Union[int, float, torch.Tensor]:
if np.isscalar(x):
x_tensor = torch.as_tensor([x]).cuda()
dist.reduce(x_tensor, dst)
return x_tensor.item()
elif isinstance(x, torch.Tensor):
dist.reduce(x, dst)
return x
else:
raise TypeError("not supported type: {}".format(type(x)))


def allreduce_data(x: Union[int, float, torch.Tensor], op: str) -> Union[int, float, torch.Tensor]:
assert op in ['sum', 'avg'], op
if np.isscalar(x):
x_tensor = torch.as_tensor([x]).cuda()
dist.all_reduce(x_tensor)
if op == 'avg':
x_tensor.div_(get_world_size())
return x_tensor.item()
elif isinstance(x, torch.Tensor):
dist.all_reduce(x)
if op == 'avg':
x.div_(get_world_size())
return x
else:
raise TypeError("not supported type: {}".format(type(x)))


synchronize = torch.cuda.synchronize


@@ -119,7 +150,7 @@ def dist_finalize() -> None:
pass


class DistContext:
class DDPContext:

def __init__(self) -> None:
pass
@@ -146,3 +177,10 @@ def simple_group_split(world_size: int, rank: int, num_groups: int) -> List:
groups.append(dist.new_group(rank_list[i]))
group_size = world_size // num_groups
return groups[rank // group_size]


def to_ddp_config(cfg: EasyDict) -> EasyDict:
w = get_world_size()
cfg.policy.learn.batch_size = int(np.ceil(cfg.policy.learn.batch_size / w))
    cfg.policy.collect.n_sample = int(np.ceil(cfg.policy.collect.n_sample / w))
return cfg
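Together these helpers keep the global workload constant as GPUs are added: to_ddp_config shrinks each rank's batch_size and n_sample to a 1/world_size share (rounded up), and allreduce_data sums or averages per-rank statistics back into a global value. A hedged usage sketch with illustrative numbers only:

from easydict import EasyDict
from ding.utils import allreduce_data, to_ddp_config

# With world_size = 4 and global targets batch_size=256, n_sample=96,
# to_ddp_config leaves each rank with ceil(256 / 4) = 64 and ceil(96 / 4) = 24.
cfg = EasyDict(dict(policy=dict(learn=dict(batch_size=256), collect=dict(n_sample=96))))
cfg = to_ddp_config(cfg)  # world size is read from the torch.distributed process group

# Later, per-rank scalars can be combined across GPUs, e.g. averaging a loss value
# (allreduce_data moves scalars to CUDA internally, as required by the NCCL backend):
# mean_loss = allreduce_data(loss.item(), 'avg')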
16 changes: 3 additions & 13 deletions ding/worker/collector/interaction_serial_evaluator.py
@@ -2,12 +2,11 @@
from collections import namedtuple
import numpy as np
import torch
import torch.distributed as dist

from ding.envs import BaseEnvManager
from ding.torch_utils import to_tensor, to_ndarray, to_item
from ding.utils import build_logger, EasyTimer, SERIAL_EVALUATOR_REGISTRY
from ding.utils import get_world_size, get_rank
from ding.utils import get_world_size, get_rank, broadcast_object_list
from .base_serial_evaluator import ISerialEvaluator, VectorEvalMonitor


@@ -65,10 +64,7 @@ def __init__(
'./{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name
)
else:
self._logger, _ = build_logger(
'./{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name, need_tb=False
)
self._tb_logger = None
            self._logger, self._tb_logger = None, None # set to None so close() handles them gracefully
self.reset(policy, env)

self._timer = EasyTimer()
@@ -199,12 +195,6 @@ def eval(
- stop_flag (:obj:`bool`): Whether this training program can be ended.
- return_info (:obj:`dict`): Current evaluation return information.
'''
if get_world_size() > 1:
# sum up envstep to rank0
envstep_tensor = torch.tensor(envstep).cuda()
dist.reduce(envstep_tensor, dst=0)
envstep = envstep_tensor.item()

# evaluator only work on rank0
stop_flag, return_info = False, []
if get_rank() == 0:
@@ -308,7 +298,7 @@ def eval(

if get_world_size() > 1:
objects = [stop_flag, return_info]
dist.broadcast_object_list(objects, src=0)
broadcast_object_list(objects, src=0)
stop_flag, return_info = objects

return_info = to_item(return_info)
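The evaluation flow now steps the evaluator environments only on rank 0 and shares just the verdict with the other ranks. A minimal standalone sketch of that pattern, assuming the process group is already initialized (for example inside DDPContext):

import torch.distributed as dist

def eval_on_rank0_and_share(run_eval):
    # run_eval is any callable returning (stop_flag, return_info); only rank 0 invokes it.
    stop_flag, return_info = False, []
    if dist.get_rank() == 0:
        stop_flag, return_info = run_eval()
    if dist.get_world_size() > 1:
        objects = [stop_flag, return_info]
        dist.broadcast_object_list(objects, src=0)  # every rank receives rank 0's result
        stop_flag, return_info = objects
    return stop_flag, return_info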
58 changes: 41 additions & 17 deletions ding/worker/collector/sample_serial_collector.py
@@ -6,7 +6,8 @@
import torch

from ding.envs import BaseEnvManager
from ding.utils import build_logger, EasyTimer, SERIAL_COLLECTOR_REGISTRY, one_time_warning
from ding.utils import build_logger, EasyTimer, SERIAL_COLLECTOR_REGISTRY, one_time_warning, get_rank, get_world_size, \
broadcast_object_list, allreduce_data
from ding.torch_utils import to_tensor, to_ndarray
from .base_serial_collector import ISerialCollector, CachePool, TrajBuffer, INF, to_tensor_transitions

@@ -52,16 +53,27 @@ def __init__(
self._cfg = cfg
self._timer = EasyTimer()
self._end_flag = False
self._rank = get_rank()
self._world_size = get_world_size()

if tb_logger is not None:
if self._rank == 0:
if tb_logger is not None:
self._logger, _ = build_logger(
path='./{}/log/{}'.format(self._exp_name, self._instance_name),
name=self._instance_name,
need_tb=False
)
self._tb_logger = tb_logger
else:
self._logger, self._tb_logger = build_logger(
path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name
)
else:
self._logger, _ = build_logger(
path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name, need_tb=False
)
self._tb_logger = tb_logger
else:
self._logger, self._tb_logger = build_logger(
path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name
)
self._tb_logger = None

self.reset(policy, env)

def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None:
@@ -184,8 +196,9 @@ def close(self) -> None:
return
self._end_flag = True
self._env.close()
self._tb_logger.flush()
self._tb_logger.close()
if self._tb_logger:
self._tb_logger.flush()
self._tb_logger.close()

def __del__(self) -> None:
"""
@@ -231,6 +244,8 @@ def collect(
if policy_kwargs is None:
policy_kwargs = {}
collected_sample = 0
collected_step = 0
collected_episode = 0
return_data = []

while collected_sample < n_sample:
@@ -276,7 +291,7 @@ def collect(
transition['collect_iter'] = train_iter
self._traj_buffer[env_id].append(transition)
self._env_info[env_id]['step'] += 1
self._total_envstep_count += 1
collected_step += 1
# prepare data
if timestep.done or len(self._traj_buffer[env_id]) == self._traj_len:
# If policy is r2d2:
@@ -294,7 +309,6 @@ def collect(
transitions = to_tensor_transitions(self._traj_buffer[env_id], not self._deepcopy_obs)
train_sample = self._policy.get_train_sample(transitions)
return_data.extend(train_sample)
self._total_train_sample_count += len(train_sample)
self._env_info[env_id]['train_sample'] += len(train_sample)
collected_sample += len(train_sample)
self._traj_buffer[env_id].clear()
@@ -303,7 +317,7 @@ def collect(

# If env is done, record episode info and reset
if timestep.done:
self._total_episode_count += 1
collected_episode += 1
reward = timestep.info['eval_episode_return']
info = {
'reward': reward,
@@ -315,6 +329,17 @@ def collect(
# Env reset is done by env_manager automatically
self._policy.reset([env_id])
self._reset_stat(env_id)
collected_duration = sum([d['time'] for d in self._episode_info])
        # reduce data across ranks when DDP is enabled
if self._world_size > 1:
collected_sample = allreduce_data(collected_sample, 'sum')
collected_step = allreduce_data(collected_step, 'sum')
collected_episode = allreduce_data(collected_episode, 'sum')
collected_duration = allreduce_data(collected_duration, 'sum')
self._total_envstep_count += collected_step
self._total_episode_count += collected_episode
self._total_duration += collected_duration
self._total_train_sample_count += collected_sample
# log
if record_random_collect: # default is true, but when random collect, record_random_collect is False
self._output_log(train_iter)
@@ -333,19 +358,20 @@ def collect(
def _output_log(self, train_iter: int) -> None:
"""
Overview:
Print the output log information. You can refer to Docs/Best Practice/How to understand\
training generated folders/Serial mode/log/collector for more details.
Print the output log information. You can refer to the docs of `Best Practice` to understand \
the training generated logs and tensorboards.
Arguments:
- train_iter (:obj:`int`): the number of training iteration.
"""
if self._rank != 0:
return
if (train_iter - self._last_train_iter) >= self._collect_print_freq and len(self._episode_info) > 0:
self._last_train_iter = train_iter
episode_count = len(self._episode_info)
envstep_count = sum([d['step'] for d in self._episode_info])
train_sample_count = sum([d['train_sample'] for d in self._episode_info])
duration = sum([d['time'] for d in self._episode_info])
episode_return = [d['reward'] for d in self._episode_info]
self._total_duration += duration
info = {
'episode_count': episode_count,
'envstep_count': envstep_count,
@@ -355,15 +381,13 @@ def _output_log(self, train_iter: int) -> None:
'avg_envstep_per_sec': envstep_count / duration,
'avg_train_sample_per_sec': train_sample_count / duration,
'avg_episode_per_sec': episode_count / duration,
'collect_time': duration,
'reward_mean': np.mean(episode_return),
'reward_std': np.std(episode_return),
'reward_max': np.max(episode_return),
'reward_min': np.min(episode_return),
'total_envstep_count': self._total_envstep_count,
'total_train_sample_count': self._total_train_sample_count,
'total_episode_count': self._total_episode_count,
'total_duration': self._total_duration,
# 'each_reward': episode_return,
}
self._episode_info.clear()
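The collector now accumulates per-iteration counters locally and sums them across ranks once per collect() call, so total_envstep_count, total_episode_count, total_train_sample_count and total_duration describe the whole DDP job instead of a single process. A standalone sketch of that aggregation step (plain CPU tensor for brevity; the helper in this PR moves values to CUDA first, which the NCCL backend requires):

import torch
import torch.distributed as dist

def allreduce_sum(value: float) -> float:
    # Same idea as allreduce_data(x, 'sum'): every rank contributes, every rank gets the total.
    t = torch.as_tensor([value], dtype=torch.float64)
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(t)  # default reduction op is SUM
    return float(t.item())

# Per-iteration, per-rank counters gathered during collection:
collected_step, collected_episode, collected_sample = 128, 2, 64
total_envstep_count = 0
total_envstep_count += int(allreduce_sum(collected_step))  # global env steps for this iteration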