Skip to content

Commit

Permalink
Notebook for mcmc transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
romain-ragonnet committed Sep 7, 2023
1 parent 40e8b1f commit 64a1915
Showing 1 changed file with 107 additions and 0 deletions.
107 changes: 107 additions & 0 deletions notebooks/user/rragonnet/mcmc_transform.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from matplotlib import pyplot as plt"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Prior bounds"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"a, b = 0, 1"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Default transformation (standard with pymc, stan...)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def plot_default_transform():\n",
" default_transform = lambda x: np.log(x - a) - np.log(b - x)\n",
"\n",
" x = np.linspace(a, b, num=1000)[1:-1] \n",
" y = default_transform(x)\n",
" plt.plot(x, y, label=\"Default transform\")\n",
" plt.xlim((a - .5 * (b-a), b + .5 * (b-a)))\n",
" plt.xlabel(\"Original parameter\")\n",
" plt.ylabel(\"Transformed parameter\")\n",
" plt.vlines(x=[a, b], ymin=plt.gca().get_ylim()[0], ymax=plt.gca().get_ylim()[1], linestyles=[\"--\", \"--\"], color='black')\n",
"\n",
"\n",
"plot_default_transform()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Tweaked transformation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def plot_tweaked_transform(eps=.1):\n",
" a_prime = a - eps * (b - a)\n",
" b_prime = b + eps * (b - a)\n",
" tweaked_transform = lambda x: np.log(x - a_prime) - np.log(b_prime - x)\n",
"\n",
" x = np.linspace(a_prime, b_prime, num=1000)[1:-1] \n",
" y = tweaked_transform(x)\n",
" plot_default_transform()\n",
" plt.plot(x, y, label=\"Tweaked transform\")\n",
" plt.legend()\n",
"\n",
"plot_tweaked_transform(.1)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "summer2",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit 64a1915

Please sign in to comment.