Skip to content

Commit

Permalink
Merge pull request #33 from RLE-Foundation/envpool_procgen
Browse files Browse the repository at this point in the history
Envpool procgen
  • Loading branch information
Yuanmo authored Oct 16, 2023
2 parents 9673b6a + 446f32e commit e6ce692
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 7 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,6 @@ _build/
/train.py
/checkpoints_in
/rllte/copilot/g4f
/misc

# find . | grep -E "(/__pycache__$|\.pyc$|\.pyo$)" | xargs rm -rf
6 changes: 3 additions & 3 deletions rllte/env/procgen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

import gymnasium as gym
import numpy as np
from gymnasium.spaces import Box
from gymnasium.spaces import Box, Discrete
from gymnasium.wrappers import NormalizeReward, RecordEpisodeStatistics, TransformObservation, TransformReward
from procgen import ProcgenEnv

Expand All @@ -53,7 +53,7 @@ def __init__(self, env: gym.Env, num_envs: int) -> None:
shape=[3, 64, 64],
dtype=env.observation_space["rgb"].dtype,
)
self.single_action_space = env.action_space
self.single_action_space = Discrete(env.action_space.n)
self.is_vector_env = True
self.num_envs = num_envs

Expand Down Expand Up @@ -169,7 +169,7 @@ def make_procgen_env(
num_levels=num_levels,
start_level=start_level,
distribution_mode=distribution_mode,
rand_seed=seed,
# rand_seed=seed,
)
envs = AdapterEnv(envs, num_envs)
envs = TransformObservation(envs, lambda obs: obs["rgb"].transpose(0, 3, 1, 2))
Expand Down
30 changes: 27 additions & 3 deletions rllte/hub/applications/procgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def __init__(self, agent: str = "PPO", env_id: str = "bigfish", seed: int = 1, d
num_levels=200,
start_level=0,
distribution_mode="easy",
asynchronous=False
)
eval_envs = make_envpool_procgen_env(
env_id=env_id,
Expand Down Expand Up @@ -90,6 +91,29 @@ def __init__(self, agent: str = "PPO", env_id: str = "bigfish", seed: int = 1, d
init_fn="xavier_uniform",
)
elif agent == "DAAC":
# Best hyperparameters for DAAC reported in
# https://github.com/rraileanu/idaac/blob/main/hyperparams.py
if env_id in ['plunder', 'chaser']:
value_epochs = 1
else:
value_epochs = 9

if env_id in ['miner', 'bigfish', 'dodgeball']:
value_freq = 32
elif env_id == 'plunder':
value_freq = 8
else:
value_freq = 1

if env_id == 'plunder':
adv_coef = 0.3
elif env_id == 'chaser':
adv_coef = 0.15
elif env_id in ['climber', 'bigfish']:
adv_coef = 0.05
else:
adv_coef = 0.25

self.agent = DAAC( # type: ignore[assignment]
env=envs,
eval_env=eval_envs,
Expand All @@ -104,11 +128,11 @@ def __init__(self, agent: str = "PPO", env_id: str = "bigfish", seed: int = 1, d
clip_range=0.2,
clip_range_vf=0.2,
policy_epochs=1,
value_epochs=9,
value_freq=3,
value_epochs=value_epochs,
value_freq=value_freq,
vf_coef=0.5,
ent_coef=0.01,
adv_coef=0.05,
adv_coef=adv_coef,
max_grad_norm=0.5,
init_fn="xavier_uniform",
)
Expand Down
2 changes: 1 addition & 1 deletion rllte/xploit/policy/on_policy_decoupled_actor_critic.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def freeze(self, encoder: nn.Module, dist: Distribution) -> None:
# initialize parameters
self.apply(self.init_fn)
# synchronize the parameters of actor_encoder and critic_encoder
self.critic_encoder.load_state_dict(self.actor_encoder.state_dict())
# self.critic_encoder.load_state_dict(self.actor_encoder.state_dict())
# build optimizers
self.actor_params = itertools.chain(self.actor_encoder.parameters(), self.actor.parameters(), self.gae.parameters())
self.critic_params = itertools.chain(self.critic_encoder.parameters(), self.critic.parameters())
Expand Down

0 comments on commit e6ce692

Please sign in to comment.