Skip to content

Commit

Permalink
Visualize policies with Gantt chart. (#113)
Browse files Browse the repository at this point in the history
  • Loading branch information
tobiasraabe authored May 9, 2021
1 parent 363f231 commit 240db8c
Show file tree
Hide file tree
Showing 8 changed files with 387 additions and 12 deletions.
1 change: 1 addition & 0 deletions docs/source/changes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ all releases are available on `Anaconda.org
0.0.5 - 2021-05-09
------------------

- :gh:`113` implements a Gantt chart to visualize policies.
- :gh:`115` allows for heterogeneous effects of seasonality on contact models.
- :gh:`116` adds a plot to investigate which contact model caused how many infections
and :gh:`118` makes the data preparation more performant.
Expand Down
24 changes: 13 additions & 11 deletions docs/source/reference_guides/policies.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@
Policies
========

.. seealso::

For a hands-on example on how to specify contact policies, look at the `tutorial
about contact policies <../tutorials/how_to_specify_policies.ipynb>`_.

In sid we can implement nearly any type of policy as a modification of the
:ref:`contact_models`. However, to keep things separable and modular, policies can also
be specified outside the contact models in a separate, specialized ``contact_policies``
Expand Down Expand Up @@ -33,14 +38,11 @@ To specify when the policy is active, you have three options:

.. code-block:: python
{
"1st_lockdown_school": {
"affected_contact_model": "school",
"policy": 0,
"start": "2020-03-22",
"end": "2020-04-20",
},
}
For an example on how to specify contact policies, look at the `contact policies
tutorial <../tutorials/how_to_specify_policies.ipynb>`_
{
"1st_lockdown_school": {
"affected_contact_model": "school",
"policy": 0,
"start": "2020-03-22",
"end": "2020-04-20",
},
}
107 changes: 107 additions & 0 deletions docs/source/tutorials/how_to_visualize.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "5f1e1d38-7c8f-4ab2-bae3-cdb18bde476d",
"metadata": {},
"source": [
"# How to visualize"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "49d4dac5-32b2-45da-bc62-f9e801c066af",
"metadata": {},
"outputs": [],
"source": [
"import sid"
]
},
{
"cell_type": "markdown",
"id": "c7a90f2a-3f30-41b4-897c-ec5854c1b132",
"metadata": {},
"source": [
"## Gantt chart of policies"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ecd98784-69b5-454b-96c2-cd03ca15d318",
"metadata": {},
"outputs": [],
"source": [
"policies = {\n",
" \"closed_schools\": {\n",
" \"affected_contact_model\": \"school\",\n",
" \"start\": \"2020-03-09\",\n",
" \"end\": \"2020-05-31\",\n",
" \"policy\": 0,\n",
" },\n",
" \"partially_closed_schools\": {\n",
" \"affected_contact_model\": \"school\",\n",
" \"start\": \"2020-06-01\",\n",
" \"end\": \"2020-09-30\",\n",
" \"policy\": 0.5,\n",
" },\n",
" \"partially_closed_kindergarden\": {\n",
" \"affected_contact_model\": \"school\",\n",
" \"start\": \"2020-05-20\",\n",
" \"end\": \"2020-06-30\",\n",
" \"policy\": 0.5,\n",
" },\n",
" \"work_closed\": {\n",
" \"affected_contact_model\": \"work\",\n",
" \"start\": \"2020-03-09\",\n",
" \"end\": \"2020-06-15\",\n",
" \"policy\": 0.4,\n",
" },\n",
" \"work_partially_opened\": {\n",
" \"affected_contact_model\": \"work\",\n",
" \"start\": \"2020-05-01\",\n",
" \"end\": \"2020-08-15\",\n",
" \"policy\": 0.7,\n",
" },\n",
" \"closed_leisure_activities\": {\n",
" \"affected_contact_model\": \"leisure\",\n",
" \"start\": \"2020-03-09\",\n",
" \"policy\": 0,\n",
" },\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7651e5fe-916c-4f16-9a18-1c01fea247e6",
"metadata": {},
"outputs": [],
"source": [
"sid.plotting.plot_policy_gantt_chart(policies, effects=True)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.2"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
1 change: 1 addition & 0 deletions docs/source/tutorials/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@ simple model and show you how to use sid to simulate your own model.
how_to_reduce_memory_consumption
how_to_resume_a_simulation
how_to_simulate_multiple_virus_strains
how_to_visualize
how_to_plot_infection_rates_by_contact_models
2 changes: 1 addition & 1 deletion src/sid/colors.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,9 @@ def _mono_list_to_triangle(mono_list):
CAT_LIST = [
"#547482",
"#C87259",
"#C2D8C2",
"#F1B05D",
"#818662",
"#C2D8C2",
"#6C4A4D",
"#7A8C87",
"#EE8445",
Expand Down
152 changes: 152 additions & 0 deletions src/sid/plotting.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,158 @@
import itertools

import dask.dataframe as dd
import holoviews as hv
import pandas as pd
from bokeh.models import HoverTool
from sid.colors import get_colors
from sid.policies import compute_pseudo_effect_sizes_of_policies


# Default figure options for ``plot_policy_gantt_chart``. User-supplied
# ``fig_kwargs`` are merged on top of these defaults, and the merged mapping
# is forwarded as keyword arguments to ``segments.opts``.
DEFAULT_FIGURE_KWARGS = {
    "height": 400,
    "width": 600,
    "line_width": 12,
    "title": "Gantt Chart of Policies",
}


def plot_policy_gantt_chart(
    policies,
    effects=False,
    colors="categorical",
    fig_kwargs=None,
):
    """Plot a Gantt chart of the policies.

    Each policy is drawn as a horizontal segment from its start to its end
    date. Policies are grouped and colored by the contact model they affect,
    and overlapping policies of the same contact model are stacked on separate
    lines.

    Args:
        policies (dict or pandas.DataFrame): Either a dictionary mapping policy
            names to policy specifications, or a DataFrame with one row per
            policy. The code reads the columns ``"affected_contact_model"``,
            ``"start"`` and ``"end"``, and ``"name"`` for the hover tool and
            for merging the effects. A DataFrame is copied and never modified.
        effects (bool or dict): If truthy, pseudo effect sizes are computed
            with ``compute_pseudo_effect_sizes_of_policies`` and the mean
            effect controls the transparency of each segment. A dictionary is
            forwarded as keyword arguments to that function.
        colors: Forwarded to ``get_colors`` to select the palette.
        fig_kwargs (dict, optional): Figure options which override
            ``DEFAULT_FIGURE_KWARGS``.

    Returns:
        The ``holoviews.Segments`` element with all options applied.

    Raises:
        ValueError: If ``policies`` is neither a dict nor a pandas.DataFrame.

    """
    if fig_kwargs is None:
        fig_kwargs = {}
    fig_kwargs = {**DEFAULT_FIGURE_KWARGS, **fig_kwargs}

    if isinstance(policies, dict):
        # No explicit datetime cast here: ``_complete_dates`` converts both
        # date columns with ``pd.to_datetime``. The former cast to the
        # unit-less dtype "datetime64" is rejected by pandas >= 2.0.
        df = (
            pd.DataFrame(policies)
            .T.reset_index()
            .rename(columns={"index": "name"})
            .drop(columns="policy")
        )
    elif isinstance(policies, pd.DataFrame):
        # Work on a copy so that the columns added below do not leak into the
        # caller's DataFrame.
        df = policies.copy()
    else:
        raise ValueError("'policies' should be either a dict or pandas.DataFrame.")

    # Keep the flag separate from the computed effect sizes instead of
    # rebinding the ``effects`` argument.
    show_effects = bool(effects)

    if show_effects:
        effect_kwargs = effects if isinstance(effects, dict) else {}
        effect_sizes = compute_pseudo_effect_sizes_of_policies(
            policies=policies, **effect_kwargs
        )
        effects_s = pd.DataFrame(
            [
                {"policy": name, "effect": effect_sizes[name]["mean"]}
                for name in effect_sizes
            ]
        ).set_index("policy")["effect"]
        df = df.merge(effects_s, left_on="name", right_index=True)
        # Larger effects yield smaller alpha values; the offset of 0.1 keeps
        # alpha strictly positive for an effect of 1.
        df["alpha"] = (1 - df["effect"] + 0.1) / 1.1
    else:
        df["alpha"] = 1

    df = df.reset_index()
    df = _complete_dates(df)
    df = _add_color_to_gantt_groups(df, colors)
    df = _add_positions(df)

    hv.extension("bokeh", logo=False)

    segments = hv.Segments(
        df,
        [
            hv.Dimension("start", label="Date"),
            hv.Dimension("position", label="Affected contact model"),
            "end",
            "position",
        ],
    )
    y_ticks_and_labels = list(zip(*_create_y_ticks_and_labels(df)))

    tooltips = [("Name", "@name")]
    if show_effects:
        tooltips.append(("Effect", "@effect"))
    hover = HoverTool(tooltips=tooltips)

    gantt = segments.opts(
        color="color",
        alpha="alpha",
        tools=[hover],
        yticks=y_ticks_and_labels,
        **fig_kwargs,
    )

    return gantt


def _complete_dates(df):
    """Parse the date columns and fill in missing dates.

    ``"start"`` and ``"end"`` are converted with ``pd.to_datetime``. Missing
    start dates are replaced by the earliest start date and missing end dates
    by the latest end date. The DataFrame is modified in place and returned.

    """
    df["start"] = pd.to_datetime(df["start"])
    df["end"] = pd.to_datetime(df["end"])

    earliest_start = df["start"].min()
    latest_end = df["end"].max()

    df["start"] = df["start"].fillna(earliest_start)
    df["end"] = df["end"].fillna(latest_end)

    return df


def _add_color_to_gantt_groups(df, colors):
    """Assign one color to every affected contact model.

    Four colors are requested from ``get_colors`` and cycled, so more than
    four contact models reuse colors. The color is stored in the new column
    ``"color"`` and the DataFrame is returned.

    """
    contact_models = df["affected_contact_model"].unique()
    palette = itertools.cycle(get_colors(colors, 4))
    color_by_contact_model = {model: next(palette) for model in contact_models}

    df["color"] = df["affected_contact_model"].replace(color_by_contact_model)

    return df


def _stack_overlapping_segments(group, min_position):
    """Assign a line to every policy of one contact model.

    Policies are processed in row order. A policy stays on ``min_position``
    unless its time window overlaps with an earlier policy of the group, in
    which case it receives the lowest line not used by any overlapping policy.

    """
    position = pd.Series(data=min_position, index=group.index)
    starts = group["start"]
    ends = group["end"]

    for i in range(1, len(group)):
        start = starts.iloc[i]
        end = ends.iloc[i]
        # Two closed intervals overlap if each one starts before the other
        # ends. This also covers a window which fully encloses an earlier one,
        # a case that is missed when only the endpoints of the new window are
        # tested for containment.
        is_overlapping = (
            (starts.iloc[:i] <= end) & (start <= ends.iloc[:i])
        ).to_numpy()
        if is_overlapping.any():
            taken = set(position.iloc[:i][is_overlapping])
            # At most ``i`` lines are taken, so one of these ``i + 1``
            # candidates is always free.
            candidates = range(min_position, min_position + i + 1)
            position.iloc[i] = min(p for p in candidates if p not in taken)

    return position


def _add_positions(df):
    """Add positions.

    This function computes the positions of policies, displayed as segments on
    the timeline. For example, if two policies affecting the same contact model
    have overlapping time windows, the segments are stacked and drawn onto
    different horizontal lines.

    Two columns are added: ``"position_local"`` holds the line assigned while
    iterating over the contact models in sorted order, and ``"position"``
    enumerates the sorted combinations of contact model and local position and
    is used as the y-coordinate. The index of ``df`` must be unique.

    """
    position_local = pd.Series(data=0, index=df.index, dtype="int64")

    # Iterate over the groups explicitly. This keeps the running offset in a
    # plain local variable and avoids relying on the output shape of
    # ``groupby.apply`` for functions which return an identically indexed
    # Series.
    min_position = 0
    for _, group in df.groupby("affected_contact_model", sort=True):
        within_group = _stack_overlapping_segments(group, min_position)
        position_local.loc[group.index] = within_group.to_numpy()
        # The next contact model starts above all lines used so far.
        min_position = int(within_group.max()) + 1

    df["position_local"] = position_local
    df["position"] = df.groupby(
        ["affected_contact_model", "position_local"], sort=True
    ).ngroup()

    return df


def _create_y_ticks_and_labels(df):
    """Compute the tick locations and tick labels of the y-axis.

    Every affected contact model gets a single tick which is placed at the
    mean of the positions it occupies and is labeled with the name of the
    contact model.

    Returns:
        tuple: The tick positions and the tick labels, each as a
        pandas.Series.

    """
    # One row per occupied position, so stacked lines are counted only once.
    first_row_per_position = df.groupby("position", as_index=False).first()

    ticks = first_row_per_position.groupby(
        "affected_contact_model", as_index=False
    )["position"].mean()

    tick_positions = ticks["position"]
    tick_labels = ticks["affected_contact_model"]
    return tick_positions, tick_labels


ERROR_MISSING_CHANNEL = (
Expand Down
Loading

0 comments on commit 240db8c

Please sign in to comment.