From b21cd203eab2bcc931629c8a5ae5b35d7b68550a Mon Sep 17 00:00:00 2001 From: Aaron Zuspan <50475791+aazuspan@users.noreply.github.com> Date: Sun, 12 Mar 2023 16:49:10 -0700 Subject: [PATCH] Add theme parameter to sankify #39 --- sankee/datasets.py | 10 +++-- sankee/plotting.py | 93 +++++++++++++++++++----------------------- sankee/themes.py | 82 +++++++++++++++++++++++++++++++++++++ tests/test_plotting.py | 1 + 4 files changed, 133 insertions(+), 53 deletions(-) create mode 100644 sankee/themes.py diff --git a/sankee/datasets.py b/sankee/datasets.py index 746755d..548537c 100644 --- a/sankee/datasets.py +++ b/sankee/datasets.py @@ -4,6 +4,7 @@ import pandas as pd import plotly.graph_objects as go +from sankee import themes from sankee.plotting import sankify @@ -111,6 +112,7 @@ def sankify( seed: int = 0, exclude: None = None, label_type: str = "class", + theme: Union[str, themes.Theme] = themes.DEFAULT, ) -> go.Figure: """ Generate an interactive Sankey plot showing land cover change over time from a series of @@ -144,9 +146,10 @@ def sankify( exclude : None Unused parameter that will be removed in a future release. label_type : str, default "class" - The type of label to display for each link, one of "class", "percent", or "count". Selecting - "class" will use the class label, "percent" will use the proportion of sampled pixels in each - class, and "count" will use the number of sampled pixels in each class. + The type of label to display for each link, one of "class", "percent", or "count". + Selecting "class" will use the class label, "percent" will use the proportion of + sampled pixels in each class, and "count" will use the number of sampled pixels in each + class. Returns ------- @@ -180,6 +183,7 @@ def sankify( scale=scale, seed=seed, label_type=label_type, + theme=theme, ) diff --git a/sankee/plotting.py b/sankee/plotting.py index 8080529..51a578b 100644 --- a/sankee/plotting.py +++ b/sankee/plotting.py @@ -6,7 +6,7 @@ import pandas as pd import plotly.graph_objects as go -from sankee import sampling, utils +from sankee import sampling, themes, utils SankeyParameters = namedtuple( "SankeyParameters", @@ -36,6 +36,7 @@ def sankify( scale: Union[None, int] = None, seed: int = 0, label_type: Union[None, str] = "class", + theme: Union[str, themes.Theme] = themes.DEFAULT, ) -> go.Figure: """ Generate an interactive Sankey plot showing land cover change over time from a series of images. @@ -75,9 +76,13 @@ def sankify( seed : int, default 0 The seed value used to generate repeatable results during random sampling. label_type : str, default "class" - The type of label to display for each link, one of "class", "percent", "count", or False. Selecting - "class" will use the class label, "percent" will use the proportion of sampled pixels in each - class, and "count" will use the number of sampled pixels in each class. False will disable link labels. + The type of label to display for each link, one of "class", "percent", "count", or False. + Selecting "class" will use the class label, "percent" will use the proportion of sampled + pixels in each class, and "count" will use the number of sampled pixels in each class. + False will disable link labels. + theme : str or themes.Theme + The theme to apply to the Sankey diagram. Can be the name of a built-in theme (e.g. "d3") or + a custom `sankee.themes.Theme` object. Returns ------- @@ -113,6 +118,7 @@ def sankify( title=title, samples=samples, label_type=label_type, + theme=theme, ) @@ -125,6 +131,7 @@ def __init__( title: str, samples: ee.FeatureCollection, label_type: str, + theme: Union[str, themes.Theme], ): self.data = data self.labels = labels @@ -132,6 +139,7 @@ def __init__( self.title = title self.samples = samples self.label_type = label_type + self.theme = theme self.hide = [] # Initialized by `self.generate_plot` @@ -196,12 +204,14 @@ def generate_plot_parameters(self) -> SankeyParameters: # Calculate the proportion of each class in each year melted = self.data.melt(var_name="year") melted = melted.groupby(["year", "value"]).size().reset_index(name="count") - melted["proportion_of_total"] = (melted - .groupby("year")["count"] + melted["proportion_of_total"] = ( + melted.groupby("year")["count"] .transform(lambda x: x / x.sum()) .apply(lambda x: f"{x:.0%}") ) - all_classes = all_classes.merge(melted, left_on=["year", "class"], right_on=["year", "value"]) + all_classes = all_classes.merge( + melted, left_on=["year", "class"], right_on=["year", "value"] + ) if self.label_type == "class": all_classes["label"] = all_classes["class"].apply(lambda k: self.labels[k]) @@ -212,7 +222,9 @@ def generate_plot_parameters(self) -> SankeyParameters: elif not self.label_type: all_classes["label"] = "" else: - raise ValueError("Invalid label_type. Choose from 'class', 'percent', 'count', or False.") + raise ValueError( + "Invalid label_type. Choose from 'class', 'percent', 'count', or False." + ) return SankeyParameters( node_labels=all_classes.year, @@ -362,58 +374,39 @@ def generate_plot(self) -> go.Figure: self.df = self.generate_dataframe() params = self.generate_plot_parameters() - shadow_color = "#76777a" - label_style = f""" - color: #fff; - font-weight: 600; - letter-spacing: -1px; - text-shadow: - 0 0 4px black, - -1px 1px 0 {shadow_color}, - 1px 1px 0 {shadow_color}, - 1px -1px 0 {shadow_color}, - -1px -1px 0 {shadow_color}; - """ - - title_style = """ - color: #fff; - font-weight: 900; - word-spacing: 10px; - letter-spacing: 3px; - text-shadow: - 0 0 1px black, - 0 0 2px black, - 0 0 4px black; - """ + theme = ( + self.theme if isinstance(self.theme, themes.Theme) else themes.load_theme(self.theme) + ) + + node_kwargs = dict( + customdata=params.node_labels, + hovertemplate="%{customdata}", + label=[f"{s}" for s in params.label], + color=params.node_palette, + ) + link_kwargs = dict( + source=params.source, + target=params.target, + value=params.value, + color=params.link_palette, + customdata=params.link_labels, + hovertemplate="%{customdata} ", + ) fig = go.FigureWidget( data=[ go.Sankey( arrangement="snap", - node=dict( - pad=30, - thickness=10, - line=dict(color="#505050", width=1.5), - customdata=params.node_labels, - hovertemplate="%{customdata}", - label=[f"{s}" for s in params.label], - color=params.node_palette, - ), - link=dict( - source=params.source, - target=params.target, - line=dict(color="#909090", width=1), - value=params.value, - color=params.link_palette, - customdata=params.link_labels, - hovertemplate="%{customdata} ", - ), + node={**node_kwargs, **theme.node_kwargs}, + link={**link_kwargs, **theme.link_kwargs}, ) ] ) fig.update_layout( - title_text=f"{self.title}" if self.title else None, + title_text=f"{self.title}" + if self.title + else None, font_size=16, title_x=0.5, paper_bgcolor="rgba(0, 0, 0, 0)", diff --git a/sankee/themes.py b/sankee/themes.py new file mode 100644 index 0000000..5087045 --- /dev/null +++ b/sankee/themes.py @@ -0,0 +1,82 @@ +from typing import Dict, Union + + +class Theme: + def __init__( + self, + label_style: Union[None, str] = None, + title_style: Union[None, str] = None, + node_kwargs: Union[None, Dict] = None, + link_kwargs: Union[None, Dict] = None, + ): + self.label_style = label_style + self.title_style = title_style + self.node_kwargs = node_kwargs if node_kwargs is not None else {} + self.link_kwargs = link_kwargs if link_kwargs is not None else {} + + +DEFAULT = Theme( + node_kwargs=dict( + pad=30, + thickness=10, + line=dict(color="#505050", width=1.5), + ), + link_kwargs=dict( + line=dict(color="#909090", width=1), + ), + label_style=""" + color: #fff; + font-weight: 600; + letter-spacing: -1px; + text-shadow: + 0 0 4px black, + -1px 1px 0 #76777a, + 1px 1px 0 #76777a, + 1px -1px 0 #76777a, + -1px -1px 0 #76777a; + """, + title_style=""" + color: #fff; + font-weight: 900; + word-spacing: 10px; + letter-spacing: 3px; + text-shadow: + 0 0 1px black, + 0 0 2px black, + 0 0 4px black; + """, +) + + +D3 = Theme( + node_kwargs=dict(line=dict(width=1), pad=20, thickness=15), + link_kwargs=dict(color="rgba(120, 120, 120, 0.25)"), +) + +SIMPLE = Theme( + node_kwargs=dict(line=dict(width=0), pad=60, thickness=30), + link_kwargs=dict(color="rgba(120, 120, 120, 0.25)"), + label_style=""" + color: #666666; + font-size: 18px; + color: #666666; + """, + title_style=""" + color: #666666; + font-size: 24px; + font-weight: 900; + """, +) + + +THEMES = { + "default": DEFAULT, + "d3": D3, + "simple": SIMPLE, +} + + +def load_theme(theme_name): + if theme_name not in THEMES: + raise ValueError(f"Theme `{theme_name}` not found. Choose from {list(THEMES.keys())}.") + return THEMES[theme_name] diff --git a/tests/test_plotting.py b/tests/test_plotting.py index e3252a1..d1aa9bf 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -21,6 +21,7 @@ def sankey(): title="", samples=None, label_type="class", + theme="default", )