diff --git a/renew/06-calibration.ipynb b/renew/06-calibration.ipynb index 6464567..4eadff1 100644 --- a/renew/06-calibration.ipynb +++ b/renew/06-calibration.ipynb @@ -86,7 +86,7 @@ "priors.append(esp.UniformPrior(proc_req['name'], (proc_req['lower'], proc_req['upper']), size=n_process_periods))\n", "renewal_model = RenewalModel(calib_kwargs['pop'], n_times, run_in, n_process_periods)\n", "obj_func = get_obj_func(renewal_model)\n", - "n_draws = 100\n", + "n_draws = 100 # This is obviously far too short - demonstration code only\n", "\n", "with pm.Model() as pmm:\n", " variables = use_model(priors, obj_func)\n", @@ -103,16 +103,6 @@ "az.summary(idata)" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "8e294388", - "metadata": {}, - "outputs": [], - "source": [ - "mean_posterior_params = mpp = np.array(az.summary(idata)[\"mean\"])" - ] - }, { "cell_type": "code", "execution_count": null, @@ -120,16 +110,9 @@ "metadata": {}, "outputs": [], "source": [ - "incidence = renewal_model.func(mpp[0], mpp[1], mpp[4:], np.log(mpp[3])).incidence * mpp[2]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5f77ea75-5a92-45ec-bafc-2e417d7457c2", - "metadata": {}, - "outputs": [], - "source": [ + "# Pull out a parameter set to look at a single run\n", + "mpp = np.array(az.summary(idata)[\"mean\"])\n", + "incidence = renewal_model.func(mpp[0], mpp[1], mpp[4:], np.log(mpp[3])).incidence * mpp[2]\n", "inc = pd.DataFrame(incidence)\n", "inc[\"targets\"] = mys_data\n", "inc.plot()" @@ -159,9 +142,10 @@ "spaghetti = pd.DataFrame()\n", "for i, p in enumerate(sample_params):\n", " incidence = renewal_model.func(p['gen_mean'], p['gen_sd'], p['random_process'], np.log(p['seed'])).incidence\n", - " cdr = p['cdr']\n", - " spaghetti[i] = incidence * cdr\n", - "spaghetti.columns = sample_params.index.to_flat_index().map(str)" + " spaghetti[i] = incidence * p['cdr']\n", + "spaghetti.columns = sample_params.index.to_flat_index().map(str)\n", + "spaghetti['targets'] = calib_kwargs['targets']\n", + "spaghetti.plot()" ] }, { @@ -177,17 +161,6 @@ "params_table = params_table.rename_axis(None)\n", "params_table = params_table.rename(columns={'init': 'Starting value', 'lower': 'Lower limit', 'upper': 'Upper limit'})" ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "70ab40ed-8da4-4464-bb7b-64058bf063f1", - "metadata": {}, - "outputs": [], - "source": [ - "spaghetti['targets'] = calib_kwargs['targets']\n", - "spaghetti.plot()" - ] } ], "metadata": {