feature(zjow): add PPO demo for complex env observation #644

Merged: 9 commits, Apr 19, 2023
Changes from 5 commits
200 changes: 200 additions & 0 deletions ding/example/ppo_with_complex_obs.py
@@ -0,0 +1,200 @@
from typing import Dict
import os
import torch
import torch.nn as nn
import numpy as np
import gym
from gym import spaces
from ditk import logging
from ding.envs import DingEnvWrapper, EvalEpisodeReturnEnv, \
BaseEnvManagerV2
from ding.config import compile_config
from ding.policy import PPOPolicy
from ding.utils import set_pkg_seed
from ding.model import VAC
from ding.framework import task, ding_init
from ding.framework.context import OnlineRLContext
from ding.framework.middleware import multistep_trainer, StepCollector, interaction_evaluator, CkptSaver, \
gae_estimator, online_logger
from easydict import EasyDict

my_env_ppo_config = dict(
exp_name='my_env_ppo_seed0',
env=dict(
collector_env_num=4,
evaluator_env_num=4,
n_evaluator_episode=4,
stop_value=195,
),
policy=dict(
cuda=True,
action_space='discrete',
model=dict(
obs_shape=None,
action_shape=2,
action_space='discrete',
critic_head_hidden_size=138,
actor_head_hidden_size=138,
),
learn=dict(
epoch_per_collect=2,
batch_size=64,
learning_rate=0.001,
value_weight=0.5,
entropy_weight=0.01,
clip_ratio=0.2,
learner=dict(hook=dict(save_ckpt_after_iter=100)),
),
collect=dict(
n_sample=256, unroll_len=1, discount_factor=0.9, gae_lambda=0.95, collector=dict(transform_obs=True, )
),
eval=dict(evaluator=dict(eval_freq=100, ), ),
),
)
my_env_ppo_config = EasyDict(my_env_ppo_config)
main_config = my_env_ppo_config
my_env_ppo_create_config = dict(
env_manager=dict(type='base'),
policy=dict(type='ppo'),
)
my_env_ppo_create_config = EasyDict(my_env_ppo_create_config)
create_config = my_env_ppo_create_config


class MyEnv(gym.Env):

def __init__(self, seq_len=5, feature_dim=10, image_size=(10, 10, 3)):
super().__init__()

# Define the action space
self.action_space = spaces.Discrete(2)

# Define the observation space
self.observation_space = spaces.Dict(
(
{
'key_0': spaces.Dict(
{
'k1': spaces.Box(low=0, high=np.inf, shape=(1, ), dtype=np.float32),
'k2': spaces.Box(low=-1, high=1, shape=(1, ), dtype=np.float32),
}
),
'key_1': spaces.Box(low=-np.inf, high=np.inf, shape=(seq_len, feature_dim), dtype=np.float32),
'key_2': spaces.Box(low=0, high=255, shape=image_size, dtype=np.uint8),
'key_3': spaces.Box(low=0, high=np.array([np.inf, 3]), shape=(2, ), dtype=np.float32)
}
)
)

def reset(self):
# Generate a random initial state
return self.observation_space.sample()

def step(self, action):
        # Generate a random reward and done flag (the action is ignored in this toy example)
reward = np.random.uniform(low=0.0, high=1.0)

done = False
if np.random.uniform(low=0.0, high=1.0) > 0.7:
done = True

info = {}

# Return the next state, reward, and done flag
return self.observation_space.sample(), reward, done, info


def ding_env_maker():
return DingEnvWrapper(
MyEnv(), cfg={'env_wrapper': [
lambda env: EvalEpisodeReturnEnv(env),
]}
)


class Encoder(nn.Module):

def __init__(self, feature_dim: int):
super(Encoder, self).__init__()

# Define the networks for each input type
self.fc_net_1_k1 = nn.Sequential(nn.Linear(1, 8), nn.ReLU())
self.fc_net_1_k2 = nn.Sequential(nn.Linear(1, 8), nn.ReLU())
self.fc_net_1 = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
        '''
        The transformer_encoder implementation follows the Vision Transformer (ViT):
        https://arxiv.org/abs/2010.11929
        https://pytorch.org/vision/main/_modules/torchvision/models/vision_transformer.html
        '''
self.class_token = nn.Parameter(torch.zeros(1, 1, feature_dim))
self.encoder_layer = nn.TransformerEncoderLayer(d_model=feature_dim, nhead=2, batch_first=True)
self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=1)

self.conv_net = nn.Sequential(
nn.Conv2d(3, 16, kernel_size=3, padding=1), nn.ReLU(), nn.Conv2d(16, 32, kernel_size=3, padding=1),
nn.ReLU()
)
self.conv_fc_net = nn.Sequential(nn.Flatten(), nn.Linear(3200, 64), nn.ReLU())

self.fc_net_2 = nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 32), nn.ReLU(), nn.Flatten())

def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        # Unpack the observation dict
dict_input = inputs['key_0'] # dict{key:(B)}
transformer_input = inputs['key_1'] # (B, seq_len, feature_dim)
conv_input = inputs['key_2'] # (B, H, W, 3)
fc_input = inputs['key_3'] # (B, X)

B = fc_input.shape[0]

# Pass each input through its corresponding network
dict_output = self.fc_net_1(
torch.cat(
[self.fc_net_1_k1(dict_input['k1'].unsqueeze(-1)),
self.fc_net_1_k2(dict_input['k2'].unsqueeze(-1))],
dim=1
)
)

batch_class_token = self.class_token.expand(B, -1, -1)
transformer_output = self.transformer_encoder(torch.cat([batch_class_token, transformer_input], dim=1))
transformer_output = transformer_output[:, 0]

conv_output = self.conv_fc_net(self.conv_net(conv_input.permute(0, 3, 1, 2)))
fc_output = self.fc_net_2(fc_input)

# Concatenate the outputs along the feature dimension
encoded_output = torch.cat([dict_output, transformer_output, conv_output, fc_output], dim=1)

return encoded_output


def main():
logging.getLogger().setLevel(logging.INFO)
cfg = compile_config(main_config, create_cfg=create_config, auto=True)
ding_init(cfg)
with task.start(async_mode=False, ctx=OnlineRLContext()):
collector_env = BaseEnvManagerV2(
env_fn=[ding_env_maker for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager
)
evaluator_env = BaseEnvManagerV2(
env_fn=[ding_env_maker for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager
)

set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)

encoder = Encoder(feature_dim=10)
model = VAC(encoder=encoder, **cfg.policy.model)
policy = PPOPolicy(cfg.policy, model=model)

task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
task.use(StepCollector(cfg, policy.collect_mode, collector_env))
task.use(gae_estimator(cfg, policy.collect_mode))
task.use(multistep_trainer(policy.learn_mode, log_freq=50))
task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
task.use(online_logger(train_show_freq=3))
task.run()


if __name__ == "__main__":
main()
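
As a quick sanity check (not part of the PR), the encoder's output width can be verified against the 138 used for actor_head_hidden_size / critic_head_hidden_size in the config above: 32 (dict branch) + 10 (transformer class token) + 64 (conv branch) + 32 (fc branch) = 138. A minimal sketch, assuming the module is importable as ding.example.ppo_with_complex_obs and following the per-key shape comments in Encoder.forward; float inputs are used directly since the middleware changes below cast observations to float32 before the forward pass:

import torch
from ding.example.ppo_with_complex_obs import Encoder

B = 4  # dummy batch size
obs = {
    'key_0': {'k1': torch.rand(B), 'k2': torch.rand(B) * 2 - 1},  # scalar features, shape (B,)
    'key_1': torch.randn(B, 5, 10),     # sequence input: (B, seq_len, feature_dim)
    'key_2': torch.rand(B, 10, 10, 3),  # image input as float: (B, H, W, C)
    'key_3': torch.rand(B, 2),          # flat vector input: (B, 2)
}
out = Encoder(feature_dim=10)(obs)
assert out.shape == (B, 138)  # 32 + 10 + 64 + 32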
4 changes: 2 additions & 2 deletions ding/framework/middleware/functional/advantage_estimator.py
@@ -45,8 +45,8 @@ def _gae(ctx: "OnlineRLContext"):
with torch.no_grad():
if cfg.policy.cuda:
data = data.cuda()
value = model.forward(data.obs, mode='compute_critic')['value']
next_value = model.forward(data.next_obs, mode='compute_critic')['value']
value = model.forward(data.obs.to(dtype=ttorch.float32), mode='compute_critic')['value']
next_value = model.forward(data.next_obs.to(dtype=ttorch.float32), mode='compute_critic')['value']
data.value = value

traj_flag = data.done.clone()
3 changes: 2 additions & 1 deletion ding/framework/middleware/functional/collector.py
@@ -70,8 +70,9 @@ def _inference(ctx: "OnlineRLContext"):
if env.closed:
env.launch()

obs = ttorch.as_tensor(env.ready_obs).to(dtype=ttorch.float32)
obs = ttorch.as_tensor(env.ready_obs)
ctx.obs = obs
obs = obs.to(dtype=ttorch.float32)
# TODO mask necessary rollout

obs = {i: obs[i] for i in range(get_shape0(obs))} # TBD
13 changes: 12 additions & 1 deletion ding/framework/middleware/functional/trainer.py
@@ -1,5 +1,6 @@
from typing import TYPE_CHECKING, Callable, Union
from easydict import EasyDict
import treetensor.torch as ttorch
from ditk import logging
import numpy as np
from ding.policy import Policy
@@ -28,6 +29,9 @@ def _train(ctx: Union["OnlineRLContext", "OfflineRLContext"]):

if ctx.train_data is None:
return
        data = ctx.train_data
        data['obs'] = data['obs'].to(dtype=ttorch.float32)
        data['next_obs'] = data['next_obs'].to(dtype=ttorch.float32)
        train_output = policy.forward(data)
#if ctx.train_iter % cfg.policy.learn.learner.hook.log_show_after_iter == 0:
if True:
@@ -72,7 +76,14 @@ def _train(ctx: Union["OnlineRLContext", "OfflineRLContext"]):

if ctx.train_data is None: # no enough data from data fetcher
return
data = ctx.train_data.to(policy._device)
        if hasattr(policy, "_device"):  # for PPOF-style policies
            data = ctx.train_data.to(policy._device)
        elif hasattr(policy, "get_attribute"):  # for other DI-engine policies
            data = ctx.train_data.to(policy.get_attribute("device"))
        else:
            raise AttributeError("Policy should expose a '_device' attribute or get_attribute('device').")
data['obs'] = data['obs'].to(dtype=ttorch.float32)
data['next_obs'] = data['next_obs'].to(dtype=ttorch.float32)
train_output = policy.forward(data)
nonlocal last_log_iter
if ctx.train_iter - last_log_iter >= log_freq:
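
The device lookup added in the second hunk can be read as a small standalone helper; the sketch below is for readability only (the helper name is hypothetical and not part of the PR), with the error path written out as an explicit raise:

def _policy_device(policy):
    # PPOF-style policies expose the device directly.
    if hasattr(policy, "_device"):
        return policy._device
    # Standard DI-engine policies expose it via get_attribute.
    if hasattr(policy, "get_attribute"):
        return policy.get_attribute("device")
    raise AttributeError("Policy should expose a '_device' attribute or get_attribute('device').")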
19 changes: 2 additions & 17 deletions ding/torch_utils/data_helper.py
@@ -10,6 +10,8 @@
import torch
import treetensor.torch as ttorch

from ding.utils.default_helper import get_shape0


def to_device(item: Any, device: str, ignore_keys: list = []) -> Any:
"""
@@ -456,20 +458,3 @@ def get_null_data(template: Any, num: int) -> List[Any]:
data['reward'].zero_()
ret.append(data)
return ret


def get_shape0(data):
if isinstance(data, torch.Tensor):
return data.shape[0]
elif isinstance(data, ttorch.Tensor):

def fn(t):
item = list(t.values())[0]
if np.isscalar(item[0]):
return item[0]
else:
return fn(item)

return fn(data.shape)
else:
raise TypeError("not support type: {}".format(data))
14 changes: 1 addition & 13 deletions ding/torch_utils/tests/test_data_helper.py
@@ -7,7 +7,7 @@
import treetensor.torch as ttorch

from ding.torch_utils import CudaFetcher, to_device, to_dtype, to_tensor, to_ndarray, to_list, \
tensor_to_list, same_shape, build_log_buffer, get_tensor_data, get_shape0, to_item
tensor_to_list, same_shape, build_log_buffer, get_tensor_data, to_item
from ding.utils import EasyTimer


@@ -163,18 +163,6 @@ def test_get_tensor_data(self):
with pytest.raises(TypeError):
get_tensor_data(EasyTimer())

def test_get_shape0(self):
a = {
'a': {
'b': torch.randn(4, 3)
},
'c': {
'd': torch.randn(4)
},
}
a = ttorch.as_tensor(a)
assert get_shape0(a) == 4


@pytest.mark.unittest
def test_log_dict():
28 changes: 27 additions & 1 deletion ding/utils/default_helper.py
@@ -5,6 +5,32 @@
from functools import lru_cache # in python3.9, we can change to cache
import numpy as np
import torch
import treetensor.torch as ttorch


def get_shape0(data):
"""
Get shape[0] of data's torch tensor or treetensor
"""
if isinstance(data, list) or isinstance(data, tuple):
return get_shape0(data[0])
elif isinstance(data, dict):
for k, v in data.items():
return get_shape0(v)
elif isinstance(data, torch.Tensor):
return data.shape[0]
elif isinstance(data, ttorch.Tensor):

def fn(t):
item = list(t.values())[0]
if np.isscalar(item[0]):
return item[0]
else:
return fn(item)

return fn(data.shape)
else:
        raise TypeError("get_shape0 does not support type: {}".format(data))


def lists_to_dicts(
@@ -430,7 +456,7 @@ def split_data_generator(data: dict, split_size: int, shuffle: bool = True) -> d
elif k in ['prev_state', 'prev_actor_state', 'prev_critic_state']:
length.append(len(v))
elif isinstance(v, list) or isinstance(v, tuple):
length.append(len(v[0]))
length.append(get_shape0(v[0]))
elif isinstance(v, dict):
length.append(len(v[list(v.keys())[0]]))
else:
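
For reference (not part of the diff), the relocated get_shape0 behaves as the unit test removed from test_data_helper.py above describes; a minimal usage sketch:

import torch
import treetensor.torch as ttorch
from ding.utils.default_helper import get_shape0

batch = ttorch.as_tensor({'a': {'b': torch.randn(4, 3)}, 'c': {'d': torch.randn(4)}})
assert get_shape0(batch) == 4                # treetensor: leading dim of the first leaf
assert get_shape0(torch.randn(4, 3)) == 4    # plain tensor: shape[0]
assert get_shape0([torch.randn(4, 3)]) == 4  # list/tuple: recurse into the first element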