polish(davide) add example of Gail entry + config for Mujoco and Cartpole #114

Merged 30 commits into opendilab:main from gail-entry-config on Nov 19, 2021

Commits (30)
3a975df
added gail entry
davide97l Oct 18, 2021
2de346f
added lunarlander and cartpole config
davide97l Oct 19, 2021
2c0b62f
added gail mujoco config
davide97l Oct 20, 2021
c554094
added mujoco exp
davide97l Oct 22, 2021
46f01d6
update22-10
davide97l Oct 22, 2021
180bc49
added third exp
davide97l Oct 25, 2021
2767f5a
added metric to evaluate policies
davide97l Oct 29, 2021
9bbd325
Merge branch 'opendilab:main' into gail-entry-config
davide97l Nov 1, 2021
147310d
added GAIL entry and config for Cartpole and Walker2d
davide97l Nov 1, 2021
804e88f
checked style and unittest
davide97l Nov 1, 2021
0dbcf26
restored lunarlander env
davide97l Nov 1, 2021
fb3dde0
style problems
davide97l Nov 1, 2021
a4d8871
bug correction
davide97l Nov 1, 2021
8c8998d
Delete expert_data_train.pkl
davide97l Nov 1, 2021
30884c4
changed loss of GAIL
davide97l Nov 2, 2021
d3cb245
Merge branch 'opendilab:main' into gail-entry-config
davide97l Nov 2, 2021
1ecd8b4
Update walker2d_ddpg_gail_config.py
davide97l Nov 2, 2021
4a41d0b
Merge branch 'opendilab:main' into gail-entry-config
davide97l Nov 3, 2021
db0ade6
changed gail reward from -D(s, a) to -log(D(s, a))
davide97l Nov 4, 2021
7d54853
added small constant to reward function
davide97l Nov 4, 2021
fa28c4a
added comment to clarify config
davide97l Nov 8, 2021
7ac7627
Update walker2d_ddpg_gail_config.py
davide97l Nov 8, 2021
c4fe1ca
added lunarlander entry + config
davide97l Nov 9, 2021
93c173a
Added Atari discriminator + Pong entry config
davide97l Nov 11, 2021
a1ab0d9
Update gail_irl_model.py
davide97l Nov 15, 2021
56b1ad5
Update gail_irl_model.py
davide97l Nov 15, 2021
6b38eee
added gail serial pipeline and onehot actions for gail atari
davide97l Nov 19, 2021
c8743a1
related to previous commit
davide97l Nov 19, 2021
7499588
removed main files
davide97l Nov 19, 2021
1284543
removed old comment
davide97l Nov 19, 2021
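
Commits db0ade6 and 7d54853 change the GAIL reward from -D(s, a) to -log(D(s, a)) and add a small constant to keep the logarithm finite. A minimal sketch of that shaping, assuming the discriminator outputs a probability in (0, 1); the function name, the epsilon value, and the example values are illustrative and not taken from this PR:

import torch


def gail_reward(d_out: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # d_out holds discriminator outputs D(s, a) in (0, 1); the reward is
    # -log(D(s, a) + eps), where eps keeps the log finite as D(s, a) -> 0.
    return -torch.log(d_out + eps)


# The reward grows as D(s, a) approaches 0 and shrinks toward 0 as D(s, a) approaches 1.
print(gail_reward(torch.tensor([0.9, 0.5, 0.1])))
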
ding/entry/application_entry.py (7 changes: 5 additions & 2 deletions)
@@ -80,6 +80,7 @@ def collect_demo_data(
         env_setting: Optional[List[Any]] = None,
         model: Optional[torch.nn.Module] = None,
         state_dict: Optional[dict] = None,
+        state_dict_path: Optional[str] = None,
 ) -> None:
     r"""
     Overview:
@@ -95,6 +96,7 @@ def collect_demo_data(
             ``BaseEnv`` subclass, collector env config, and evaluator env config.
         - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
         - state_dict (:obj:`Optional[dict]`): The state_dict of policy or model.
+        - state_dict_path (:obj:`Optional[str]`): The path of the state_dict of policy or model.
     """
     if isinstance(input_cfg, str):
         cfg, create_cfg = read_config(input_cfg)
@@ -134,14 +136,15 @@ def collect_demo_data(
     # )
     collect_demo_policy = policy.collect_mode
     if state_dict is None:
-        state_dict = torch.load(cfg.learner.load_path, map_location='cpu')
+        assert state_dict_path is not None
+        state_dict = torch.load(state_dict_path, map_location='cpu')
     policy.collect_mode.load_state_dict(state_dict)
     collector = SampleSerialCollector(cfg.policy.collect.collector, collector_env, collect_demo_policy)

     policy_kwargs = None if not hasattr(cfg.policy.other.get('eps', None), 'collect') \
         else {'eps': cfg.policy.other.eps.get('collect', 0.2)}

-    # Let's collect some expert demostrations
+    # Let's collect some expert demonstrations
     exp_data = collector.collect(n_sample=collect_count, policy_kwargs=policy_kwargs)
     if cfg.policy.cuda:
         exp_data = to_device(exp_data, 'cpu')
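
The new state_dict_path argument lets a caller point collect_demo_data at an expert checkpoint explicitly instead of relying on cfg.learner.load_path. A hedged usage sketch mirroring the call made in serial_entry_gail.py further below; the config file and checkpoint paths are hypothetical:

from ding.entry import collect_demo_data

# Hypothetical expert config and paths; keyword arguments mirror the call in serial_entry_gail.py.
collect_demo_data(
    'cartpole_dqn_config.py',   # expert config: file path or (cfg, create_cfg) tuple
    0,                          # seed
    state_dict_path='./cartpole_dqn/ckpt/ckpt_best.pth.tar',
    expert_data_path='./expert_data.pkl',
    collect_count=10000,
)
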
ding/entry/cli.py (16 changes: 15 additions & 1 deletion)
@@ -52,7 +52,10 @@ def print_registry(ctx: Context, param: Option, value: str):
 @click.option(
     '-m',
     '--mode',
-    type=click.Choice(['serial', 'serial_onpolicy', 'serial_sqil', 'serial_dqfd', 'parallel', 'dist', 'eval']),
+    type=click.Choice(
+        ['serial', 'serial_onpolicy', 'serial_sqil', 'serial_dqfd', 'parallel', 'dist', 'eval', 'serial_reward_model',
+         'serial_gail']
+    ),
     help='serial-train or parallel-train or dist-train or eval'
 )
 @click.option('-c', '--config', type=str, help='Path to DRL experiment config')
@@ -157,6 +160,17 @@ def cli(
             config = get_predefined_config(env, policy)
         expert_config = input("Enter the name of the config you used to generate your expert model: ")
         serial_pipeline_sqil(config, expert_config, seed, max_iterations=train_iter)
+    elif mode == 'serial_reward_model':
+        from .serial_entry_reward_model import serial_pipeline_reward_model
+        if config is None:
+            config = get_predefined_config(env, policy)
+        serial_pipeline_reward_model(config, seed, max_iterations=train_iter)
+    elif mode == 'serial_gail':
+        from .serial_entry_gail import serial_pipeline_gail
+        if config is None:
+            config = get_predefined_config(env, policy)
+        expert_config = input("Enter the name of the config you used to generate your expert model: ")
+        serial_pipeline_gail(config, expert_config, seed, max_iterations=train_iter, collect_data=True)
     elif mode == 'serial_dqfd':
         from .serial_entry_dqfd import serial_pipeline_dqfd
         if config is None:
ding/entry/serial_entry_gail.py (new file, 165 additions & 0 deletions)
@@ -0,0 +1,165 @@
from typing import Optional, Tuple
import os
import logging
from functools import partial
from tensorboardX import SummaryWriter

from ding.envs import get_vec_env_setting, create_env_manager
from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \
create_serial_collector
from ding.config import compile_config, read_config
from ding.policy import create_policy, PolicyFactory
from ding.reward_model import create_reward_model
from ding.utils import set_pkg_seed
from ding.entry import collect_demo_data
from ding.utils import save_file
import numpy as np


def save_reward_model(path, reward_model, weights_name='best'):
    path = os.path.join(path, 'reward_model', 'ckpt')
    if not os.path.exists(path):
        try:
            os.makedirs(path)
        except FileExistsError:
            pass
    path = os.path.join(path, 'ckpt_{}.pth.tar'.format(weights_name))
    state_dict = reward_model.state_dict()
    save_file(path, state_dict)
    print('Saved reward model ckpt in {}'.format(path))


def serial_pipeline_gail(
        input_cfg: Tuple[dict, dict],
        expert_cfg: Tuple[dict, dict],
        seed: int = 0,
        max_iterations: Optional[int] = int(1e9),
        collect_data: bool = True,
) -> 'Policy':  # noqa
"""
Overview:
Serial pipeline entry with reward model.
Arguments:
- input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
``str`` type means config file path. \
``Tuple[dict, dict]`` type means [user_config, create_cfg].
- expert_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Expert config in dict type. \
``str`` type means config file path. \
``Tuple[dict, dict]`` type means [user_config, create_cfg].
- seed (:obj:`int`): Random seed.
- max_iterations (:obj:`Optional[torch.nn.Module]`): Learner's max iteration. Pipeline will stop \
when reaching this iteration.
- collect_data (:obj:`bool`): Collect expert data.
Returns:
- policy (:obj:`Policy`): Converged policy.
"""
    if isinstance(input_cfg, str):
        cfg, create_cfg = read_config(input_cfg)
    else:
        cfg, create_cfg = input_cfg
    if isinstance(expert_cfg, str):
        expert_cfg, expert_create_cfg = read_config(expert_cfg)
    else:
        expert_cfg, expert_create_cfg = expert_cfg
    create_cfg.policy.type = create_cfg.policy.type + '_command'
    cfg = compile_config(cfg, seed=seed, auto=True, create_cfg=create_cfg, save_cfg=True)
    # Load expert data
    if collect_data:
        if expert_cfg.policy.get('other', None) is not None and expert_cfg.policy.other.get('eps', None) is not None:
            expert_cfg.policy.other.eps.collect = -1
        if expert_cfg.policy.get('load_path', None) is None:
            expert_cfg.policy.load_path = os.path.join(expert_cfg.exp_name, 'ckpt/ckpt_best.pth.tar')
        collect_demo_data(
            (expert_cfg, expert_create_cfg),
            seed,
            state_dict_path=expert_cfg.policy.load_path,
            expert_data_path=cfg.reward_model.expert_data_path,
            collect_count=cfg.reward_model.collect_count
        )
    # Create main components: env, policy
    env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
    collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
    evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
    collector_env.seed(cfg.seed)
    evaluator_env.seed(cfg.seed, dynamic_seed=False)
    set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
    policy = create_policy(cfg.policy, enable_field=['learn', 'collect', 'eval', 'command'])

    # Create worker components: learner, collector, evaluator, replay buffer, commander.
    tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
    learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
    collector = create_serial_collector(
        cfg.policy.collect.collector,
        env=collector_env,
        policy=policy.collect_mode,
        tb_logger=tb_logger,
        exp_name=cfg.exp_name
    )
    evaluator = InteractionSerialEvaluator(
        cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
    )
    replay_buffer = create_buffer(cfg.policy.other.replay_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name)
    commander = BaseSerialCommander(
        cfg.policy.other.commander, learner, collector, evaluator, replay_buffer, policy.command_mode
    )
    reward_model = create_reward_model(cfg.reward_model, policy.collect_mode.get_attribute('device'), tb_logger)

    # ==========
    # Main loop
    # ==========
    # Learner's before_run hook.
    learner.call_hook('before_run')

    # Accumulate plenty of data at the beginning of training.
    if cfg.policy.get('random_collect_size', 0) > 0:
        action_space = collector_env.env_info().act_space
        random_policy = PolicyFactory.get_random_policy(policy.collect_mode, action_space=action_space)
        collector.reset_policy(random_policy)
        new_data = collector.collect(n_sample=cfg.policy.random_collect_size)
        replay_buffer.push(new_data, cur_collector_envstep=0)
        collector.reset_policy(policy.collect_mode)
    best_reward = -np.inf
    for _ in range(max_iterations):
        # Evaluate policy performance
        collect_kwargs = commander.step()
        if evaluator.should_eval(learner.train_iter):
            stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
            if reward >= best_reward:
                # Keep the reward model checkpoint that matches the best policy seen so far.
                best_reward = reward
                save_reward_model(cfg.exp_name, reward_model)
            if stop:
                break
        new_data_count, target_new_data_count = 0, cfg.reward_model.get('target_new_data_count', 1)
        while new_data_count < target_new_data_count:
            new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
            new_data_count += len(new_data)
            # collect data for reward_model training
            reward_model.collect_data(new_data)
            replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
        # update reward_model
        reward_model.train()
        reward_model.clear_data()
        # Learn policy from collected data
        for i in range(cfg.policy.learn.update_per_collect):
            # Learner will train ``update_per_collect`` times in one iteration.
            train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
            if train_data is None:
                # It is possible that replay buffer's data count is too few to train ``update_per_collect`` times
                logging.warning(
                    "Replay buffer's data can only train for {} steps. ".format(i) +
                    "You can modify data collect config, e.g. increasing n_sample, n_episode."
                )
                break
            # update train_data reward
            reward_model.estimate(train_data)
            learner.train(train_data, collector.envstep)
            if learner.policy.get_attribute('priority'):
                replay_buffer.update(learner.priority_info)
        if cfg.policy.on_policy:
            # On-policy algorithm must clear the replay buffer.
            replay_buffer.clear()

    # Learner's after_run hook.
    learner.call_hook('after_run')
    save_reward_model(cfg.exp_name, reward_model, 'last')
    # evaluate
    evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
    return policy
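
Together with the new serial_gail CLI mode wired up above, the entry can also be invoked directly from Python. A minimal sketch; both config file names are hypothetical placeholders for a GAIL config and for the config of the expert that provides demonstrations:

from ding.entry.serial_entry_gail import serial_pipeline_gail

# The first config drives GAIL training, the second describes the expert whose
# checkpoint and demonstrations are collected first (collect_data=True).
serial_pipeline_gail(
    'cartpole_dqn_gail_config.py',
    'cartpole_dqn_config.py',
    seed=0,
    max_iterations=int(1e6),
    collect_data=True,
)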