Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(lyd): refactor dt_policy in new pipeline and add img input support #693

Merged
merged 27 commits into from
Aug 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions ding/entry/tests/test_serial_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from dizoo.classic_control.pendulum.config.pendulum_cql_config import pendulum_cql_config, pendulum_cql_create_config # noqa
from dizoo.classic_control.cartpole.config.cartpole_qrdqn_generation_data_config import cartpole_qrdqn_generation_data_config, cartpole_qrdqn_generation_data_create_config # noqa
from dizoo.classic_control.cartpole.config.cartpole_cql_config import cartpole_discrete_cql_config, cartpole_discrete_cql_create_config # noqa
from dizoo.classic_control.cartpole.config.cartpole_dt_config import cartpole_discrete_dt_config, cartpole_discrete_dt_create_config # noqa
from dizoo.classic_control.pendulum.config.pendulum_td3_data_generation_config import pendulum_td3_generation_config, pendulum_td3_generation_create_config # noqa
from dizoo.classic_control.pendulum.config.pendulum_td3_bc_config import pendulum_td3_bc_config, pendulum_td3_bc_create_config # noqa
from dizoo.classic_control.pendulum.config.pendulum_ibc_config import pendulum_ibc_config, pendulum_ibc_create_config
Expand Down Expand Up @@ -621,6 +622,70 @@ def test_discrete_cql():
os.popen('rm -rf cartpole cartpole_cql')


@pytest.mark.platformtest
@pytest.mark.unittest
def test_discrete_dt():
# train expert
config = [deepcopy(cartpole_qrdqn_config), deepcopy(cartpole_qrdqn_create_config)]
config[0].policy.learn.update_per_collect = 1
config[0].exp_name = 'dt_cartpole'
try:
serial_pipeline(config, seed=0, max_train_iter=1)
except Exception:
assert False, "pipeline fail"
# collect expert data
import torch
config = [deepcopy(cartpole_qrdqn_generation_data_config), deepcopy(cartpole_qrdqn_generation_data_create_config)]
state_dict = torch.load('./dt_cartpole/ckpt/iteration_0.pth.tar', map_location='cpu')
try:
collect_demo_data(config, seed=0, collect_count=1000, state_dict=state_dict)
except Exception as e:
assert False, "pipeline fail"
print(repr(e))

# train dt
config = [deepcopy(cartpole_discrete_dt_config), deepcopy(cartpole_discrete_dt_create_config)]
config[0].policy.eval.evaluator.eval_freq = 5
try:
from ding.framework import task
from ding.framework.context import OfflineRLContext
from ding.envs import SubprocessEnvManagerV2, BaseEnvManagerV2
from ding.envs.env_wrappers.env_wrappers import AllinObsWrapper
from dizoo.classic_control.cartpole.envs import CartPoleEnv
from ding.utils import set_pkg_seed
from ding.data import create_dataset
from ding.config import compile_config
from ding.model.template.dt import DecisionTransformer
from ding.policy import DTPolicy
from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, \
offline_data_fetcher_from_mem_c, offline_logger, termination_checker
config = compile_config(config[0], create_cfg=config[1], auto=True)
with task.start(async_mode=False, ctx=OfflineRLContext()):
evaluator_env = BaseEnvManagerV2(
env_fn=[lambda: AllinObsWrapper(CartPoleEnv(config.env)) for _ in range(config.env.evaluator_env_num)],
cfg=config.env.manager
)

set_pkg_seed(config.seed, use_cuda=config.policy.cuda)

dataset = create_dataset(config)

model = DecisionTransformer(**config.policy.model)
policy = DTPolicy(config.policy, model=model)

task.use(termination_checker(max_train_iter=1))
task.use(interaction_evaluator(config, policy.eval_mode, evaluator_env))
task.use(offline_data_fetcher_from_mem_c(config, dataset))
task.use(trainer(config, policy.learn_mode))
task.use(CkptSaver(policy, config.exp_name, train_freq=100))
task.use(offline_logger(config.exp_name))
task.run()
except Exception:
assert False, "pipeline fail"
finally:
os.popen('rm -rf cartpole cartpole_dt')


@pytest.mark.platformtest
@pytest.mark.unittest
def test_td3_bc():
Expand Down
19 changes: 19 additions & 0 deletions ding/envs/env/tests/test_ding_env_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,3 +180,22 @@ def test_hybrid(self):
action = ding_env_hybrid.random_action()
print('random_action', action)
assert isinstance(action, dict)

@pytest.mark.unittest
def test_AllinObsWrapper(self):
env_cfg = EasyDict(env_id='PongNoFrameskip-v4', env_wrapper='reward_in_obs')
ding_env_aio = DingEnvWrapper(cfg=env_cfg)

data = ding_env_aio.reset()
assert isinstance(data, dict)
assert 'obs' in data.keys() and 'reward' in data.keys()
assert data['obs'].shape == ding_env_aio.observation_space
while True:
action = ding_env_aio.random_action()
timestep = ding_env_aio.step(action)
# print(timestep.reward)
assert isinstance(timestep.obs, dict)
if timestep.done:
assert 'eval_episode_return' in timestep.info, timestep.info
break
print(ding_env_aio.observation_space, ding_env_aio.action_space, ding_env_aio.reward_space)
37 changes: 37 additions & 0 deletions ding/envs/env_wrappers/env_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1215,6 +1215,43 @@ def reset(self):
return self.env.reset()


