Skip to content

Commit

Permalink
intro tutorial: Switch to using Seaborn
Browse files Browse the repository at this point in the history
Co-authored-by: rht <rhtbot@protonmail.com>
  • Loading branch information
EwoutH and rht committed Jun 17, 2023
1 parent 000408d commit f7db1f5
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 22 deletions.
68 changes: 47 additions & 21 deletions docs/tutorials/intro_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
"Install Mesa:\n",
"\n",
"```bash\n",
"pip install mesa\n",
"pip install --upgrade mesa\n",
"```\n",
"\n",
"Install Jupyter Notebook (optional):\n",
Expand Down Expand Up @@ -130,8 +130,8 @@
"source": [
"import mesa\n",
"\n",
"# Data visualization tool.\n",
"import matplotlib.pyplot as plt\n",
"# Data visualization tools.\n",
"import seaborn as sns\n",
"\n",
"# Has multi-dimensional arrays and matrices. Has a large collection of\n",
"# mathematical functions to operate on these arrays.\n",
Expand Down Expand Up @@ -509,6 +509,7 @@
"If you are running from a text editor or IDE, you'll also need to add this line, to make the graph appear.\n",
"\n",
"```python\n",
"import matplotlib.pyplot as plt\n",
"plt.show()\n",
"```"
]
Expand All @@ -531,7 +532,11 @@
"import matplotlib.pyplot as plt\n",
"\n",
"agent_wealth = [a.wealth for a in model.schedule.agents]\n",
"plt.hist(agent_wealth)"
"# Create a histogram with seaborn\n",
"g = sns.histplot(agent_wealth, discrete=True)\n",
"g.set(\n",
" title=\"Wealth distribution\", xlabel=\"Wealth\", ylabel=\"Number of agents\"\n",
"); # The semicolon is just to avoid printing the object representation"
]
},
{
Expand Down Expand Up @@ -571,7 +576,9 @@
" for agent in model.schedule.agents:\n",
" all_wealth.append(agent.wealth)\n",
"\n",
"plt.hist(all_wealth, bins=range(max(all_wealth) + 1))"
"# Use seaborn\n",
"g = sns.histplot(all_wealth, discrete=True)\n",
"g.set(title=\"Wealth distribution\", xlabel=\"Wealth\", ylabel=\"Number of agents\");"
]
},
{
Expand Down Expand Up @@ -758,7 +765,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's create a model with 50 agents on a 10x10 grid, and run it for 20 steps."
"Let's create a model with 100 agents on a 10x10 grid, and run it for 20 steps."
]
},
{
Expand All @@ -772,7 +779,7 @@
},
"outputs": [],
"source": [
"model = MoneyModel(50, 10, 10)\n",
"model = MoneyModel(100, 10, 10)\n",
"for i in range(20):\n",
" model.step()"
]
Expand Down Expand Up @@ -802,11 +809,10 @@
" cell_content, x, y = cell\n",
" agent_count = len(cell_content)\n",
" agent_counts[x][y] = agent_count\n",
"plt.imshow(agent_counts, interpolation=\"nearest\")\n",
"plt.colorbar()\n",
"\n",
"# If running from a text editor or IDE, remember you'll need the following:\n",
"# plt.show()"
"# Plot using seaborn, with a size of 5x5\n",
"g = sns.heatmap(agent_counts, cmap=\"viridis\", annot=True, cbar=False, square=True)\n",
"g.figure.set_size_inches(4, 4)\n",
"g.set(title=\"Number of agents on each cell of the grid\");"
]
},
{
Expand Down Expand Up @@ -923,7 +929,7 @@
},
"outputs": [],
"source": [
"model = MoneyModel(50, 10, 10)\n",
"model = MoneyModel(100, 10, 10)\n",
"for i in range(100):\n",
" model.step()"
]
Expand All @@ -947,7 +953,9 @@
"outputs": [],
"source": [
"gini = model.datacollector.get_model_vars_dataframe()\n",
"gini.plot()"
"# Plot the Gini coefficient over time\n",
"g = sns.lineplot(data=gini)\n",
"g.set(title=\"Gini Coefficient over Time\", ylabel=\"Gini Coefficient\");"
]
},
{
Expand Down Expand Up @@ -976,7 +984,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"You'll see that the DataFrame's index is pairings of model step and agent ID. You can analyze it the way you would any other DataFrame. For example, to get a histogram of agent wealth at the model's end:"
"You'll see that the DataFrame's index is pairings of model step and agent ID. This is because the data collector stores the data in a dictionary, with the step number as the key, and a dictionary of agent ID and variable value pairs as the value. The data collector then converts this dictionary into a DataFrame, which is why the index is a pair of (model step, agent ID). You can analyze it the way you would any other DataFrame. For example, to get a histogram of agent wealth at the model's end:"
]
},
{
Expand All @@ -990,8 +998,15 @@
},
"outputs": [],
"source": [
"end_wealth = agent_wealth.xs(99, level=\"Step\")[\"Wealth\"]\n",
"end_wealth.hist(bins=range(agent_wealth.Wealth.max() + 1))"
"last_step = agent_wealth.index.get_level_values(\"Step\").max()\n",
"end_wealth = agent_wealth.xs(last_step, level=\"Step\")[\"Wealth\"]\n",
"# Create a histogram of wealth at the last step\n",
"g = sns.histplot(end_wealth, discrete=True)\n",
"g.set(\n",
" title=\"Distribution of wealth at the end of simulation\",\n",
" xlabel=\"Wealth\",\n",
" ylabel=\"Number of agents\",\n",
");"
]
},
{
Expand All @@ -1012,8 +1027,12 @@
},
"outputs": [],
"source": [
"# Get the wealth of agent 14 over time\n",
"one_agent_wealth = agent_wealth.xs(14, level=\"AgentID\")\n",
"one_agent_wealth.Wealth.plot()"
"\n",
"# Plot the wealth of agent 14 over time\n",
"g = sns.lineplot(data=one_agent_wealth, x=\"Step\", y=\"Wealth\")\n",
"g.set(title=\"Wealth of agent 14 over time\");"
]
},
{
Expand Down Expand Up @@ -1235,10 +1254,17 @@
},
"outputs": [],
"source": [
"# Filter the results to only contain the data of one agent (the Gini coefficient will be the same for the entire population at any time) at the 100th step of each episode\n",
"results_filtered = results_df[(results_df.AgentID == 0) & (results_df.Step == 100)]\n",
"N_values = results_filtered.N.values\n",
"gini_values = results_filtered.Gini.values\n",
"plt.scatter(N_values, gini_values)"
"results_filtered[[\"iteration\", \"N\", \"Gini\"]].reset_index(\n",
" drop=True\n",
").head() # Create a scatter plot\n",
"g = sns.scatterplot(data=results_filtered, x=\"N\", y=\"Gini\")\n",
"g.set(\n",
" xlabel=\"Number of agents\",\n",
" ylabel=\"Gini coefficient\",\n",
" title=\"Gini coefficient vs. number of agents\",\n",
");"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
# Explicitly install ipykernel for Python 3.8.
# See https://stackoverflow.com/questions/28831854/how-do-i-add-python3-kernel-to-jupyter-ipython
# Could be removed in the future
"docs": ["sphinx<7", "ipython", "nbsphinx", "ipykernel", "pydata_sphinx_theme"],
"docs": ["sphinx<7", "ipython", "nbsphinx", "ipykernel", "pydata_sphinx_theme", "seaborn"],
}

version = ""
Expand Down

0 comments on commit f7db1f5

Please sign in to comment.