feature(zp): add dreamerv3 algorithm #652

Merged · 17 commits, merged on Aug 7, 2023
Changes from all commits
2 changes: 1 addition & 1 deletion ding/entry/__init__.py
@@ -25,6 +25,6 @@
from .serial_entry_preference_based_irl_onpolicy \
import serial_pipeline_preference_based_irl_onpolicy
from .application_entry_drex_collect_data import drex_collecting_data
from .serial_entry_mbrl import serial_pipeline_dyna, serial_pipeline_dream
from .serial_entry_mbrl import serial_pipeline_dyna, serial_pipeline_dream, serial_pipeline_dreamer
from .serial_entry_bco import serial_pipeline_bco
from .serial_entry_pc import serial_pipeline_pc
71 changes: 71 additions & 0 deletions ding/entry/serial_entry_mbrl.py
@@ -243,3 +243,74 @@ def serial_pipeline_dream(
learner.call_hook('after_run')

return policy


def serial_pipeline_dreamer(
input_cfg: Union[str, Tuple[dict, dict]],
seed: int = 0,
env_setting: Optional[List[Any]] = None,
model: Optional[torch.nn.Module] = None,
max_train_iter: Optional[int] = int(1e10),
max_env_step: Optional[int] = int(1e10),
) -> 'Policy': # noqa
"""
Overview:
Serial pipeline entry for dreamerv3.
Arguments:
- input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
``str`` type means config file path. \
``Tuple[dict, dict]`` type means [user_config, create_cfg].
- seed (:obj:`int`): Random seed.
- env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
``BaseEnv`` subclass, collector env config, and evaluator env config.
- model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
- max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
- max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
Returns:
- policy (:obj:`Policy`): Converged policy.
"""
cfg, policy, world_model, env_buffer, learner, collector, collector_env, evaluator, commander, tb_logger = \
mbrl_entry_setup(input_cfg, seed, env_setting, model)

learner.call_hook('before_run')

# prefill environment buffer
if cfg.policy.get('random_collect_size', 0) > 0:
cfg.policy.random_collect_size = cfg.policy.random_collect_size // cfg.policy.collect.unroll_len
random_collect(cfg.policy, policy, collector, collector_env, commander, env_buffer)

while True:
collect_kwargs = commander.step()
# eval the policy
if evaluator.should_eval(collector.envstep):
stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep, policy_kwargs=dict(world_model=world_model))
if stop:
break

# train world model and fill imagination buffer
steps = (
cfg.world_model.pretrain
if world_model.should_pretrain()
else int(world_model.should_train(collector.envstep))
)
for _ in range(steps):
batch_size = learner.policy.get_attribute('batch_size')
batch_length = cfg.policy.learn.batch_length
post, context = world_model.train(env_buffer, collector.envstep, learner.train_iter, batch_size, batch_length)

start = post  # posterior latent states serve as the starting states for imagination rollouts

learner.train(
start, collector.envstep, policy_kwargs=dict(world_model=world_model, envstep=collector.envstep)
)

# fill environment buffer
data = collector.collect(train_iter=learner.train_iter, policy_kwargs=dict(world_model=world_model, envstep=collector.envstep, **collect_kwargs))
env_buffer.push(data, cur_collector_envstep=collector.envstep)

if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
break

learner.call_hook('after_run')

return policy
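For reference, a minimal invocation sketch of the new entry point (the config pair below is a hypothetical placeholder, not part of this diff; any DreamerV3 user config / create config pair accepted by mbrl_entry_setup works):

from ding.entry import serial_pipeline_dreamer

# `dreamer_cfg` and `dreamer_create_cfg` stand in for an actual DreamerV3 config pair.
policy = serial_pipeline_dreamer(
    (dreamer_cfg, dreamer_create_cfg),
    seed=0,
    max_env_step=int(5e5),
)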
3 changes: 2 additions & 1 deletion ding/entry/utils.py
@@ -60,7 +60,8 @@ def random_collect(
new_data = collector.collect(n_episode=policy_cfg.random_collect_size, policy_kwargs=collect_kwargs)
else:
new_data = collector.collect(
n_sample=policy_cfg.random_collect_size, record_random_collect=False, policy_kwargs=collect_kwargs
n_sample=policy_cfg.random_collect_size, random_collect=True,
record_random_collect=False, policy_kwargs=collect_kwargs
) # 'record_random_collect=False' disables logging for the random collection phase
if postprocess_data_fn is not None:
new_data = postprocess_data_fn(new_data)
45 changes: 43 additions & 2 deletions ding/envs/env_wrappers/env_wrappers.py
@@ -182,8 +182,17 @@ def observation(self, frame):
import sys
logging.warning("Please install opencv-python first.")
sys.exit(1)
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
return cv2.resize(frame, (self.size, self.size), interpolation=cv2.INTER_AREA)
# TODO: make the channel order configurable instead of inferring it from the shape.
# Heuristic: a leading dimension smaller than 10 is treated as the channel axis (channel-first, CHW);
# such frames are resized per channel and kept in CHW order, otherwise the frame is converted
# to grayscale and resized as before.
if frame.shape[0] < 10:
frame = frame.transpose(1, 2, 0)
frame = cv2.resize(frame, (self.size, self.size), interpolation=cv2.INTER_AREA)
frame = frame.transpose(2, 0, 1)
else:
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
frame = cv2.resize(frame, (self.size, self.size), interpolation=cv2.INTER_AREA)

return frame


@ENV_WRAPPER_REGISTRY.register('scaled_float_frame')
@@ -256,6 +265,38 @@ def reward(self, reward):
"""
return np.sign(reward)

@ENV_WRAPPER_REGISTRY.register('action_repeat')
class ActionRepeatWrapper(gym.Wrapper):
"""
Overview:
Repeat the action to step with env.
Interface:
``__init__``, ``step``
Properties:
- env (:obj:`gym.Env`): the environment to wrap.
- ``action_repeat``

"""

def __init__(self, env, action_repeat=1):
"""
Overview:
Initialize ``self.`` See ``help(type(self))`` for accurate signature; setup the properties.
Arguments:
- env (:obj:`gym.Env`): the environment to wrap.
"""
super().__init__(env)
self.action_repeat = action_repeat

def step(self, action):
reward = 0
for _ in range(self.action_repeat):
obs, rew, done, info = self.env.step(action)
reward += rew or 0
if done:
break
return obs, reward, done, info
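A usage sketch of the new wrapper (assuming the classic 4-tuple gym step API that the wrapper itself uses; the env id is only an example):

import gym
from ding.envs.env_wrappers.env_wrappers import ActionRepeatWrapper

env = ActionRepeatWrapper(gym.make('Pendulum-v1'), action_repeat=2)
obs = env.reset()
# a single step() call advances the inner env twice and returns the summed reward
obs, reward, done, info = env.step(env.action_space.sample())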


@ENV_WRAPPER_REGISTRY.register('delay_reward')
class DelayRewardWrapper(gym.Wrapper):
33 changes: 27 additions & 6 deletions ding/model/common/encoder.py
@@ -7,6 +7,7 @@
from torch.nn import functional as F

from ding.torch_utils import ResFCBlock, ResBlock, Flatten, normed_linear, normed_conv2d
from ding.torch_utils.network.dreamer import Conv2dSame, DreamerLayerNorm
from ding.utils import SequenceType


@@ -35,6 +36,7 @@ def __init__(
kernel_size: SequenceType = [8, 4, 3],
stride: SequenceType = [4, 2, 1],
padding: Optional[SequenceType] = None,
layer_norm: Optional[bool] = False,
norm_type: Optional[str] = None
) -> None:
"""
@@ -50,6 +52,7 @@
- stride (:obj:`SequenceType`): Sequence of ``stride`` of subsequent conv layers.
- padding (:obj:`SequenceType`): Padding added to all four sides of the input for each conv layer. \
See ``nn.Conv2d`` for more details. Default is ``None``.
- layer_norm (:obj:`Optional[bool]`): Whether to build DreamerV3-style conv blocks (``Conv2dSame`` + \
``DreamerLayerNorm``); in this mode the stride is fixed to 2 and the ``stride`` argument is ignored. \
Default is ``False``.
- norm_type (:obj:`str`): Type of normalization to use. See ``ding.torch_utils.network.ResBlock`` \
for more details. Default is ``None``.
"""
@@ -63,17 +66,35 @@
layers = []
input_size = obs_shape[0] # in_channel
for i in range(len(kernel_size)):
layers.append(nn.Conv2d(input_size, hidden_size_list[i], kernel_size[i], stride[i], padding[i]))
layers.append(self.act)
if layer_norm:
layers.append(
Conv2dSame(
in_channels=input_size,
out_channels=hidden_size_list[i],
kernel_size=(kernel_size[i], kernel_size[i]),
stride=(2, 2),
bias=False,
)
)
layers.append(DreamerLayerNorm(hidden_size_list[i]))
layers.append(self.act)
else:
layers.append(nn.Conv2d(input_size, hidden_size_list[i], kernel_size[i], stride[i], padding[i]))
layers.append(self.act)
input_size = hidden_size_list[i]
assert len(set(hidden_size_list[3:-1])) <= 1, "Please indicate the same hidden size for res block parts"
for i in range(3, len(self.hidden_size_list) - 1):
layers.append(ResBlock(self.hidden_size_list[i], activation=self.act, norm_type=norm_type))
if len(self.hidden_size_list) >= len(kernel_size) + 2:
assert self.hidden_size_list[len(kernel_size) - 1] == self.hidden_size_list[
len(kernel_size)], "Please indicate the same hidden size between conv and res block"
assert len(
set(hidden_size_list[len(kernel_size):-1])
) <= 1, "Please indicate the same hidden size for res block parts"
for i in range(len(kernel_size), len(self.hidden_size_list) - 1):
layers.append(ResBlock(self.hidden_size_list[i - 1], activation=self.act, norm_type=norm_type))
layers.append(Flatten())
self.main = nn.Sequential(*layers)

flatten_size = self._get_flatten_size()
self.output_size = hidden_size_list[-1]
self.output_size = hidden_size_list[-1]  # record the output dimension for external use
self.mid = nn.Linear(flatten_size, hidden_size_list[-1])

def _get_flatten_size(self) -> int:
14 changes: 14 additions & 0 deletions ding/model/common/tests/test_encoder.py
@@ -24,6 +24,20 @@ def test_conv_encoder(self):
self.output_check(model, outputs)
assert outputs.shape == (B, 128)

def test_dreamer_conv_encoder(self):
inputs = torch.randn(B, C, H, W)
model = ConvEncoder(
(C, H, W),
hidden_size_list=[32, 64, 128, 256, 128],
activation=torch.nn.SiLU(),
kernel_size=[4, 4, 4, 4],
layer_norm=True
)
print(model)
outputs = model(inputs)
self.output_check(model, outputs)
assert outputs.shape == (B, 128)

def test_fc_encoder(self):
inputs = torch.randn(B, 32)
hidden_size_list = [128 for _ in range(3)]
2 changes: 1 addition & 1 deletion ding/model/template/__init__.py
@@ -2,7 +2,7 @@
from .q_learning import DQN, RainbowDQN, QRDQN, IQN, FQF, DRQN, C51DQN, BDQ
from .qac import QAC, DiscreteQAC
from .pdqn import PDQN
from .vac import VAC
from .vac import VAC, DREAMERVAC
from .bc import DiscreteBC, ContinuousBC
from .pg import PG
# algorithm-specific
88 changes: 88 additions & 0 deletions ding/model/template/vac.py
@@ -6,6 +6,7 @@
from ding.utils import SequenceType, squeeze, MODEL_REGISTRY
from ..common import ReparameterizationHead, RegressionHead, DiscreteHead, MultiHead, \
FCEncoder, ConvEncoder, IMPALAConvEncoder
from ding.torch_utils.network.dreamer import ActionHead, DenseHead


@MODEL_REGISTRY.register('vac')
@@ -356,3 +357,90 @@ def compute_actor_critic(self, x: torch.Tensor) -> Dict:
action_type = self.actor_head[0](actor_embedding)
action_args = self.actor_head[1](actor_embedding)
return {'logit': {'action_type': action_type['logit'], 'action_args': action_args}, 'value': value}


@MODEL_REGISTRY.register('dreamervac')
class DREAMERVAC(nn.Module):
r"""
Overview:
The VAC model.
Interfaces:
``__init__``, ``forward``, ``compute_actor``, ``compute_critic``
"""
mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']

def __init__(
self,
obs_shape: Union[int, SequenceType],
action_shape: Union[int, SequenceType, EasyDict],
dyn_stoch=32,
dyn_deter=512,
dyn_discrete=32,
actor_layers=2,
value_layers=2,
units=512,
act='SiLU',
norm='LayerNorm',
actor_dist='normal',
actor_init_std=1.0,
actor_min_std=0.1,
actor_max_std=1.0,
actor_temp=0.1,
action_unimix_ratio=0.01,
) -> None:
r"""
Overview:
Init the VAC Model according to arguments.
Arguments:
- obs_shape (:obj:`Union[int, SequenceType]`): Observation's space.
- action_shape (:obj:`Union[int, SequenceType]`): Action's space.
- action_space (:obj:`str`): Choose action head in ['discrete', 'continuous', 'hybrid']
- share_encoder (:obj:`bool`): Whether share encoder.
- encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``
- actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to actor-nn's ``Head``.
- actor_head_layer_num (:obj:`int`):
The num of layers used in the network to compute Q value output for actor's nn.
- critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to critic-nn's ``Head``.
- critic_head_layer_num (:obj:`int`):
The num of layers used in the network to compute Q value output for critic's nn.
- activation (:obj:`Optional[nn.Module]`):
The type of activation function to use in ``MLP`` the after ``layer_fn``,
if ``None`` then default set to ``nn.ReLU()``
- norm_type (:obj:`Optional[str]`):
The type of normalization to use, see ``ding.torch_utils.fc_block`` for more details`
"""
super(DREAMERVAC, self).__init__()
obs_shape: int = squeeze(obs_shape)
action_shape = squeeze(action_shape)
self.obs_shape, self.action_shape = obs_shape, action_shape

if dyn_discrete:
feat_size = dyn_stoch * dyn_discrete + dyn_deter
else:
feat_size = dyn_stoch + dyn_deter
self.actor = ActionHead(
feat_size,  # latent feature size (as in the PyTorch DreamerV3 implementation)
action_shape,
actor_layers,
units,
act,
norm,
actor_dist,
actor_init_std,
actor_min_std,
actor_max_std,
actor_temp,
outscale=1.0,
unimix_ratio=action_unimix_ratio,
)
self.critic = DenseHead(
feat_size,  # latent feature size (as in the PyTorch DreamerV3 implementation)
(255, ),
value_layers,
units,
'SiLU', # act
'LN', # norm
'twohot_symlog',
outscale=0.0,
device='cuda' if torch.cuda.is_available() else 'cpu',
)
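A minimal construction sketch of DREAMERVAC (shapes are illustrative assumptions; the actor/critic call semantics are defined by ActionHead and DenseHead in ding.torch_utils.network.dreamer):

import torch
from ding.model.template import DREAMERVAC

# illustrative shapes: 64x64 RGB observations and a 6-dimensional action space
model = DREAMERVAC(obs_shape=(3, 64, 64), action_shape=6)
# latent features fed to both heads: stochastic (32 variables x 32 classes) + deterministic (512)
feat = torch.randn(4, 32 * 32 + 512)
# model.actor(feat) and model.critic(feat) produce dreamer-style distribution outputs
# (exact return types depend on ActionHead / DenseHead)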
5 changes: 5 additions & 0 deletions ding/policy/command_mode_policy_instance.py
@@ -26,6 +26,7 @@
from .td3_bc import TD3BCPolicy
from .sac import SACPolicy, SACDiscretePolicy
from .mbpolicy.mbsac import MBSACPolicy, STEVESACPolicy
from .mbpolicy.dreamer import DREAMERPolicy
from .qmix import QMIXPolicy
from .wqmix import WQMIXPolicy
from .collaq import CollaQPolicy
@@ -302,6 +303,10 @@ class MBSACCommandModePolicy(MBSACPolicy, DummyCommandModePolicy):
class STEVESACCommandModePolicy(STEVESACPolicy, DummyCommandModePolicy):
pass

@POLICY_REGISTRY.register('dreamer_command')
class DREAMERCommandModePolicy(DREAMERPolicy, DummyCommandModePolicy):
pass


@POLICY_REGISTRY.register('cql_command')
class CQLCommandModePolicy(CQLPolicy, DummyCommandModePolicy):
1 change: 1 addition & 0 deletions ding/policy/mbpolicy/__init__.py
@@ -1 +1,2 @@
from .mbsac import MBSACPolicy
from .dreamer import DREAMERPolicy