DDPG parameters update #583

Merged: 3 commits, Feb 22, 2023
tests/rl/gym_wrapper/common.py (7 changes: 1 addition & 6 deletions)

@@ -4,19 +4,14 @@
 from typing import cast

 from maro.simulator import Env
-from maro.utils import set_seeds

 from tests.rl.gym_wrapper.simulator.business_engine import GymBusinessEngine

-set_seeds(123)
-
 env_conf = {
     "topology": "Walker2d-v4",  # HalfCheetah-v4, Hopper-v4, Walker2d-v4, Swimmer-v4, Ant-v4
     "start_tick": 0,
     "durations": 100000,  # Set a very large number
-    "options": {
-        "random_seed": None,
-    },
+    "options": {},
 }

 learn_env = Env(business_engine_cls=GymBusinessEngine, **env_conf)
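Removing the module-level set_seeds(123) makes seeding opt-in rather than a side effect of importing this module; the seed now comes from run.py's new --seed flag (see that diff below). A minimal, purely illustrative sketch of the import-time-seeding pitfall, using only the standard library:

import random

random.seed(123)   # what the old common.py effectively did on every import
a = random.random()

random.seed(123)   # any other process importing the module pinned the same stream
b = random.random()

assert a == b      # identical draws, whether the experiment wanted them or not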
Binary files modified:
tests/rl/log/Ant_1.png
tests/rl/log/Ant_11.png
tests/rl/log/HalfCheetah_1.png
tests/rl/log/HalfCheetah_11.png
tests/rl/log/Hopper_1.png
tests/rl/log/Hopper_11.png
tests/rl/log/Swimmer_1.png
tests/rl/log/Swimmer_11.png
tests/rl/log/Walker2d_1.png
tests/rl/log/Walker2d_11.png
tests/rl/plot.py (8 changes: 7 additions & 1 deletion)

@@ -14,6 +14,9 @@
 color_map = {
     "ppo": "green",
     "sac": "goldenrod",
+    "ddpg": "firebrick",
+    "vpg": "cornflowerblue",
+    "td3": "mediumpurple",
 }

@@ -62,6 +65,9 @@ def plot_performance_curves(title: str, dir_names: List[str], smooth_window_size
         elif "sac" in name:
             algorithm = "sac"
             func = get_off_policy_data
+        elif "ddpg" in name:
+            algorithm = "ddpg"
+            func = get_off_policy_data
         else:
             raise "unknown algorithm name"

@@ -85,6 +91,6 @@ def plot_performance_curves(title: str, dir_names: List[str], smooth_window_size
 for env_name in ["HalfCheetah", "Hopper", "Walker2d", "Swimmer", "Ant"]:
     plot_performance_curves(
         title=env_name,
-        dir_names=[f"{algorithm}_{env_name.lower()}" for algorithm in ["ppo", "sac"]],
+        dir_names=[f"{algorithm}_{env_name.lower()}" for algorithm in ["ppo", "sac", "ddpg"]],
         smooth_window_size=args.smooth,
     )
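Note that color_map now also gains "vpg" and "td3" entries, but plot_performance_curves only learns a "ddpg" branch. Below is a self-contained sketch of the same name-to-loader dispatch, extended to those two algorithms; the on-/off-policy assignments for vpg and td3 are assumptions here, and a ValueError replaces the raise "unknown algorithm name" string literal, which Python 3 rejects (exceptions must derive from BaseException):

from typing import Tuple


def resolve_algorithm(name: str) -> Tuple[str, str]:
    # Maps a log-directory name such as "ddpg_walker2d" to (algorithm, loader kind);
    # "on_policy"/"off_policy" stand in for plot.py's get_*_data functions.
    dispatch = [
        ("ppo", "on_policy"),
        ("vpg", "on_policy"),    # assumption: VPG logs like PPO
        ("sac", "off_policy"),
        ("ddpg", "off_policy"),
        ("td3", "off_policy"),   # assumption: TD3 logs like SAC/DDPG
    ]
    for key, loader in dispatch:
        if key in name:
            return key, loader
    raise ValueError(f"unknown algorithm name: {name}")


print(resolve_algorithm("ddpg_walker2d"))  # ('ddpg', 'off_policy')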
tests/rl/run.py (4 changes: 4 additions & 0 deletions)

@@ -4,15 +4,19 @@
 import argparse

 from maro.cli.local.commands import run
+from maro.utils.utils import set_seeds


 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
     parser.add_argument("conf_path", help="Path of the job deployment")
     parser.add_argument("--evaluate_only", action="store_true", help="Only run evaluation part of the workflow")
+    parser.add_argument("--seed", type=int, help="The random seed set before running this job")
     return parser.parse_args()


 if __name__ == "__main__":
     args = get_args()
+    if args.seed is not None:
+        set_seeds(seed=args.seed)
     run(conf_path=args.conf_path, containerize=False, evaluate_only=args.evaluate_only)
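The new --seed flag routes everything through maro.utils.utils.set_seeds. A sketch of what such a one-stop seeding helper typically does; MARO's actual implementation may differ in detail:

import random

import numpy as np
import torch


def set_seeds(seed: int) -> None:
    # Pin every RNG the workflow is likely to touch.
    random.seed(seed)        # Python's built-in RNG
    np.random.seed(seed)     # NumPy, used by Gym and replay sampling
    torch.manual_seed(seed)  # PyTorch CPU RNG (also seeds CUDA devices by default)

With this in place, a reproducible DDPG run would look like: python tests/rl/run.py tests/rl/tasks/ddpg/config.yml --seed 123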
tests/rl/tasks/ddpg/__init__.py (12 changes: 7 additions & 5 deletions)

@@ -30,14 +30,14 @@
 }
 critic_net_conf = {
     "hidden_dims": [256, 256],
-    "activation": torch.nn.Tanh,
+    "activation": torch.nn.ReLU,
 }
 actor_learning_rate = 1e-3
 critic_learning_rate = 1e-3


 class MyContinuousDDPGNet(ContinuousDDPGNet):
-    def __init__(self, state_dim: int, action_dim: int, action_limit: float) -> None:
+    def __init__(self, state_dim: int, action_dim: int, action_limit: float, noise_scale: float = 0.1) -> None:
         super(MyContinuousDDPGNet, self).__init__(state_dim=state_dim, action_dim=action_dim)

         self._net = FullyConnected(

@@ -49,12 +49,13 @@ def __init__(self, state_dim: int, action_dim: int, action_limit: float) -> None
         )
         self._optim = Adam(self._net.parameters(), lr=critic_learning_rate)
         self._action_limit = action_limit
-        self._noise_scale = 0.1  # TODO
+        self._noise_scale = noise_scale

     def _get_actions_impl(self, states: torch.Tensor, exploring: bool) -> torch.Tensor:
         action = self._net(states) * self._action_limit
         if exploring:
-            action += torch.randn(self.action_dim) * self._noise_scale
+            noise = torch.randn(self.action_dim) * self._noise_scale
+            action += noise.to(action.device)
         action = torch.clamp(action, -self._action_limit, self._action_limit)
         return action
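The exploration change above does two things: noise_scale becomes a constructor argument instead of a hard-coded 0.1, and the Gaussian noise is moved onto the action's device before being added (the old += would fail when the policy net lives on GPU, since torch.randn returns a CPU tensor). A minimal standalone version of the same logic, runnable without MARO:

import torch


def noisy_action(
    net: torch.nn.Module,
    states: torch.Tensor,
    action_limit: float,
    action_dim: int,
    noise_scale: float = 0.1,
    exploring: bool = True,
) -> torch.Tensor:
    action = net(states) * action_limit  # assumes the net's head outputs in [-1, 1]
    if exploring:
        noise = torch.randn(action_dim) * noise_scale
        action = action + noise.to(action.device)  # the device move is the fix above
    return torch.clamp(action, -action_limit, action_limit)


# Example: a tiny stand-in policy net over 4-dim states, 2-dim actions.
net = torch.nn.Sequential(torch.nn.Linear(4, 2), torch.nn.Tanh())
print(noisy_action(net, torch.zeros(1, 4), action_limit=1.0, action_dim=2))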

@@ -97,9 +98,10 @@ def get_ddpg_trainer(name: str, state_dim: int, action_dim: int) -> DDPGTrainer:
         batch_size=100,
         params=DDPGParams(
             get_q_critic_net_func=lambda: MyQCriticNet(state_dim, action_dim),
-            num_epochs=20,
+            num_epochs=50,
             n_start_train=1000,
             soft_update_coef=0.005,
+            update_target_every=1,
         ),
     )
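soft_update_coef=0.005 with update_target_every=1 corresponds to standard Polyak averaging of the target networks after every update: target = (1 - tau) * target + tau * online, with tau = 0.005. A sketch of that update, assuming this is how DDPGTrainer applies the coefficient (a reasonable reading, not confirmed against MARO's internals):

import torch


@torch.no_grad()
def soft_update(target: torch.nn.Module, online: torch.nn.Module, tau: float = 0.005) -> None:
    # Polyak averaging: nudge each target parameter a small step toward the
    # online parameter; with tau=0.005 the target trails the online net slowly.
    for t_param, o_param in zip(target.parameters(), online.parameters()):
        t_param.mul_(1.0 - tau).add_(tau * o_param)

num_epochs=50 then presumably means each training round performs 50 gradient updates, each followed by one such soft update.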
tests/rl/tasks/ddpg/config.yml (13 changes: 5 additions & 8 deletions)

@@ -4,16 +4,13 @@
 # Example RL config file for GYM scenario.
 # Please refer to `maro/rl/workflows/config/template.yml` for the complete template and detailed explanations.

-# Run this workflow by executing one of the following commands:
-#   - python tests/rl/run.py tests/rl/config.yml
-
 job: gym_rl_workflow
 scenario_path: "tests/rl/tasks/ddpg"
-log_path: "tests/rl/log/ddpg"
+log_path: "tests/rl/log/ddpg_walker2d"
 main:
-  num_episodes: 25000
-  num_steps: 200
-  eval_schedule: 25
+  num_episodes: 80000
+  num_steps: 50
+  eval_schedule: 200
   num_eval_episodes: 10
   min_n_sample: 1
 logging:

@@ -29,7 +26,7 @@ training:
   load_episode: null
   checkpointing:
     path: null
-    interval: 25
+    interval: 200
   logging:
     stdout: INFO
     file: DEBUG
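The new schedule trades longer rollouts for more, shorter ones. A back-of-the-envelope comparison, assuming num_steps environment steps are collected per episode:

old_total = 25_000 * 200  # 5,000,000 env steps, evaluating every  25 episodes
new_total = 80_000 * 50   # 4,000,000 env steps, evaluating every 200 episodes

# The new run evaluates every 200 * 50 = 10,000 env steps (the old one every
# 25 * 200 = 5,000), and checkpoints (interval: 200) line up with evaluation.
print(old_total, new_total, 200 * 50)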