feature(zt): add metadrive-simulator env and related onppo config #574

Merged: 10 commits, Feb 15, 2023
103 changes: 103 additions & 0 deletions dizoo/metadrive/config/test_ppo.py
@@ -0,0 +1,103 @@
from easydict import EasyDict
from functools import partial
from tensorboardX import SummaryWriter
import torch
from ding.envs import BaseEnvManager
from ding.config import compile_config
from ding.model.template import VAC
from ding.policy import PPOPolicy
from ding.worker import SampleSerialCollector, InteractionSerialEvaluator, BaseLearner
from dizoo.metadrive.env.drive_env import MetaDrivePPOOriginEnv
from dizoo.metadrive.env.drive_wrapper import DriveEnvWrapper


# Path of a pretrained checkpoint to load before evaluation; keep None to evaluate an untrained policy.
model_dir = None
metadrive_basic_config = dict(
exp_name='test_ppo_metadrive',
env=dict(
metadrive=dict(
            use_render=True,
            traffic_density=0.10,
            map='XSOS',
            horizon=4000,
            driving_reward=1.0,
            speed_reward=0.10,
            out_of_road_penalty=40.0,
            crash_vehicle_penalty=40.0,
            decision_repeat=20,
            use_lateral_reward=False,
            out_of_route_done=True,
),
manager=dict(
shared_memory=False,
max_retry=2,
context='spawn',
),
n_evaluator_episode=16,
stop_value=99999,
collector_env_num=1,
evaluator_env_num=1,
),
policy=dict(
cuda=True,
action_space='continuous',
model=dict(
obs_shape=[5, 84, 84],
action_shape=2,
action_space='continuous',
bound_type='tanh',
encoder_hidden_size_list=[128, 128, 64],
),
learn=dict(
epoch_per_collect=10,
batch_size=64,
learning_rate=3e-4,
            entropy_weight=0.001,
value_weight=0.5,
            clip_ratio=0.02,
adv_norm=False,
value_norm=True,
grad_clip_value=10,
),
collect=dict(
n_sample=1000,
),
eval=dict(
evaluator=dict(
eval_freq=1000,
),
),
),
)
main_config = EasyDict(metadrive_basic_config)

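# Create a MetaDrive env and wrap it with DriveEnvWrapper so it conforms to DI-engine's BaseEnv interface.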
def wrapped_env(env_cfg, wrapper_cfg=None):
return DriveEnvWrapper(MetaDrivePPOOriginEnv(env_cfg), wrapper_cfg)

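# Evaluation-only entry point: build the evaluator env, optionally load a checkpoint into the policy, and run a single evaluation pass.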
def main(cfg):
cfg = compile_config(
cfg, BaseEnvManager, PPOPolicy, BaseLearner, SampleSerialCollector, InteractionSerialEvaluator
)
collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
evaluator_env = BaseEnvManager(
env_fn=[partial(wrapped_env, cfg.env.metadrive) for _ in range(evaluator_env_num)],
cfg=cfg.env.manager,
)
model = VAC(**cfg.policy.model)
policy = PPOPolicy(cfg.policy, model=model)
if model_dir is not None:
policy._load_state_dict_collect(torch.load(model_dir, map_location='cpu'))
tb_logger = SummaryWriter('./log/{}/'.format(cfg.exp_name))
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
evaluator = InteractionSerialEvaluator(
cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
)
learner.call_hook('before_run')
stop, rate = evaluator.eval()
evaluator.close()
learner.close()


if __name__ == '__main__':
main(main_config)
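
For reference, one hypothetical way to point this evaluation script at a trained checkpoint is sketched below; the checkpoint path is illustrative only, and model_dir is the module-level variable defined at the top of test_ppo.py.

# Hypothetical usage sketch; the checkpoint path is illustrative, not produced by this PR.
import dizoo.metadrive.config.test_ppo as test_ppo

test_ppo.model_dir = './test_ppo_metadrive/ckpt/ckpt_best.pth.tar'  # assumed checkpoint location
test_ppo.main(test_ppo.main_config)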
115 changes: 115 additions & 0 deletions dizoo/metadrive/config/train_ppo.py
@@ -0,0 +1,115 @@
from easydict import EasyDict
from functools import partial
from tensorboardX import SummaryWriter
import metadrive
from ding.envs import SyncSubprocessEnvManager
from ding.config import compile_config
from ding.model.template import VAC
from ding.policy import PPOPolicy
from ding.worker import SampleSerialCollector, InteractionSerialEvaluator, BaseLearner
from dizoo.metadrive.env.drive_env import MetaDrivePPOOriginEnv
from dizoo.metadrive.env.drive_wrapper import DriveEnvWrapper


metadrive_basic_config = dict(
exp_name='train_ppo_metadrive',
env=dict(
metadrive=dict(
            use_render=False,  # do not open the MetaDrive rendering window during training

@puyuan1996 (Collaborator) commented on Feb 12, 2023:
no space near =; execute bash format.sh dizoo/metadrive to reformat the files.

            traffic_density=0.10,  # density of surrounding traffic vehicles on the map

Collaborator comment:
add some annotations about the key parameters in the metadrive env?

            map='XSOS',  # map definition string; each character specifies one road block
            horizon=4000,  # maximum number of environment steps per episode
            driving_reward=1.0,  # reward coefficient for progress along the route
            speed_reward=0.1,  # reward coefficient for driving speed
            out_of_road_penalty=40.0,  # penalty applied when the vehicle leaves the road
            crash_vehicle_penalty=40.0,  # penalty applied when crashing into another vehicle
            decision_repeat=20,  # number of simulation steps each action is repeated
            use_lateral_reward=False,  # whether to add a lateral-position term to the reward
            out_of_route_done=True,  # terminate the episode when the vehicle leaves its route
),
manager=dict(
shared_memory=False,
max_retry=2,
context='spawn',
),
n_evaluator_episode=16,
stop_value=99999,
collector_env_num=8,
evaluator_env_num=8,
),
policy=dict(
cuda=True,
action_space='continuous',
model=dict(
obs_shape=[5, 84, 84],
action_shape=2,
action_space='continuous',
bound_type='tanh',
encoder_hidden_size_list=[128, 128, 64],
),
learn=dict(
epoch_per_collect=10,
batch_size=64,
learning_rate=3e-4,
            entropy_weight=0.001,
            value_weight=0.5,
            clip_ratio=0.02,
adv_norm=False,
value_norm=True,
grad_clip_value=10,
),
collect=dict(
n_sample=3000,
),
eval=dict(
evaluator=dict(
eval_freq=1000,
),
),
),
)
main_config = EasyDict(metadrive_basic_config)

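# Same helper as in test_ppo.py: wrap a raw MetaDrive env for use with DI-engine's env managers.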
def wrapped_env(env_cfg, wrapper_cfg=None):
return DriveEnvWrapper(MetaDrivePPOOriginEnv(env_cfg), wrapper_cfg)

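# Training entry point: build collector and evaluator env managers, then alternate evaluation, data collection, and PPO updates until the configured stop_value is reached.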
def main(cfg):
cfg = compile_config(
cfg, SyncSubprocessEnvManager, PPOPolicy, BaseLearner, SampleSerialCollector, InteractionSerialEvaluator
)
collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
collector_env = SyncSubprocessEnvManager(
env_fn=[partial(wrapped_env, cfg.env.metadrive) for _ in range(collector_env_num)],
cfg=cfg.env.manager,
)
evaluator_env = SyncSubprocessEnvManager(
env_fn=[partial(wrapped_env, cfg.env.metadrive) for _ in range(evaluator_env_num)],
cfg=cfg.env.manager,
)
model = VAC(**cfg.policy.model)
policy = PPOPolicy(cfg.policy, model=model)
tb_logger = SummaryWriter('./log/{}/'.format(cfg.exp_name))
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
collector = SampleSerialCollector(
cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
)
evaluator = InteractionSerialEvaluator(
cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
)
learner.call_hook('before_run')
    while True:
        # Evaluate periodically and stop once the configured stop_value is reached.
        if evaluator.should_eval(learner.train_iter):
            stop, rate = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
            if stop:
                break
        # Sample new transitions from the collector environments.
        new_data = collector.collect(cfg.policy.collect.n_sample, train_iter=learner.train_iter)
        # Update the PPO policy on the freshly collected data.
        learner.train(new_data, collector.envstep)
collector.close()
evaluator.close()
learner.close()


if __name__ == '__main__':
main(main_config)
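
As a rough illustration of how the EasyDict config above can be adapted before launch, the hypothetical smoke-test sketch below shrinks the environment counts and sample budget; every key already exists in metadrive_basic_config, while the reduced values and the experiment name are purely illustrative.

# Hypothetical smoke-test sketch; keys come from metadrive_basic_config above, values are illustrative.
from copy import deepcopy

from dizoo.metadrive.config.train_ppo import main, main_config

smoke_cfg = deepcopy(main_config)
smoke_cfg.exp_name = 'train_ppo_metadrive_smoke'
smoke_cfg.env.collector_env_num = 2
smoke_cfg.env.evaluator_env_num = 1
smoke_cfg.env.n_evaluator_episode = 2
smoke_cfg.policy.collect.n_sample = 300

if __name__ == '__main__':
    main(smoke_cfg)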