From 8aafcb3b42da5611b3fe16d657a669c1653bbe01 Mon Sep 17 00:00:00 2001 From: Felix Chalumeau Date: Fri, 8 Jul 2022 19:17:34 -0400 Subject: [PATCH] feat(notebooks): add example notebook for NSGA2 and SPEA2 (#58) Add example notebook for NSGA2 and SPEA2 + minor fixes in MOME notebook --- README.md | 5 + notebooks/mome_example.ipynb | 32 +-- notebooks/nsga2_spea2_example.ipynb | 369 ++++++++++++++++++++++++++++ 3 files changed, 391 insertions(+), 15 deletions(-) create mode 100644 notebooks/nsga2_spea2_example.ipynb diff --git a/README.md b/README.md index c2075bb8..26605f40 100644 --- a/README.md +++ b/README.md @@ -59,6 +59,7 @@ repertoire.genotypes, repertoire.fitnesses, repertoire.descriptors ## QDax core algorithms QDax currently supports the following algorithms: + | Algorithm | Example | | --- | --- | | [MAP-Elites](https://arxiv.org/abs/1504.04909) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/mapelites_example.ipynb) | @@ -68,13 +69,17 @@ QDax currently supports the following algorithms: | [CMA-MEGA](https://arxiv.org/abs/2106.03894) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/cmamega_example.ipynb) | | [Multi-Objective Quality-Diversity (MOME)](https://arxiv.org/abs/2202.03057) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/mome_example.ipynb) | + ## QDax baseline algorithms The QDax library also provides implementations for some useful baseline algorithms: + | Algorithm | Example | | --- | --- | | [DIAYN](https://arxiv.org/abs/1802.06070) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/diayn_example.ipynb) | | [DADS](https://arxiv.org/abs/1907.01657) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/dads_example.ipynb) | | [SMERL](https://arxiv.org/abs/2010.14484) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/smerl_example.ipynb) | +| [NSGA2](https://ieeexplore.ieee.org/document/996017) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/nsga2_spea2_example.ipynb) | +| [SPEA2](https://www.semanticscholar.org/paper/SPEA2%3A-Improving-the-strength-pareto-evolutionary-Zitzler-Laumanns/b13724cb54ae4171916f3f969d304b9e9752a57f) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/nsga2_spea2_example.ipynb) | ## QDax Overview diff --git a/notebooks/mome_example.ipynb b/notebooks/mome_example.ipynb index dc67d37d..b2e225fa 100644 --- a/notebooks/mome_example.ipynb +++ b/notebooks/mome_example.ipynb @@ -86,20 +86,22 @@ "metadata": {}, "outputs": [], "source": [ - "pareto_front_max_length = 50\n", - "num_variables = 100\n", - "num_iterations = 1000\n", + "#@markdown ---\n", + "pareto_front_max_length = 50 #@param {type:\"integer\"}\n", + "num_variables = 100 #@param {type:\"integer\"}\n", + "num_iterations = 1000 #@param {type:\"integer\"}\n", "\n", - "num_centroids = 64\n", - "minval = -2\n", - "maxval = 4\n", - "proportion_to_mutate = 0.6\n", - "eta = 1\n", - "proportion_var_to_change = 0.5\n", - "crossover_percentage = 1.\n", - "batch_size = 100\n", - "lag = 2.2\n", - "base_lag = 0" + "num_centroids = 64 #@param {type:\"integer\"}\n", + "minval = -2 #@param {type:\"number\"}\n", + "maxval = 4 #@param {type:\"number\"}\n", + "proportion_to_mutate = 0.6 #@param {type:\"number\"}\n", + "eta = 1 #@param {type:\"number\"}\n", + "proportion_var_to_change = 0.5 #@param {type:\"number\"}\n", + "crossover_percentage = 1. #@param {type:\"number\"}\n", + "batch_size = 100 #@param {type:\"integer\"}\n", + "lag = 2.2 #@param {type:\"number\"}\n", + "base_lag = 0 #@param {type:\"number\"}\n", + "#@markdown ---" ] }, { @@ -120,8 +122,8 @@ "outputs": [], "source": [ "def rastrigin_scorer(\n", - " genotypes: jnp.ndarray, base_lag: int, lag: int\n", - ") -> Tuple[jnp.ndarray, jnp.ndarray]:\n", + " genotypes: jnp.ndarray, base_lag: float, lag: float\n", + ") -> Tuple[Fitness, Descriptor]:\n", " \"\"\"\n", " Rastrigin Scorer with first two dimensions as descriptors\n", " \"\"\"\n", diff --git a/notebooks/nsga2_spea2_example.ipynb b/notebooks/nsga2_spea2_example.ipynb new file mode 100644 index 00000000..771dde31 --- /dev/null +++ b/notebooks/nsga2_spea2_example.ipynb @@ -0,0 +1,369 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/nsga2_spea2_example.ipynb)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Optimizing multiple objectives with NSGA2 & SPEA2 in Jax\n", + "\n", + "This notebook shows how to use QDax to find diverse and performing parameters on a multi-objectives Rastrigin problem, using [NSGA2](https://ieeexplore.ieee.org/document/996017) and [SPEA2](https://www.semanticscholar.org/paper/SPEA2%3A-Improving-the-strength-pareto-evolutionary-Zitzler-Laumanns/b13724cb54ae4171916f3f969d304b9e9752a57f) algorithms. It can be run locally or on Google Colab. We recommand to use a GPU. This notebook will show:\n", + "\n", + "- how to define the problem\n", + "- how to create an emitter instance\n", + "- how to create an NSGA2 instance\n", + "- how to create an SPEA2 instance\n", + "- which functions must be defined before training\n", + "- how to launch a certain number of training steps\n", + "- how to visualise the optimization process" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import jax.numpy as jnp\n", + "import jax\n", + "\n", + "from typing import Tuple\n", + "\n", + "import matplotlib.pyplot as plt\n", + "\n", + "from functools import partial\n", + "\n", + "try:\n", + " import qdax\n", + "except:\n", + " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", + " import qdax\n", + "\n", + "from qdax.core.nsga2 import (\n", + " NSGA2\n", + ")\n", + "from qdax.core.spea2 import (\n", + " SPEA2\n", + ")\n", + "\n", + "from qdax.core.emitters.mutation_operators import (\n", + " polynomial_crossover, \n", + " polynomial_mutation\n", + ")\n", + "from qdax.core.emitters.standard_emitters import MixingEmitter\n", + "from qdax.utils.pareto_front import compute_pareto_front\n", + "from qdax.utils.plotting import plot_global_pareto_front\n", + "\n", + "from qdax.utils.pareto_front import compute_pareto_front\n", + "from qdax.utils.plotting import plot_global_pareto_front\n", + "from qdax.utils.metrics import default_ga_metrics\n", + "\n", + "from qdax.types import Genotype, Fitness, Descriptor" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Set the hyperparameters" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#@markdown ---\n", + "population_size = 1000 #@param {type:\"integer\"}\n", + "num_iterations = 1000 #@param {type:\"integer\"}\n", + "proportion_mutation = 0.80 #@param {type:\"number\"}\n", + "minval = -5.12 #@param {type:\"number\"}\n", + "maxval = 5.12 #@param {type:\"number\"}\n", + "batch_size = 100 #@param {type:\"integer\"}\n", + "genotype_dim = 6 #@param {type:\"integer\"}\n", + "lag = 2.2 #@param {type:\"number\"}\n", + "base_lag = 0 #@param {type:\"number\"}\n", + "# for spea2\n", + "num_neighbours=1 #@param {type:\"integer\"}\n", + "#@markdown ---" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define the scoring function: rastrigin multi-objective\n", + "\n", + "We use two rastrigin functions with an offset to create a multi-objective problem." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def rastrigin_scorer(\n", + " genotypes: jnp.ndarray, base_lag: float, lag: float\n", + ") -> Tuple[Fitness, Descriptor]:\n", + " \"\"\"\n", + " Rastrigin Scorer with first two dimensions as descriptors\n", + " \"\"\"\n", + " descriptors = genotypes[:, :2]\n", + " f1 = -(\n", + " 10 * genotypes.shape[1]\n", + " + jnp.sum(\n", + " (genotypes - base_lag) ** 2\n", + " - 10 * jnp.cos(2 * jnp.pi * (genotypes - base_lag)),\n", + " axis=1,\n", + " )\n", + " )\n", + "\n", + " f2 = -(\n", + " 10 * genotypes.shape[1]\n", + " + jnp.sum(\n", + " (genotypes - lag) ** 2 - 10 * jnp.cos(2 * jnp.pi * (genotypes - lag)),\n", + " axis=1,\n", + " )\n", + " )\n", + " scores = jnp.stack([f1, f2], axis=-1)\n", + "\n", + " return scores, descriptors" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Scoring function\n", + "scoring_function = partial(\n", + " rastrigin_scorer,\n", + " lag=lag,\n", + " base_lag=base_lag\n", + ")\n", + "\n", + "def scoring_fn(x, random_key):\n", + " return scoring_function(x)[0], {}, random_key" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define initial population and emitter" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Initial population\n", + "random_key = jax.random.PRNGKey(0)\n", + "random_key, subkey = jax.random.split(random_key)\n", + "init_genotypes = jax.random.uniform(\n", + " subkey, (batch_size, genotype_dim), minval=minval, maxval=maxval, dtype=jnp.float32\n", + ")\n", + "\n", + "# Mutation & Crossover\n", + "crossover_function = partial(\n", + " polynomial_crossover, \n", + " proportion_var_to_change=0.5,\n", + ")\n", + "\n", + "mutation_function = partial(\n", + " polynomial_mutation, \n", + " proportion_to_mutate=0.5, \n", + " eta=0.05, \n", + " minval=minval, \n", + " maxval=maxval\n", + ")\n", + "\n", + "# Define the emitter\n", + "mixing_emitter = MixingEmitter(\n", + " mutation_fn=mutation_function, \n", + " variation_fn=crossover_function, \n", + " variation_percentage=1-proportion_mutation, \n", + " batch_size=batch_size\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Instantiate and init NSGA2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# instantitiate nsga2\n", + "nsga2 = NSGA2(\n", + " scoring_function=scoring_fn,\n", + " emitter=mixing_emitter,\n", + " metrics_function=default_ga_metrics\n", + ")\n", + "\n", + "# init nsga2\n", + "repertoire, emitter_state, random_key = nsga2.init(\n", + " init_genotypes,\n", + " population_size,\n", + " random_key\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Run and visualize result" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%time\n", + "\n", + "# run optimization loop\n", + "(repertoire, emitter_state, random_key), _ = jax.lax.scan(\n", + " nsga2.scan_update, (repertoire, emitter_state, random_key), (), length=num_iterations\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(figsize=(9, 6))\n", + "pareto_bool = compute_pareto_front(repertoire.fitnesses)\n", + "plot_global_pareto_front(repertoire.fitnesses[pareto_bool], ax=ax)\n", + "ax.set_title(\"Pareto front obtained by NSGA2\", fontsize=16)\n", + "ax.set_xlabel(\"Fitness Dimension 1\", fontsize=14)\n", + "ax.set_ylabel(\"Fitness Dimension 2\", fontsize=14)\n", + "plt.grid()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Instantiate and init SPEA2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# instantitiate spea2\n", + "spea2 = SPEA2(\n", + " scoring_function=scoring_fn,\n", + " emitter=mixing_emitter,\n", + " metrics_function=default_ga_metrics\n", + ")\n", + "\n", + "# init spea2\n", + "repertoire, emitter_state, random_key = spea2.init(\n", + " init_genotypes,\n", + " population_size,\n", + " num_neighbours,\n", + " random_key\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%time\n", + "\n", + "# run optimization loop\n", + "(repertoire, emitter_state, random_key), _ = jax.lax.scan(\n", + " spea2.scan_update, (repertoire, emitter_state, random_key), (), length=num_iterations\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Run and visualize result" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(figsize=(9, 6))\n", + "pareto_bool = compute_pareto_front(repertoire.fitnesses)\n", + "plot_global_pareto_front(repertoire.fitnesses[pareto_bool], ax=ax)\n", + "ax.set_title(\"Pareto front obtained by SPEA2\", fontsize=16)\n", + "ax.set_xlabel(\"Fitness Dimension 1\", fontsize=14)\n", + "ax.set_ylabel(\"Fitness Dimension 2\", fontsize=14)\n", + "plt.grid()\n", + "plt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.8.13 ('qdaxpy38')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.13" + }, + "vscode": { + "interpreter": { + "hash": "9ae46cf6a59eb5e192bc4f27fbb5c33d8a30eb9acb43edbb510eeaf7c819ab64" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}