Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PX: Avoid groupby when possible and access groups more efficiently #3765

Merged
merged 15 commits into from
Jun 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ This project adheres to [Semantic Versioning](http://semver.org/).
- `pattern_shape` options now available in `px.timeline()` [#3774](https://github.com/plotly/plotly.py/pull/3774)
- `facet_*` and `category_orders` now available in `px.pie()` [#3775](https://github.com/plotly/plotly.py/pull/3775)

### Performance

- `px` methods no longer call `groupby` on the input dataframe when the result would be a single group, and no longer groups by a lambda, for significant speedups [#3765](https://github.com/plotly/plotly.py/pull/3765)

### Updated

- Allow non-string extras in `flaglist` attributes, to support upcoming changes to `ax.automargin` in plotly.js [plotly.js#6193](https://github.com/plotly/plotly.js/pull/6193), [#3749](https://github.com/plotly/plotly.py/pull/3749)
Expand Down
65 changes: 45 additions & 20 deletions packages/python/plotly/plotly/express/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1920,40 +1920,66 @@ def infer_config(args, constructor, trace_patch, layout_patch):
return trace_specs, grouped_mappings, sizeref, show_colorbar


def get_orderings(args, grouper, grouped):
def get_groups_and_orders(args, grouper):
"""
`orders` is the user-supplied ordering with the remaining data-frame-supplied
ordering appended if the column is used for grouping. It includes anything the user
gave, for any variable, including values not present in the dataset. It's a dict
where the keys are e.g. "x" or "color"

`sorted_group_names` is the set of groups, ordered by the order above. It's a list
of tuples like [("value1", ""), ("value2", "")] where each tuple contains the name
`groups` is the dicts of groups, ordered by the order above. Its keys are
tuples like [("value1", ""), ("value2", "")] where each tuple contains the name
of a single dimension-group
"""

orders = {} if "category_orders" not in args else args["category_orders"].copy()

# figure out orders and what the single group name would be if there were one
single_group_name = []
unique_cache = dict()
for col in grouper:
if col != one_group:
uniques = list(args["data_frame"][col].unique())
if col == one_group:
single_group_name.append("")
else:
if col not in unique_cache:
unique_cache[col] = list(args["data_frame"][col].unique())
uniques = unique_cache[col]
if len(uniques) == 1:
single_group_name.append(uniques[0])
if col not in orders:
orders[col] = uniques
else:
orders[col] = list(OrderedDict.fromkeys(list(orders[col]) + uniques))
df = args["data_frame"]
if len(single_group_name) == len(grouper):
# we have a single group, so we can skip all group-by operations!
groups = {tuple(single_group_name): df}
else:
required_grouper = [g for g in grouper if g != one_group]
grouped = df.groupby(required_grouper, sort=False) # skip one_group groupers
group_indices = grouped.indices
sorted_group_names = [
g if len(required_grouper) != 1 else (g,) for g in group_indices
]

sorted_group_names = []
for group_name in grouped.groups:
if len(grouper) == 1:
group_name = (group_name,)
sorted_group_names.append(group_name)

for i, col in reversed(list(enumerate(grouper))):
if col != one_group:
for i, col in reversed(list(enumerate(required_grouper))):
sorted_group_names = sorted(
sorted_group_names,
key=lambda g: orders[col].index(g[i]) if g[i] in orders[col] else -1,
)
return orders, sorted_group_names

# calculate the full group_names by inserting "" in the tuple index for one_group groups
full_sorted_group_names = [list(t) for t in sorted_group_names]
for i, col in enumerate(grouper):
if col == one_group:
for g in full_sorted_group_names:
g.insert(i, "")
full_sorted_group_names = [tuple(g) for g in full_sorted_group_names]

groups = {
sf: grouped.get_group(s if len(s) > 1 else s[0])
for sf, s in zip(full_sorted_group_names, sorted_group_names)
}
return groups, orders


def make_figure(args, constructor, trace_patch=None, layout_patch=None):
Expand All @@ -1974,9 +2000,7 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
args, constructor, trace_patch, layout_patch
)
grouper = [x.grouper or one_group for x in grouped_mappings] or [one_group]
grouped = args["data_frame"].groupby(grouper, sort=False)

orders, sorted_group_names = get_orderings(args, grouper, grouped)
groups, orders = get_groups_and_orders(args, grouper)

col_labels = []
row_labels = []
Expand Down Expand Up @@ -2005,8 +2029,7 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
trendline_rows = []
trace_name_labels = None
facet_col_wrap = args.get("facet_col_wrap", 0)
for group_name in sorted_group_names:
group = grouped.get_group(group_name if len(group_name) > 1 else group_name[0])
for group_name, group in groups.items():
mapping_labels = OrderedDict()
trace_name_labels = OrderedDict()
frame_name = ""
Expand Down Expand Up @@ -2224,6 +2247,8 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
fig.update_layout(layout_patch)
if "template" in args and args["template"] is not None:
fig.update_layout(template=args["template"], overwrite=True)
for f in frame_list:
f["name"] = str(f["name"])
fig.frames = frame_list if len(frames) > 1 else []

if args.get("trendline") and args.get("trendline_scope", "trace") == "overall":
Expand Down