From b21cd203eab2bcc931629c8a5ae5b35d7b68550a Mon Sep 17 00:00:00 2001
From: Aaron Zuspan <50475791+aazuspan@users.noreply.github.com>
Date: Sun, 12 Mar 2023 16:49:10 -0700
Subject: [PATCH] Add theme parameter to sankify #39
---
sankee/datasets.py | 10 +++--
sankee/plotting.py | 93 +++++++++++++++++++-----------------------
sankee/themes.py | 82 +++++++++++++++++++++++++++++++++++++
tests/test_plotting.py | 1 +
4 files changed, 133 insertions(+), 53 deletions(-)
create mode 100644 sankee/themes.py
diff --git a/sankee/datasets.py b/sankee/datasets.py
index 746755d..548537c 100644
--- a/sankee/datasets.py
+++ b/sankee/datasets.py
@@ -4,6 +4,7 @@
import pandas as pd
import plotly.graph_objects as go
+from sankee import themes
from sankee.plotting import sankify
@@ -111,6 +112,7 @@ def sankify(
seed: int = 0,
exclude: None = None,
label_type: str = "class",
+ theme: Union[str, themes.Theme] = themes.DEFAULT,
) -> go.Figure:
"""
Generate an interactive Sankey plot showing land cover change over time from a series of
@@ -144,9 +146,10 @@ def sankify(
exclude : None
Unused parameter that will be removed in a future release.
label_type : str, default "class"
- The type of label to display for each link, one of "class", "percent", or "count". Selecting
- "class" will use the class label, "percent" will use the proportion of sampled pixels in each
- class, and "count" will use the number of sampled pixels in each class.
+ The type of label to display for each link, one of "class", "percent", or "count".
+ Selecting "class" will use the class label, "percent" will use the proportion of
+ sampled pixels in each class, and "count" will use the number of sampled pixels in each
+ class.
Returns
-------
@@ -180,6 +183,7 @@ def sankify(
scale=scale,
seed=seed,
label_type=label_type,
+ theme=theme,
)
diff --git a/sankee/plotting.py b/sankee/plotting.py
index 8080529..51a578b 100644
--- a/sankee/plotting.py
+++ b/sankee/plotting.py
@@ -6,7 +6,7 @@
import pandas as pd
import plotly.graph_objects as go
-from sankee import sampling, utils
+from sankee import sampling, themes, utils
SankeyParameters = namedtuple(
"SankeyParameters",
@@ -36,6 +36,7 @@ def sankify(
scale: Union[None, int] = None,
seed: int = 0,
label_type: Union[None, str] = "class",
+ theme: Union[str, themes.Theme] = themes.DEFAULT,
) -> go.Figure:
"""
Generate an interactive Sankey plot showing land cover change over time from a series of images.
@@ -75,9 +76,13 @@ def sankify(
seed : int, default 0
The seed value used to generate repeatable results during random sampling.
label_type : str, default "class"
- The type of label to display for each link, one of "class", "percent", "count", or False. Selecting
- "class" will use the class label, "percent" will use the proportion of sampled pixels in each
- class, and "count" will use the number of sampled pixels in each class. False will disable link labels.
+ The type of label to display for each link, one of "class", "percent", "count", or False.
+ Selecting "class" will use the class label, "percent" will use the proportion of sampled
+ pixels in each class, and "count" will use the number of sampled pixels in each class.
+ False will disable link labels.
+ theme : str or themes.Theme
+ The theme to apply to the Sankey diagram. Can be the name of a built-in theme (e.g. "d3") or
+ a custom `sankee.themes.Theme` object.
Returns
-------
@@ -113,6 +118,7 @@ def sankify(
title=title,
samples=samples,
label_type=label_type,
+ theme=theme,
)
@@ -125,6 +131,7 @@ def __init__(
title: str,
samples: ee.FeatureCollection,
label_type: str,
+ theme: Union[str, themes.Theme],
):
self.data = data
self.labels = labels
@@ -132,6 +139,7 @@ def __init__(
self.title = title
self.samples = samples
self.label_type = label_type
+ self.theme = theme
self.hide = []
# Initialized by `self.generate_plot`
@@ -196,12 +204,14 @@ def generate_plot_parameters(self) -> SankeyParameters:
# Calculate the proportion of each class in each year
melted = self.data.melt(var_name="year")
melted = melted.groupby(["year", "value"]).size().reset_index(name="count")
- melted["proportion_of_total"] = (melted
- .groupby("year")["count"]
+ melted["proportion_of_total"] = (
+ melted.groupby("year")["count"]
.transform(lambda x: x / x.sum())
.apply(lambda x: f"{x:.0%}")
)
- all_classes = all_classes.merge(melted, left_on=["year", "class"], right_on=["year", "value"])
+ all_classes = all_classes.merge(
+ melted, left_on=["year", "class"], right_on=["year", "value"]
+ )
if self.label_type == "class":
all_classes["label"] = all_classes["class"].apply(lambda k: self.labels[k])
@@ -212,7 +222,9 @@ def generate_plot_parameters(self) -> SankeyParameters:
elif not self.label_type:
all_classes["label"] = ""
else:
- raise ValueError("Invalid label_type. Choose from 'class', 'percent', 'count', or False.")
+ raise ValueError(
+ "Invalid label_type. Choose from 'class', 'percent', 'count', or False."
+ )
return SankeyParameters(
node_labels=all_classes.year,
@@ -362,58 +374,39 @@ def generate_plot(self) -> go.Figure:
self.df = self.generate_dataframe()
params = self.generate_plot_parameters()
- shadow_color = "#76777a"
- label_style = f"""
- color: #fff;
- font-weight: 600;
- letter-spacing: -1px;
- text-shadow:
- 0 0 4px black,
- -1px 1px 0 {shadow_color},
- 1px 1px 0 {shadow_color},
- 1px -1px 0 {shadow_color},
- -1px -1px 0 {shadow_color};
- """
-
- title_style = """
- color: #fff;
- font-weight: 900;
- word-spacing: 10px;
- letter-spacing: 3px;
- text-shadow:
- 0 0 1px black,
- 0 0 2px black,
- 0 0 4px black;
- """
+ theme = (
+ self.theme if isinstance(self.theme, themes.Theme) else themes.load_theme(self.theme)
+ )
+
+ node_kwargs = dict(
+ customdata=params.node_labels,
+ hovertemplate="%{customdata}",
+ label=[f"{s}" for s in params.label],
+ color=params.node_palette,
+ )
+ link_kwargs = dict(
+ source=params.source,
+ target=params.target,
+ value=params.value,
+ color=params.link_palette,
+ customdata=params.link_labels,
+ hovertemplate="%{customdata} ",
+ )
fig = go.FigureWidget(
data=[
go.Sankey(
arrangement="snap",
- node=dict(
- pad=30,
- thickness=10,
- line=dict(color="#505050", width=1.5),
- customdata=params.node_labels,
- hovertemplate="%{customdata}",
- label=[f"{s}" for s in params.label],
- color=params.node_palette,
- ),
- link=dict(
- source=params.source,
- target=params.target,
- line=dict(color="#909090", width=1),
- value=params.value,
- color=params.link_palette,
- customdata=params.link_labels,
- hovertemplate="%{customdata} ",
- ),
+ node={**node_kwargs, **theme.node_kwargs},
+ link={**link_kwargs, **theme.link_kwargs},
)
]
)
fig.update_layout(
- title_text=f"{self.title}" if self.title else None,
+ title_text=f"{self.title}"
+ if self.title
+ else None,
font_size=16,
title_x=0.5,
paper_bgcolor="rgba(0, 0, 0, 0)",
diff --git a/sankee/themes.py b/sankee/themes.py
new file mode 100644
index 0000000..5087045
--- /dev/null
+++ b/sankee/themes.py
@@ -0,0 +1,82 @@
+from typing import Dict, Union
+
+
+class Theme:
+ def __init__(
+ self,
+ label_style: Union[None, str] = None,
+ title_style: Union[None, str] = None,
+ node_kwargs: Union[None, Dict] = None,
+ link_kwargs: Union[None, Dict] = None,
+ ):
+ self.label_style = label_style
+ self.title_style = title_style
+ self.node_kwargs = node_kwargs if node_kwargs is not None else {}
+ self.link_kwargs = link_kwargs if link_kwargs is not None else {}
+
+
+DEFAULT = Theme(
+ node_kwargs=dict(
+ pad=30,
+ thickness=10,
+ line=dict(color="#505050", width=1.5),
+ ),
+ link_kwargs=dict(
+ line=dict(color="#909090", width=1),
+ ),
+ label_style="""
+ color: #fff;
+ font-weight: 600;
+ letter-spacing: -1px;
+ text-shadow:
+ 0 0 4px black,
+ -1px 1px 0 #76777a,
+ 1px 1px 0 #76777a,
+ 1px -1px 0 #76777a,
+ -1px -1px 0 #76777a;
+ """,
+ title_style="""
+ color: #fff;
+ font-weight: 900;
+ word-spacing: 10px;
+ letter-spacing: 3px;
+ text-shadow:
+ 0 0 1px black,
+ 0 0 2px black,
+ 0 0 4px black;
+ """,
+)
+
+
+D3 = Theme(
+ node_kwargs=dict(line=dict(width=1), pad=20, thickness=15),
+ link_kwargs=dict(color="rgba(120, 120, 120, 0.25)"),
+)
+
+SIMPLE = Theme(
+ node_kwargs=dict(line=dict(width=0), pad=60, thickness=30),
+ link_kwargs=dict(color="rgba(120, 120, 120, 0.25)"),
+ label_style="""
+ color: #666666;
+ font-size: 18px;
+ color: #666666;
+ """,
+ title_style="""
+ color: #666666;
+ font-size: 24px;
+ font-weight: 900;
+ """,
+)
+
+
+THEMES = {
+ "default": DEFAULT,
+ "d3": D3,
+ "simple": SIMPLE,
+}
+
+
+def load_theme(theme_name):
+ if theme_name not in THEMES:
+ raise ValueError(f"Theme `{theme_name}` not found. Choose from {list(THEMES.keys())}.")
+ return THEMES[theme_name]
diff --git a/tests/test_plotting.py b/tests/test_plotting.py
index e3252a1..d1aa9bf 100644
--- a/tests/test_plotting.py
+++ b/tests/test_plotting.py
@@ -21,6 +21,7 @@ def sankey():
title="",
samples=None,
label_type="class",
+ theme="default",
)