Skip to content

Commit

Permalink
add sb3 example
Browse files Browse the repository at this point in the history
  • Loading branch information
TheEimer committed Dec 11, 2023
1 parent 0b43304 commit 1ba15be
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 0 deletions.
38 changes: 38 additions & 0 deletions examples/carl_with_sb3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import carl
import gymnasium as gym
from gymnasium.wrappers import FlattenObservation
from stable_baselines3 import DQN
from stable_baselines3.common.evaluation import evaluate_policy

from carl.envs import CARLLunarLander
from carl.context.context_space import NormalFloatContextFeature
from carl.context.sampler import ContextSampler

# Create environment
context_distributions = [NormalFloatContextFeature("GRAVITY_X", mu=9.8, sigma=1)]
context_sampler = ContextSampler(
context_distributions=context_distributions,
context_space=CARLLunarLander.get_context_space(),
seed=42,
)
contexts = context_sampler.sample_contexts(n_contexts=5)

print("Training contexts are:")
print(contexts)

env = gym.make("carl/CARLLunarLander-v0", render_mode="rgb_array", contexts=contexts)
env = FlattenObservation(env)

# Instantiate the agent
model = DQN("MlpPolicy", env, verbose=1)
# Train the agent and display a progress bar
model.learn(total_timesteps=int(2e4), progress_bar=True)
mean_reward, std_reward = evaluate_policy(model, model.get_env(), n_eval_episodes=10)

# Enjoy trained agent
vec_env = model.get_env()
obs = vec_env.reset()
for i in range(1000):
action, _states = model.predict(obs, deterministic=True)
obs, rewards, dones, info = vec_env.step(action)
vec_env.render("human")
3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ def read_file(filepath: str) -> str:
"sphinx-autoapi>=1.8.4",
"automl-sphinx-theme>=0.1.9",
],
"examples": [
"stable-baselines3",
]
}

setuptools.setup(
Expand Down

0 comments on commit 1ba15be

Please sign in to comment.