Skip to content

Commit

Permalink
Limit display widths for api components, allow row wrapping of cohort…
Browse files Browse the repository at this point in the history
… selections. (#81)
  • Loading branch information
gbowlin authored Aug 28, 2024
1 parent 14ba4ad commit 2974cb5
Show file tree
Hide file tree
Showing 13 changed files with 145 additions and 131 deletions.
1 change: 1 addition & 0 deletions changelog/81.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Limits control max-widths to 1200px in most cases, allowing row wrap when needed.
12 changes: 1 addition & 11 deletions example-notebooks/binary-classifier/classifier_bin.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -825,17 +825,7 @@
},
"outputs": [],
"source": [
"#sm.ExploreCohortEvaluation()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "516faf4a-7129-4022-841c-3bcb6986bfb8",
"metadata": {},
"outputs": [],
"source": [
"sm.plot_cohort_evaluation('Age', ('[0-10)', '[10-20)', '[20-50)', '[50-70)', '70+'), 'Readmitted within 30 Days', 'LGBM_score', (0.08, 0.15), per_context=False)"
"sm.ExploreCohortEvaluation()"
]
},
{
Expand Down
53 changes: 32 additions & 21 deletions src/seismometer/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,15 +263,16 @@ def generate_fairness_audit(

data = data[[target, score_column] + cohort_columns]
data = FilterRule.isin(target, (0, 1)).filter(data)
if len(data.index) < sg.censor_threshold:
positive_samples = data[target].sum()
if min(positive_samples, len(data) - positive_samples) < sg.censor_threshold:
return template.render_censored_plot_message(sg.censor_threshold)

try:
altair_plot = fairness_audit_altair(
data, cohort_columns, score_column, target, score_threshold, metric_list, fairness_threshold
)
except CensoredResultException as error:
return template.render_censored_data_message(error.message)
return template.render_censored_data_message(str(error))

if NotebookHost.supports_iframe():
altair_plot.save(fairness_path, format="html")
Expand All @@ -294,26 +295,25 @@ def cohort_list():
from ipywidgets import Output, VBox

from .controls.selection import MultiSelectionListWidget
from .controls.styles import BOX_GRID_LAYOUT

options = sg.available_cohort_groups

comparison_selections = MultiSelectionListWidget(options, title="Cohort")
output = Output()

def on_widget_value_changed(*args):
output.clear_output(wait=True)
with output:
display("Recalculating...")
output.clear_output(wait=True)
display("Recalculating...", clear=True)
html = _cohort_list_details(comparison_selections.value)
display(html)
display(html, clear=True)

comparison_selections.observe(on_widget_value_changed, "value")

# get initial value
on_widget_value_changed()

return VBox(children=[comparison_selections, output])
return VBox(children=[comparison_selections, output], layout=BOX_GRID_LAYOUT)


@disk_cached_html_segment
Expand All @@ -335,26 +335,37 @@ def _cohort_list_details(cohort_dict: dict[str, tuple[Any]]) -> HTML:
from .data.filter import filter_rule_from_cohort_dictionary

sg = Seismogram()
cfg = sg.config
target_cols = [pdh.event_value(x) for x in cfg.targets]
intervention_cols = [pdh.event_value(x) for x in cfg.interventions]
outcome_cols = [pdh.event_value(x) for x in cfg.outcomes]

rule = filter_rule_from_cohort_dictionary(cohort_dict)
data = rule.filter(sg.dataframe)
data = rule.filter(sg.dataframe)[
cfg.entity_keys + cfg.output_list + intervention_cols + outcome_cols + target_cols
]
cohort_count = data[sg.entity_keys[0]].nunique()
if cohort_count < sg.censor_threshold:
return template.render_censored_plot_message(sg.censor_threshold)

cfg = sg.config
target_cols = [pdh.event_value(x) for x in cfg.targets]
intervention_cols = [pdh.event_value(x) for x in cfg.interventions]
outcome_cols = [pdh.event_value(x) for x in cfg.outcomes]
groups = data[cfg.entity_keys + cfg.output_list + intervention_cols + outcome_cols + target_cols].groupby(
target_cols
)
aggregation = {cfg.entity_id: ["count", "nunique"]}
if len(cfg.context_id):
aggregation[cfg.context_id] = "nunique"
# add in other keys for aggregation
aggregation.update({k: "mean" for k in cfg.output_list + intervention_cols + outcome_cols})
groups = data.groupby(target_cols)
float_cols = list(data[intervention_cols + outcome_cols].select_dtypes(include=float))

stat_dict = {k: ["mean"] for k in float_cols}
stat_dict.update({cfg.entity_id: ["nunique", "count"], cfg.context_id: ["nunique"]})

groupstats = groups.agg(stat_dict)
groupstats.columns = [pdh.event_name(x) for x in float_cols] + [
f"Unique {cfg.entity_id}",
f"{cfg.entity_id} Count",
f"Unique {cfg.context_id}",
]
new_names = [pdh.event_name(x) for x in target_cols]
if len(new_names) == 1:
new_names = new_names[0] # because pandas Index only accepts a string for rename.
groupstats.index.rename(new_names, inplace=True)
html_table = groupstats.to_html()
title = "Summary"
html_table = groups.agg(aggregation).to_html()
return template.render_title_message(title, html_table)


Expand Down
6 changes: 3 additions & 3 deletions src/seismometer/controls/selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import traitlets
from ipywidgets import HTML, Box, Dropdown, Label, Layout, Stack, ToggleButton, ValueWidget, VBox, jslink

from .styles import html_title
from .styles import DROPDOWN_LAYOUT, html_title


class SelectionListWidget(ValueWidget, VBox):
Expand Down Expand Up @@ -245,7 +245,7 @@ def __init__(
self.dropdown = Dropdown(
options=[key for key in values],
value=value[0],
layout=Layout(width="calc(max(max-content, var(--jp-widgets-inline-width-short)))"),
layout=DROPDOWN_LAYOUT,
)
self.dropdown.observe(self._on_selection_change, "value")
self.selection_widgets = {}
Expand All @@ -256,7 +256,7 @@ def __init__(
self.stack = Stack(children=[self.selection_widgets[key] for key in self.selection_widgets], selected_index=0)
self.children = [self.title_box, self.dropdown, self.stack]
jslink((self.dropdown, "index"), (self.stack, "selected_index"))
self.layout = Layout(width="calc(100% - var(--jp-widgets-border-width)* 2)")
self.layout = Layout(width="calc(100% - var(--jp-widgets-border-width)* 2)", max_width="min-content")
self._on_selection_change()
self.observe(self._on_value_change, "value")
self._disabled = False
Expand Down
3 changes: 2 additions & 1 deletion src/seismometer/controls/styles.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
WIDE_LABEL_STYLE = {"description_width": "120px"}

BOX_GRID_LAYOUT = Layout(
align_items="flex-start", grid_gap="20px", width="100%", min_width="300px", max_width="1400px"
align_items="flex-start", grid_gap="20px", width="100%", min_width="300px", max_width="1200px"
)
WIDE_BUTTON_LAYOUT = Layout(align_items="flex-start", width="max-content", min_width="200px")
DROPDOWN_LAYOUT = Layout(width="calc(max(max-content, var(--jp-widgets-inline-width-short)))")


def html_title(title: str) -> HTML:
Expand Down
16 changes: 9 additions & 7 deletions src/seismometer/html/resources/cohorts.html
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@
}
</style>

<h3>Cohort Summaries</h3>
{% for cohort, df_list in dfs.items() %}
<div class="top">
{% for df in df_list %}
<div class="inner">
{{ df }}
<div style="{{ display_style }}">
<h3>Cohort Summaries</h3>
{% for cohort, df_list in dfs.items() %}
<div class="top">
{% for df in df_list %}
<div class="inner">
{{ df }}
</div>
{% endfor %}
</div>
{% endfor %}
</div>
{% endfor %}
55 changes: 28 additions & 27 deletions src/seismometer/html/resources/info.html
Original file line number Diff line number Diff line change
Expand Up @@ -3,32 +3,33 @@
text-align: left !important;
}
</style>

<h3>Summary</h3>
The preloaded data covers {{ num_predictions }} predictions over {{ num_entities }} entities from the dates {{ start_date }} to {{ end_date }}.
<table>
<thead>
<div style="{{ display_style }}">
<h3>Summary</h3>
The preloaded data covers {{ num_predictions }} predictions over {{ num_entities }} entities from the dates {{ start_date }} to {{ end_date }}.
<table>
<thead>
<tr>
<th>Dataframe Name</th>
<th>Rows</th>
<th>Columns</th>
<th>Content</th>
</tr>
</thead>
{% for table in tables %}
<tr>
<th>Dataframe Name</th>
<th>Rows</th>
<th>Columns</th>
<th>Content</th>
<td><code>{{ table["name"] }}</code></td>
<td>{{ table["num_rows"] }}</td>
<td>{{ table["num_cols"] }}</td>
<td>{{ table["description"] }}</td>
</tr>
</thead>
{% for table in tables %}
<tr>
<td><code>{{ table["name"] }}</code></td>
<td>{{ table["num_rows"] }}</td>
<td>{{ table["num_cols"] }}</td>
<td>{{ table["description"] }}</td>
</tr>
{% endfor %}
</table>
{% if plot_help %}
<h4>Plot Functions</h4>
<ul>
<li><code>sm.model_evaluation()</code> - Overall performance across thresholds</li>
<li><code>sm.cohort_evaluation(cohort_group)</code> - Performance split by specified cohort</li>
<li><code>sm.plot_outcome(outcome, intervention, cohort)</code> - Compare trends of interventions to outcomes</li>
</ul>
{% endif %}
{% endfor %}
</table>
{% if plot_help %}
<h4>Plot Functions</h4>
<ul>
<li><code>sm.ExploreModelEvaluation()</code> - Overall performance across thresholds</li>
<li><code>sm.ExploreCohortEvaluation()</code> - Performance split by specified cohort</li>
<li><code>sm.ExploreCohortOutcomeInterventionTimes()</code> - Compare trends of interventions to outcomes</li>
</ul>
{% endif %}
</div>
4 changes: 2 additions & 2 deletions src/seismometer/html/resources/title_image.html
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
<div style="width: max-content;">
<div style="{{ display_style }}">
<h3 style="text-align: center;">{{ title }}</h3>
{{ image }}
</div>
</div>
4 changes: 2 additions & 2 deletions src/seismometer/html/resources/title_message.html
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
<div style="width: 100%;">
<div style="{{ display_style }}">
<h3>{{ title }}</h3>
{{ message }}
</div>
</div>
8 changes: 6 additions & 2 deletions src/seismometer/html/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from IPython.display import HTML, SVG
from jinja2 import Environment, PackageLoader, TemplateNotFound

FULL_WIDTH_STYLE = "width: 100%; max-width: 1200px;"

logger = logging.getLogger("seismometer")

# Initializing Jinja
Expand Down Expand Up @@ -45,7 +47,7 @@ def render_cohort_summary_template(dfs: dict[str, list[str]]) -> HTML:
return render_into_template("cohorts", {"dfs": dfs})


def render_into_template(name: str, values: dict = None) -> HTML:
def render_into_template(name: str, values: dict = None, display_style=FULL_WIDTH_STYLE) -> HTML:
"""
Uses jinja to render a dictionary of values into a template.
Expand All @@ -55,6 +57,8 @@ def render_into_template(name: str, values: dict = None) -> HTML:
The template name.
values : Optional[dict], optional
A dictionary of values to be templated into the HTML, by default None.
display_style : str, optional
The display style for the template, by default FULL_WIDTH_STYLE.
Returns
-------
Expand All @@ -68,7 +72,7 @@ def render_into_template(name: str, values: dict = None) -> HTML:
logger.warning(f"HTML template {name} not found.")
return HTML()

return HTML(template.render(values))
return HTML(template.render(values, display_style=display_style))


def render_title_message(title: str, message: str) -> HTML:
Expand Down
32 changes: 17 additions & 15 deletions tests/resources/rendered_templates/cohort_summaries_template.html
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,22 @@
}
</style>

<h3>Cohort Summaries</h3>
<div class="top">
<div class="inner">
dataframe1!
<div style="width: 100%; max-width: 1200px;">
<h3>Cohort Summaries</h3>
<div class="top">
<div class="inner">
dataframe1!
</div>
<div class="inner">
dataframe2!
</div>
</div>
<div class="inner">
dataframe2!
<div class="top">
<div class="inner">
dataframe1!
</div>
<div class="inner">
dataframe2!
</div>
</div>
</div>
<div class="top">
<div class="inner">
dataframe1!
</div>
<div class="inner">
dataframe2!
</div>
</div>
</div>
47 changes: 24 additions & 23 deletions tests/resources/rendered_templates/info.html
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,29 @@
text-align: left !important;
}
</style>

<h3>Summary</h3>
The preloaded data covers num_predictions predictions over num_entities entities from the dates start_date to end_date.
<table>
<thead>
<div style="width: 100%; max-width: 1200px;">
<h3>Summary</h3>
The preloaded data covers num_predictions predictions over num_entities entities from the dates start_date to end_date.
<table>
<thead>
<tr>
<th>Dataframe Name</th>
<th>Rows</th>
<th>Columns</th>
<th>Content</th>
</tr>
</thead>
<tr>
<th>Dataframe Name</th>
<th>Rows</th>
<th>Columns</th>
<th>Content</th>
<td><code>sg.dataframe</code></td>
<td>1</td>
<td>1</td>
<td>Scores, features, configured demographics, and merged events for each prediction</td>
</tr>
</thead>
<tr>
<td><code>sg.dataframe</code></td>
<td>1</td>
<td>1</td>
<td>Scores, features, configured demographics, and merged events for each prediction</td>
</tr>
</table>
<h4>Plot Functions</h4>
<ul>
<li><code>sm.model_evaluation()</code> - Overall performance across thresholds</li>
<li><code>sm.cohort_evaluation(cohort_group)</code> - Performance split by specified cohort</li>
<li><code>sm.plot_outcome(outcome, intervention, cohort)</code> - Compare trends of interventions to outcomes</li>
</ul>
</table>
<h4>Plot Functions</h4>
<ul>
<li><code>sm.ExploreModelEvaluation()</code> - Overall performance across thresholds</li>
<li><code>sm.ExploreCohortEvaluation()</code> - Performance split by specified cohort</li>
<li><code>sm.ExploreCohortOutcomeInterventionTimes()</code> - Compare trends of interventions to outcomes</li>
</ul>
</div>
Loading

0 comments on commit 2974cb5

Please sign in to comment.