forked from DLR-RM/stable-baselines3
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
16 changed files
with
712 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
*.swp | ||
*.pyc | ||
*.pkl | ||
*.py~ | ||
*.bak | ||
.pytest_cache | ||
.DS_Store | ||
.idea | ||
.coverage | ||
.coverage.* | ||
__pycache__/ | ||
_build/ | ||
*.npz | ||
|
||
# Setuptools distribution and build folders. | ||
/dist/ | ||
/build | ||
keys/ | ||
|
||
# Virtualenv | ||
/env | ||
|
||
|
||
*.sublime-project | ||
*.sublime-workspace | ||
|
||
.idea | ||
|
||
logs/ | ||
|
||
.ipynb_checkpoints | ||
ghostdriver.log | ||
|
||
htmlcov | ||
|
||
junk | ||
src | ||
|
||
*.egg-info | ||
.cache | ||
|
||
MUJOCO_LOG.TXT |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
The MIT License | ||
|
||
Copyright (c) 2019 Antonin Raffin | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
|
||
The above copyright notice and this permission notice shall be included in | ||
all copies or substantial portions of the Software. | ||
|
||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN | ||
THE SOFTWARE. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
# torchy-baselines | ||
# Torchy-Baselines |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import sys | ||
import subprocess | ||
from setuptools import setup, find_packages | ||
|
||
|
||
setup(name='torchy_baselines', | ||
packages=[package for package in find_packages() | ||
if package.startswith('torchy_baselines')], | ||
install_requires=[ | ||
'gym[classic_control]>=0.10.9', | ||
'numpy', | ||
'torch>=1.2.0+cpu' # torch>=1.2.0 | ||
], | ||
extras_require={ | ||
'tests': [ | ||
'pytest', | ||
'pytest-cov', | ||
'pytest-env', | ||
'pytest-xdist', | ||
], | ||
'docs': [ | ||
'sphinx', | ||
'sphinx-autobuild', | ||
'sphinx-rtd-theme' | ||
] | ||
}, | ||
description='Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms.', | ||
author='Antonin Raffin', | ||
url='', | ||
author_email='antonin.raffin@dlr.de', | ||
keywords="reinforcement-learning-algorithms reinforcement-learning machine-learning " | ||
"gym openai stable baselines toolbox python data-science", | ||
license="MIT", | ||
long_description="", | ||
long_description_content_type='text/markdown', | ||
version="0.0.1", | ||
) | ||
|
||
# python setup.py sdist | ||
# python setup.py bdist_wheel | ||
# twine upload --repository-url https://test.pypi.org/legacy/ dist/* | ||
# twine upload dist/* |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
import gym | ||
|
||
from torchy_baselines import TD3 | ||
|
||
def test_simple_run(): | ||
env = gym.make("Pendulum-v0") | ||
model = TD3('MlpPolicy', env, policy_kwargs=dict(net_arch=[64, 64]), verbose=1) | ||
model.learn(total_timesteps=50000) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from torchy_baselines.td3 import TD3 | ||
|
||
__version__ = "0.0.1" |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,167 @@ | ||
from abc import ABC, abstractmethod | ||
|
||
|
||
import numpy as np | ||
import gym | ||
|
||
from torchy_baselines.common.policies import get_policy_from_name | ||
|
||
|
||
class BaseRLModel(ABC): | ||
""" | ||
The base RL model | ||
:param policy: (BasePolicy) Policy object | ||
:param env: (Gym environment) The environment to learn from | ||
(if registered in Gym, can be str. Can be None for loading trained models) | ||
:param verbose: (int) the verbosity level: 0 none, 1 training information, 2 debug | ||
:param policy_base: (BasePolicy) the base policy used by this method | ||
""" | ||
|
||
def __init__(self, policy, env, policy_base, policy_kwargs=None, verbose=0): | ||
# if isinstance(policy, str) and policy_base is not None: | ||
# self.policy = get_policy_from_name(policy_base, policy) | ||
# else: | ||
# self.policy = policy | ||
self.policy = None | ||
self.env = env | ||
self.verbose = verbose | ||
self.policy_kwargs = {} if policy_kwargs is None else policy_kwargs | ||
self.observation_space = None | ||
self.action_space = None | ||
self.n_envs = None | ||
self.num_timesteps = 0 | ||
self.params = None | ||
|
||
if env is not None: | ||
self.env = env | ||
self.n_envs = 1 | ||
self.observation_space = env.observation_space | ||
self.action_space = env.action_space | ||
|
||
def get_env(self): | ||
""" | ||
returns the current environment (can be None if not defined) | ||
:return: (Gym Environment) The current environment | ||
""" | ||
return self.env | ||
|
||
def set_env(self, env): | ||
""" | ||
Checks the validity of the environment, and if it is coherent, set it as the current environment. | ||
:param env: (Gym Environment) The environment for learning a policy | ||
""" | ||
pass | ||
|
||
def get_parameter_list(self): | ||
""" | ||
Get pytorch Variables of model's parameters | ||
This includes all variables necessary for continuing training (saving / loading). | ||
:return: (list) List of pytorch Variables | ||
""" | ||
pass | ||
|
||
def get_parameters(self): | ||
""" | ||
Get current model parameters as dictionary of variable name -> ndarray. | ||
:return: (OrderedDict) Dictionary of variable name -> ndarray of model's parameters. | ||
""" | ||
raise NotImplementedError() | ||
|
||
def pretrain(self, dataset, n_epochs=10, learning_rate=1e-4, | ||
adam_epsilon=1e-8, val_interval=None): | ||
""" | ||
Pretrain a model using behavior cloning: | ||
supervised learning given an expert dataset. | ||
NOTE: only Box and Discrete spaces are supported for now. | ||
:param dataset: (ExpertDataset) Dataset manager | ||
:param n_epochs: (int) Number of iterations on the training set | ||
:param learning_rate: (float) Learning rate | ||
:param adam_epsilon: (float) the epsilon value for the adam optimizer | ||
:param val_interval: (int) Report training and validation losses every n epochs. | ||
By default, every 10th of the maximum number of epochs. | ||
:return: (BaseRLModel) the pretrained model | ||
""" | ||
raise NotImplementedError() | ||
|
||
@abstractmethod | ||
def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name="run", | ||
reset_num_timesteps=True): | ||
""" | ||
Return a trained model. | ||
:param total_timesteps: (int) The total number of samples to train on | ||
:param seed: (int) The initial seed for training, if None: keep current seed | ||
:param callback: (function (dict, dict)) -> boolean function called at every steps with state of the algorithm. | ||
It takes the local and global variables. If it returns False, training is aborted. | ||
:param log_interval: (int) The number of timesteps before logging. | ||
:param tb_log_name: (str) the name of the run for tensorboard log | ||
:param reset_num_timesteps: (bool) whether or not to reset the current timestep number (used in logging) | ||
:return: (BaseRLModel) the trained model | ||
""" | ||
pass | ||
|
||
@abstractmethod | ||
def predict(self, observation, state=None, mask=None, deterministic=False): | ||
""" | ||
Get the model's action from an observation | ||
:param observation: (np.ndarray) the input observation | ||
:param state: (np.ndarray) The last states (can be None, used in recurrent policies) | ||
:param mask: (np.ndarray) The last masks (can be None, used in recurrent policies) | ||
:param deterministic: (bool) Whether or not to return deterministic actions. | ||
:return: (np.ndarray, np.ndarray) the model's action and the next state (used in recurrent policies) | ||
""" | ||
pass | ||
|
||
def load_parameters(self, load_path_or_dict, exact_match=True): | ||
""" | ||
Load model parameters from a file or a dictionary | ||
Dictionary keys should be tensorflow variable names, which can be obtained | ||
with ``get_parameters`` function. If ``exact_match`` is True, dictionary | ||
should contain keys for all model's parameters, otherwise RunTimeError | ||
is raised. If False, only variables included in the dictionary will be updated. | ||
This does not load agent's hyper-parameters. | ||
.. warning:: | ||
This function does not update trainer/optimizer variables (e.g. momentum). | ||
As such training after using this function may lead to less-than-optimal results. | ||
:param load_path_or_dict: (str or file-like or dict) Save parameter location | ||
or dict of parameters as variable.name -> ndarrays to be loaded. | ||
:param exact_match: (bool) If True, expects load dictionary to contain keys for | ||
all variables in the model. If False, loads parameters only for variables | ||
mentioned in the dictionary. Defaults to True. | ||
""" | ||
raise NotImplementedError() | ||
|
||
@abstractmethod | ||
def save(self, save_path): | ||
""" | ||
Save the current parameters to file | ||
:param save_path: (str or file-like object) the save location | ||
""" | ||
raise NotImplementedError() | ||
|
||
@classmethod | ||
@abstractmethod | ||
def load(cls, load_path, env=None, **kwargs): | ||
""" | ||
Load the model from file | ||
:param load_path: (str or file-like) the saved parameter location | ||
:param env: (Gym Envrionment) the new environment to run the loaded model on | ||
(can be None if you only need prediction from a trained model) | ||
:param kwargs: extra arguments to change the model when loading | ||
""" | ||
raise NotImplementedError() |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
from abc import ABC | ||
|
||
|
||
class BasePolicy(ABC): | ||
""" | ||
The base policy object | ||
:param observation_space: (Gym Space) The observation space of the environment | ||
:param action_space: (Gym Space) The action space of the environment | ||
""" | ||
|
||
def __init__(self, observation_space, action_space, device='cpu'): | ||
self.observation_space = observation_space | ||
self.action_space = action_space | ||
self.device = device | ||
|
||
|
||
_policy_registry = { | ||
# ActorCriticPolicy: { | ||
# "MlpPolicy": MlpPolicy, | ||
# } | ||
} | ||
|
||
|
||
def get_policy_from_name(base_policy_type, name): | ||
""" | ||
returns the registed policy from the base type and name | ||
:param base_policy_type: (BasePolicy) the base policy object | ||
:param name: (str) the policy name | ||
:return: (base_policy_type) the policy | ||
""" | ||
if base_policy_type not in _policy_registry: | ||
raise ValueError("Error: the policy type {} is not registered!".format(base_policy_type)) | ||
if name not in _policy_registry[base_policy_type]: | ||
raise ValueError("Error: unknown policy type {}, the only registed policy type are: {}!" | ||
.format(name, list(_policy_registry[base_policy_type].keys()))) | ||
return _policy_registry[base_policy_type][name] | ||
|
||
|
||
def register_policy(name, policy): | ||
""" | ||
returns the registed policy from the base type and name | ||
:param name: (str) the policy name | ||
:param policy: (subclass of BasePolicy) the policy | ||
""" | ||
sub_class = None | ||
for cls in BasePolicy.__subclasses__(): | ||
if issubclass(policy, cls): | ||
sub_class = cls | ||
break | ||
if sub_class is None: | ||
raise ValueError("Error: the policy {} is not of any known subclasses of BasePolicy!".format(policy)) | ||
|
||
if sub_class not in _policy_registry: | ||
_policy_registry[sub_class] = {} | ||
if name in _policy_registry[sub_class]: | ||
raise ValueError("Error: the name {} is alreay registered for a different policy, will not override." | ||
.format(name)) | ||
_policy_registry[sub_class][name] = policy |
Oops, something went wrong.