From f652bf14cde563661207de31bc7240b1924a14ec Mon Sep 17 00:00:00 2001 From: Aaron Zuspan Date: Mon, 14 Nov 2022 15:35:45 -0800 Subject: [PATCH] Store dataframe to avoid re-generating unnecessarily (#36) --- sankee/plotting.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/sankee/plotting.py b/sankee/plotting.py index 4c6858e..382d2ff 100644 --- a/sankee/plotting.py +++ b/sankee/plotting.py @@ -134,20 +134,21 @@ def __init__( self.label_type = label_type self.hide = [] + # Initialized by `self.generate_plot` + self.df = None self.plot = self.generate_plot() self.gui = self.generate_gui() def get_sorted_classes(self) -> pd.Series: """Return all unique class values, sorted by the total number of observations.""" - df = self.generate_dataframe() start_count = ( - df.groupby("source") + self.df.groupby("source") .mean() .reset_index()[["source", "total"]] .rename(columns={"source": "class", "total": "count"}) ) end_count = ( - df.groupby("target") + self.df.groupby("target") .sum() .reset_index()[["target", "changed"]] .rename(columns={"target": "class", "changed": "count"}) @@ -158,13 +159,11 @@ def get_sorted_classes(self) -> pd.Series: def get_active_classes(self) -> pd.Series: """Return all unique active, visibile class values after filtering.""" - df = self.generate_dataframe() - - return df[["source", "target"]].melt().value.unique() + return self.df[["source", "target"]].melt().value.unique() def generate_plot_parameters(self) -> SankeyParameters: """Generate Sankey plot parameters from a formatted, cleaned dataframe""" - df = self.generate_dataframe() + df = self.df.copy() source_df = df[["source", "source_year"]].rename( columns={"source": "class", "source_year": "year"} @@ -312,10 +311,11 @@ def update_plot(): self.plot.data[0].node = new_plot.data[0].node buttons = [] + active_classes = self.get_active_classes() for i in unique_classes: label = self.labels[i] on_color = self.palette[i] - state = i in self.get_active_classes() + state = i in active_classes button = utils.ColorToggleButton(tooltip=label, on_color=on_color, state=state) button.layout.width = BUTTON_WIDTH @@ -324,7 +324,7 @@ def update_plot(): button.on_click(toggle_button) buttons.append(button) - def reset_plot(change): + def reset_plot(_): for button in buttons: if not button.state: button.click() @@ -357,6 +357,7 @@ def reset_plot(change): return gui def generate_plot(self) -> go.Figure: + self.df = self.generate_dataframe() params = self.generate_plot_parameters() shadow_color = "#76777a"