From 7e73ff560ace50ffe141bdec5c6d3273197ba91f Mon Sep 17 00:00:00 2001 From: Aaron Zuspan Date: Mon, 14 Nov 2022 14:37:15 -0800 Subject: [PATCH] Add label_type param (#33) --- sankee/datasets.py | 6 ++++++ sankee/plotting.py | 29 ++++++++++++++++++++++++++++- tests/test_plotting.py | 1 + 3 files changed, 35 insertions(+), 1 deletion(-) diff --git a/sankee/datasets.py b/sankee/datasets.py index cf98b15..746755d 100644 --- a/sankee/datasets.py +++ b/sankee/datasets.py @@ -110,6 +110,7 @@ def sankify( scale: Union[int, None] = None, seed: int = 0, exclude: None = None, + label_type: str = "class", ) -> go.Figure: """ Generate an interactive Sankey plot showing land cover change over time from a series of @@ -142,6 +143,10 @@ def sankify( The seed value used to generate repeatable results during random sampling. exclude : None Unused parameter that will be removed in a future release. + label_type : str, default "class" + The type of label to display for each link, one of "class", "percent", or "count". Selecting + "class" will use the class label, "percent" will use the proportion of sampled pixels in each + class, and "count" will use the number of sampled pixels in each class. Returns ------- @@ -174,6 +179,7 @@ def sankify( title=title, scale=scale, seed=seed, + label_type=label_type, ) diff --git a/sankee/plotting.py b/sankee/plotting.py index 55e6064..4c6858e 100644 --- a/sankee/plotting.py +++ b/sankee/plotting.py @@ -35,6 +35,7 @@ def sankify( title: Union[None, str] = None, scale: Union[None, int] = None, seed: int = 0, + label_type: str = "class", ) -> go.Figure: """ Generate an interactive Sankey plot showing land cover change over time from a series of images. @@ -73,6 +74,10 @@ def sankify( use the image's nominal scale, which may cause errors depending on the image projection. seed : int, default 0 The seed value used to generate repeatable results during random sampling. + label_type : str, default "class" + The type of label to display for each link, one of "class", "percent", or "count". Selecting + "class" will use the class label, "percent" will use the proportion of sampled pixels in each + class, and "count" will use the number of sampled pixels in each class. Returns ------- @@ -107,6 +112,7 @@ def sankify( palette=palette, title=title, samples=samples, + label_type=label_type, ) @@ -118,12 +124,14 @@ def __init__( palette: Dict[int, str], title: str, samples: ee.FeatureCollection, + label_type: str, ): self.data = data self.labels = labels self.palette = palette self.title = title self.samples = samples + self.label_type = label_type self.hide = [] self.plot = self.generate_plot() @@ -167,7 +175,6 @@ def generate_plot_parameters(self) -> SankeyParameters: all_classes = pd.concat([source_df, target_df]) all_classes = all_classes.drop_duplicates().reset_index(drop=True) - all_classes["label"] = all_classes["class"].apply(lambda k: self.labels[k]).tolist() all_classes["color"] = all_classes["class"].apply(lambda k: self.palette[k]).tolist() all_classes["id"] = all_classes.groupby(["year", "class"], sort=False).ngroup() @@ -187,6 +194,25 @@ def generate_plot_parameters(self) -> SankeyParameters: right_on=["year", "class"], )["id"] + # Calculate the proportion of each class in each year + melted = self.data.melt(var_name="year") + melted = melted.groupby(["year", "value"]).size().reset_index(name="count") + melted["proportion_of_total"] = (melted + .groupby("year")["count"] + .transform(lambda x: x / x.sum()) + .apply(lambda x: f"{x:.0%}") + ) + all_classes = all_classes.merge(melted, left_on=["year", "class"], right_on=["year", "value"]) + + if self.label_type == "class": + all_classes["label"] = all_classes["class"].apply(lambda k: self.labels[k]) + elif self.label_type == "percent": + all_classes["label"] = all_classes["proportion_of_total"] + elif self.label_type == "count": + all_classes["label"] = all_classes["count"] + else: + raise ValueError("Invalid label_type. Choose from 'class', 'percent', or 'count'.") + return SankeyParameters( node_labels=all_classes.year, link_labels=df.link_label, @@ -360,6 +386,7 @@ def generate_plot(self) -> go.Figure: fig = go.FigureWidget( data=[ go.Sankey( + arrangement="snap", node=dict( pad=30, thickness=10, diff --git a/tests/test_plotting.py b/tests/test_plotting.py index 414166e..e3252a1 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -20,6 +20,7 @@ def sankey(): palette=TEST_DATASET.palette, title="", samples=None, + label_type="class", )