@ENV_WRAPPER_REGISTRY.register('reward_in_obs')
class AllinObsWrapper(gym.Wrapper):
PaParaZz1 marked this conversation as resolved.
Show resolved Hide resolved
"""
Overview:
This wrapper is used in policy DT.
Set a dict {'obs': obs, 'reward': reward}
as the new wrapped observation,
which including the current obs, previous reward.
Interface:
``__init__``, ``reset``, ``step``, ``seed``
Properties:
- env (:obj:`gym.Env`): the environment to wrap.
"""

def __init__(self, env):
super().__init__(env)

def reset(self):
ret = {'obs': self.env.reset(), 'reward': np.array([0])}
self._observation_space = gym.spaces.Dict(
{
'obs': self.env.observation_space,
'reward': gym.spaces.Box(low=-np.inf, high=np.inf, dtype=np.float32)
}
)
return ret

def step(self, action):
obs, reward, done, info = self.env.step(action)
obs = {'obs': obs, 'reward': reward}
from ding.envs import BaseEnvTimestep
return BaseEnvTimestep(obs, reward, done, info)

def seed(self, seed: int, dynamic_seed: bool = True) -> None:
self.env.seed(seed, dynamic_seed)


def update_shape(obs_shape, act_shape, rew_shape, wrapper_names):
"""
Overview:
Expand Down
47 changes: 47 additions & 0 deletions ding/example/dt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import gym
from ditk import logging
from ding.model.template.decision_transformer import DecisionTransformer
from ding.policy import DTPolicy
from ding.envs import DingEnvWrapper, BaseEnvManager, BaseEnvManagerV2
from ding.envs.env_wrappers.env_wrappers import AllinObsWrapper
from ding.data import create_dataset
from ding.config import compile_config
from ding.framework import task, ding_init
from ding.framework.context import OfflineRLContext
from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, \
offline_data_fetcher, offline_logger, termination_checker, final_ctx_saver
from ding.utils import set_pkg_seed
from dizoo.box2d.lunarlander.envs.lunarlander_env import LunarLanderEnv
from dizoo.box2d.lunarlander.config.lunarlander_dt_config import main_config, create_config


def main():
# If you don't have offline data, you need to prepare if first and set the data_path in config
# For demostration, we also can train a RL policy (e.g. SAC) and collect some data
logging.getLogger().setLevel(logging.INFO)
cfg = compile_config(main_config, create_cfg=create_config, auto=True)
ding_init(cfg)
with task.start(async_mode=False, ctx=OfflineRLContext()):
evaluator_env = BaseEnvManagerV2(
env_fn=[lambda: AllinObsWrapper(LunarLanderEnv(cfg.env)) for _ in range(cfg.env.evaluator_env_num)],
cfg=cfg.env.manager
)

set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)

dataset = create_dataset(cfg)
cfg.policy.state_mean, cfg.policy.state_std = dataset.get_state_stats()
model = DecisionTransformer(**cfg.policy.model)
policy = DTPolicy(cfg.policy, model=model)

task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
task.use(offline_data_fetcher(cfg, dataset))
task.use(trainer(cfg, policy.learn_mode))
task.use(termination_checker(max_train_iter=1e5))
task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
task.use(offline_logger())
task.run()


if __name__ == "__main__":
main()
1 change: 1 addition & 0 deletions ding/framework/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ class OfflineRLContext(Context):

# common
total_step: int = 0
env_step: int = 0
train_epoch: int = 0
train_iter: int = 0
train_data: Union[Dict, List] = None
Expand Down
1 change: 1 addition & 0 deletions ding/framework/middleware/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
from .ckpt_handler import CkptSaver
from .distributer import ContextExchanger, ModelExchanger
from .barrier import Barrier, BarrierRuntime
from .data_fetcher import offline_data_fetcher_from_mem_c
100 changes: 100 additions & 0 deletions ding/framework/middleware/data_fetcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from typing import TYPE_CHECKING
from threading import Thread, Event
from queue import Queue
import time
import torch
import torch.distributed as dist
from easydict import EasyDict
from ding.framework import task
from ding.data import Dataset, DataLoader
from ding.utils import get_rank
import numpy as np

if TYPE_CHECKING:
from ding.framework import OfflineRLContext


class offline_data_fetcher_from_mem_c:

def __new__(cls, *args, **kwargs):
if task.router.is_active and not task.has_role(task.role.FETCHER):
return task.void()
return super(offline_data_fetcher_from_mem_c, cls).__new__(cls)

def __init__(self, cfg: EasyDict, dataset: Dataset):
device = 'cuda:{}'.format(get_rank() % torch.cuda.device_count()) if cfg.policy.cuda else 'cpu'
if device != 'cpu':
stream = torch.cuda.Stream()

def producer(queue, dataset, batch_size, device, event):
torch.set_num_threads(4)
if device != 'cpu':
nonlocal stream
sbatch_size = batch_size * dist.get_world_size()
rank = get_rank()
idx_list = np.random.permutation(len(dataset))
temp_idx_list = []
for i in range(len(dataset) // sbatch_size):
temp_idx_list.extend(idx_list[i + rank * batch_size:i + (rank + 1) * batch_size])
idx_iter = iter(temp_idx_list)

if device != 'cpu':
with torch.cuda.stream(stream):
while True:
if queue.full():
time.sleep(0.1)
else:
data = []
for _ in range(batch_size):
try:
data.append(dataset.__getitem__(next(idx_iter)))
except StopIteration:
del idx_iter
idx_list = np.random.permutation(len(dataset))
idx_iter = iter(idx_list)
data.append(dataset.__getitem__(next(idx_iter)))
data = [[i[j] for i in data] for j in range(len(data[0]))]
data = [torch.stack(x).to(device) for x in data]
queue.put(data)
if event.is_set():
break
else:
while True:
if queue.full():
time.sleep(0.1)
else:
data = []
for _ in range(batch_size):
try:
data.append(dataset.__getitem__(next(idx_iter)))
except StopIteration:
del idx_iter
idx_list = np.random.permutation(len(dataset))
idx_iter = iter(idx_list)
data.append(dataset.__getitem__(next(idx_iter)))
data = [[i[j] for i in data] for j in range(len(data[0]))]
data = [torch.stack(x) for x in data]
queue.put(data)
if event.is_set():
break

self.queue = Queue(maxsize=50)
self.event = Event()
self.producer_thread = Thread(
target=producer,
args=(self.queue, dataset, cfg.policy.batch_size, device, self.event),
name='cuda_fetcher_producer'
)

def __call__(self, ctx: "OfflineRLContext"):
if not self.producer_thread.is_alive():
time.sleep(5)
self.producer_thread.start()
while self.queue.empty():
time.sleep(0.001)
ctx.train_data = self.queue.get()

def __del__(self):
if self.producer_thread.is_alive():
self.event.set()
del self.queue
2 changes: 1 addition & 1 deletion ding/framework/middleware/functional/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .trainer import trainer, multistep_trainer
from .data_processor import offpolicy_data_fetcher, data_pusher, offline_data_fetcher, offline_data_saver, \
sqil_data_pusher, buffer_saver
offline_data_fetcher_from_mem, sqil_data_pusher, buffer_saver
from .collector import inferencer, rolloutor, TransitionList
from .evaluator import interaction_evaluator, interaction_evaluator_ttorch
from .termination_checker import termination_checker, ddp_termination_checker
Expand Down
47 changes: 46 additions & 1 deletion ding/framework/middleware/functional/data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from ding.data import Buffer, Dataset, DataLoader, offline_data_save_type
from ding.data.buffer.middleware import PriorityExperienceReplay
from ding.framework import task
from ding.utils import get_rank

if TYPE_CHECKING:
from ding.framework import OnlineRLContext, OfflineRLContext
Expand Down Expand Up @@ -180,6 +181,51 @@ def _fetch(ctx: "OnlineRLContext"):
return _fetch


def offline_data_fetcher_from_mem(cfg: EasyDict, dataset: Dataset) -> Callable:

from threading import Thread
from queue import Queue
import time
stream = torch.cuda.Stream()

def producer(queue, dataset, batch_size, device):
torch.set_num_threads(4)
nonlocal stream
idx_iter = iter(range(len(dataset)))
with torch.cuda.stream(stream):
while True:
if queue.full():
time.sleep(0.1)
else:
try:
start_idx = next(idx_iter)
except StopIteration:
del idx_iter
idx_iter = iter(range(len(dataset)))
start_idx = next(idx_iter)
data = [dataset.__getitem__(idx) for idx in range(start_idx, start_idx + batch_size)]
data = [[i[j] for i in data] for j in range(len(data[0]))]
data = [torch.stack(x).to(device) for x in data]
queue.put(data)

queue = Queue(maxsize=50)
device = 'cuda:{}'.format(get_rank() % torch.cuda.device_count()) if cfg.policy.cuda else 'cpu'
producer_thread = Thread(
target=producer, args=(queue, dataset, cfg.policy.batch_size, device), name='cuda_fetcher_producer'
)

def _fetch(ctx: "OfflineRLContext"):
nonlocal queue, producer_thread
if not producer_thread.is_alive():
time.sleep(5)
producer_thread.start()
while queue.empty():
time.sleep(0.001)
ctx.train_data = queue.get()

return _fetch


def offline_data_fetcher(cfg: EasyDict, dataset: Dataset) -> Callable:
"""
Overview:
Expand Down Expand Up @@ -208,7 +254,6 @@ def _fetch(ctx: "OfflineRLContext"):
for i, data in enumerate(dataloader):
ctx.train_data = data
yield
ctx.train_epoch += 1
# TODO apply data update (e.g. priority) in offline setting when necessary

return _fetch
Expand Down
Loading
Loading