{ "cells": [ { "cell_type": "markdown", "id": "b8617e49-f31b-4404-bf15-4378de5c55eb", "metadata": {}, "source": [ "# OutOfSampleCausalTuning" ] }, { "cell_type": "markdown", "id": "27851131-170a-4feb-a8bd-8feeea3fcdd5", "metadata": {}, "source": [ "## Import and settings" ] }, { "cell_type": "code", "execution_count": 1, "id": "ef796beb-8a8a-4eb8-81e3-73c41e55a3ab", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "from lingam import DirectLiNGAM\n", "from lingam.utils import make_dot\n", "\n", "from lingam.experimental import OutOfSampleCausalTuning\n", "\n", "import warnings\n", "warnings.filterwarnings('ignore')\n", "\n", "np.random.seed(0)" ] }, { "cell_type": "markdown", "id": "1ed7081b-b11f-4ae7-bbe0-28a23c87ff40", "metadata": {}, "source": [ "## Test data\n", "\n", "First we create simple structural data." ] }, { "cell_type": "code", "execution_count": 2, "id": "2a699e48-c86a-455e-8804-aa043005ca40", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n", "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n", " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n", "<!-- Generated by graphviz version 2.43.0 (0)\n", " -->\n", "<!-- Title: %3 Pages: 1 -->\n", "<svg width=\"336pt\" height=\"392pt\"\n", " viewBox=\"0.00 0.00 336.00 392.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 388)\">\n", "<title>%3</title>\n", "<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-388 332,-388 332,4 -4,4\"/>\n", "<!-- x0 -->\n", "<g id=\"node1\" class=\"node\">\n", "<title>x0</title>\n", "<ellipse fill=\"none\" stroke=\"black\" cx=\"27\" cy=\"-366\" rx=\"27\" ry=\"18\"/>\n", "<text text-anchor=\"middle\" x=\"27\" y=\"-362.3\" font-family=\"Times,serif\" font-size=\"14.00\">x0</text>\n", "</g>\n", "<!-- x2 -->\n", "<g 
id=\"node3\" class=\"node\">\n", "<title>x2</title>\n", "<ellipse fill=\"none\" stroke=\"black\" cx=\"160\" cy=\"-279\" rx=\"27\" ry=\"18\"/>\n", "<text text-anchor=\"middle\" x=\"160\" y=\"-275.3\" font-family=\"Times,serif\" font-size=\"14.00\">x2</text>\n", "</g>\n", "<!-- x0->x2 -->\n", "<g id=\"edge1\" class=\"edge\">\n", "<title>x0->x2</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M45.77,-353C68.32,-338.59 106.34,-314.3 132.27,-297.72\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"134.35,-300.55 140.89,-292.21 130.58,-294.65 134.35,-300.55\"/>\n", "<text text-anchor=\"middle\" x=\"120.5\" y=\"-318.8\" font-family=\"Times,serif\" font-size=\"14.00\">-0.72</text>\n", "</g>\n", "<!-- x7 -->\n", "<g id=\"node8\" class=\"node\">\n", "<title>x7</title>\n", "<ellipse fill=\"none\" stroke=\"black\" cx=\"83\" cy=\"-192\" rx=\"27\" ry=\"18\"/>\n", "<text text-anchor=\"middle\" x=\"83\" y=\"-188.3\" font-family=\"Times,serif\" font-size=\"14.00\">x7</text>\n", "</g>\n", "<!-- x0->x7 -->\n", "<g id=\"edge4\" class=\"edge\">\n", "<title>x0->x7</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M21.56,-348.12C15.7,-326.96 8.64,-289.98 20,-261 27.56,-241.71 43.45,-224.63 57.38,-212.41\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"59.85,-214.9 65.27,-205.79 55.36,-209.53 59.85,-214.9\"/>\n", "<text text-anchor=\"middle\" x=\"36\" y=\"-275.3\" font-family=\"Times,serif\" font-size=\"14.00\">0.21</text>\n", "</g>\n", "<!-- x1 -->\n", "<g id=\"node2\" class=\"node\">\n", "<title>x1</title>\n", "<ellipse fill=\"none\" stroke=\"black\" cx=\"88\" cy=\"-279\" rx=\"27\" ry=\"18\"/>\n", "<text text-anchor=\"middle\" x=\"88\" y=\"-275.3\" font-family=\"Times,serif\" font-size=\"14.00\">x1</text>\n", "</g>\n", "<!-- x1->x7 -->\n", "<g id=\"edge5\" class=\"edge\">\n", "<title>x1->x7</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M86.99,-260.8C86.3,-249.16 85.39,-233.55 84.6,-220.24\"/>\n", "<polygon fill=\"black\" stroke=\"black\" 
points=\"88.09,-219.95 84.01,-210.18 81.1,-220.36 88.09,-219.95\"/>\n", "<text text-anchor=\"middle\" x=\"102\" y=\"-231.8\" font-family=\"Times,serif\" font-size=\"14.00\">0.83</text>\n", "</g>\n", "<!-- x5 -->\n", "<g id=\"node6\" class=\"node\">\n", "<title>x5</title>\n", "<ellipse fill=\"none\" stroke=\"black\" cx=\"227\" cy=\"-192\" rx=\"27\" ry=\"18\"/>\n", "<text text-anchor=\"middle\" x=\"227\" y=\"-188.3\" font-family=\"Times,serif\" font-size=\"14.00\">x5</text>\n", "</g>\n", "<!-- x2->x5 -->\n", "<g id=\"edge2\" class=\"edge\">\n", "<title>x2->x5</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M171.98,-262.8C182.11,-249.95 196.8,-231.32 208.45,-216.54\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"211.46,-218.37 214.9,-208.35 205.96,-214.04 211.46,-218.37\"/>\n", "<text text-anchor=\"middle\" x=\"215.5\" y=\"-231.8\" font-family=\"Times,serif\" font-size=\"14.00\">-0.88</text>\n", "</g>\n", "<!-- x2->x7 -->\n", "<g id=\"edge6\" class=\"edge\">\n", "<title>x2->x7</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M149.08,-262.08C141.85,-251.94 131.9,-238.74 122,-228 117.47,-223.08 112.32,-218.1 107.29,-213.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"109.38,-210.68 99.58,-206.65 104.73,-215.91 109.38,-210.68\"/>\n", "<text text-anchor=\"middle\" x=\"150\" y=\"-231.8\" font-family=\"Times,serif\" font-size=\"14.00\">0.38</text>\n", "</g>\n", "<!-- x3 -->\n", "<g id=\"node4\" class=\"node\">\n", "<title>x3</title>\n", "<ellipse fill=\"none\" stroke=\"black\" cx=\"299\" cy=\"-192\" rx=\"27\" ry=\"18\"/>\n", "<text text-anchor=\"middle\" x=\"299\" y=\"-188.3\" font-family=\"Times,serif\" font-size=\"14.00\">x3</text>\n", "</g>\n", "<!-- x6 -->\n", "<g id=\"node7\" class=\"node\">\n", "<title>x6</title>\n", "<ellipse fill=\"none\" stroke=\"black\" cx=\"294\" cy=\"-105\" rx=\"27\" ry=\"18\"/>\n", "<text text-anchor=\"middle\" x=\"294\" y=\"-101.3\" font-family=\"Times,serif\" font-size=\"14.00\">x6</text>\n", "</g>\n", "<!-- 
x3->x6 -->\n", "<g id=\"edge3\" class=\"edge\">\n", "<title>x3->x6</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M297.99,-173.8C297.3,-162.16 296.39,-146.55 295.6,-133.24\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"299.09,-132.95 295.01,-123.18 292.1,-133.36 299.09,-132.95\"/>\n", "<text text-anchor=\"middle\" x=\"312\" y=\"-144.8\" font-family=\"Times,serif\" font-size=\"14.00\">0.21</text>\n", "</g>\n", "<!-- x4 -->\n", "<g id=\"node5\" class=\"node\">\n", "<title>x4</title>\n", "<ellipse fill=\"none\" stroke=\"black\" cx=\"155\" cy=\"-192\" rx=\"27\" ry=\"18\"/>\n", "<text text-anchor=\"middle\" x=\"155\" y=\"-188.3\" font-family=\"Times,serif\" font-size=\"14.00\">x4</text>\n", "</g>\n", "<!-- x8 -->\n", "<g id=\"node9\" class=\"node\">\n", "<title>x8</title>\n", "<ellipse fill=\"none\" stroke=\"black\" cx=\"192\" cy=\"-105\" rx=\"27\" ry=\"18\"/>\n", "<text text-anchor=\"middle\" x=\"192\" y=\"-101.3\" font-family=\"Times,serif\" font-size=\"14.00\">x8</text>\n", "</g>\n", "<!-- x4->x8 -->\n", "<g id=\"edge7\" class=\"edge\">\n", "<title>x4->x8</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M159.39,-173.97C162.19,-164.13 166.19,-151.63 171,-141 172.58,-137.5 174.44,-133.92 176.38,-130.44\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"179.52,-132 181.58,-121.61 173.49,-128.45 179.52,-132\"/>\n", "<text text-anchor=\"middle\" x=\"187\" y=\"-144.8\" font-family=\"Times,serif\" font-size=\"14.00\">0.67</text>\n", "</g>\n", "<!-- x5->x8 -->\n", "<g id=\"edge8\" class=\"edge\">\n", "<title>x5->x8</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M220.09,-174.21C215.12,-162.14 208.32,-145.64 202.66,-131.89\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"205.83,-130.38 198.78,-122.47 199.35,-133.05 205.83,-130.38\"/>\n", "<text text-anchor=\"middle\" x=\"227\" y=\"-144.8\" font-family=\"Times,serif\" font-size=\"14.00\">0.78</text>\n", "</g>\n", "<!-- x9 -->\n", "<g id=\"node10\" class=\"node\">\n", 
"<title>x9</title>\n", "<ellipse fill=\"none\" stroke=\"black\" cx=\"192\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n", "<text text-anchor=\"middle\" x=\"192\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\">x9</text>\n", "</g>\n", "<!-- x6->x9 -->\n", "<g id=\"edge9\" class=\"edge\">\n", "<title>x6->x9</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M277.61,-90.34C261.03,-76.53 235.21,-55.01 216.26,-39.21\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"218.27,-36.34 208.35,-32.62 213.79,-41.71 218.27,-36.34\"/>\n", "<text text-anchor=\"middle\" x=\"267.5\" y=\"-57.8\" font-family=\"Times,serif\" font-size=\"14.00\">-0.55</text>\n", "</g>\n", "<!-- x7->x9 -->\n", "<g id=\"edge10\" class=\"edge\">\n", "<title>x7->x9</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M86.21,-174.03C90.87,-152.46 100.97,-114.7 119,-87 131.44,-67.89 150.38,-50.46 165.86,-38.03\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"168.3,-40.56 174.03,-31.66 163.99,-35.04 168.3,-40.56\"/>\n", "<text text-anchor=\"middle\" x=\"137.5\" y=\"-101.3\" font-family=\"Times,serif\" font-size=\"14.00\">-0.45</text>\n", "</g>\n", "<!-- x8->x9 -->\n", "<g id=\"edge11\" class=\"edge\">\n", "<title>x8->x9</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M192,-86.8C192,-75.16 192,-59.55 192,-46.24\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"195.5,-46.18 192,-36.18 188.5,-46.18 195.5,-46.18\"/>\n", "<text text-anchor=\"middle\" x=\"210.5\" y=\"-57.8\" font-family=\"Times,serif\" font-size=\"14.00\">-0.76</text>\n", "</g>\n", "</g>\n", "</svg>\n" ], "text/plain": [ "<graphviz.graphs.Digraph at 0x7f4ffaf74520>" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "(1000, 10)" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "node_num = 10\n", "n_samples = 1000\n", "\n", "# make graph\n", "graph = np.random.choice([0, 1], p=[0.8, 0.2], size=(node_num, node_num))\n", "graph= 
np.tril(graph, k=-1)\n", "\n", "# force to have child nodes\n", "for i in range(node_num - 1):\n", " if np.sum(graph[:, i]) == 0:\n", " pos = np.random.choice(range(i + 1, node_num))\n", " graph[pos, i] = 1\n", "\n", "# coefficients\n", "graph = graph * np.random.uniform(0.2, 1.0, size=graph.shape) * np.random.choice([-1, 1], size=graph.shape)\n", "\n", "display(make_dot(graph))\n", "\n", "# generate data\n", "X = np.zeros((n_samples, graph.shape[0]))\n", "for i in range(graph.shape[0]):\n", " X[:, i] = graph[i] @ X.T + np.random.uniform(0, 1, size=n_samples)\n", "\n", "X = pd.DataFrame(X, columns=[f\"x{i}\" for i in range(node_num)])\n", "X.shape" ] }, { "cell_type": "markdown", "id": "d835862c-98d9-49ae-a87f-332e19a0ae18", "metadata": {}, "source": [ "## Preparation of configurations\n", "\n", "We prepare two configurations. Each configuration specifies the causal discovery algorithm (`model`); the second one additionally passes keyword arguments (`init_kwargs`) to that algorithm's constructor. In the second configuration we deliberately set prior knowledge that forbids every edge the true graph has: `graph.astype(bool) - 1` is 0 (no directed path) where the true graph has an edge and -1 (unknown) elsewhere. The second configuration is therefore misspecified, so we expect the first one to be selected." ] }, { "cell_type": "code", "execution_count": 3, "id": "cf1b1ad6-76c0-42a7-b01a-5bdc5cc8cee9", "metadata": {}, "outputs": [], "source": [ "configs = [\n", " {\n", " \"model\": DirectLiNGAM\n", " },\n", " {\n", " \"model\": DirectLiNGAM,\n", " \"init_kwargs\": {\n", " \"prior_knowledge\": graph.astype(bool) - 1\n", " }\n", " },\n", "]" ] }, { "cell_type": "markdown", "id": "4471fa4e-06a6-418a-8c4f-f51aae1c4ab2", "metadata": {}, "source": [ "## Initialization\n", "\n", "The constructor takes two main arguments: `cv` specifies the number of cross-validation folds, and `thr` specifies the threshold of the permutation test used in the sparsity penalty." 
] }, { "cell_type": "code", "execution_count": 4, "id": "47e96818-fe4f-4974-b16f-91a462e57317", "metadata": {}, "outputs": [], "source": [ "model = OutOfSampleCausalTuning(cv=10, thr=0.05)" ] }, { "cell_type": "markdown", "id": "635a540d-2141-47c9-ad36-f24f55897e00", "metadata": {}, "source": [ "## Run" ] }, { "cell_type": "code", "execution_count": 5, "id": "148d5eec-225e-4829-8e2d-11eca20e605d", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "<lingam.experimental.oct.OutOfSampleCausalTuning at 0x7f5031230c40>" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.fit(X, configs)" ] }, { "cell_type": "markdown", "id": "ae35ef75-da93-4c74-8881-9aa75a968ddf", "metadata": {}, "source": [ "The index of the selected configuration is stored in `best_config_index_`." ] }, { "cell_type": "code", "execution_count": 6, "id": "b9e693d7-2987-496f-8f03-9949d9051f5a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0\n" ] } ], "source": [ "print(model.best_config_index_)" ] }, { "cell_type": "markdown", "id": "8155b91a-ed96-40d5-ac08-2ab4d8000569", "metadata": {}, "source": [ "The result is 0, as expected: the second configuration forbids the edges of the true graph, so the first configuration is selected." ] }, { "cell_type": "markdown", "id": "dc6c475a-3301-432a-9295-e28ce14aea4f", "metadata": {}, "source": [ "## More information about the result\n", "\n", "Details computed in `fit()` are stored in `scores_`.\n", "\n", "`performance` stores the score of each given configuration. `mb_size` stores the average Markov blanket size for each configuration. `sp_score` stores the sparsity penalty scores as a dictionary whose keys are configuration indices and whose values are the corresponding scores. The configurations with the best score are not included in `sp_score`." 
] }, { "cell_type": "code", "execution_count": 7, "id": "bc1f7cff-7c6c-47de-b12a-3bf9dc535cb5", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'performance': [0.3024831563927093, 0.30094328627844685],\n", " 'mb_size': [3.4, 4.0],\n", " 'sp_score': {1: 0.0}}" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "display(model.scores_)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" }, "widgets": { "application/vnd.jupyter.widget-state+json": { "state": {}, "version_major": 2, "version_minor": 0 } } }, "nbformat": 4, "nbformat_minor": 5 }