Skip to content

Commit

Permalink
Store dataframe to avoid re-generating unnecessarily (#36)
Browse files Browse the repository at this point in the history
  • Loading branch information
aazuspan committed Nov 14, 2022
1 parent f957213 commit f652bf1
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions sankee/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,20 +134,21 @@ def __init__(
self.label_type = label_type

self.hide = []
# Initialized by `self.generate_plot`
self.df = None
self.plot = self.generate_plot()
self.gui = self.generate_gui()

def get_sorted_classes(self) -> pd.Series:
"""Return all unique class values, sorted by the total number of observations."""
df = self.generate_dataframe()
start_count = (
df.groupby("source")
self.df.groupby("source")
.mean()
.reset_index()[["source", "total"]]
.rename(columns={"source": "class", "total": "count"})
)
end_count = (
df.groupby("target")
self.df.groupby("target")
.sum()
.reset_index()[["target", "changed"]]
.rename(columns={"target": "class", "changed": "count"})
Expand All @@ -158,13 +159,11 @@ def get_sorted_classes(self) -> pd.Series:

def get_active_classes(self) -> pd.Series:
"""Return all unique active, visibile class values after filtering."""
df = self.generate_dataframe()

return df[["source", "target"]].melt().value.unique()
return self.df[["source", "target"]].melt().value.unique()

def generate_plot_parameters(self) -> SankeyParameters:
"""Generate Sankey plot parameters from a formatted, cleaned dataframe"""
df = self.generate_dataframe()
df = self.df.copy()

source_df = df[["source", "source_year"]].rename(
columns={"source": "class", "source_year": "year"}
Expand Down Expand Up @@ -312,10 +311,11 @@ def update_plot():
self.plot.data[0].node = new_plot.data[0].node

buttons = []
active_classes = self.get_active_classes()
for i in unique_classes:
label = self.labels[i]
on_color = self.palette[i]
state = i in self.get_active_classes()
state = i in active_classes

button = utils.ColorToggleButton(tooltip=label, on_color=on_color, state=state)
button.layout.width = BUTTON_WIDTH
Expand All @@ -324,7 +324,7 @@ def update_plot():
button.on_click(toggle_button)
buttons.append(button)

def reset_plot(change):
def reset_plot(_):
for button in buttons:
if not button.state:
button.click()
Expand Down Expand Up @@ -357,6 +357,7 @@ def reset_plot(change):
return gui

def generate_plot(self) -> go.Figure:
self.df = self.generate_dataframe()
params = self.generate_plot_parameters()

shadow_color = "#76777a"
Expand Down

0 comments on commit f652bf1

Please sign in to comment.