From 3f5e9049666be93d9e04943d61050e66a84aaf4f Mon Sep 17 00:00:00 2001 From: David Shipman Date: Fri, 24 Nov 2023 09:19:24 +1100 Subject: [PATCH] LDS notebook --- notebooks/workshops/LDS.ipynb | 519 ++++++++++++++++++++++++++++++++++ 1 file changed, 519 insertions(+) create mode 100644 notebooks/workshops/LDS.ipynb diff --git a/notebooks/workshops/LDS.ipynb b/notebooks/workshops/LDS.ipynb new file mode 100644 index 000000000..e19b99880 --- /dev/null +++ b/notebooks/workshops/LDS.ipynb @@ -0,0 +1,519 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "89adb60c", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "from scipy.stats import qmc\n", + "from matplotlib import pyplot as plt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6dd2fbbd", + "metadata": {}, + "outputs": [], + "source": [ + "ndim = 1\n", + "\n", + "sobol = qmc.Sobol(ndim, scramble=False)\n", + "M = 1\n", + "N = 2**M\n", + "samp = sobol.random(N)\n", + "samp.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "31ac2cce", + "metadata": {}, + "outputs": [], + "source": [ + "pd.Series(np.ones_like(samp[:,0]),index=samp[:,0]).plot(style='.',ylim=(0.0,2.0),xlim=(0.0,1.0))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "25bd4ecb", + "metadata": {}, + "outputs": [], + "source": [ + "ndim = 4\n", + "\n", + "sobol = qmc.Sobol(ndim, scramble=True, seed=0)\n", + "M = 3\n", + "N = 2**M\n", + "samp = sobol.random(N)\n", + "samp.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "83f4b631", + "metadata": {}, + "outputs": [], + "source": [ + "xdim = 0\n", + "ydim = 1\n", + "\n", + "fig = plt.figure()\n", + "ax = fig.gca()\n", + "spacing = 1/8\n", + "minorLocator = plt.MultipleLocator(spacing)\n", + "ax.yaxis.set_minor_locator(minorLocator)\n", + "ax.xaxis.set_minor_locator(minorLocator)\n", + "ax.grid(which = 'minor')\n", + "ax.scatter(samp[:,xdim],samp[:,ydim])\n", + "ax.set_xlim(0.0,1.0)\n", + "ax.set_ylim(0.0,1.0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d78f33ad", + "metadata": {}, + "outputs": [], + "source": [ + "for i in range(32):\n", + " lhs = qmc.LatinHypercube(ndim)\n", + " print(qmc.discrepancy(lhs.random(N)))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "57b51120", + "metadata": {}, + "outputs": [], + "source": [ + "for i in range(32):\n", + " sobol = qmc.Sobol(ndim, scramble=True)\n", + " print(qmc.discrepancy(sobol.random(N)))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a1f82ab8", + "metadata": {}, + "outputs": [], + "source": [ + "from summer2 import CompartmentalModel, Stratification\n", + "from summer2.parameters import Parameter as param" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1e04afed", + "metadata": {}, + "outputs": [], + "source": [ + "def sirs_parametric_age(times=[0, 100], agegroups=list(range(0, 80, 5))):\n", + " m = CompartmentalModel(times, [\"S\", \"I\", \"R\"], \"I\")\n", + " m.set_initial_population({\"S\": 99990.0, \"I\": 10.0})\n", + " m.add_infection_frequency_flow(\"infection\", param(\"contact_rate\"), \"S\", \"I\")\n", + " m.add_transition_flow(\"recovery\", 1.0 / param(\"recovery_duration\"), \"I\", \"R\")\n", + " m.add_transition_flow(\"waning\", 1.0 / param(\"waning_duration\"), \"R\", \"S\")\n", + "\n", + " max_strl = len(str(agegroups[-1]))\n", + " agegroup_keys = [str(k).zfill(max_strl) for k in agegroups]\n", + " num_age = len(agegroup_keys)\n", + " age_strat = Stratification(\"age\", agegroup_keys, [\"S\", \"I\", \"R\"])\n", + "\n", + " # +++ SUMMER3\n", + " # Not currently parameterizable\n", + " suscept_adj = {}\n", + " for k in agegroup_keys:\n", + " suscept_adj[k] = np.exp(param(f\"suscept_{k}\"))\n", + " age_strat.set_flow_adjustments(\"infection\", suscept_adj)\n", + " rec_adj = {k: adj for k, adj in zip(agegroup_keys, np.linspace(0.5, 2.5, num_age))}\n", + " age_strat.set_flow_adjustments(\"recovery\", rec_adj)\n", + " wane_adj = {k: adj for k, adj in zip(agegroup_keys, np.linspace(2.5, 0.1, num_age))}\n", + " age_strat.set_flow_adjustments(\"waning\", wane_adj)\n", + "\n", + " #mm_base = np.linspace(5.5, 0.1, num_age).reshape((1, num_age))\n", + " #mm = (mm_base * mm_base.T) * 0.1\n", + " mm = np.random.uniform(size=(num_age,num_age))\n", + "\n", + " age_strat.set_mixing_matrix(mm)\n", + "\n", + " pop_spread = np.linspace(2.0, 1.0, num_age)\n", + " pop_split = pop_spread / pop_spread.sum()\n", + "\n", + " age_strat.set_population_split({k: pop_prop for k, pop_prop in zip(agegroup_keys, pop_split)})\n", + "\n", + " m.stratify_with(age_strat)\n", + "\n", + " # +++ SUMMER3\n", + " # Arbitrary dims for output requests\n", + " # Export to xarray instead of dataframe\n", + " incidence = m.request_output_for_flow(\"incidence\", \"infection\")\n", + " m.request_function_output(\"notifications\", incidence * param(\"cdr\"))\n", + "\n", + " for k in agegroup_keys:\n", + " m.request_output_for_flow(f\"incidenceXage_{k}\", \"infection\", source_strata={\"age\": k})\n", + "\n", + " m.set_default_parameters(\n", + " {\"contact_rate\": 0.04, \"recovery_duration\": 10.0, \"waning_duration\": 30.0, \"cdr\": 0.2} |\n", + " {f\"suscept_{k}\":0.0 for k in agegroup_keys}\n", + " )\n", + " return m\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6ce4b795", + "metadata": {}, + "outputs": [], + "source": [ + "m = sirs_parametric_age()\n", + "agegroups = m.stratifications[\"age\"].strata" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "972e0083", + "metadata": {}, + "outputs": [], + "source": [ + "targetp = {f\"suscept_{k}\":np.random.normal(0.0,0.2) for k in agegroups}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e38756fc", + "metadata": {}, + "outputs": [], + "source": [ + "m.run(targetp)\n", + "inoisy = (m.get_derived_outputs_df()[\"incidence\"] * np.exp(np.random.normal(0.0,0.1,101)))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7a559869", + "metadata": {}, + "outputs": [], + "source": [ + "def pdopt(backend=None):\n", + " if backend is None:\n", + " if pd.options.plotting.backend == \"plotly\":\n", + " backend = \"matplotlib\"\n", + " else:\n", + " backend = \"plotly\"\n", + " pd.options.plotting.backend = backend" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f3234ff4", + "metadata": {}, + "outputs": [], + "source": [ + "pdopt(\"plotly\")\n", + "m.get_derived_outputs_df().plot()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0ea9661b", + "metadata": {}, + "outputs": [], + "source": [ + "pdopt()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a4dfd245", + "metadata": {}, + "outputs": [], + "source": [ + "inoisy.plot()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9fe7823f", + "metadata": {}, + "outputs": [], + "source": [ + "from estival import priors as esp\n", + "from estival import targets as est\n", + "from estival.model import BayesianCompartmentalModel" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bde32a83", + "metadata": {}, + "outputs": [], + "source": [ + "defp = m.get_default_parameters()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "09f06eca", + "metadata": {}, + "outputs": [], + "source": [ + "priors = [\n", + " #esp.UniformPrior(\"contact_rate\", (0.04,0.06)),\n", + " #esp.UniformPrior(\"waning_duration\", (15.0,25.0))\n", + "] + [\n", + " esp.UniformPrior(f\"suscept_{k}\", (-1.0,1.0)) for k in agegroups\n", + "]\n", + "targets = [est.TruncatedNormalTarget(\"incidence\",inoisy, (0.0,np.inf),inoisy*0.1)]\n", + "bcm = BayesianCompartmentalModel(m, defp, priors, targets)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "61117fc1", + "metadata": {}, + "outputs": [], + "source": [ + "res = bcm.run(defp)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0c4d4490", + "metadata": {}, + "outputs": [], + "source": [ + "def bcm_to_salib(bcm):\n", + " problem = {\n", + " \"num_vars\": len(bcm.priors),\n", + " \"names\": [p for p in bcm.priors],\n", + " \"bounds\": [p.bounds() for p in bcm.priors.values()]\n", + " }\n", + " return problem" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "40ee6584", + "metadata": {}, + "outputs": [], + "source": [ + "problem = bcm_to_salib(bcm)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2c7e083f", + "metadata": {}, + "outputs": [], + "source": [ + "from SALib.analyze.sobol import analyze\n", + "from SALib.sample.sobol import sample" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "db82b01d", + "metadata": {}, + "outputs": [], + "source": [ + "from estival.sampling.tools import model_results_for_samples" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b2e1c376", + "metadata": {}, + "outputs": [], + "source": [ + "samp_df = pd.DataFrame(sample(problem, 32,scramble=True,seed=0,skip_values=0), columns=bcm.priors)\n", + "samp_df" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b342c3b7", + "metadata": {}, + "outputs": [], + "source": [ + "res = model_results_for_samples(samp_df, bcm)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fdee8b60", + "metadata": {}, + "outputs": [], + "source": [ + "np.unique([k for k,_ in res.results.columns])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c1ab25ae", + "metadata": {}, + "outputs": [], + "source": [ + "y = ()\n", + "\n", + "y = np.concatenate([y, res.results[\"incidence\"].sum().to_numpy()])\n", + "ares = analyze(problem, y)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "43651ea9", + "metadata": {}, + "outputs": [], + "source": [ + "df = ares.to_df()[0].sort_values(\"ST\",ascending=False)\n", + "df" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d9b1a332", + "metadata": {}, + "outputs": [], + "source": [ + "sobol = qmc.Sobol(len(bcm.priors),scramble=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "24bb71f8", + "metadata": {}, + "outputs": [], + "source": [ + "samps = sobol.random(2048)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f4bd8749", + "metadata": {}, + "outputs": [], + "source": [ + "sdf = bcm.sample.ppf(samps, \"pandas\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1e43dbf5", + "metadata": {}, + "outputs": [], + "source": [ + "ssres = model_results_for_samples(sdf, bcm)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "19cfd68f", + "metadata": {}, + "outputs": [], + "source": [ + "sorted_res = sdf.loc[ssres.extras.sort_values(\"loglikelihood\",ascending=False).index]\n", + "best_params = sdf.loc[sorted_res.index[0]]\n", + "sorted_res" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "21085255", + "metadata": {}, + "outputs": [], + "source": [ + "bcm.run(targetp).extras" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "577a3776", + "metadata": {}, + "outputs": [], + "source": [ + "bcm.run(best_params).extras" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f37a7a5b", + "metadata": {}, + "outputs": [], + "source": [ + "pd.DataFrame({\n", + " \"best\": bcm.run(best_params).derived_outputs[\"incidence\"],\n", + " \"target\": bcm.run(targetp).derived_outputs[\"incidence\"],\n", + " \"data\": inoisy}).plot()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "35af7ed3", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}