Skip to content

Commit

Permalink
Merge pull request #3765 from plotly/one_group_short_circuit
Browse files Browse the repository at this point in the history
PX: Avoid `groupby` when possible and access groups more efficiently
  • Loading branch information
nicolaskruchten authored Jun 23, 2022
2 parents a4b9887 + 4b22199 commit f83921f
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 20 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ This project adheres to [Semantic Versioning](http://semver.org/).
- `pattern_shape` options now available in `px.timeline()` [#3774](https://github.com/plotly/plotly.py/pull/3774)
- `facet_*` and `category_orders` now available in `px.pie()` [#3775](https://github.com/plotly/plotly.py/pull/3775)

### Performance

- `px` methods no longer call `groupby` on the input dataframe when the result would be a single group, and no longer groups by a lambda, for significant speedups [#3765](https://github.com/plotly/plotly.py/pull/3765)

### Updated

- Allow non-string extras in `flaglist` attributes, to support upcoming changes to `ax.automargin` in plotly.js [plotly.js#6193](https://github.com/plotly/plotly.js/pull/6193), [#3749](https://github.com/plotly/plotly.py/pull/3749)
Expand Down
65 changes: 45 additions & 20 deletions packages/python/plotly/plotly/express/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1920,40 +1920,66 @@ def infer_config(args, constructor, trace_patch, layout_patch):
return trace_specs, grouped_mappings, sizeref, show_colorbar


def get_orderings(args, grouper, grouped):
def get_groups_and_orders(args, grouper):
"""
`orders` is the user-supplied ordering with the remaining data-frame-supplied
ordering appended if the column is used for grouping. It includes anything the user
gave, for any variable, including values not present in the dataset. It's a dict
where the keys are e.g. "x" or "color"
`sorted_group_names` is the set of groups, ordered by the order above. It's a list
of tuples like [("value1", ""), ("value2", "")] where each tuple contains the name
`groups` is the dicts of groups, ordered by the order above. Its keys are
tuples like [("value1", ""), ("value2", "")] where each tuple contains the name
of a single dimension-group
"""

orders = {} if "category_orders" not in args else args["category_orders"].copy()

# figure out orders and what the single group name would be if there were one
single_group_name = []
unique_cache = dict()
for col in grouper:
if col != one_group:
uniques = list(args["data_frame"][col].unique())
if col == one_group:
single_group_name.append("")
else:
if col not in unique_cache:
unique_cache[col] = list(args["data_frame"][col].unique())
uniques = unique_cache[col]
if len(uniques) == 1:
single_group_name.append(uniques[0])
if col not in orders:
orders[col] = uniques
else:
orders[col] = list(OrderedDict.fromkeys(list(orders[col]) + uniques))
df = args["data_frame"]
if len(single_group_name) == len(grouper):
# we have a single group, so we can skip all group-by operations!
groups = {tuple(single_group_name): df}
else:
required_grouper = [g for g in grouper if g != one_group]
grouped = df.groupby(required_grouper, sort=False) # skip one_group groupers
group_indices = grouped.indices
sorted_group_names = [
g if len(required_grouper) != 1 else (g,) for g in group_indices
]

sorted_group_names = []
for group_name in grouped.groups:
if len(grouper) == 1:
group_name = (group_name,)
sorted_group_names.append(group_name)

for i, col in reversed(list(enumerate(grouper))):
if col != one_group:
for i, col in reversed(list(enumerate(required_grouper))):
sorted_group_names = sorted(
sorted_group_names,
key=lambda g: orders[col].index(g[i]) if g[i] in orders[col] else -1,
)
return orders, sorted_group_names

# calculate the full group_names by inserting "" in the tuple index for one_group groups
full_sorted_group_names = [list(t) for t in sorted_group_names]
for i, col in enumerate(grouper):
if col == one_group:
for g in full_sorted_group_names:
g.insert(i, "")
full_sorted_group_names = [tuple(g) for g in full_sorted_group_names]

groups = {
sf: grouped.get_group(s if len(s) > 1 else s[0])
for sf, s in zip(full_sorted_group_names, sorted_group_names)
}
return groups, orders


def make_figure(args, constructor, trace_patch=None, layout_patch=None):
Expand All @@ -1974,9 +2000,7 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
args, constructor, trace_patch, layout_patch
)
grouper = [x.grouper or one_group for x in grouped_mappings] or [one_group]
grouped = args["data_frame"].groupby(grouper, sort=False)

orders, sorted_group_names = get_orderings(args, grouper, grouped)
groups, orders = get_groups_and_orders(args, grouper)

col_labels = []
row_labels = []
Expand Down Expand Up @@ -2005,8 +2029,7 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
trendline_rows = []
trace_name_labels = None
facet_col_wrap = args.get("facet_col_wrap", 0)
for group_name in sorted_group_names:
group = grouped.get_group(group_name if len(group_name) > 1 else group_name[0])
for group_name, group in groups.items():
mapping_labels = OrderedDict()
trace_name_labels = OrderedDict()
frame_name = ""
Expand Down Expand Up @@ -2224,6 +2247,8 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
fig.update_layout(layout_patch)
if "template" in args and args["template"] is not None:
fig.update_layout(template=args["template"], overwrite=True)
for f in frame_list:
f["name"] = str(f["name"])
fig.frames = frame_list if len(frames) > 1 else []

if args.get("trendline") and args.get("trendline_scope", "trace") == "overall":
Expand Down

0 comments on commit f83921f

Please sign in to comment.