From 46d8d9725b0a95c239d240ebcb65ae32ba7a15d0 Mon Sep 17 00:00:00 2001
From: Antonin Raffin
Date: Thu, 5 Sep 2019 17:29:41 +0200
Subject: [PATCH] Init: TD3

---
 .gitignore                               |  42 ++++++
 LICENSE                                  |  21 +++
 README.md                                |   2 +-
 setup.py                                 |  42 ++++++
 tests/__init__.py                        |   0
 tests/test_td3.py                        |   8 +
 torchy_baselines/__init__.py             |   3 +
 torchy_baselines/common/__init__.py      |   0
 torchy_baselines/common/base_class.py    | 167 +++++++++++++++
 torchy_baselines/common/evaluation.py    |   0
 torchy_baselines/common/policies.py      |  61 ++++++++
 torchy_baselines/common/replay_buffer.py |  85 +++++++++++
 torchy_baselines/common/utils.py         |  20 +++
 torchy_baselines/td3/__init__.py         |   1 +
 torchy_baselines/td3/policies.py         |  82 +++++++++++
 torchy_baselines/td3/td3.py              | 179 +++++++++++++++++++++++
 16 files changed, 712 insertions(+), 1 deletion(-)
 create mode 100644 .gitignore
 create mode 100644 LICENSE
 create mode 100644 setup.py
 create mode 100644 tests/__init__.py
 create mode 100644 tests/test_td3.py
 create mode 100644 torchy_baselines/__init__.py
 create mode 100644 torchy_baselines/common/__init__.py
 create mode 100644 torchy_baselines/common/base_class.py
 create mode 100644 torchy_baselines/common/evaluation.py
 create mode 100644 torchy_baselines/common/policies.py
 create mode 100644 torchy_baselines/common/replay_buffer.py
 create mode 100644 torchy_baselines/common/utils.py
 create mode 100644 torchy_baselines/td3/__init__.py
 create mode 100644 torchy_baselines/td3/policies.py
 create mode 100644 torchy_baselines/td3/td3.py

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 000000000..5a26d08e0
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,42 @@
+*.swp
+*.pyc
+*.pkl
+*.py~
+*.bak
+.pytest_cache
+.DS_Store
+.idea
+.coverage
+.coverage.*
+__pycache__/
+_build/
+*.npz
+
+# Setuptools distribution and build folders.
+/dist/
+/build
+keys/
+
+# Virtualenv
+/env
+
+
+*.sublime-project
+*.sublime-workspace
+
+.idea
+
+logs/
+
+.ipynb_checkpoints
+ghostdriver.log
+
+htmlcov
+
+junk
+src
+
+*.egg-info
+.cache
+
+MUJOCO_LOG.TXT
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 000000000..0951e29b4
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+The MIT License
+
+Copyright (c) 2019 Antonin Raffin
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
diff --git a/README.md b/README.md
index a88497496..2bb028465 100644
--- a/README.md
+++ b/README.md
@@ -1 +1 @@
-# torchy-baselines
\ No newline at end of file
+# Torchy-Baselines
diff --git a/setup.py b/setup.py
new file mode 100644
index 000000000..201356d08
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,42 @@
+import sys
+import subprocess
+from setuptools import setup, find_packages
+
+
+setup(name='torchy_baselines',
+      packages=[package for package in find_packages()
+                if package.startswith('torchy_baselines')],
+      install_requires=[
+          'gym[classic_control]>=0.10.9',
+          'numpy',
+          'torch>=1.2.0+cpu'  # torch>=1.2.0
+      ],
+      extras_require={
+          'tests': [
+              'pytest',
+              'pytest-cov',
+              'pytest-env',
+              'pytest-xdist',
+          ],
+          'docs': [
+              'sphinx',
+              'sphinx-autobuild',
+              'sphinx-rtd-theme'
+          ]
+      },
+      description='Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms.',
+      author='Antonin Raffin',
+      url='',
+      author_email='antonin.raffin@dlr.de',
+      keywords="reinforcement-learning-algorithms reinforcement-learning machine-learning "
+               "gym openai stable baselines toolbox python data-science",
+      license="MIT",
+      long_description="",
+      long_description_content_type='text/markdown',
+      version="0.0.1",
+      )
+
+# python setup.py sdist
+# python setup.py bdist_wheel
+# twine upload --repository-url https://test.pypi.org/legacy/ dist/*
+# twine upload dist/*
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/test_td3.py b/tests/test_td3.py
new file mode 100644
index 000000000..d29d007c6
--- /dev/null
+++ b/tests/test_td3.py
@@ -0,0 +1,8 @@
+import gym
+
+from torchy_baselines import TD3
+
+def test_simple_run():
+    env = gym.make("Pendulum-v0")
+    model = TD3('MlpPolicy', env, policy_kwargs=dict(net_arch=[64, 64]), verbose=1)
+    model.learn(total_timesteps=50000)
diff --git a/torchy_baselines/__init__.py b/torchy_baselines/__init__.py
new file mode 100644
index 000000000..5656e719a
--- /dev/null
+++ b/torchy_baselines/__init__.py
@@ -0,0 +1,3 @@
+from torchy_baselines.td3 import TD3
+
+__version__ = "0.0.1"
diff --git a/torchy_baselines/common/__init__.py b/torchy_baselines/common/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py
new file mode 100644
index 000000000..c7eaee4df
--- /dev/null
+++ b/torchy_baselines/common/base_class.py
@@ -0,0 +1,167 @@
+from abc import ABC, abstractmethod
+
+
+import numpy as np
+import gym
+
+from torchy_baselines.common.policies import get_policy_from_name
+
+
+class BaseRLModel(ABC):
+    """
+    The base RL model
+
+    :param policy: (BasePolicy) Policy object
+    :param env: (Gym environment) The environment to learn from
+        (if registered in Gym, can be str. Can be None for loading trained models)
+    :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 debug
+    :param policy_base: (BasePolicy) the base policy used by this method
+    """
+
+    def __init__(self, policy, env, policy_base, policy_kwargs=None, verbose=0):
+        # if isinstance(policy, str) and policy_base is not None:
+        #     self.policy = get_policy_from_name(policy_base, policy)
+        # else:
+        #     self.policy = policy
+        self.policy = None
+        self.env = env
+        self.verbose = verbose
+        self.policy_kwargs = {} if policy_kwargs is None else policy_kwargs
+        self.observation_space = None
+        self.action_space = None
+        self.n_envs = None
+        self.num_timesteps = 0
+        self.params = None
+
+        if env is not None:
+            self.env = env
+            self.n_envs = 1
+            self.observation_space = env.observation_space
+            self.action_space = env.action_space
+
+    def get_env(self):
+        """
+        returns the current environment (can be None if not defined)
+
+        :return: (Gym Environment) The current environment
+        """
+        return self.env
+
+    def set_env(self, env):
+        """
+        Checks the validity of the environment, and if it is coherent, sets it as the current environment.
+
+        :param env: (Gym Environment) The environment for learning a policy
+        """
+        pass
+
+    def get_parameter_list(self):
+        """
+        Get pytorch Variables of model's parameters
+
+        This includes all variables necessary for continuing training (saving / loading).
+
+        :return: (list) List of pytorch Variables
+        """
+        pass
+
+    def get_parameters(self):
+        """
+        Get current model parameters as dictionary of variable name -> ndarray.
+
+        :return: (OrderedDict) Dictionary of variable name -> ndarray of model's parameters.
+        """
+        raise NotImplementedError()
+
+    def pretrain(self, dataset, n_epochs=10, learning_rate=1e-4,
+                 adam_epsilon=1e-8, val_interval=None):
+        """
+        Pretrain a model using behavior cloning:
+        supervised learning given an expert dataset.
+
+        NOTE: only Box and Discrete spaces are supported for now.
+
+        :param dataset: (ExpertDataset) Dataset manager
+        :param n_epochs: (int) Number of iterations on the training set
+        :param learning_rate: (float) Learning rate
+        :param adam_epsilon: (float) the epsilon value for the adam optimizer
+        :param val_interval: (int) Report training and validation losses every n epochs.
+            By default, every 10th of the maximum number of epochs.
+        :return: (BaseRLModel) the pretrained model
+        """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name="run",
+              reset_num_timesteps=True):
+        """
+        Return a trained model.
+
+        :param total_timesteps: (int) The total number of samples to train on
+        :param seed: (int) The initial seed for training, if None: keep current seed
+        :param callback: (function (dict, dict)) -> boolean function called at every step with state of the algorithm.
+            It takes the local and global variables. If it returns False, training is aborted.
+        :param log_interval: (int) The number of timesteps before logging.
+        :param tb_log_name: (str) the name of the run for tensorboard log
+        :param reset_num_timesteps: (bool) whether or not to reset the current timestep number (used in logging)
+        :return: (BaseRLModel) the trained model
+        """
+        pass
+
+    @abstractmethod
+    def predict(self, observation, state=None, mask=None, deterministic=False):
+        """
+        Get the model's action from an observation
+
+        :param observation: (np.ndarray) the input observation
+        :param state: (np.ndarray) The last states (can be None, used in recurrent policies)
+        :param mask: (np.ndarray) The last masks (can be None, used in recurrent policies)
+        :param deterministic: (bool) Whether or not to return deterministic actions.
+        :return: (np.ndarray, np.ndarray) the model's action and the next state (used in recurrent policies)
+        """
+        pass
+
+    def load_parameters(self, load_path_or_dict, exact_match=True):
+        """
+        Load model parameters from a file or a dictionary
+
+        Dictionary keys should be pytorch variable names, which can be obtained
+        with ``get_parameters`` function. If ``exact_match`` is True, dictionary
+        should contain keys for all model's parameters, otherwise RuntimeError
+        is raised. If False, only variables included in the dictionary will be updated.
+
+        This does not load agent's hyper-parameters.
+
+        .. warning::
+            This function does not update trainer/optimizer variables (e.g. momentum).
+            As such training after using this function may lead to less-than-optimal results.
+
+        :param load_path_or_dict: (str or file-like or dict) Save parameter location
+            or dict of parameters as variable.name -> ndarrays to be loaded.
+        :param exact_match: (bool) If True, expects load dictionary to contain keys for
+            all variables in the model. If False, loads parameters only for variables
+            mentioned in the dictionary. Defaults to True.
+        """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def save(self, save_path):
+        """
+        Save the current parameters to file
+
+        :param save_path: (str or file-like object) the save location
+        """
+        raise NotImplementedError()
+
+    @classmethod
+    @abstractmethod
+    def load(cls, load_path, env=None, **kwargs):
+        """
+        Load the model from file
+
+        :param load_path: (str or file-like) the saved parameter location
+        :param env: (Gym Environment) the new environment to run the loaded model on
+            (can be None if you only need prediction from a trained model)
+        :param kwargs: extra arguments to change the model when loading
+        """
+        raise NotImplementedError()
diff --git a/torchy_baselines/common/evaluation.py b/torchy_baselines/common/evaluation.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/torchy_baselines/common/policies.py b/torchy_baselines/common/policies.py
new file mode 100644
index 000000000..417d3e8f9
--- /dev/null
+++ b/torchy_baselines/common/policies.py
@@ -0,0 +1,61 @@
+from abc import ABC
+
+
+class BasePolicy(ABC):
+    """
+    The base policy object
+
+    :param observation_space: (Gym Space) The observation space of the environment
+    :param action_space: (Gym Space) The action space of the environment
+    """
+
+    def __init__(self, observation_space, action_space, device='cpu'):
+        self.observation_space = observation_space
+        self.action_space = action_space
+        self.device = device
+
+
+_policy_registry = {
+    # ActorCriticPolicy: {
+    #     "MlpPolicy": MlpPolicy,
+    # }
+}
+
+
+def get_policy_from_name(base_policy_type, name):
+    """
+    returns the registered policy from the base type and name
+
+    :param base_policy_type: (BasePolicy) the base policy object
+    :param name: (str) the policy name
+    :return: (base_policy_type) the policy
+    """
+    if base_policy_type not in _policy_registry:
+        raise ValueError("Error: the policy type {} is not registered!".format(base_policy_type))
+    if name not in _policy_registry[base_policy_type]:
+        raise ValueError("Error: unknown policy type {}, the only registered policy types are: {}!"
+                         .format(name, list(_policy_registry[base_policy_type].keys())))
+    return _policy_registry[base_policy_type][name]
+
+
+def register_policy(name, policy):
+    """
+    registers a policy, so that it can be retrieved later by its name
+
+    :param name: (str) the policy name
+    :param policy: (subclass of BasePolicy) the policy
+    """
+    sub_class = None
+    for cls in BasePolicy.__subclasses__():
+        if issubclass(policy, cls):
+            sub_class = cls
+            break
+    if sub_class is None:
+        raise ValueError("Error: the policy {} is not of any known subclasses of BasePolicy!".format(policy))
+
+    if sub_class not in _policy_registry:
+        _policy_registry[sub_class] = {}
+    if name in _policy_registry[sub_class]:
+        raise ValueError("Error: the name {} is already registered for a different policy, will not override."
+                         .format(name))
+    _policy_registry[sub_class][name] = policy
diff --git a/torchy_baselines/common/replay_buffer.py b/torchy_baselines/common/replay_buffer.py
new file mode 100644
index 000000000..1d15b4b22
--- /dev/null
+++ b/torchy_baselines/common/replay_buffer.py
@@ -0,0 +1,85 @@
+import numpy as np
+import torch as th
+
+# Code based on:
+# https://github.com/openai/baselines/blob/master/baselines/deepq/replay_buffer.py
+
+# Expects tuples of (state, next_state, action, reward, done)
+# class ReplayBuffer(object):
+#     def __init__(self, max_size=1e6):
+#         self.storage = []
+#         self.max_size = max_size
+#         self.ptr = 0
+#
+#     def add(self, data):
+#         if len(self.storage) == self.max_size:
+#             self.storage[int(self.ptr)] = data
+#             self.ptr = (self.ptr + 1) % self.max_size
+#         else:
+#             self.storage.append(data)
+#
+#     def sample(self, batch_size):
+#         ind = np.random.randint(0, len(self.storage), size=batch_size)
+#         x, y, u, r, d = [], [], [], [], []
+#
+#         for i in ind:
+#             X, Y, U, R, D = self.storage[i]
+#             x.append(np.array(X, copy=False))
+#             y.append(np.array(Y, copy=False))
+#             u.append(np.array(U, copy=False))
+#             r.append(np.array(R, copy=False))
+#             d.append(np.array(D, copy=False))
+#
+#         return np.array(x), np.array(y), np.array(u), np.array(r).reshape(-1, 1), np.array(d).reshape(-1, 1)
+
+
+class ReplayBuffer(object):
+
+    def __init__(self, buffer_size, state_dim, action_dim, device='cpu'):
+        super(ReplayBuffer, self).__init__()
+        # params
+        self.buffer_size = buffer_size
+        self.state_dim = state_dim
+        self.action_dim = action_dim
+        self.pos = 0
+        self.full = False
+        self.device = device
+
+        self.states = th.zeros(self.buffer_size, self.state_dim)
+        self.actions = th.zeros(self.buffer_size, self.action_dim)
+        self.next_states = th.zeros(self.buffer_size, self.state_dim)
+        self.rewards = th.zeros(self.buffer_size, 1)
+        self.dones = th.zeros(self.buffer_size, 1)
+
+    def size(self):
+        if self.full:
+            return self.buffer_size
+        return self.pos
+
+    def get_pos(self):
+        return self.pos
+
+    def add(self, state, next_state, action, reward, done):
+
+        self.states[self.pos] = th.FloatTensor(state)
+        self.next_states[self.pos] = th.FloatTensor(next_state)
+        self.actions[self.pos] = th.FloatTensor(action)
+        self.rewards[self.pos] = th.FloatTensor([reward])
+        self.dones[self.pos] = th.FloatTensor([done])
+
+        self.pos += 1
+        if self.pos == self.buffer_size:
+            self.full = True
+            self.pos = 0
+
+    def sample(self, batch_size):
+
+        upper_bound = self.buffer_size if self.full else self.pos
+        batch_inds = th.LongTensor(
+            np.random.randint(0, upper_bound, size=batch_size))
+
+        return (self.states[batch_inds].to(self.device),
+                self.actions[batch_inds].to(self.device),
+                self.next_states[batch_inds].to(self.device),
+                self.dones[batch_inds].to(self.device),
+                self.rewards[batch_inds].to(self.device))
diff --git a/torchy_baselines/common/utils.py b/torchy_baselines/common/utils.py
new file mode 100644
index 000000000..a1030cbf7
--- /dev/null
+++ b/torchy_baselines/common/utils.py
@@ -0,0 +1,20 @@
+import random
+
+import torch as th
+import numpy as np
+
+
+def set_random_seed(seed, using_cuda=False):
+    """
+    Seed the different random generators
+    :param seed: (int)
+    :param using_cuda: (bool)
+    """
+    random.seed(seed)
+    np.random.seed(seed)
+    th.manual_seed(seed)
+
+    if using_cuda:
+        # Make CuDNN Deterministic
+        th.backends.cudnn.deterministic = True
+        th.cuda.manual_seed(seed)
diff --git a/torchy_baselines/td3/__init__.py b/torchy_baselines/td3/__init__.py
new file mode 100644
index 000000000..51225e620
--- /dev/null
+++ b/torchy_baselines/td3/__init__.py
@@ -0,0 +1 @@
+from torchy_baselines.td3.td3 import TD3
diff --git a/torchy_baselines/td3/policies.py b/torchy_baselines/td3/policies.py
new file mode 100644
index 000000000..b9a8c27c3
--- /dev/null
+++ b/torchy_baselines/td3/policies.py
@@ -0,0 +1,82 @@
+import torch as th
+import torch.nn as nn
+
+from torchy_baselines.common.policies import BasePolicy
+
+
+class Actor(nn.Module):
+    def __init__(self, state_dim, action_dim, net_arch=None):
+        super(Actor, self).__init__()
+
+        if net_arch is None:
+            net_arch = [400, 300]
+
+        self.actor_net = nn.Sequential(
+            nn.Linear(state_dim, net_arch[0]),
+            nn.ReLU(),
+            nn.Linear(net_arch[0], net_arch[1]),
+            nn.ReLU(),
+            nn.Linear(net_arch[1], action_dim),
+            nn.Tanh(),
+        )
+
+    def forward(self, x):
+        return self.actor_net(x)
+
+
+class Critic(nn.Module):
+    def __init__(self, state_dim, action_dim, net_arch=None):
+        super(Critic, self).__init__()
+
+        if net_arch is None:
+            net_arch = [400, 300]
+
+        self.q1_net = nn.Sequential(
+            nn.Linear(state_dim + action_dim, net_arch[0]),
+            nn.ReLU(),
+            nn.Linear(net_arch[0], net_arch[1]),
+            nn.ReLU(),
+            nn.Linear(net_arch[1], 1),
+        )
+
+        self.q2_net = nn.Sequential(
+            nn.Linear(state_dim + action_dim, net_arch[0]),
+            nn.ReLU(),
+            nn.Linear(net_arch[0], net_arch[1]),
+            nn.ReLU(),
+            nn.Linear(net_arch[1], 1),
+        )
+
+    def forward(self, obs, action):
+        qvalue_input = th.cat([obs, action], dim=1)
+        return self.q1_net(qvalue_input), self.q2_net(qvalue_input)
+
+    def q1_forward(self, obs, action):
+        return self.q1_net(th.cat([obs, action], dim=1))
+
+
+class TD3Policy(BasePolicy):
+    def __init__(self, observation_space, action_space,
+                 learning_rate=1e-3, net_arch=None, device='cpu'):
+        super(TD3Policy, self).__init__(observation_space, action_space, device)
+        self.state_dim = self.observation_space.shape[0]
+        self.action_dim = self.action_space.shape[0]
+        self.net_arch = net_arch
+        self._build(learning_rate)
+
+    def _build(self, learning_rate):
+        self.actor = self.make_actor()
+        self.actor_target = self.make_actor()
+        self.actor_target.load_state_dict(self.actor.state_dict())
+        self.actor.optimizer = th.optim.Adam(self.actor.parameters(), lr=learning_rate)
+
+        self.critic = self.make_critic()
+        self.critic_target = self.make_critic()
+        self.critic_target.load_state_dict(self.critic.state_dict())
+        self.critic.optimizer = th.optim.Adam(self.critic.parameters(), lr=learning_rate)
+
+    def make_actor(self):
+        return Actor(self.state_dim, self.action_dim, self.net_arch).to(self.device)
+
+    def make_critic(self):
+        return Critic(self.state_dim, self.action_dim, self.net_arch).to(self.device)
diff --git a/torchy_baselines/td3/td3.py b/torchy_baselines/td3/td3.py
new file mode 100644
index 000000000..6656aa29f
--- /dev/null
+++ b/torchy_baselines/td3/td3.py
@@ -0,0 +1,179 @@
+import torch as th
+import torch.nn.functional as F
+import numpy as np
+
+from torchy_baselines.common.base_class import BaseRLModel
+from torchy_baselines.common.replay_buffer import ReplayBuffer
+from torchy_baselines.common.utils import set_random_seed
+from torchy_baselines.td3.policies import TD3Policy
+
+
+class TD3(BaseRLModel):
+    """
+    Implementation of Twin Delayed Deep Deterministic Policy Gradients (TD3)
+    Paper: https://arxiv.org/abs/1802.09477
+    Code: https://github.com/sfujim/TD3
+    """
+
+    def __init__(self, policy, env, policy_kwargs=None, verbose=0,
+                 buffer_size=int(1e6), learning_rate=1e-3, seed=0, device='cpu',
+                 action_noise_std=0.1, start_timesteps=10000, _init_setup_model=True):
+
+        super(TD3, self).__init__(policy, env, TD3Policy, policy_kwargs, verbose)
+
+        self.max_action = float(self.action_space.high)
+        self.replay_buffer = None
+        self.policy = None
+        self.device = device
+        self.action_noise_std = action_noise_std
+        self.learning_rate = learning_rate
+        self.buffer_size = buffer_size
+        self.start_timesteps = start_timesteps
+        self.seed = seed
+
+        if _init_setup_model:
+            self._setup_model()
+
+    def _setup_model(self, seed=None):
+        state_dim, action_dim = self.observation_space.shape[0], self.action_space.shape[0]
+        set_random_seed(self.seed, using_cuda=self.device != 'cpu')
+
+        self.replay_buffer = ReplayBuffer(self.buffer_size, state_dim, action_dim, self.device)
+        self.policy = TD3Policy(self.observation_space, self.action_space,
+                                self.learning_rate, device=self.device, **self.policy_kwargs)
+        self._create_aliases()
+
+    def _create_aliases(self):
+        self.actor = self.policy.actor
+        self.actor_target = self.policy.actor_target
+        self.critic = self.policy.critic
+        self.critic_target = self.policy.critic_target
+
+    def select_action(self, observation):
+        with th.no_grad():
+            observation = th.FloatTensor(observation.reshape(1, -1)).to(self.device)
+            return self.actor(observation).cpu().data.numpy().flatten()
+
+    def predict(self, observation, state=None, mask=None, deterministic=True):
+        """
+        Get the model's action from an observation
+
+        :param observation: (np.ndarray) the input observation
+        :param state: (np.ndarray) The last states (can be None, used in recurrent policies)
+        :param mask: (np.ndarray) The last masks (can be None, used in recurrent policies)
+        :param deterministic: (bool) Whether or not to return deterministic actions.
+        :return: (np.ndarray, np.ndarray) the model's action and the next state (used in recurrent policies)
+        """
+        return self.max_action * self.select_action(observation)
+
+    def train(self, n_iterations, batch_size=100, discount=0.99,
+              tau=0.005, policy_noise=0.2, noise_clip=0.5, policy_freq=2):
+
+        for it in range(n_iterations):
+
+            # Sample replay buffer
+            state, action, next_state, done, reward = self.replay_buffer.sample(batch_size)
+
+            # Select action according to policy and add clipped noise
+            noise = action.data.normal_(0, policy_noise).to(self.device)
+            noise = noise.clamp(-noise_clip, noise_clip)
+            next_action = (self.actor_target(next_state) + noise).clamp(-1, 1)
+
+            # Compute the target Q value
+            target_q1, target_q2 = self.critic_target(next_state, next_action)
+            target_q = th.min(target_q1, target_q2)
+            target_q = reward + ((1 - done) * discount * target_q).detach()
+
+            # Get current Q estimates
+            current_q1, current_q2 = self.critic(state, action)
+
+            # Compute critic loss
+            critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)
+
+            # Optimize the critic
+            self.critic.optimizer.zero_grad()
+            critic_loss.backward()
+            self.critic.optimizer.step()
+
+            # Delayed policy updates
+            if it % policy_freq == 0:
+
+                # Compute actor loss
+                actor_loss = -self.critic.q1_forward(state, self.actor(state)).mean()
+
+                # Optimize the actor
+                self.actor.optimizer.zero_grad()
+                actor_loss.backward()
+                self.actor.optimizer.step()
+
+                # Update the frozen target models
+                for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
+                    target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
+
+                for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
+                    target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
+
+    def learn(self, total_timesteps, callback=None, seed=None, log_interval=100,
+              tb_log_name="TD3", reset_num_timesteps=True):
+        num_timesteps = 0
+        timesteps_since_eval = 0
+        episode_num = 0
+        done = True
+
+        while num_timesteps < total_timesteps:
+
+            if done:
+                if num_timesteps > 0:
+                    print("Total T: {} Episode Num: {} Episode T: {} Reward: {}".format(
+                        num_timesteps, episode_num, episode_timesteps, episode_reward))
+                    self.train(episode_timesteps)
+
+                # Evaluate episode
+                # if timesteps_since_eval >= args.eval_freq:
+                #     timesteps_since_eval %= args.eval_freq
+                #     evaluations.append(evaluate_policy(policy))
+
+                # Reset environment
+                obs = self.env.reset()
+                episode_reward = 0
+                episode_timesteps = 0
+                episode_num += 1
+
+            # Select action randomly or according to policy
+            if num_timesteps < self.start_timesteps:
+                action = self.env.action_space.sample()
+            else:
+                action = self.select_action(np.array(obs))
+
+            if self.action_noise_std > 0:
+                # NOTE: in the original implementation, the noise is applied to the unscaled action
+                action_noise = np.random.normal(0, self.action_noise_std, size=self.action_space.shape[0])
+                action = (action + action_noise).clip(-1, 1)
+
+            # Rescale and perform action
+            new_obs, reward, done, _ = self.env.step(self.max_action * action)
+            done_bool = 0 if episode_timesteps + 1 == self.env._max_episode_steps else float(done)
+            episode_reward += reward
+
+            # Store data in replay buffer
+            # self.replay_buffer.add(state, next_state, action, reward, done)
+            self.replay_buffer.add(obs, new_obs, action, reward, done_bool)
+
+            obs = new_obs
+
+            episode_timesteps += 1
+            num_timesteps += 1
+            timesteps_since_eval += 1
+
+    def save(self, path):
+        if not path.endswith('.pth'):
+            path += '.pth'
+        th.save(self.policy.state_dict(), path)
+
+    def load(self, path, env=None, **_kwargs):
+        if not path.endswith('.pth'):
+            path += '.pth'
+        if env is not None:
+            pass
+        self.policy.load_state_dict(th.load(path))
+        self._create_aliases()
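
Usage sketch (not part of the patch): the snippet below mirrors tests/test_td3.py and only uses names introduced above (the TD3 constructor, learn() and predict()); the environment, net_arch and timestep budget are the values from the test and are arbitrary choices, not requirements.

import gym

from torchy_baselines import TD3

# Same setup as tests/test_td3.py: a small MLP actor/critic on Pendulum-v0.
env = gym.make("Pendulum-v0")
model = TD3('MlpPolicy', env, policy_kwargs=dict(net_arch=[64, 64]), verbose=1)

# learn() collects `start_timesteps` random steps before querying the actor,
# then runs a training phase at the end of every episode.
model.learn(total_timesteps=50000)

# predict() returns a numpy action already rescaled by the environment's action bound.
obs = env.reset()
action = model.predict(obs)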