feat(pid-lagrange, test): add algo and update test #210

Merged: 23 commits, Apr 13, 2023
Changes from 17 commits
1 change: 1 addition & 0 deletions docs/source/spelling_wordlist.txt
@@ -424,3 +424,4 @@ nums
bdg
num
gpu
Stooke
2 changes: 0 additions & 2 deletions docs/source/utils/distributed.rst
@@ -7,7 +7,6 @@ OmniSafe Distributed

setup_distributed
get_rank
is_master
world_size
fork
avg_tensor
@@ -34,7 +33,6 @@ Set up distributed training

.. autofunction:: setup_distributed
.. autofunction:: get_rank
.. autofunction:: is_master
.. autofunction:: world_size
.. autofunction:: fork

5 changes: 0 additions & 5 deletions docs/source/utils/math.rst
@@ -7,10 +7,8 @@ OmniSafe Math

get_transpose
get_diagonal
safe_inverse
discount_cumsum
conjugate_gradients
gaussian_kl
SafeTanhTransformer
TanhNormal

@@ -27,7 +25,6 @@ Tensor Operations

.. autofunction:: get_transpose
.. autofunction:: get_diagonal
.. autofunction:: safe_inverse
.. autofunction:: discount_cumsum
.. autofunction:: conjugate_gradients

@@ -42,8 +39,6 @@ Distribution Operations
Documentation
^^^

.. autofunction:: gaussian_kl

.. autoclass:: SafeTanhTransformer
:members:
:private-members:
2 changes: 0 additions & 2 deletions docs/source/utils/model.rst
@@ -8,7 +8,6 @@ OmniSafe Model Utils
initialize_layer
get_activation
build_mlp_network
set_optimizer

Model Building Utils
--------------------
@@ -23,4 +22,3 @@ Model Building Utils
.. autofunction:: initialize_layer
.. autofunction:: get_activation
.. autofunction:: build_mlp_network
.. autofunction:: set_optimizer
4 changes: 3 additions & 1 deletion examples/evaluate_saved_policy.py
@@ -24,7 +24,8 @@
LOG_DIR = ''
if __name__ == '__main__':
evaluator = omnisafe.Evaluator(render_mode='rgb_array')
for item in os.scandir(os.path.join(LOG_DIR, 'torch_save')):
scan_dir = os.scandir(os.path.join(LOG_DIR, 'torch_save'))
for item in scan_dir:
if item.is_file() and item.name.split('.')[-1] == 'pt':
evaluator.load_saved(
save_dir=LOG_DIR,
@@ -35,3 +36,4 @@
)
evaluator.render(num_episodes=1)
evaluator.evaluate(num_episodes=1)
scan_dir.close()
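Side note on the change above: the explicit scan_dir.close() releases the directory handle once iteration is done. A minimal sketch of the equivalent context-manager form, which closes the handle automatically (LOG_DIR is the same placeholder path used in this example):

```python
import os

LOG_DIR = ''  # placeholder, as in the example above

# os.scandir supports the context-manager protocol, so the handle is closed on exit.
with os.scandir(os.path.join(LOG_DIR, 'torch_save')) as scan_dir:
    for item in scan_dir:
        if item.is_file() and item.name.split('.')[-1] == 'pt':
            print(item.name)  # e.g. hand the checkpoint name to the evaluator here
```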
2 changes: 1 addition & 1 deletion examples/train_policy.py
@@ -61,7 +61,7 @@
parser.add_argument(
'--vector-env-nums',
type=int,
default=2,
default=1,
Review comment (Gaiejj, Member Author): The default value of vector-env-nums needs to be 1 because the early-terminated algorithms do not support vector environments. The original default of 2 would trigger an error, so we set it to 1 instead.

metavar='VECTOR-ENV',
help='number of vector envs to use for training',
)
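As the review comment above notes, only the default changed; algorithms that do support vectorized environments can still opt back in. A minimal sketch, assuming the omnisafe.Agent entry point and the custom_cfgs override used by examples/train_policy.py (this excerpt only shows the argparse flag, so treat the exact config keys, the algorithm name, and the environment id as illustrative assumptions):

```python
import omnisafe

# Opt back into vectorized environments for an algorithm that supports them.
custom_cfgs = {'train_cfgs': {'vector_env_nums': 2}}
agent = omnisafe.Agent('PPOLag', 'SafetyPointGoal1-v0', custom_cfgs=custom_cfgs)
agent.learn()
```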
2 changes: 1 addition & 1 deletion omnisafe/adapter/early_terminated_adapter.py
@@ -30,7 +30,7 @@ def __init__(self, env_id: str, num_envs: int, seed: int, cfgs: Config) -> None:

super().__init__(env_id, num_envs, seed, cfgs)

self._cost_limit = cfgs.cost_limit
self._cost_limit = cfgs.algo_cfgs.cost_limit
self._cost_logger = torch.zeros(self._env.num_envs)

def step(
5 changes: 0 additions & 5 deletions omnisafe/adapter/simmer_adapter.py
@@ -61,11 +61,6 @@ def safety_budget(self) -> torch.Tensor:
"""Return the safety budget."""
return self._safety_budget

@property
def upper_budget(self) -> torch.Tensor:
"""Return the upper budget."""
return self._upper_budget

@safety_budget.setter
def safety_budget(self, safety_budget: torch.Tensor) -> None:
"""Set the safety budget."""
7 changes: 7 additions & 0 deletions omnisafe/algorithms/__init__.py
@@ -24,19 +24,26 @@
# On-Policy Safe
from omnisafe.algorithms.on_policy import (
CPO,
CPPOPID,
CUP,
FOCOPS,
PCPO,
PDO,
PPO,
RCPO,
TRPO,
TRPOPID,
NaturalPG,
OnCRPO,
PolicyGradient,
PPOEarlyTerminated,
PPOLag,
PPOSaute,
PPOSimmerPID,
TRPOEarlyTerminated,
TRPOLag,
TRPOSaute,
TRPOSimmerPID,
)


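A quick sanity-check sketch of the expanded export list, using the ALGORITHMS registry that algo_wrapper.py (next file) already consults; it assumes the registry's 'all' key lists algorithms by their exported class names:

```python
from omnisafe.algorithms import ALGORITHMS

# The newly exported algorithms should all be registered under the 'all' key.
for algo in ('CPPOPID', 'TRPOPID', 'PPOEarlyTerminated', 'TRPOEarlyTerminated',
             'PPOSaute', 'PPOSimmerPID'):
    assert algo in ALGORITHMS['all'], f'{algo} is not registered'
```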
39 changes: 15 additions & 24 deletions omnisafe/algorithms/algo_wrapper.py
@@ -16,12 +16,10 @@

from __future__ import annotations

import difflib
import os
import sys
from typing import Any

import psutil
import torch

from omnisafe.algorithms import ALGORITHM2TYPE, ALGORITHMS, registry
@@ -59,17 +57,10 @@ def __init__(

def _init_config(self):
"""Init config."""
assert self.algo in ALGORITHMS['all'], (
f"{self.algo} doesn't exist. "
f"Did you mean {difflib.get_close_matches(self.algo, ALGORITHMS['all'], n=1)[0]}?"
)
assert (
self.algo in ALGORITHMS['all']
), f"{self.algo} doesn't exist. Please choose from {ALGORITHMS['all']}."
self.algo_type = ALGORITHM2TYPE.get(self.algo, '')
if self.algo_type is None or self.algo_type == '':
raise ValueError(f'{self.algo} is not supported!')
if self.algo_type in {'off-policy', 'model-based'} and self.train_terminal_cfgs is not None:
assert (
self.train_terminal_cfgs['parallel'] == 1
), 'off-policy or model-based only support parallel==1!'
cfgs = get_default_kwargs_yaml(self.algo, self.env_id, self.algo_type)

# update the cfgs from custom configurations
@@ -105,6 +96,10 @@ def _init_config(self):
cfgs.train_cfgs.recurisve_update(
{'epochs': cfgs.train_cfgs.total_steps // cfgs.algo_cfgs.update_cycle},
)
if self.algo_type in {'off-policy', 'model-based'}:
assert (
cfgs.train_cfgs.parallel == 1
), 'off-policy or model-based only support parallel==1!'
return cfgs

def _init_checks(self):
@@ -113,21 +108,14 @@
assert isinstance(self.cfgs.train_cfgs.parallel, int), 'parallel must be an integer!'
assert self.cfgs.train_cfgs.parallel > 0, 'parallel must be greater than 0!'
assert (
isinstance(self.custom_cfgs, dict) or self.custom_cfgs is None
), 'custom_cfgs must be a dict!'
assert self.env_id in support_envs(), (
f"{self.env_id} doesn't exist. "
f'Did you mean {difflib.get_close_matches(self.env_id, support_envs(), n=1)[0]}?'
)
self.env_id in support_envs()
), f"{self.env_id} doesn't exist. Please choose from {support_envs()}."

def learn(self):
"""Agent Learning."""
# Use number of physical cores as default.
# If also hardware threading CPUs should be used
# enable this by the use_number_of_threads=True
physical_cores = psutil.cpu_count(logical=False)
use_number_of_threads = bool(self.cfgs.train_cfgs.parallel > physical_cores)

check_all_configs(self.cfgs, self.algo_type)
device = self.cfgs.train_cfgs.device
if device == 'cpu':
Expand All @@ -137,7 +125,6 @@ def learn(self):
torch.cuda.set_device(self.cfgs.train_cfgs.device)
if distributed.fork(
self.cfgs.train_cfgs.parallel,
use_number_of_threads=use_number_of_threads,
device=self.cfgs.train_cfgs.device,
):
# Re-launches the current script with workers linked by MPI
@@ -186,10 +173,12 @@ def evaluate(self, num_episodes: int = 10, cost_criteria: float = 1.0):
cost_criteria (float): the cost criteria to evaluate.
"""
assert self._evaluator is not None, 'Please run learn() first!'
for item in os.scandir(os.path.join(self.agent.logger.log_dir, 'torch_save')):
scan_dir = os.scandir(os.path.join(self.agent.logger.log_dir, 'torch_save'))
for item in scan_dir:
if item.is_file() and item.name.split('.')[-1] == 'pt':
self._evaluator.load_saved(save_dir=self.agent.logger.log_dir, model_name=item.name)
self._evaluator.evaluate(num_episodes=num_episodes, cost_criteria=cost_criteria)
scan_dir.close()

# pylint: disable-next=too-many-arguments
def render(
@@ -211,7 +200,8 @@ def render(
height (int): height of the rendered image.
"""
assert self._evaluator is not None, 'Please run learn() first!'
for item in os.scandir(os.path.join(self.agent.logger.log_dir, 'torch_save')):
scan_dir = os.scandir(os.path.join(self.agent.logger.log_dir, 'torch_save'))
for item in scan_dir:
if item.is_file() and item.name.split('.')[-1] == 'pt':
self._evaluator.load_saved(
save_dir=self.agent.logger.log_dir,
@@ -222,3 +212,4 @@
height=height,
)
self._evaluator.render(num_episodes=num_episodes)
scan_dir.close()
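A minimal end-to-end sketch of the wrapper flow shown above (learn, then evaluate and render the saved checkpoints). It assumes omnisafe.Agent is the public alias of this wrapper; the algorithm and environment names are placeholders:

```python
import omnisafe

agent = omnisafe.Agent('TRPOPID', 'SafetyPointGoal1-v0')  # placeholder algo/env
agent.learn()                   # trains and writes checkpoints under torch_save/
agent.evaluate(num_episodes=5)  # iterates the saved *.pt files, as in evaluate() above
agent.render(num_episodes=1)
```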
14 changes: 8 additions & 6 deletions omnisafe/algorithms/on_policy/__init__.py
@@ -16,22 +16,24 @@

from omnisafe.algorithms.on_policy import (
base,
early_terminated,
first_order,
naive_lagrange,
penalty_function,
pid_lagrange,
saute,
second_order,
simmer,
)
from omnisafe.algorithms.on_policy.base import PPO, TRPO, NaturalPG, PolicyGradient

# from omnisafe.algorithms.on_policy.early_terminated import PPOEarlyTerminated, PPOLagEarlyTerminated
from omnisafe.algorithms.on_policy.early_terminated import PPOEarlyTerminated, TRPOEarlyTerminated
from omnisafe.algorithms.on_policy.first_order import CUP, FOCOPS
from omnisafe.algorithms.on_policy.naive_lagrange import PDO, RCPO, OnCRPO, PPOLag, TRPOLag
from omnisafe.algorithms.on_policy.penalty_function import IPO, P3O
from omnisafe.algorithms.on_policy.saute import TRPOSaute
from omnisafe.algorithms.on_policy.pid_lagrange import CPPOPID, TRPOPID
from omnisafe.algorithms.on_policy.saute import PPOSaute, TRPOSaute
from omnisafe.algorithms.on_policy.second_order import CPO, PCPO
from omnisafe.algorithms.on_policy.simmer import TRPOSimmerPID
from omnisafe.algorithms.on_policy.simmer import PPOSimmerPID, TRPOSimmerPID


# from omnisafe.algorithms.on_policy.pid_lagrange import CPPOPid, TRPOPid
@@ -47,11 +49,11 @@

__all__ = [
*base.__all__,
# *early_terminated.__all__,
*early_terminated.__all__,
*first_order.__all__,
*naive_lagrange.__all__,
*penalty_function.__all__,
# *pid_lagrange.__all__,
*pid_lagrange.__all__,
*saute.__all__,
*second_order.__all__,
*simmer.__all__,
2 changes: 1 addition & 1 deletion omnisafe/algorithms/on_policy/base/policy_gradient.py
@@ -199,7 +199,7 @@ def _init_log(self) -> None:
self._logger.setup_torch_saver(what_to_save)
self._logger.torch_save()

self._logger.register_key('Metrics/EpRet', window_length=50)
self._logger.register_key('Metrics/EpRet', window_length=50, min_and_max=True)
self._logger.register_key('Metrics/EpCost', window_length=50)
self._logger.register_key('Metrics/EpLen', window_length=50)

4 changes: 1 addition & 3 deletions omnisafe/algorithms/on_policy/base/trpo.py
@@ -107,9 +107,7 @@ def _search_step_size(
# average processes.... multi-processing style like: mpi_tools.mpi_avg(xxx)
loss_improve = distributed.dist_avg(loss_improve)
self._logger.log(f'Expected Improvement: {expected_improve} Actual: {loss_improve}')
if not torch.isfinite(loss):
self._logger.log('WARNING: loss_pi not finite')
elif loss_improve < 0:
if loss_improve < 0:
self._logger.log('INFO: did not improve improve <0')
elif kl > self._cfgs.algo_cfgs.target_kl:
self._logger.log('INFO: violated KL constraint.')
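For reference, a small standalone mirror of the acceptance checks that remain visible in _search_step_size after this change (illustrative only; not the project's API, and it omits whatever happens in the surrounding backtracking loop):

```python
def accept_step(loss_improve: float, kl: float, target_kl: float) -> bool:
    """Mirror the remaining checks: require improvement and respect the KL trust region."""
    if loss_improve < 0:
        return False  # corresponds to 'INFO: did not improve improve <0'
    if kl > target_kl:
        return False  # corresponds to 'INFO: violated KL constraint.'
    return True
```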
24 changes: 24 additions & 0 deletions omnisafe/algorithms/on_policy/early_terminated/__init__.py
@@ -0,0 +1,24 @@
# Copyright 2022-2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Early Terminated Reinforcement Learning algorithms."""

from omnisafe.algorithms.on_policy.early_terminated.ppo_early_terminated import PPOEarlyTerminated
from omnisafe.algorithms.on_policy.early_terminated.trpo_early_terminated import TRPOEarlyTerminated


__all__ = [
'TRPOEarlyTerminated',
'PPOEarlyTerminated',
]
57 changes: 57 additions & 0 deletions omnisafe/algorithms/on_policy/early_terminated/ppo_early_terminated.py
@@ -0,0 +1,57 @@
# Copyright 2022-2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of the Early terminated version of the PPO algorithm."""


from omnisafe.adapter.early_terminated_adapter import EarlyTerminatedAdapter
from omnisafe.algorithms import registry
from omnisafe.algorithms.on_policy.base.ppo import PPO
from omnisafe.utils import distributed


@registry.register
class PPOEarlyTerminated(PPO):
"""The Early terminated version of the PPO algorithm.

A simple combination of Early Terminated RL and the Proximal Policy Optimization algorithm.
"""

def _init_env(self) -> None:
"""Initialize the environment.

OmniSafe uses :class:`omnisafe.adapter.EarlyTerminatedAdapter` to adapt the environment to the algorithm.

Users can customize the environment by overriding this function.

Example:
>>> def _init_env(self) -> None:
>>> self._env = CustomAdapter()
"""
if self._cfgs.train_cfgs.vector_env_nums != 1:
raise NotImplementedError('Early terminated RL only supports single environment.')
self._env = EarlyTerminatedAdapter(
self._env_id,
self._cfgs.train_cfgs.vector_env_nums,
self._seed,
self._cfgs,
)
assert (self._cfgs.algo_cfgs.update_cycle) % (
distributed.world_size() * self._cfgs.train_cfgs.vector_env_nums
) == 0, 'The number of steps per epoch is not divisible by the number of environments.'
self._steps_per_epoch = (
self._cfgs.algo_cfgs.update_cycle
// distributed.world_size()
// self._cfgs.train_cfgs.vector_env_nums
)
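A worked check of the divisibility constraint asserted in _init_env above; all numbers are illustrative placeholders:

```python
# Early terminated RL runs a single environment, so the constraint reduces to
# update_cycle % world_size == 0.
update_cycle = 20000      # stands in for cfgs.algo_cfgs.update_cycle
world_size = 1            # single worker; distributed.world_size() would return this
vector_env_nums = 1       # enforced by the NotImplementedError guard above

assert update_cycle % (world_size * vector_env_nums) == 0
steps_per_epoch = update_cycle // world_size // vector_env_nums
print(steps_per_epoch)    # 20000
```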