Commit: Init: TD3

araffin committed Sep 5, 2019
1 parent ad6076b commit 46d8d97
Showing 16 changed files with 712 additions and 1 deletion.
42 changes: 42 additions & 0 deletions .gitignore
@@ -0,0 +1,42 @@
*.swp
*.pyc
*.pkl
*.py~
*.bak
.pytest_cache
.DS_Store
.idea
.coverage
.coverage.*
__pycache__/
_build/
*.npz

# Setuptools distribution and build folders.
/dist/
/build
keys/

# Virtualenv
/env


*.sublime-project
*.sublime-workspace

.idea

logs/

.ipynb_checkpoints
ghostdriver.log

htmlcov

junk
src

*.egg-info
.cache

MUJOCO_LOG.TXT
21 changes: 21 additions & 0 deletions LICENSE
@@ -0,0 +1,21 @@
The MIT License

Copyright (c) 2019 Antonin Raffin

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
2 changes: 1 addition & 1 deletion README.md
@@ -1 +1 @@
- # torchy-baselines
+ # Torchy-Baselines
42 changes: 42 additions & 0 deletions setup.py
@@ -0,0 +1,42 @@
import sys
import subprocess
from setuptools import setup, find_packages


setup(name='torchy_baselines',
      packages=[package for package in find_packages()
                if package.startswith('torchy_baselines')],
      install_requires=[
          'gym[classic_control]>=0.10.9',
          'numpy',
          'torch>=1.2.0+cpu'  # torch>=1.2.0
      ],
      extras_require={
          'tests': [
              'pytest',
              'pytest-cov',
              'pytest-env',
              'pytest-xdist',
          ],
          'docs': [
              'sphinx',
              'sphinx-autobuild',
              'sphinx-rtd-theme'
          ]
      },
      description='PyTorch version of Stable Baselines, implementations of reinforcement learning algorithms.',
      author='Antonin Raffin',
      url='',
      author_email='antonin.raffin@dlr.de',
      keywords="reinforcement-learning-algorithms reinforcement-learning machine-learning "
               "gym openai stable baselines toolbox python data-science",
      license="MIT",
      long_description="",
      long_description_content_type='text/markdown',
      version="0.0.1",
      )

# python setup.py sdist
# python setup.py bdist_wheel
# twine upload --repository-url https://test.pypi.org/legacy/ dist/*
# twine upload dist/*
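As a usage note for the extras declared above (a sketch, assuming a local checkout of the repository), the test and docs dependencies can be installed with pip's extras syntax:

pip install -e .[tests,docs]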
Empty file added tests/__init__.py
Empty file.
8 changes: 8 additions & 0 deletions tests/test_td3.py
@@ -0,0 +1,8 @@
import gym

from torchy_baselines import TD3

def test_simple_run():
    env = gym.make("Pendulum-v0")
    model = TD3('MlpPolicy', env, policy_kwargs=dict(net_arch=[64, 64]), verbose=1)
    model.learn(total_timesteps=50000)
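As a follow-on sketch (assuming ``predict`` matches the ``BaseRLModel`` signature in ``torchy_baselines/common/base_class.py`` below, returning the action and a recurrent state that TD3 does not use), the trained model could then be queried inside the test:

obs = env.reset()
# deterministic=True asks the policy for its deterministic action, the usual choice at evaluation time
action, _ = model.predict(obs, deterministic=True)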
3 changes: 3 additions & 0 deletions torchy_baselines/__init__.py
@@ -0,0 +1,3 @@
from torchy_baselines.td3 import TD3

__version__ = "0.0.1"
Empty file.
167 changes: 167 additions & 0 deletions torchy_baselines/common/base_class.py
@@ -0,0 +1,167 @@
from abc import ABC, abstractmethod


import numpy as np
import gym

from torchy_baselines.common.policies import get_policy_from_name


class BaseRLModel(ABC):
    """
    The base RL model

    :param policy: (BasePolicy) Policy object
    :param env: (Gym environment) The environment to learn from
        (if registered in Gym, can be str. Can be None for loading trained models)
    :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 debug
    :param policy_base: (BasePolicy) the base policy used by this method
    """

    def __init__(self, policy, env, policy_base, policy_kwargs=None, verbose=0):
        # if isinstance(policy, str) and policy_base is not None:
        #     self.policy = get_policy_from_name(policy_base, policy)
        # else:
        #     self.policy = policy
        self.policy = None
        self.env = env
        self.verbose = verbose
        self.policy_kwargs = {} if policy_kwargs is None else policy_kwargs
        self.observation_space = None
        self.action_space = None
        self.n_envs = None
        self.num_timesteps = 0
        self.params = None

        if env is not None:
            self.env = env
            self.n_envs = 1
            self.observation_space = env.observation_space
            self.action_space = env.action_space

    def get_env(self):
        """
        Returns the current environment (can be None if not defined).

        :return: (Gym Environment) The current environment
        """
        return self.env

    def set_env(self, env):
        """
        Checks the validity of the environment and, if it is coherent, sets it as the current environment.

        :param env: (Gym Environment) The environment for learning a policy
        """
        pass

    def get_parameter_list(self):
        """
        Get pytorch Variables of the model's parameters.

        This includes all variables necessary for continuing training (saving / loading).

        :return: (list) List of pytorch Variables
        """
        pass

    def get_parameters(self):
        """
        Get current model parameters as a dictionary of variable name -> ndarray.

        :return: (OrderedDict) Dictionary of variable name -> ndarray of model's parameters.
        """
        raise NotImplementedError()

    def pretrain(self, dataset, n_epochs=10, learning_rate=1e-4,
                 adam_epsilon=1e-8, val_interval=None):
        """
        Pretrain a model using behavior cloning:
        supervised learning given an expert dataset.

        NOTE: only Box and Discrete spaces are supported for now.

        :param dataset: (ExpertDataset) Dataset manager
        :param n_epochs: (int) Number of iterations on the training set
        :param learning_rate: (float) Learning rate
        :param adam_epsilon: (float) the epsilon value for the adam optimizer
        :param val_interval: (int) Report training and validation losses every n epochs.
            By default, every 10th of the maximum number of epochs.
        :return: (BaseRLModel) the pretrained model
        """
        raise NotImplementedError()

    @abstractmethod
    def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name="run",
              reset_num_timesteps=True):
        """
        Return a trained model.

        :param total_timesteps: (int) The total number of samples to train on
        :param seed: (int) The initial seed for training, if None: keep current seed
        :param callback: (function (dict, dict) -> boolean) function called at every step with the state of the algorithm.
            It takes the local and global variables. If it returns False, training is aborted.
        :param log_interval: (int) The number of timesteps before logging.
        :param tb_log_name: (str) the name of the run for tensorboard logging
        :param reset_num_timesteps: (bool) whether or not to reset the current timestep number (used in logging)
        :return: (BaseRLModel) the trained model
        """
        pass

    @abstractmethod
    def predict(self, observation, state=None, mask=None, deterministic=False):
        """
        Get the model's action from an observation.

        :param observation: (np.ndarray) the input observation
        :param state: (np.ndarray) The last states (can be None, used in recurrent policies)
        :param mask: (np.ndarray) The last masks (can be None, used in recurrent policies)
        :param deterministic: (bool) Whether or not to return deterministic actions.
        :return: (np.ndarray, np.ndarray) the model's action and the next state (used in recurrent policies)
        """
        pass

    def load_parameters(self, load_path_or_dict, exact_match=True):
        """
        Load model parameters from a file or a dictionary.

        Dictionary keys should be pytorch variable names, which can be obtained
        with the ``get_parameters`` function. If ``exact_match`` is True, the dictionary
        should contain keys for all the model's parameters, otherwise a RuntimeError
        is raised. If False, only variables included in the dictionary will be updated.

        This does not load the agent's hyper-parameters.

        .. warning::
            This function does not update trainer/optimizer variables (e.g. momentum).
            As such, training after using this function may lead to less-than-optimal results.

        :param load_path_or_dict: (str or file-like or dict) Save parameter location
            or dict of parameters as variable name -> ndarrays to be loaded.
        :param exact_match: (bool) If True, expects the load dictionary to contain keys for
            all variables in the model. If False, loads parameters only for variables
            mentioned in the dictionary. Defaults to True.
        """
        raise NotImplementedError()

    @abstractmethod
    def save(self, save_path):
        """
        Save the current parameters to file.

        :param save_path: (str or file-like object) the save location
        """
        raise NotImplementedError()

    @classmethod
    @abstractmethod
    def load(cls, load_path, env=None, **kwargs):
        """
        Load the model from file.

        :param load_path: (str or file-like) the saved parameter location
        :param env: (Gym Environment) the new environment to run the loaded model on
            (can be None if you only need prediction from a trained model)
        :param kwargs: extra arguments to change the model when loading
        """
        raise NotImplementedError()
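A minimal sketch of how a concrete algorithm could subclass ``BaseRLModel`` — the ``RandomModel`` class and its trivial method bodies are hypothetical, for illustration only, and not part of this commit:

from torchy_baselines.common.base_class import BaseRLModel


class RandomModel(BaseRLModel):
    """Hypothetical agent: samples random actions and learns nothing."""

    def __init__(self, policy, env, verbose=0):
        super(RandomModel, self).__init__(policy, env, policy_base=None, verbose=verbose)

    def learn(self, total_timesteps, callback=None, seed=None, log_interval=100,
              tb_log_name="run", reset_num_timesteps=True):
        # A real implementation would interact with self.env and update the policy here
        self.num_timesteps += total_timesteps
        return self

    def predict(self, observation, state=None, mask=None, deterministic=False):
        # No recurrent state, so the second element of the returned tuple is None
        return self.action_space.sample(), None

    def save(self, save_path):
        raise NotImplementedError()

    @classmethod
    def load(cls, load_path, env=None, **kwargs):
        raise NotImplementedError()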
Empty file.
61 changes: 61 additions & 0 deletions torchy_baselines/common/policies.py
@@ -0,0 +1,61 @@
from abc import ABC


class BasePolicy(ABC):
    """
    The base policy object

    :param observation_space: (Gym Space) The observation space of the environment
    :param action_space: (Gym Space) The action space of the environment
    """

    def __init__(self, observation_space, action_space, device='cpu'):
        self.observation_space = observation_space
        self.action_space = action_space
        self.device = device


_policy_registry = {
    # ActorCriticPolicy: {
    #     "MlpPolicy": MlpPolicy,
    # }
}


def get_policy_from_name(base_policy_type, name):
    """
    Returns the registered policy from the base type and name.

    :param base_policy_type: (BasePolicy) the base policy class
    :param name: (str) the policy name
    :return: (base_policy_type) the policy
    """
    if base_policy_type not in _policy_registry:
        raise ValueError("Error: the policy type {} is not registered!".format(base_policy_type))
    if name not in _policy_registry[base_policy_type]:
        raise ValueError("Error: unknown policy name {}, the only registered policy names are: {}!"
                         .format(name, list(_policy_registry[base_policy_type].keys())))
    return _policy_registry[base_policy_type][name]


def register_policy(name, policy):
    """
    Register a policy, so it can later be retrieved by name for its base policy type.

    :param name: (str) the policy name
    :param policy: (subclass of BasePolicy) the policy
    """
    sub_class = None
    for cls in BasePolicy.__subclasses__():
        if issubclass(policy, cls):
            sub_class = cls
            break
    if sub_class is None:
        raise ValueError("Error: the policy {} is not of any known subclass of BasePolicy!".format(policy))

    if sub_class not in _policy_registry:
        _policy_registry[sub_class] = {}
    if name in _policy_registry[sub_class]:
        raise ValueError("Error: the name {} is already registered for a different policy, will not override."
                         .format(name))
    _policy_registry[sub_class][name] = policy
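A short sketch of the intended registry workflow — ``ActorPolicy`` and ``CustomMlpPolicy`` are hypothetical names, for illustration only:

from torchy_baselines.common.policies import BasePolicy, register_policy, get_policy_from_name


class ActorPolicy(BasePolicy):
    """Hypothetical base type: a direct subclass of BasePolicy groups its own registry entries."""


class CustomMlpPolicy(ActorPolicy):
    """Hypothetical concrete policy to be registered."""


# register_policy walks BasePolicy.__subclasses__() and files the entry under ActorPolicy
register_policy("CustomMlpPolicy", CustomMlpPolicy)
assert get_policy_from_name(ActorPolicy, "CustomMlpPolicy") is CustomMlpPolicy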