From a73428b0b4e5c7f09add5a237227143712a49eb5 Mon Sep 17 00:00:00 2001 From: Matteo Bettini <55539777+matteobettini@users.noreply.github.com> Date: Thu, 14 Sep 2023 15:00:40 +0100 Subject: [PATCH] [Environment] Petting zoo (#1471) Signed-off-by: Matteo Bettini Co-authored-by: vmoens --- .../scripts_pettingzoo/environment.yml | 23 + .../linux_libs/scripts_pettingzoo/install.sh | 46 + .../scripts_pettingzoo/post_process.sh | 6 + .../scripts_pettingzoo/run-clang-format.py | 356 +++++++ .../linux_libs/scripts_pettingzoo/run_test.sh | 30 + .../scripts_pettingzoo/setup_env.sh | 53 ++ .github/workflows/test-linux-pettingzoo.yml | 40 + docs/source/reference/envs.rst | 13 + test/test_env.py | 25 +- test/test_libs.py | 220 ++++- torchrl/envs/__init__.py | 2 + torchrl/envs/common.py | 19 +- torchrl/envs/libs/pettingzoo.py | 878 ++++++++++++++++++ torchrl/envs/libs/vmas.py | 7 +- torchrl/envs/utils.py | 121 ++- 15 files changed, 1828 insertions(+), 11 deletions(-) create mode 100644 .github/unittest/linux_libs/scripts_pettingzoo/environment.yml create mode 100755 .github/unittest/linux_libs/scripts_pettingzoo/install.sh create mode 100755 .github/unittest/linux_libs/scripts_pettingzoo/post_process.sh create mode 100755 .github/unittest/linux_libs/scripts_pettingzoo/run-clang-format.py create mode 100755 .github/unittest/linux_libs/scripts_pettingzoo/run_test.sh create mode 100755 .github/unittest/linux_libs/scripts_pettingzoo/setup_env.sh create mode 100644 .github/workflows/test-linux-pettingzoo.yml create mode 100644 torchrl/envs/libs/pettingzoo.py diff --git a/.github/unittest/linux_libs/scripts_pettingzoo/environment.yml b/.github/unittest/linux_libs/scripts_pettingzoo/environment.yml new file mode 100644 index 00000000000..76f97355f7a --- /dev/null +++ b/.github/unittest/linux_libs/scripts_pettingzoo/environment.yml @@ -0,0 +1,23 @@ +channels: + - pytorch + - defaults +dependencies: + - pip + - swig + - pip: + - cloudpickle + - gym + - gym-notices + - importlib-metadata + - six + - zipp + - pytest + - pytest-cov + - pytest-mock + - pytest-instafail + - pytest-rerunfailures + - pytest-error-for-skips + - expecttest + - pyyaml + - autorom[accept-rom-license] + - pettingzoo[all]==1.24.1 diff --git a/.github/unittest/linux_libs/scripts_pettingzoo/install.sh b/.github/unittest/linux_libs/scripts_pettingzoo/install.sh new file mode 100755 index 00000000000..cb36c7cc48a --- /dev/null +++ b/.github/unittest/linux_libs/scripts_pettingzoo/install.sh @@ -0,0 +1,46 @@ +#!/usr/bin/env bash + +unset PYTORCH_VERSION +# For unittest, nightly PyTorch is used as the following section, +# so no need to set PYTORCH_VERSION. +# In fact, keeping PYTORCH_VERSION forces us to hardcode PyTorch version in config. 
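+# Note on the CUDA parsing below (illustrative): a 4-character CU_VERSION such
+# as "cu92" is sliced into CUDA_VERSION "9.2", a 5-character one such as
+# "cu117" into "11.7", and CU_VERSION=cpu skips CUDA selection entirely.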
+ +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env + +if [ "${CU_VERSION:-}" == cpu ] ; then + version="cpu" +else + if [[ ${#CU_VERSION} -eq 4 ]]; then + CUDA_VERSION="${CU_VERSION:2:1}.${CU_VERSION:3:1}" + elif [[ ${#CU_VERSION} -eq 5 ]]; then + CUDA_VERSION="${CU_VERSION:2:2}.${CU_VERSION:4:1}" + fi + echo "Using CUDA $CUDA_VERSION as determined by CU_VERSION ($CU_VERSION)" + version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")" +fi + +# submodules +git submodule sync && git submodule update --init --recursive + +printf "Installing PyTorch with %s\n" "${CU_VERSION}" +if [ "${CU_VERSION:-}" == cpu ] ; then + # conda install -y pytorch torchvision cpuonly -c pytorch-nightly + # use pip to install pytorch as conda can frequently pick older release +# conda install -y pytorch cpuonly -c pytorch-nightly + pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu --force-reinstall +else + pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cu116 --force-reinstall +fi + +# install tensordict +pip install git+https://github.com/pytorch-labs/tensordict.git + +# smoke test +python -c "import tensordict" + +printf "* Installing torchrl\n" +python setup.py develop +python -c "import torchrl" diff --git a/.github/unittest/linux_libs/scripts_pettingzoo/post_process.sh b/.github/unittest/linux_libs/scripts_pettingzoo/post_process.sh new file mode 100755 index 00000000000..e97bf2a7b1b --- /dev/null +++ b/.github/unittest/linux_libs/scripts_pettingzoo/post_process.sh @@ -0,0 +1,6 @@ +#!/usr/bin/env bash + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env diff --git a/.github/unittest/linux_libs/scripts_pettingzoo/run-clang-format.py b/.github/unittest/linux_libs/scripts_pettingzoo/run-clang-format.py new file mode 100755 index 00000000000..5783a885d86 --- /dev/null +++ b/.github/unittest/linux_libs/scripts_pettingzoo/run-clang-format.py @@ -0,0 +1,356 @@ +#!/usr/bin/env python +""" +MIT License + +Copyright (c) 2017 Guillaume Papin + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +A wrapper script around clang-format, suitable for linting multiple files +and to use for continuous integration. + +This is an alternative API for the clang-format command line. +It runs over multiple files and directories in parallel. +A diff output is produced and a sensible exit code is returned. 
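+
+Illustrative invocation (using the flags defined by the argparse setup below):
+
+    run-clang-format.py -r --exclude "*/third_party/*" src include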
+ +""" + +import argparse +import difflib +import fnmatch +import multiprocessing +import os +import signal +import subprocess +import sys +import traceback +from functools import partial + +try: + from subprocess import DEVNULL # py3k +except ImportError: + DEVNULL = open(os.devnull, "wb") + + +DEFAULT_EXTENSIONS = "c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu" + + +class ExitStatus: + SUCCESS = 0 + DIFF = 1 + TROUBLE = 2 + + +def list_files(files, recursive=False, extensions=None, exclude=None): + if extensions is None: + extensions = [] + if exclude is None: + exclude = [] + + out = [] + for file in files: + if recursive and os.path.isdir(file): + for dirpath, dnames, fnames in os.walk(file): + fpaths = [os.path.join(dirpath, fname) for fname in fnames] + for pattern in exclude: + # os.walk() supports trimming down the dnames list + # by modifying it in-place, + # to avoid unnecessary directory listings. + dnames[:] = [ + x + for x in dnames + if not fnmatch.fnmatch(os.path.join(dirpath, x), pattern) + ] + fpaths = [x for x in fpaths if not fnmatch.fnmatch(x, pattern)] + for f in fpaths: + ext = os.path.splitext(f)[1][1:] + if ext in extensions: + out.append(f) + else: + out.append(file) + return out + + +def make_diff(file, original, reformatted): + return list( + difflib.unified_diff( + original, + reformatted, + fromfile=f"{file}\t(original)", + tofile=f"{file}\t(reformatted)", + n=3, + ) + ) + + +class DiffError(Exception): + def __init__(self, message, errs=None): + super().__init__(message) + self.errs = errs or [] + + +class UnexpectedError(Exception): + def __init__(self, message, exc=None): + super().__init__(message) + self.formatted_traceback = traceback.format_exc() + self.exc = exc + + +def run_clang_format_diff_wrapper(args, file): + try: + ret = run_clang_format_diff(args, file) + return ret + except DiffError: + raise + except Exception as e: + raise UnexpectedError(f"{file}: {e.__class__.__name__}: {e}", e) + + +def run_clang_format_diff(args, file): + try: + with open(file, encoding="utf-8") as f: + original = f.readlines() + except OSError as exc: + raise DiffError(str(exc)) + invocation = [args.clang_format_executable, file] + + # Use of utf-8 to decode the process output. + # + # Hopefully, this is the correct thing to do. + # + # It's done due to the following assumptions (which may be incorrect): + # - clang-format will returns the bytes read from the files as-is, + # without conversion, and it is already assumed that the files use utf-8. + # - if the diagnostics were internationalized, they would use utf-8: + # > Adding Translations to Clang + # > + # > Not possible yet! + # > Diagnostic strings should be written in UTF-8, + # > the client can translate to the relevant code page if needed. + # > Each translation completely replaces the format string + # > for the diagnostic. 
+ # > -- http://clang.llvm.org/docs/InternalsManual.html#internals-diag-translation + + try: + proc = subprocess.Popen( + invocation, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + universal_newlines=True, + encoding="utf-8", + ) + except OSError as exc: + raise DiffError( + f"Command '{subprocess.list2cmdline(invocation)}' failed to start: {exc}" + ) + proc_stdout = proc.stdout + proc_stderr = proc.stderr + + # hopefully the stderr pipe won't get full and block the process + outs = list(proc_stdout.readlines()) + errs = list(proc_stderr.readlines()) + proc.wait() + if proc.returncode: + raise DiffError( + "Command '{}' returned non-zero exit status {}".format( + subprocess.list2cmdline(invocation), proc.returncode + ), + errs, + ) + return make_diff(file, original, outs), errs + + +def bold_red(s): + return "\x1b[1m\x1b[31m" + s + "\x1b[0m" + + +def colorize(diff_lines): + def bold(s): + return "\x1b[1m" + s + "\x1b[0m" + + def cyan(s): + return "\x1b[36m" + s + "\x1b[0m" + + def green(s): + return "\x1b[32m" + s + "\x1b[0m" + + def red(s): + return "\x1b[31m" + s + "\x1b[0m" + + for line in diff_lines: + if line[:4] in ["--- ", "+++ "]: + yield bold(line) + elif line.startswith("@@ "): + yield cyan(line) + elif line.startswith("+"): + yield green(line) + elif line.startswith("-"): + yield red(line) + else: + yield line + + +def print_diff(diff_lines, use_color): + if use_color: + diff_lines = colorize(diff_lines) + sys.stdout.writelines(diff_lines) + + +def print_trouble(prog, message, use_colors): + error_text = "error:" + if use_colors: + error_text = bold_red(error_text) + print(f"{prog}: {error_text} {message}", file=sys.stderr) + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--clang-format-executable", + metavar="EXECUTABLE", + help="path to the clang-format executable", + default="clang-format", + ) + parser.add_argument( + "--extensions", + help=f"comma separated list of file extensions (default: {DEFAULT_EXTENSIONS})", + default=DEFAULT_EXTENSIONS, + ) + parser.add_argument( + "-r", + "--recursive", + action="store_true", + help="run recursively over directories", + ) + parser.add_argument("files", metavar="file", nargs="+") + parser.add_argument("-q", "--quiet", action="store_true") + parser.add_argument( + "-j", + metavar="N", + type=int, + default=0, + help="run N clang-format jobs in parallel (default number of cpus + 1)", + ) + parser.add_argument( + "--color", + default="auto", + choices=["auto", "always", "never"], + help="show colored diff (default: auto)", + ) + parser.add_argument( + "-e", + "--exclude", + metavar="PATTERN", + action="append", + default=[], + help="exclude paths matching the given glob-like pattern(s) from recursive search", + ) + + args = parser.parse_args() + + # use default signal handling, like diff return SIGINT value on ^C + # https://bugs.python.org/issue14229#msg156446 + signal.signal(signal.SIGINT, signal.SIG_DFL) + try: + signal.SIGPIPE + except AttributeError: + # compatibility, SIGPIPE does not exist on Windows + pass + else: + signal.signal(signal.SIGPIPE, signal.SIG_DFL) + + colored_stdout = False + colored_stderr = False + if args.color == "always": + colored_stdout = True + colored_stderr = True + elif args.color == "auto": + colored_stdout = sys.stdout.isatty() + colored_stderr = sys.stderr.isatty() + + version_invocation = [args.clang_format_executable, "--version"] + try: + subprocess.check_call(version_invocation, stdout=DEVNULL) + except subprocess.CalledProcessError as e: + 
print_trouble(parser.prog, str(e), use_colors=colored_stderr) + return ExitStatus.TROUBLE + except OSError as e: + print_trouble( + parser.prog, + f"Command '{subprocess.list2cmdline(version_invocation)}' failed to start: {e}", + use_colors=colored_stderr, + ) + return ExitStatus.TROUBLE + + retcode = ExitStatus.SUCCESS + files = list_files( + args.files, + recursive=args.recursive, + exclude=args.exclude, + extensions=args.extensions.split(","), + ) + + if not files: + return + + njobs = args.j + if njobs == 0: + njobs = multiprocessing.cpu_count() + 1 + njobs = min(len(files), njobs) + + if njobs == 1: + # execute directly instead of in a pool, + # less overhead, simpler stacktraces + it = (run_clang_format_diff_wrapper(args, file) for file in files) + pool = None + else: + pool = multiprocessing.Pool(njobs) + it = pool.imap_unordered(partial(run_clang_format_diff_wrapper, args), files) + while True: + try: + outs, errs = next(it) + except StopIteration: + break + except DiffError as e: + print_trouble(parser.prog, str(e), use_colors=colored_stderr) + retcode = ExitStatus.TROUBLE + sys.stderr.writelines(e.errs) + except UnexpectedError as e: + print_trouble(parser.prog, str(e), use_colors=colored_stderr) + sys.stderr.write(e.formatted_traceback) + retcode = ExitStatus.TROUBLE + # stop at the first unexpected error, + # something could be very wrong, + # don't process all files unnecessarily + if pool: + pool.terminate() + break + else: + sys.stderr.writelines(errs) + if outs == []: + continue + if not args.quiet: + print_diff(outs, use_color=colored_stdout) + if retcode == ExitStatus.SUCCESS: + retcode = ExitStatus.DIFF + return retcode + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/.github/unittest/linux_libs/scripts_pettingzoo/run_test.sh b/.github/unittest/linux_libs/scripts_pettingzoo/run_test.sh new file mode 100755 index 00000000000..d215b514081 --- /dev/null +++ b/.github/unittest/linux_libs/scripts_pettingzoo/run_test.sh @@ -0,0 +1,30 @@ +#!/usr/bin/env bash + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env +apt-get update && apt-get install -y git wget + + +export PYTORCH_TEST_WITH_SLOW='1' +python -m torch.utils.collect_env +# Avoid error: "fatal: unsafe repository" +git config --global --add safe.directory '*' + +root_dir="$(git rev-parse --show-toplevel)" +env_dir="${root_dir}/env" +lib_dir="${env_dir}/lib" + +# solves ImportError: /lib64/libstdc++.so.6: version `GLIBCXX_3.4.21' not found +export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir +export MKL_THREADING_LAYER=GNU +# more logging +export MAGNUM_LOG=verbose MAGNUM_GPU_VALIDATION=ON + +# this workflow only tests the libs +python -c "import pettingzoo" + +python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestPettingZoo --error-for-skips +coverage combine +coverage xml -i diff --git a/.github/unittest/linux_libs/scripts_pettingzoo/setup_env.sh b/.github/unittest/linux_libs/scripts_pettingzoo/setup_env.sh new file mode 100755 index 00000000000..a3f833112a9 --- /dev/null +++ b/.github/unittest/linux_libs/scripts_pettingzoo/setup_env.sh @@ -0,0 +1,53 @@ +#!/usr/bin/env bash + +# This script is for setting up environment in which unit test is ran. +# To speed up the CI time, the resulting environment is cached. +# +# Do not install PyTorch and torchvision here, otherwise they also get cached. 
+ +set -e + + +this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +# Avoid error: "fatal: unsafe repository" +git config --global --add safe.directory '*' +root_dir="$(git rev-parse --show-toplevel)" +conda_dir="${root_dir}/conda" +env_dir="${root_dir}/env" + +cd "${root_dir}" + +case "$(uname -s)" in + Darwin*) os=MacOSX;; + *) os=Linux +esac + +# 1. Install conda at ./conda +if [ ! -d "${conda_dir}" ]; then + printf "* Installing conda\n" + wget -O miniconda.sh "http://repo.continuum.io/miniconda/Miniconda3-latest-${os}-x86_64.sh" + bash ./miniconda.sh -b -f -p "${conda_dir}" +fi +eval "$(${conda_dir}/bin/conda shell.bash hook)" + +# 2. Create test environment at ./env +printf "python: ${PYTHON_VERSION}\n" +if [ ! -d "${env_dir}" ]; then + printf "* Creating a test environment\n" + conda create --prefix "${env_dir}" -y python="$PYTHON_VERSION" +fi +conda activate "${env_dir}" + +# 4. Install Conda dependencies +printf "* Installing dependencies (except PyTorch)\n" +echo " - python=${PYTHON_VERSION}" >> "${this_dir}/environment.yml" +cat "${this_dir}/environment.yml" + +pip install pip --upgrade + +conda env update --file "${this_dir}/environment.yml" --prune + +# 5. Download atari roms +autorom_dir="${env_dir}/lib/python${PYTHON_VERSION}/site-packages/AutoROM/roms" +multi_atari_rom_dir="${env_dir}/lib/python${PYTHON_VERSION}/site-packages/multi_agent_ale_py/roms" +ln -s "${autorom_dir}" "${multi_atari_rom_dir}" diff --git a/.github/workflows/test-linux-pettingzoo.yml b/.github/workflows/test-linux-pettingzoo.yml new file mode 100644 index 00000000000..bbf775f4c27 --- /dev/null +++ b/.github/workflows/test-linux-pettingzoo.yml @@ -0,0 +1,40 @@ +name: PettingZoo Tests on Linux + +on: + pull_request: + push: + branches: + - nightly + - main + - release/* + workflow_dispatch: + +concurrency: + # Documentation suggests ${{ github.head_ref }}, but that's only available on pull_request/pull_request_target triggers, so using ${{ github.ref }}. + # On master, we want all builds to complete even if merging happens faster to make it easier to discover at which point something broke. + group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && format('ci-master-{0}', github.sha) || format('ci-{0}', github.ref) }} + cancel-in-progress: true + +jobs: + unittests: + uses: pytorch/test-infra/.github/workflows/linux_job.yml@main + with: + repository: pytorch/rl + runner: "linux.g5.4xlarge.nvidia.gpu" + gpu-arch-type: cuda + gpu-arch-version: "11.7" + timeout: 120 + script: | + set -euo pipefail + export PYTHON_VERSION="3.9" + export CU_VERSION="11.7" + export TAR_OPTIONS="--no-same-owner" + export UPLOAD_CHANNEL="nightly" + export TF_CPP_MIN_LOG_LEVEL=0 + + nvidia-smi + + bash .circleci/unittest/linux_libs/scripts_pettingzoo/setup_env.sh + bash .circleci/unittest/linux_libs/scripts_pettingzoo/install.sh + bash .circleci/unittest/linux_libs/scripts_pettingzoo/run_test.sh + bash .circleci/unittest/linux_libs/scripts_pettingzoo/post_process.sh diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index f6e6c536ce5..6527418ea75 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -219,6 +219,8 @@ etc.), but one can not use an arbitrary TorchRL environment, as it is possible w Multi-agent environments ------------------------ +.. currentmodule:: torchrl.envs + TorchRL supports multi-agent learning out-of-the-box. 
*The same classes used in a single-agent learning pipeline can be seamlessly used in multi-agent contexts, without any modification or dedicated multi-agent infrastructure.* @@ -344,6 +346,15 @@ single agent standards. Note that `env.reward_spec == env.output_spec["full_reward_spec"][env.reward_key]`. +.. autosummary:: + :toctree: generated/ + :template: rl_template_fun.rst + + MarlGroupMapType + check_marl_grouping + + + Transforms ---------- .. currentmodule:: torchrl.envs.transforms @@ -626,5 +637,7 @@ the following function will return ``1`` when queried: jumanji.JumanjiEnv jumanji.JumanjiWrapper openml.OpenMLEnv + pettingzoo.PettingZooEnv + pettingzoo.PettingZooWrapper vmas.VmasEnv vmas.VmasWrapper diff --git a/test/test_env.py b/test/test_env.py index f5d03ba366c..ffa336331f3 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -61,7 +61,13 @@ from torchrl.envs.libs.dm_control import _has_dmc, DMControlEnv from torchrl.envs.libs.gym import _has_gym, GymEnv, GymWrapper from torchrl.envs.transforms import Compose, StepCounter, TransformedEnv -from torchrl.envs.utils import check_env_specs, make_composite_from_td, step_mdp +from torchrl.envs.utils import ( + check_env_specs, + check_marl_grouping, + make_composite_from_td, + MarlGroupMapType, + step_mdp, +) from torchrl.modules import Actor, ActorCriticOperator, MLP, SafeModule, ValueOperator from torchrl.modules.tensordict_module import WorldModelWrapper @@ -1647,6 +1653,23 @@ def test_make_spec_from_td(): assert val.dtype is spec[key].dtype +@pytest.mark.parametrize("group_type", list(MarlGroupMapType)) +def test_marl_group_type(group_type): + agent_names = ["agent"] + check_marl_grouping(group_type.get_group_map(agent_names), agent_names) + + agent_names = ["agent", "agent"] + with pytest.raises(ValueError): + check_marl_grouping(group_type.get_group_map(agent_names), agent_names) + + agent_names = ["agent_0", "agent_1"] + check_marl_grouping(group_type.get_group_map(agent_names), agent_names) + + agent_names = [] + with pytest.raises(ValueError): + check_marl_grouping(group_type.get_group_map(agent_names), agent_names) + + @pytest.mark.skipif(not torch.cuda.device_count(), reason="No cuda device") class TestConcurrentEnvs: """Concurrent parallel envs on multiple procs can interfere.""" diff --git a/test/test_libs.py b/test/test_libs.py index 2a44e5a70bc..c9e971e75c7 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -62,9 +62,10 @@ from torchrl.envs.libs.habitat import _has_habitat, HabitatEnv from torchrl.envs.libs.jumanji import _has_jumanji, JumanjiEnv from torchrl.envs.libs.openml import OpenMLEnv +from torchrl.envs.libs.pettingzoo import _has_pettingzoo, PettingZooEnv from torchrl.envs.libs.robohive import RoboHiveEnv from torchrl.envs.libs.vmas import _has_vmas, VmasEnv, VmasWrapper -from torchrl.envs.utils import check_env_specs, ExplorationType +from torchrl.envs.utils import check_env_specs, ExplorationType, MarlGroupMapType from torchrl.envs.vec_env import _has_envpool, MultiThreadedEnvWrapper, SerialEnv from torchrl.modules import ActorCriticOperator, MLP, SafeModule, ValueOperator @@ -101,6 +102,7 @@ if _has_vmas: import vmas + if _has_envpool: import envpool @@ -1635,6 +1637,222 @@ def test_env(self, task, num_envs, device): # break +@pytest.mark.skipif(not _has_pettingzoo, reason="PettingZoo not found") +class TestPettingZoo: + @pytest.mark.parametrize("parallel", [True, False]) + @pytest.mark.parametrize("continuous_actions", [True, False]) + @pytest.mark.parametrize("use_mask", [True]) + 
@pytest.mark.parametrize("return_state", [True, False]) + @pytest.mark.parametrize( + "group_map", + [None, MarlGroupMapType.ALL_IN_ONE_GROUP, MarlGroupMapType.ONE_GROUP_PER_AGENT], + ) + def test_pistonball( + self, parallel, continuous_actions, use_mask, return_state, group_map + ): + + kwargs = {"n_pistons": 21, "continuous": continuous_actions} + + env = PettingZooEnv( + task="pistonball_v6", + parallel=parallel, + seed=0, + return_state=return_state, + use_mask=use_mask, + group_map=group_map, + **kwargs, + ) + + check_env_specs(env) + + @pytest.mark.parametrize( + "wins_player_0", + [True, False], + ) + def test_tic_tac_toe(self, wins_player_0): + env = PettingZooEnv( + task="tictactoe_v3", + parallel=False, + group_map={"player": ["player_1", "player_2"]}, + categorical_actions=False, + seed=0, + use_mask=True, + ) + + class Policy: + + action = 0 + t = 0 + + def __call__(self, td): + new_td = env.input_spec["full_action_spec"].zero() + + player_acting = 0 if self.t % 2 == 0 else 1 + other_player = 1 if self.t % 2 == 0 else 0 + # The acting player has "mask" True and "action_mask" set to the available actions + assert td["player", "mask"][player_acting].all() + assert td["player", "action_mask"][player_acting].any() + # The non-acting player has "mask" False and "action_mask" set to all Trues + assert not td["player", "mask"][other_player].any() + assert td["player", "action_mask"][other_player].all() + + if self.t % 2 == 0: + if not wins_player_0 and self.t == 4: + new_td["player", "action"][0][self.action + 1] = 1 + else: + new_td["player", "action"][0][self.action] = 1 + else: + new_td["player", "action"][1][self.action + 6] = 1 + if td["player", "mask"][1].all(): + self.action += 1 + self.t += 1 + return td.update(new_td) + + td = env.rollout(100, policy=Policy()) + + assert td.batch_size[0] == (5 if wins_player_0 else 6) + assert (td[:-1]["next", "player", "reward"] == 0).all() + if wins_player_0: + assert ( + td[-1]["next", "player", "reward"] == torch.tensor([[1], [-1]]) + ).all() + else: + assert ( + td[-1]["next", "player", "reward"] == torch.tensor([[-1], [1]]) + ).all() + + @pytest.mark.parametrize( + "task", + [ + "multiwalker_v9", + "waterworld_v4", + "pursuit_v4", + "simple_spread_v3", + "simple_v3", + "rps_v2", + "cooperative_pong_v5", + "pistonball_v6", + ], + ) + def test_envs_one_group_parallel(self, task): + env = PettingZooEnv( + task=task, + parallel=True, + seed=0, + use_mask=False, + ) + check_env_specs(env) + env.rollout(100, break_when_any_done=False) + + @pytest.mark.parametrize( + "task", + [ + "multiwalker_v9", + "waterworld_v4", + "pursuit_v4", + "simple_spread_v3", + "simple_v3", + "rps_v2", + "cooperative_pong_v5", + "pistonball_v6", + "connect_four_v3", + "tictactoe_v3", + "chess_v6", + "gin_rummy_v4", + "tictactoe_v3", + ], + ) + def test_envs_one_group_aec(self, task): + env = PettingZooEnv( + task=task, + parallel=False, + seed=0, + use_mask=True, + ) + check_env_specs(env) + env.rollout(100, break_when_any_done=False) + + @pytest.mark.parametrize( + "task", + [ + "simple_adversary_v3", + "simple_crypto_v3", + "simple_push_v3", + "simple_reference_v3", + "simple_speaker_listener_v4", + "simple_tag_v3", + "simple_world_comm_v3", + "knights_archers_zombies_v10", + "basketball_pong_v3", + "boxing_v2", + "foozpong_v3", + ], + ) + def test_envs_more_groups_parallel(self, task): + env = PettingZooEnv( + task=task, + parallel=True, + seed=0, + use_mask=False, + ) + check_env_specs(env) + env.rollout(100, break_when_any_done=False) + + 
@pytest.mark.parametrize( + "task", + [ + "simple_adversary_v3", + "simple_crypto_v3", + "simple_push_v3", + "simple_reference_v3", + "simple_speaker_listener_v4", + "simple_tag_v3", + "simple_world_comm_v3", + "knights_archers_zombies_v10", + "basketball_pong_v3", + "boxing_v2", + "foozpong_v3", + "go_v5", + ], + ) + def test_envs_more_groups_aec(self, task): + env = PettingZooEnv( + task=task, + parallel=False, + seed=0, + use_mask=True, + ) + check_env_specs(env) + env.rollout(100, break_when_any_done=False) + + @pytest.mark.parametrize("task", ["knights_archers_zombies_v10", "pistonball_v6"]) + @pytest.mark.parametrize("parallel", [True, False]) + def test_vec_env(self, task, parallel): + env_fun = lambda: PettingZooEnv( + task=task, + parallel=parallel, + seed=0, + use_mask=not parallel, + ) + vec_env = ParallelEnv(2, create_env_fn=env_fun) + vec_env.rollout(100, break_when_any_done=False) + + @pytest.mark.parametrize("task", ["knights_archers_zombies_v10", "pistonball_v6"]) + @pytest.mark.parametrize("parallel", [True, False]) + def test_collector(self, task, parallel): + env_fun = lambda: PettingZooEnv( + task=task, + parallel=parallel, + seed=0, + use_mask=not parallel, + ) + coll = SyncDataCollector( + create_env_fn=env_fun, frames_per_batch=30, total_frames=60, policy=None + ) + for _ in coll: + break + + class TestRoboHive: @pytest.mark.parametrize("envname", RoboHiveEnv.env_list) @pytest.mark.parametrize("from_pixels", [True, False]) diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py index 1fb8f37a6ec..9a9944248b9 100644 --- a/torchrl/envs/__init__.py +++ b/torchrl/envs/__init__.py @@ -56,10 +56,12 @@ ) from .utils import ( check_env_specs, + check_marl_grouping, exploration_mode, exploration_type, ExplorationType, make_composite_from_td, + MarlGroupMapType, set_exploration_mode, set_exploration_type, step_mdp, diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 5ecdf148238..56b045cb823 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -225,6 +225,7 @@ def __init__( dtype: Optional[Union[torch.dtype, np.dtype]] = None, batch_size: Optional[torch.Size] = None, run_type_checks: bool = False, + allow_done_after_reset: bool = False, ): if device is None: device = torch.device("cpu") @@ -250,6 +251,7 @@ def __init__( # it's already been set self.batch_size = torch.Size(batch_size) self._run_type_checks = run_type_checks + self._allow_done_after_reset = allow_done_after_reset @classmethod def __new__(cls, *args, _inplace_update=False, _batch_locked=True, **kwargs): @@ -1300,13 +1302,14 @@ def reset( self.output_spec["full_done_spec"][done_key].zero(leading_dim), ) - for done_key in self.done_keys: - if done_key not in _reset_map: - if tensordict_reset.get(done_key).any(): - raise DONE_AFTER_RESET_ERROR - else: - if tensordict_reset.get(done_key)[_reset_map[done_key]].any(): - raise DONE_AFTER_RESET_ERROR + if not self._allow_done_after_reset: + for done_key in self.done_keys: + if done_key not in _reset_map: + if tensordict_reset.get(done_key).any(): + raise DONE_AFTER_RESET_ERROR + else: + if tensordict_reset.get(done_key)[_reset_map[done_key]].any(): + raise DONE_AFTER_RESET_ERROR if tensordict is not None: tensordict.update(tensordict_reset) @@ -1711,6 +1714,7 @@ def __init__( dtype: Optional[np.dtype] = None, device: DEVICE_TYPING = None, batch_size: Optional[torch.Size] = None, + allow_done_after_reset: bool = False, **kwargs, ): if device is None: @@ -1719,6 +1723,7 @@ def __init__( device=device, dtype=dtype, 
            batch_size=batch_size,
+            allow_done_after_reset=allow_done_after_reset,
        )
        if len(args):
            raise ValueError(
diff --git a/torchrl/envs/libs/pettingzoo.py b/torchrl/envs/libs/pettingzoo.py
new file mode 100644
index 00000000000..4b260b79cf6
--- /dev/null
+++ b/torchrl/envs/libs/pettingzoo.py
@@ -0,0 +1,878 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import copy
+import importlib
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+from tensordict.tensordict import TensorDictBase
+
+from torchrl.data import (
+    CompositeSpec,
+    DiscreteTensorSpec,
+    OneHotDiscreteTensorSpec,
+    UnboundedContinuousTensorSpec,
+)
+from torchrl.envs.common import _EnvWrapper
+from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform, set_gym_backend
+from torchrl.envs.utils import _replace_last, check_marl_grouping, MarlGroupMapType
+
+
+_has_pettingzoo = importlib.util.find_spec("pettingzoo") is not None
+
+
+def _get_envs() -> List[str]:
+    if not _has_pettingzoo:
+        return []
+    from pettingzoo.utils.all_modules import all_environments
+
+    return list(all_environments.keys())
+
+
+class PettingZooWrapper(_EnvWrapper):
+    """PettingZoo environment wrapper.
+
+    To install PettingZoo, follow the guide `here
+    <https://github.com/Farama-Foundation/PettingZoo>`__.
+
+    This class is a general torchrl wrapper for all PettingZoo environments.
+    It is able to wrap both ``pettingzoo.AECEnv`` and ``pettingzoo.ParallelEnv``.
+
+    Let's see this in more detail:
+
+    In a wrapped ``pettingzoo.ParallelEnv``, all agents step at each environment step.
+    If the number of agents varies during the task, please set ``use_mask=True``.
+    ``"mask"`` will be provided
+    as an output in each group and should be used to mask out dead agents.
+    The environment will be reset as soon as one agent is done.
+
+    In a wrapped ``pettingzoo.AECEnv``, at each step only one agent will act.
+    For this reason, it is compulsory to set ``use_mask=True`` for this type of environment.
+    ``"mask"`` will be provided as an output for each group and can be used to mask out non-acting agents.
+    The environment will be reset only when all agents are done.
+
+    If there are any unavailable actions for an agent,
+    the environment will also automatically update the mask of its ``action_spec`` and output an ``"action_mask"``
+    for each group to reflect the latest available actions. This should be passed to a masked distribution during
+    training.
+
+    As a feature of torchrl multiagent, you are able to control the grouping of agents in your environment.
+    You can group agents together (stacking their tensors) to leverage vectorization when passing them through the same
+    neural network. You can split agents into different groups where they are heterogeneous or should be processed by
+    different neural networks. To group, you just need to pass a ``group_map`` at env construction time.
+
+    By default, agents in pettingzoo will be grouped by name.
+ For example, with agents ``["agent_0","agent_1","agent_2","adversary_0"]``, the tensordicts will look like: + + >>> print(env.rand_action(env.reset())) + TensorDict( + fields={ + agent: TensorDict( + fields={ + action: Tensor(shape=torch.Size([3, 9]), device=cpu, dtype=torch.int64, is_shared=False), + action_mask: Tensor(shape=torch.Size([3, 9]), device=cpu, dtype=torch.bool, is_shared=False), + done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([3, 3, 3, 2]), device=cpu, dtype=torch.int8, is_shared=False)}, + batch_size=torch.Size([3]))}, + adversary: TensorDict( + fields={ + action: Tensor(shape=torch.Size([9]), device=cpu, dtype=torch.int64, is_shared=False), + action_mask: Tensor(shape=torch.Size([9]), device=cpu, dtype=torch.bool, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([3, 3, 2]), device=cpu, dtype=torch.int8, is_shared=False)}, + batch_size=torch.Size([]))}, + batch_size=torch.Size([])) + >>> print(env.group_map) + {"agent": ["agent_0", "agent_1", "agent_2"], "adversary": ["adversary_0"]} + + Otherwise, a group map can be specified or selected from some premade options. + See :class:`torchrl.envs.utils.MarlGroupMapType` for more info. + For example, you can provide ``MarlGroupMapType.ONE_GROUP_PER_AGENT``, telling that each agent should + have its own tensordict (similar to the pettingzoo parallel API). + + Grouping is useful for leveraging vectorisation among agents whose data goes through the same + neural network. + + Args: + env (``pettingzoo.utils.env.ParallelEnv`` or ``pettingzoo.utils.env.AECEnv``): the pettingzoo environment to wrap. + return_state (bool, optional): whether to return the global state from pettingzoo + (not available in all environments). Defaults to ``False``. + group_map (MarlGroupMapType or Dict[str, List[str]]], optional): how to group agents in tensordicts for + input/output. By default, agents will be grouped by their name. Otherwise, a group map can be specified + or selected from some premade options. See :class:`torchrl.envs.utils.MarlGroupMapType` for more info. + use_mask (bool, optional): whether the environment should output a ``"mask"``. This is compulsory in + wrapped ``pettingzoo.AECEnv`` to mask out non-acting agents and should be also used + for ``pettingzoo.ParallelEnv`` when the number of agents can vary. Defaults to ``False``. + categorical_actions (bool, optional): if the enviornments actions are discrete, whether to transform + them to categorical or one-hot. + seed (int, optional): the seed. Defaults to ``None``. + + Examples: + >>> # Parallel env + >>> from torchrl.envs.libs.pettingzoo import PettingZooWrapper + >>> from pettingzoo.butterfly import pistonball_v6 + >>> kwargs = {"n_pistons": 21, "continuous": True} + >>> env = PettingZooWrapper( + ... env=pistonball_v6.parallel_env(**kwargs), + ... return_state=True, + ... group_map=None, # Use default for parallel (all pistons grouped together) + ... ) + >>> print(env.group_map) + ... {'piston': ['piston_0', 'piston_1', ..., 'piston_20']} + >>> env.rollout(10) + >>> # AEC env + >>> from pettingzoo.classic import tictactoe_v3 + >>> from torchrl.envs.libs.pettingzoo import PettingZooWrapper + >>> from torchrl.envs.utils import MarlGroupMapType + >>> env = PettingZooWrapper( + ... env=tictactoe_v3.env(), + ... use_mask=True, # Must use it since one player plays at a time + ... 
group_map=None # # Use default for AEC (one group per player) + ... ) + >>> print(env.group_map) + ... {'player_1': ['player_1'], 'player_2': ['player_2']} + >>> env.rollout(10) + """ + + git_url = "https://github.com/Farama-Foundation/PettingZoo" + libname = "pettingzoo" + available_envs = _get_envs() + + def __init__( + self, + env: Union[ + "pettingzoo.utils.env.ParallelEnv", # noqa: F821 + "pettingzoo.utils.env.AECEnv", # noqa: F821 + ] = None, + return_state: Optional[bool] = False, + group_map: Optional[Union[MarlGroupMapType, Dict[str, List[str]]]] = None, + use_mask: bool = False, + categorical_actions: bool = True, + seed: Optional[int] = None, + **kwargs, + ): + if env is not None: + kwargs["env"] = env + + self.group_map = group_map + self.return_state = return_state + self.seed = seed + self.use_mask = use_mask + self.categorical_actions = categorical_actions + + super().__init__(**kwargs, allow_done_after_reset=True) + + def _get_default_group_map(self, agent_names: List[str]): + # This function performs the default grouping in pettingzoo + if not self.parallel: + # In AEC envs we will have one group per agent by default + group_map = MarlGroupMapType.ONE_GROUP_PER_AGENT.get_group_map(agent_names) + else: + # In parallel envs, by default + # Agents with names "str_int" will be grouped in group name "str" + group_map = {} + for agent_name in agent_names: + # See if the agent follows the convention "name_int" + follows_convention = True + agent_name_split = agent_name.split("_") + if len(agent_name_split) == 1: + follows_convention = False + try: + int(agent_name_split[-1]) + except ValueError: + follows_convention = False + + # If not, just put it in a single group + if not follows_convention: + group_map[agent_name] = [agent_name] + # Otherwise, group it with other agents that follow the same convention + else: + group_name = "_".join(agent_name_split[:-1]) + if group_name in group_map: + group_map[group_name].append(agent_name) + else: + group_map[group_name] = [agent_name] + + return group_map + + @property + def lib(self): + import pettingzoo + + return pettingzoo + + def _build_env( + self, + env: Union[ + "pettingzoo.utils.env.ParallelEnv", # noqa: F821 + "pettingzoo.utils.env.AECEnv", # noqa: F821 + ], + ): + import pettingzoo + + self.parallel = isinstance(env, pettingzoo.utils.env.ParallelEnv) + if not self.parallel and not self.use_mask: + raise ValueError("For AEC environments you need to set use_mask=True") + if len(self.batch_size): + raise RuntimeError( + f"PettingZoo does not support custom batch_size {self.batch_size}." 
+ ) + + return env + + @set_gym_backend("gymnasium") + def _make_specs( + self, + env: Union[ + "pettingzoo.utils.env.ParallelEnv", # noqa: F821 + "pettingzoo.utils.env.AECEnv", # noqa: F821 + ], + ) -> None: + + # Create and check group map + if self.group_map is None: + self.group_map = self._get_default_group_map(self.possible_agents) + elif isinstance(self.group_map, MarlGroupMapType): + self.group_map = self.group_map.get_group_map(self.possible_agents) + check_marl_grouping(self.group_map, self.possible_agents) + self.has_action_mask = {group: False for group in self.group_map.keys()} + + action_spec = CompositeSpec() + observation_spec = CompositeSpec() + reward_spec = CompositeSpec() + done_spec = CompositeSpec() + for group, agents in self.group_map.items(): + ( + group_observation_spec, + group_action_spec, + group_reward_spec, + group_done_spec, + ) = self._make_group_specs(group_name=group, agent_names=agents) + action_spec[group] = group_action_spec + observation_spec[group] = group_observation_spec + reward_spec[group] = group_reward_spec + done_spec[group] = group_done_spec + + self.action_spec = action_spec + self.observation_spec = observation_spec + self.reward_spec = reward_spec + self.done_spec = done_spec + + def _make_group_specs(self, group_name: str, agent_names: List[str]): + n_agents = len(agent_names) + action_specs = [] + observation_specs = [] + for agent in agent_names: + action_specs.append( + CompositeSpec( + { + "action": _gym_to_torchrl_spec_transform( + self.action_space(agent), + remap_state_to_observation=False, + categorical_action_encoding=self.categorical_actions, + device=self.device, + ) + }, + ) + ) + observation_specs.append( + CompositeSpec( + { + "observation": _gym_to_torchrl_spec_transform( + self.observation_space(agent), + remap_state_to_observation=False, + device=self.device, + ) + } + ) + ) + group_action_spec = torch.stack(action_specs, dim=0) + group_observation_spec = torch.stack(observation_specs, dim=0) + + # Sometimes the observation spec contains an action mask. + # Or sometimes the info spec contains an action mask. + # We uniform this by removing it from both places and optionally set it in a standard location. 
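+        # Concretely: whether PettingZoo exposes the mask under
+        # obs["action_mask"] or under info["action_mask"], torchrl strips it
+        # from that location and surfaces it as a standalone
+        # ("group", "action_mask") boolean entry whose shape matches the
+        # one-hot action spec, or (*action_shape, n) when categorical actions
+        # are used.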
+ group_observation_inner_spec = group_observation_spec["observation"] + if ( + isinstance(group_observation_inner_spec, CompositeSpec) + and "action_mask" in group_observation_inner_spec.keys() + ): + self.has_action_mask[group_name] = True + del group_observation_inner_spec["action_mask"] + group_observation_spec["action_mask"] = DiscreteTensorSpec( + n=2, + shape=group_action_spec["action"].shape + if not self.categorical_actions + else ( + *group_action_spec["action"].shape, + group_action_spec["action"].space.n, + ), + dtype=torch.bool, + device=self.device, + ) + + if self.use_mask: + group_observation_spec["mask"] = DiscreteTensorSpec( + n=2, + shape=torch.Size((n_agents,)), + dtype=torch.bool, + device=self.device, + ) + + group_reward_spec = CompositeSpec( + { + "reward": UnboundedContinuousTensorSpec( + shape=torch.Size((n_agents, 1)), + device=self.device, + dtype=torch.float32, + ) + }, + shape=torch.Size((n_agents,)), + ) + group_done_spec = CompositeSpec( + { + "done": DiscreteTensorSpec( + n=2, + shape=torch.Size((n_agents, 1)), + dtype=torch.bool, + device=self.device, + ), + }, + shape=torch.Size((n_agents,)), + ) + if n_agents == 1: + # When there is only one agent in the group we remove the singleton corresponding to the group size + group_observation_spec = group_observation_spec.squeeze(0) + group_action_spec = group_action_spec.squeeze(0) + group_reward_spec = group_reward_spec.squeeze(0) + group_done_spec = group_done_spec.squeeze(0) + return ( + group_observation_spec, + group_action_spec, + group_reward_spec, + group_done_spec, + ) + + def _check_kwargs(self, kwargs: Dict): + import pettingzoo + + if "env" not in kwargs: + raise TypeError("Could not find environment key 'env' in kwargs.") + env = kwargs["env"] + if not isinstance( + env, (pettingzoo.utils.env.ParallelEnv, pettingzoo.utils.env.AECEnv) + ): + raise TypeError("env is not of type expected.") + + def _init_env(self) -> Optional[int]: + # Add info + if self.parallel: + _, info_dict = self._reset_parallel() + else: + _, info_dict = self._reset_aec() + + for group, agents in self.group_map.items(): + info_specs = [] + for agent in agents: + info_specs.append( + CompositeSpec( + { + "info": CompositeSpec( + { + key: UnboundedContinuousTensorSpec( + shape=torch.tensor(value).shape, + device=self.device, + ) + for key, value in info_dict[agent].items() + } + ) + }, + device=self.device, + ) + ) + info_specs = torch.stack(info_specs, dim=0) + if ("info", "action_mask") in info_specs.keys(True, True): + if not self.has_action_mask[group]: + self.has_action_mask[group] = True + group_action_spec = self.input_spec[ + "full_action_spec", group, "action" + ] + self.observation_spec[group]["action_mask"] = DiscreteTensorSpec( + n=2, + shape=group_action_spec.shape + if not self.categorical_actions + else (*group_action_spec.shape, group_action_spec.space.n), + dtype=torch.bool, + device=self.device, + ) + group_inner_info_spec = info_specs["info"] + del group_inner_info_spec["action_mask"] + + if len(info_specs["info"].keys()): + self.observation_spec[group].update(info_specs) + + if self.return_state: + try: + state_spec = _gym_to_torchrl_spec_transform( + self.state_space, + remap_state_to_observation=False, + device=self.device, + ) + except AttributeError: + state_example = torch.tensor(self.state(), device=self.device) + state_spec = UnboundedContinuousTensorSpec( + shape=state_example.shape, + dtype=state_example.dtype, + device=self.device, + ) + self.observation_spec["state"] = state_spec + + # Caching + 
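+        # The zeroed reset/step outputs below are built once from the specs and
+        # cloned on every call, so the per-step cost is a single TensorDict
+        # clone instead of a full spec.zero() traversal.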
self.cached_reset_output_zero = self.observation_spec.zero() + self.cached_reset_output_zero.update(self.output_spec["full_done_spec"].zero()) + + self.cached_step_output_zero = self.observation_spec.zero() + self.cached_step_output_zero.update(self.output_spec["full_reward_spec"].zero()) + self.cached_step_output_zero.update(self.output_spec["full_done_spec"].zero()) + + def _set_seed(self, seed: Optional[int]): + self.seed = seed + + def _reset( + self, tensordict: Optional[TensorDictBase] = None, **kwargs + ) -> TensorDictBase: + + if self.parallel: + # This resets when any is done + observation_dict, info_dict = self._reset_parallel() + else: + # This resets when all are done + observation_dict, info_dict = self._reset_aec(tensordict) + + # We start with zeroed data and fill in the data for alive agents + tensordict_out = self.cached_reset_output_zero.clone() + # Update the "mask" for non-acting agents + self._update_agent_mask(tensordict_out) + # Update the "action_mask" for non-available actions + observation_dict, info_dict = self._update_action_mask( + tensordict_out, observation_dict, info_dict + ) + + # Now we get the data (obs and info) + for group, agent_names in self.group_map.items(): + group_observation = tensordict_out.get((group, "observation")) + group_info = tensordict_out.get((group, "info"), None) + + for i, agent in enumerate(agent_names): + index = ( + i if len(agent_names) > 1 else Ellipsis + ) # If group has one agent we index with '...' + group_observation[index] = self.observation_spec[group, "observation"][ + index + ].encode(observation_dict[agent]) + if group_info is not None: + agent_info_dict = info_dict[agent] + for agent_info, value in agent_info_dict.items(): + group_info.get(agent_info)[index] = torch.tensor( + value, device=self.device + ) + + return tensordict_out + + def _reset_aec(self, tensordict=None) -> Tuple[Dict, Dict]: + all_done = True + if tensordict is not None: + _resets = [] + for done_key in self.done_keys: + _reset_key = _replace_last(done_key, "_reset") + _reset = tensordict.get(_reset_key, default=None) + if _reset is None: + continue + _resets.append(_reset) + if len(_resets) < len(self.done_keys): + all_done = False + else: + for _reset in _resets: + if not _reset.all(): + all_done = False + break + + if all_done: + self._env.reset(seed=self.seed) + + observation_dict = { + agent: self._env.observe(agent) for agent in self.possible_agents + } + info_dict = self._env.infos + return observation_dict, info_dict + + def _reset_parallel( + self, + ) -> Tuple[Dict, Dict]: + return self._env.reset(seed=self.seed) + + def _step( + self, + tensordict: TensorDictBase, + ) -> TensorDictBase: + + if self.parallel: + ( + observation_dict, + rewards_dict, + terminations_dict, + truncations_dict, + info_dict, + ) = self._step_parallel(tensordict) + else: + ( + observation_dict, + rewards_dict, + terminations_dict, + truncations_dict, + info_dict, + ) = self._step_aec(tensordict) + + # We start with zeroed data and fill in the data for alive agents + tensordict_out = self.cached_step_output_zero.clone() + # Update the "mask" for non-acting agents + self._update_agent_mask(tensordict_out) + # Update the "action_mask" for non-available actions + observation_dict, info_dict = self._update_action_mask( + tensordict_out, observation_dict, info_dict + ) + + # Now we get the data + for group, agent_names in self.group_map.items(): + group_observation = tensordict_out.get((group, "observation")) + group_reward = tensordict_out.get((group, "reward")) + 
+            group_done = tensordict_out.get((group, "done"))
+            group_info = tensordict_out.get((group, "info"), None)
+
+            for i, agent in enumerate(agent_names):
+                if agent in observation_dict:  # Live agents
+                    index = (
+                        i if len(agent_names) > 1 else Ellipsis
+                    )  # If group has one agent, we index with '...'
+                    group_observation[index] = self.observation_spec[
+                        group, "observation"
+                    ][index].encode(observation_dict[agent])
+                    group_reward[index] = torch.tensor(
+                        rewards_dict[agent],
+                        device=self.device,
+                        dtype=torch.float32,
+                    )
+                    group_done[index] = torch.tensor(
+                        terminations_dict[agent] or truncations_dict[agent],
+                        device=self.device,
+                        dtype=torch.bool,
+                    )
+
+                    if group_info is not None:
+                        agent_info_dict = info_dict[agent]
+                        for agent_info, value in agent_info_dict.items():
+                            group_info.get(agent_info)[index] = torch.tensor(
+                                value, device=self.device
+                            )
+
+                elif not self.use_mask:
+                    # Dead agent: if we are not masking it out, this is not allowed
+                    raise ValueError(
+                        "Dead agents found in the environment,"
+                        " you need to set use_mask=True to allow this."
+                    )
+
+        return tensordict_out
+
+    def _step_parallel(
+        self,
+        tensordict: TensorDictBase,
+    ) -> Tuple[Dict, Dict, Dict, Dict, Dict]:
+        action_dict = {}
+        for group, agents in self.group_map.items():
+            group_action = tensordict.get((group, "action"))
+            group_action_np = self.input_spec[
+                "full_action_spec", group, "action"
+            ].to_numpy(group_action)
+            for i, agent in enumerate(agents):
+                index = i if len(agents) > 1 else Ellipsis
+                action_dict[agent] = group_action_np[index]
+
+        return self._env.step(action_dict)
+
+    def _step_aec(
+        self,
+        tensordict: TensorDictBase,
+    ) -> Tuple[Dict, Dict, Dict, Dict, Dict]:
+
+        for group, agents in self.group_map.items():
+            if self.agent_selection in agents:
+                agent_index = (
+                    agents.index(self._env.agent_selection)
+                    if len(agents) > 1
+                    else Ellipsis
+                )
+                group_action = tensordict.get((group, "action"))
+                group_action_np = self.input_spec[
+                    "full_action_spec", group, "action"
+                ].to_numpy(group_action)
+                action = group_action_np[agent_index]
+                break
+
+        self._env.step(action)
+        terminations_dict = self._env.terminations
+        truncations_dict = self._env.truncations
+        info_dict = self._env.infos
+        rewards_dict = self._env.rewards
+        observation_dict = {
+            agent: self._env.observe(agent) for agent in self.possible_agents
+        }
+        return (
+            observation_dict,
+            rewards_dict,
+            terminations_dict,
+            truncations_dict,
+            info_dict,
+        )
+
+    def _update_action_mask(self, td, observation_dict, info_dict):
+
+        # Since we remove the action_mask keys, we need to copy the data
+        observation_dict = copy.deepcopy(observation_dict)
+        info_dict = copy.deepcopy(info_dict)
+        # In AEC only one agent acts; in a parallel env, self.agents contains the agents alive
+        agents_acting = self.agents if self.parallel else [self.agent_selection]
+
+        for group, agents in self.group_map.items():
+            if self.has_action_mask[group]:
+                group_mask = td.get((group, "action_mask"))
+                group_mask += True
+                for i, agent in enumerate(agents):
+                    index = (
+                        i if len(agents) > 1 else Ellipsis
+                    )  # If group has one agent we index with '...'
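+                    # The per-agent mask may come from either the observation
+                    # dict or the info dict; whichever holds it is read for the
+                    # acting agents and the key is then removed so downstream
+                    # encoding never sees it.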
+ agent_obs = observation_dict[agent] + agent_info = info_dict[agent] + if isinstance(agent_obs, Dict) and "action_mask" in agent_obs: + if agent in agents_acting: + group_mask[index] = torch.tensor( + agent_obs["action_mask"], + device=self.device, + dtype=torch.bool, + ) + del agent_obs["action_mask"] + elif isinstance(agent_info, Dict) and "action_mask" in agent_info: + if agent in agents_acting: + group_mask[index] = torch.tensor( + agent_info["action_mask"], + device=self.device, + dtype=torch.bool, + ) + del agent_info["action_mask"] + + group_action_spec = self.input_spec["full_action_spec", group, "action"] + if isinstance( + group_action_spec, (DiscreteTensorSpec, OneHotDiscreteTensorSpec) + ): + # We update the mask for available actions + group_action_spec.update_mask(group_mask.clone()) + + return observation_dict, info_dict + + def _update_agent_mask(self, td): + if self.use_mask: + # In AEC only one agent acts, in parallel env self.agents contains the agents alive + agents_acting = self.agents if self.parallel else [self.agent_selection] + for group, agents in self.group_map.items(): + group_mask = td.get((group, "mask")) + group_mask += True + + # We now add dead agents to the mask + for i, agent in enumerate(agents): + index = ( + i if len(agents) > 1 else Ellipsis + ) # If group has one agent we index with '...' + if agent not in agents_acting: + group_mask[index] = False + + def close(self) -> None: + self._env.close() + + +class PettingZooEnv(PettingZooWrapper): + """PettingZoo Environment. + + To install petting zoo follow the guide `here __`. + + This class is a general torchrl wrapper for all PettingZoo environments. + It is able to wrap both ``pettingzoo.AECEnv`` and ``pettingzoo.ParallelEnv``. + + Let's see how more in details: + + For wrapping ``pettingzoo.ParallelEnv`` provide the name of your petting zoo task (in the ``task`` argument) + and specify ``parallel=True``. This will construct the ``pettingzoo.ParallelEnv`` version of that task + (if it is supported in pettingzoo) and wrap it for torchrl. + In wrapped ``pettingzoo.ParallelEnv`` all agents will step at each environment step. + If the number of agents during the task varies, please set ``use_mask=True``. + ``"mask"`` will be provided + as an output in each group and should be used to mask out dead agents. + The environment will be reset as soon as one agent is done. + + For wrapping ``pettingzoo.AECEnv`` provide the name of your petting zoo task (in the ``task`` argument) + and specify ``parallel=False``. This will construct the ``pettingzoo.AECEnv`` version of that task + and wrap it for torchrl. + In wrapped ``pettingzoo.AECEnv``, at each step only one agent will act. + For this reason, it is compulsory to set ``use_mask=True`` for this type of environment. + ``"mask"`` will be provided as an output for each group and can be used to mask out non-acting agents. + The environment will be reset only when all agents are done. + + If there are any unavailable actions for an agent, + the environment will also automatically update the mask of its ``action_spec`` and output an ``"action_mask"`` + for each group to reflect the latest available actions. This should be passed to a masked distribution during + training. + + As a feature of torchrl multiagent, you are able to control the grouping of agents in your environment. + You can group agents together (stacking their tensors) to leverage vectorization when passing them through the same + neural network. 
You can split agents into different groups where they are heterogeneous or should be processed by
+    different neural networks. To group, you just need to pass a ``group_map`` at env construction time.
+
+    By default, agents in pettingzoo will be grouped by name.
+    For example, with agents ``["agent_0","agent_1","agent_2","adversary_0"]``, the tensordicts will look like:
+
+    >>> print(env.rand_action(env.reset()))
+    TensorDict(
+        fields={
+            agent: TensorDict(
+                fields={
+                    action: Tensor(shape=torch.Size([3, 9]), device=cpu, dtype=torch.int64, is_shared=False),
+                    action_mask: Tensor(shape=torch.Size([3, 9]), device=cpu, dtype=torch.bool, is_shared=False),
+                    done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                    observation: Tensor(shape=torch.Size([3, 3, 3, 2]), device=cpu, dtype=torch.int8, is_shared=False)},
+                batch_size=torch.Size([3]))},
+            adversary: TensorDict(
+                fields={
+                    action: Tensor(shape=torch.Size([9]), device=cpu, dtype=torch.int64, is_shared=False),
+                    action_mask: Tensor(shape=torch.Size([9]), device=cpu, dtype=torch.bool, is_shared=False),
+                    done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                    observation: Tensor(shape=torch.Size([3, 3, 2]), device=cpu, dtype=torch.int8, is_shared=False)},
+                batch_size=torch.Size([]))},
+        batch_size=torch.Size([]))
+    >>> print(env.group_map)
+    {"agent": ["agent_0", "agent_1", "agent_2"], "adversary": ["adversary_0"]}
+
+    Otherwise, a group map can be specified or selected from some premade options.
+    See :class:`torchrl.envs.utils.MarlGroupMapType` for more info.
+    For example, you can provide ``MarlGroupMapType.ONE_GROUP_PER_AGENT``, telling that each agent should
+    have its own tensordict (similar to the pettingzoo parallel API).
+
+    Grouping is useful for leveraging vectorisation among agents whose data goes through the same
+    neural network.
+
+    Args:
+        task (str): the name of the pettingzoo task to create (for example, "multiwalker_v9").
+        parallel (bool): whether to construct the ``pettingzoo.ParallelEnv`` version of the task or the ``pettingzoo.AECEnv``.
+        return_state (bool, optional): whether to return the global state from pettingzoo
+            (not available in all environments). Defaults to ``False``.
+        group_map (MarlGroupMapType or Dict[str, List[str]], optional): how to group agents in tensordicts for
+            input/output. By default, agents will be grouped by their name. Otherwise, a group map can be specified
+            or selected from some premade options. See :class:`torchrl.envs.utils.MarlGroupMapType` for more info.
+        use_mask (bool, optional): whether the environment should output a ``"mask"``. This is compulsory in
+            wrapped ``pettingzoo.AECEnv`` to mask out non-acting agents and should be also used
+            for ``pettingzoo.ParallelEnv`` when the number of agents can vary. Defaults to ``False``.
+        categorical_actions (bool, optional): if the environment's actions are discrete, whether to transform
+            them to categorical or one-hot.
+        seed (int, optional): the seed. Defaults to ``None``.
+
+    Examples:
+        >>> # Parallel env
+        >>> from torchrl.envs.libs.pettingzoo import PettingZooEnv
+        >>> kwargs = {"n_pistons": 21, "continuous": True}
+        >>> env = PettingZooEnv(
+        ...     task="pistonball_v6",
+        ...     parallel=True,
+        ...     return_state=True,
+        ...     group_map=None, # Use default (all pistons grouped together)
+        ...     **kwargs,
+        ... )
+        >>> print(env.group_map)
+        ...
+    Args:
+        task (str): the name of the pettingzoo task to create (for example, "multiwalker_v9").
+        parallel (bool): whether to construct the ``pettingzoo.ParallelEnv`` version of the task or the ``pettingzoo.AECEnv``.
+        return_state (bool, optional): whether to return the global state from pettingzoo
+            (not available in all environments). Defaults to ``False``.
+        group_map (MarlGroupMapType or Dict[str, List[str]], optional): how to group agents in tensordicts for
+            input/output. By default, agents will be grouped by their name. Otherwise, a group map can be specified
+            or selected from some premade options. See :class:`torchrl.envs.utils.MarlGroupMapType` for more info.
+        use_mask (bool, optional): whether the environment should output a ``"mask"``. This is compulsory in
+            wrapped ``pettingzoo.AECEnv`` to mask out non-acting agents and should also be used
+            for ``pettingzoo.ParallelEnv`` when the number of agents can vary. Defaults to ``False``.
+        categorical_actions (bool, optional): if the environment's actions are discrete, whether to transform
+            them to categorical or one-hot. Defaults to ``True``.
+        seed (int, optional): the seed. Defaults to ``None``.
+
+    Examples:
+        >>> # Parallel env
+        >>> from torchrl.envs.libs.pettingzoo import PettingZooEnv
+        >>> kwargs = {"n_pistons": 21, "continuous": True}
+        >>> env = PettingZooEnv(
+        ...     task="pistonball_v6",
+        ...     parallel=True,
+        ...     return_state=True,
+        ...     group_map=None,  # Use default (all pistons grouped together)
+        ...     **kwargs,
+        ... )
+        >>> print(env.group_map)
+        {'piston': ['piston_0', 'piston_1', ..., 'piston_20']}
+        >>> env.rollout(10)
+        >>> # AEC env
+        >>> from torchrl.envs.libs.pettingzoo import PettingZooEnv
+        >>> from torchrl.envs.utils import MarlGroupMapType
+        >>> env = PettingZooEnv(
+        ...     task="tictactoe_v3",
+        ...     parallel=False,
+        ...     use_mask=True,  # Mandatory since only one player acts at a time
+        ...     group_map=None,  # Use default for AEC (one group per player)
+        ... )
+        >>> print(env.group_map)
+        {'player_1': ['player_1'], 'player_2': ['player_2']}
+        >>> env.rollout(10)
+    """
+
+    def __init__(
+        self,
+        task: str,
+        parallel: bool,
+        return_state: Optional[bool] = False,
+        group_map: Optional[Union[MarlGroupMapType, Dict[str, List[str]]]] = None,
+        use_mask: bool = False,
+        categorical_actions: bool = True,
+        seed: Optional[int] = None,
+        **kwargs,
+    ):
+        if not _has_pettingzoo:
+            raise ImportError(
+                "pettingzoo python package was not found. Please install this dependency. "
+                f"More info: {self.git_url}."
+            )
+        kwargs["task"] = task
+        kwargs["parallel"] = parallel
+        kwargs["return_state"] = return_state
+        kwargs["group_map"] = group_map
+        kwargs["use_mask"] = use_mask
+        kwargs["categorical_actions"] = categorical_actions
+        kwargs["seed"] = seed
+
+        super().__init__(**kwargs)
+
+    def _check_kwargs(self, kwargs: Dict):
+        if "task" not in kwargs:
+            raise TypeError("Could not find environment key 'task' in kwargs.")
+        if "parallel" not in kwargs:
+            raise TypeError("Could not find environment key 'parallel' in kwargs.")
+
+    def _build_env(
+        self,
+        task: str,
+        parallel: bool,
+        **kwargs,
+    ) -> Union[
+        "pettingzoo.utils.env.ParallelEnv",  # noqa: F821
+        "pettingzoo.utils.env.AECEnv",  # noqa: F821
+    ]:
+        self.task_name = task
+
+        from pettingzoo.utils.all_modules import all_environments
+
+        if task not in all_environments:
+            # Fall back to matching the task against the module names of the
+            # registered environments (e.g. "pistonball_v6")
+            task_module = None
+            for value in all_environments.values():
+                if value.__name__.split(".")[-1] == task:
+                    task_module = value
+                    break
+            if task_module is None:
+                raise RuntimeError(f"Specified task not in {_get_envs()}")
+        else:
+            task_module = all_environments[task]
+
+        if parallel:
+            petting_zoo_env = task_module.parallel_env(**kwargs)
+        else:
+            petting_zoo_env = task_module.env(**kwargs)
+
+        return super()._build_env(env=petting_zoo_env)
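+
+
+# For reference, an already-instantiated PettingZoo environment can also be
+# wrapped directly (a minimal sketch; that ``PettingZooWrapper`` accepts an
+# ``env`` argument mirrors the ``_build_env(env=...)`` call above):
+#
+# >>> from pettingzoo.butterfly import pistonball_v6
+# >>> env = PettingZooWrapper(env=pistonball_v6.parallel_env())
+# >>> env.rollout(5)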
diff --git a/torchrl/envs/libs/vmas.py b/torchrl/envs/libs/vmas.py
index 64bc71d0a5a..7ce2bc91885 100644
--- a/torchrl/envs/libs/vmas.py
+++ b/torchrl/envs/libs/vmas.py
@@ -1,3 +1,8 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
 from typing import Dict, List, Optional, Union
 
 import torch
@@ -116,7 +121,7 @@ def __init__(
             raise TypeError("Env device is different from vmas device")
         kwargs["device"] = str(env.device)
         self.categorical_actions = categorical_actions
-        super().__init__(**kwargs)
+        super().__init__(**kwargs, allow_done_after_reset=True)
 
     @property
     def lib(self):
diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py
index ca80c0c2189..e4d41d14bfa 100644
--- a/torchrl/envs/utils.py
+++ b/torchrl/envs/utils.py
@@ -9,7 +9,8 @@
 import importlib.util
 import os
 import re
-from typing import List, Union
+from enum import Enum
+from typing import Dict, List, Union
 
 import torch
 
@@ -35,6 +36,8 @@
     "check_env_specs",
     "step_mdp",
     "make_composite_from_td",
+    "MarlGroupMapType",
+    "check_marl_grouping",
 ]
 
 
@@ -612,3 +615,119 @@ def _replace_last(key: NestedKey, new_ending: str) -> NestedKey:
         return new_ending
     else:
         return key[:-1] + (new_ending,)
+
+
+class MarlGroupMapType(Enum):
+    """Marl Group Map Type.
+
+    As a feature of torchrl multiagent, you are able to control the grouping of agents in your environment.
+    You can group agents together (stacking their tensors) to leverage vectorization when passing them through
+    the same neural network. You can split agents into different groups where they are heterogeneous or should
+    be processed by different neural networks. To group, you just need to pass a ``group_map`` at env
+    construction time.
+
+    Otherwise, you can choose one of the premade grouping strategies from this class.
+
+    - With ``group_map=MarlGroupMapType.ALL_IN_ONE_GROUP`` and
+      agents ``["agent_0", "agent_1", "agent_2", "agent_3"]``,
+      the tensordicts coming and going from your environment will look
+      something like:
+
+        >>> print(env.rand_action(env.reset()))
+        TensorDict(
+            fields={
+                agents: TensorDict(
+                    fields={
+                        action: Tensor(shape=torch.Size([4, 9]), device=cpu, dtype=torch.int64, is_shared=False),
+                        done: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                        observation: Tensor(shape=torch.Size([4, 3, 3, 2]), device=cpu, dtype=torch.int8, is_shared=False)},
+                    batch_size=torch.Size([4]))},
+            batch_size=torch.Size([]))
+        >>> print(env.group_map)
+        {"agents": ["agent_0", "agent_1", "agent_2", "agent_3"]}
+
+    - With ``group_map=MarlGroupMapType.ONE_GROUP_PER_AGENT`` and
+      agents ``["agent_0", "agent_1", "agent_2", "agent_3"]``,
+      the tensordicts coming and going from your environment will look
+      something like:
+
+        >>> print(env.rand_action(env.reset()))
+        TensorDict(
+            fields={
+                agent_0: TensorDict(
+                    fields={
+                        action: Tensor(shape=torch.Size([9]), device=cpu, dtype=torch.int64, is_shared=False),
+                        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                        observation: Tensor(shape=torch.Size([3, 3, 2]), device=cpu, dtype=torch.int8, is_shared=False)},
+                    batch_size=torch.Size([])),
+                agent_1: TensorDict(
+                    fields={
+                        action: Tensor(shape=torch.Size([9]), device=cpu, dtype=torch.int64, is_shared=False),
+                        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                        observation: Tensor(shape=torch.Size([3, 3, 2]), device=cpu, dtype=torch.int8, is_shared=False)},
+                    batch_size=torch.Size([])),
+                agent_2: TensorDict(
+                    fields={
+                        action: Tensor(shape=torch.Size([9]), device=cpu, dtype=torch.int64, is_shared=False),
+                        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                        observation: Tensor(shape=torch.Size([3, 3, 2]), device=cpu, dtype=torch.int8, is_shared=False)},
+                    batch_size=torch.Size([])),
+                agent_3: TensorDict(
+                    fields={
+                        action: Tensor(shape=torch.Size([9]), device=cpu, dtype=torch.int64, is_shared=False),
+                        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                        observation: Tensor(shape=torch.Size([3, 3, 2]), device=cpu, dtype=torch.int8, is_shared=False)},
+                    batch_size=torch.Size([]))},
+            batch_size=torch.Size([]))
+        >>> print(env.group_map)
+        {"agent_0": ["agent_0"], "agent_1": ["agent_1"], "agent_2": ["agent_2"], "agent_3": ["agent_3"]}
+    """
+
+    ALL_IN_ONE_GROUP = 1
+    ONE_GROUP_PER_AGENT = 2
+
+    def get_group_map(self, agent_names: List[str]):
+        # Build the concrete group map dictionary for the given agent names
+        if self == MarlGroupMapType.ALL_IN_ONE_GROUP:
+            return {"agents": agent_names}
+        elif self == MarlGroupMapType.ONE_GROUP_PER_AGENT:
+            return {agent_name: [agent_name] for agent_name in agent_names}
+
+
+def check_marl_grouping(group_map: Dict[str, List[str]], agent_names: List[str]):
+    """Check MARL group map.
+
+    Performs checks on the group map of a MARL environment to assess its validity.
+    Raises an error in case of an invalid group map.
+
+    Args:
+        group_map (Dict[str, List[str]]): the group map mapping group names to the lists of agent names in each group.
+        agent_names (List[str]): a list of all the agent names in the environment.
+
+    Examples:
+        >>> from torchrl.envs.utils import MarlGroupMapType, check_marl_grouping
+        >>> agent_names = ["agent_0", "agent_1", "agent_2"]
+        >>> check_marl_grouping(MarlGroupMapType.ALL_IN_ONE_GROUP.get_group_map(agent_names), agent_names)
+
+    """
+    n_agents = len(agent_names)
+    if n_agents == 0:
+        raise ValueError("No agents passed")
+    if len(set(agent_names)) != n_agents:
+        raise ValueError("There are agents with the same name")
+    if len(group_map.keys()) > n_agents:
+        raise ValueError(
+            f"Number of groups {len(group_map.keys())} greater than number of agents {n_agents}"
+        )
+    # Track which agents have been assigned to a group
+    found_agents = {agent_name: False for agent_name in agent_names}
+    for group_name, group in group_map.items():
+        if not len(group):
+            raise ValueError(f"Group {group_name} is empty")
+        for agent_name in group:
+            if agent_name not in found_agents:
+                raise ValueError(f"Agent {agent_name} not present in environment")
+            if not found_agents[agent_name]:
+                found_agents[agent_name] = True
+            else:
+                raise ValueError(f"Agent {agent_name} present more than once")
+    # Every agent must belong to exactly one group
+    for agent_name, found in found_agents.items():
+        if not found:
+            raise ValueError(f"Agent {agent_name} not found in any group")
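+
+
+# For a negative case, ``check_marl_grouping`` raises when an agent is missing
+# from every group (a minimal sketch; the agent names are illustrative):
+#
+# >>> check_marl_grouping({"agents": ["agent_0"]}, ["agent_0", "agent_1"])
+# ValueError: Agent agent_1 not found in any group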