Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PX: Avoid groupby when possible and access groups more efficiently #3765

Merged
merged 15 commits into from
Jun 23, 2022
Merged
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 37 additions & 20 deletions packages/python/plotly/plotly/express/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1904,7 +1904,7 @@ def infer_config(args, constructor, trace_patch, layout_patch):
return trace_specs, grouped_mappings, sizeref, show_colorbar


def get_orderings(args, grouper, grouped):
def get_orderings(args, grouper):
"""
`orders` is the user-supplied ordering with the remaining data-frame-supplied
ordering appended if the column is used for grouping. It includes anything the user
Expand All @@ -1915,29 +1915,43 @@ def get_orderings(args, grouper, grouped):
of tuples like [("value1", ""), ("value2", "")] where each tuple contains the name
of a single dimension-group
"""

orders = {} if "category_orders" not in args else args["category_orders"].copy()

# figure out orders and what the single group name would be if there were one
single_group_name = []
for col in grouper:
if col != one_group:
if col == one_group:
single_group_name.append("")
else:
uniques = list(args["data_frame"][col].unique())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the following snippet this line is executed twice for the exact same col ("variable");

import plotly.express as px
px.line([1, 2, 3, 4])

I guess that the second computation could be avoided? (this is why I used set(grouper) to construct the order dictionary)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could memoize the results of this loop as a function of col, yeah

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(done)

if len(uniques) == 1:
single_group_name.append(uniques[0])
if col not in orders:
orders[col] = uniques
else:
orders[col] = list(OrderedDict.fromkeys(list(orders[col]) + uniques))

sorted_group_names = []
for group_name in grouped.groups:
if len(grouper) == 1:
group_name = (group_name,)
sorted_group_names.append(group_name)

for i, col in reversed(list(enumerate(grouper))):
if col != one_group:
sorted_group_names = sorted(
sorted_group_names,
key=lambda g: orders[col].index(g[i]) if g[i] in orders[col] else -1,
)
return orders, sorted_group_names
if len(single_group_name) == len(grouper):
# we have a single group, so we can skip all group-by operations!
grouped = None
sorted_group_names = [tuple(single_group_name)]
else:
grouped = args["data_frame"].groupby(grouper, sort=False)
sorted_group_names = []
for group_name in grouped.groups:
if len(grouper) == 1:
group_name = (group_name,)
sorted_group_names.append(group_name)

for i, col in reversed(list(enumerate(grouper))):
if col != one_group:
sorted_group_names = sorted(
sorted_group_names,
key=lambda g: orders[col].index(g[i])
if g[i] in orders[col]
else -1,
)
return grouped, orders, sorted_group_names


def make_figure(args, constructor, trace_patch=None, layout_patch=None):
Expand All @@ -1956,9 +1970,7 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
args, constructor, trace_patch, layout_patch
)
grouper = [x.grouper or one_group for x in grouped_mappings] or [one_group]
grouped = args["data_frame"].groupby(grouper, sort=False)

orders, sorted_group_names = get_orderings(args, grouper, grouped)
grouped, orders, sorted_group_names = get_orderings(args, grouper)

col_labels = []
row_labels = []
Expand Down Expand Up @@ -1988,7 +2000,12 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
trace_name_labels = None
facet_col_wrap = args.get("facet_col_wrap", 0)
for group_name in sorted_group_names:
group = grouped.get_group(group_name if len(group_name) > 1 else group_name[0])
if grouped is not None:
group = grouped.get_group(
group_name if len(group_name) > 1 else group_name[0]
)
else:
group = args["data_frame"]
mapping_labels = OrderedDict()
trace_name_labels = OrderedDict()
frame_name = ""
Expand Down