⚡ avoid expensive & unnecessary groupby in px #3761

Closed
wants to merge 3 commits into from
41 changes: 37 additions & 4 deletions packages/python/plotly/plotly/express/_core.py
@@ -1955,10 +1955,37 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
     trace_specs, grouped_mappings, sizeref, show_colorbar = infer_config(
         args, constructor, trace_patch, layout_patch
     )
-    grouper = [x.grouper or one_group for x in grouped_mappings] or [one_group]
-    grouped = args["data_frame"].groupby(grouper, sort=False)
+    grouper_ = [x.grouper or one_group for x in grouped_mappings] or [one_group]
+
+    all_same_group = True  # variable indicating if grouping can be avoided
+    for g in grouper_:
+        if g is not one_group:
+            arr = args["data_frame"][g].values
+            all_same_group &= (arr[0] == arr).all(axis=0)
+            if not all_same_group:
+                break  # early stopping if not all the same
+
+    if all_same_group:
+        # Do not perform an expensive groupby operation when there are either
+        # no groups to group by, or when the group has only one (i.e., the same) value
+        grouper = [g for g in grouper_ if g is not one_group]
+        assert len(grouper) <= 1
+        # -> create orders, sorted_group_names equivalent to those of get_ordings
Contributor

I love the optimization where we don't group if the groupers are all one_group but I'm worried about the duplication of logic here... If we just do the optimization when no grouping is requested at all (i.e. leave out the case when we do have e.g. symbol but there's only one value) then this PR can be less invasive/have less duplication. I've done an implementation of that over here #3765 if you want to take a look at it.

I think that most of the value of this PR comes from the "all one_group" case, and that leaving out the case where there happens to only be one group would be OK for now as this is likely not all that common, but I'm not dead set on it.
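The narrower optimization described in this comment (skip `groupby` only when no grouping was requested at all) could look roughly like the sketch below. It is not the code from #3765: `iter_groups` is a hypothetical helper, and `one_group` here is a local stand-in for the sentinel of the same name in `_core.py`.

```python
import pandas as pd


def one_group(_):
    """Stand-in sentinel grouper: maps every row to the same key."""
    return ""


def iter_groups(df, grouper):
    """Yield (group_name_tuple, sub_frame) pairs (hypothetical helper).

    When every entry of `grouper` is the `one_group` sentinel, no grouping
    was requested, so the whole frame is returned as the only group and
    `groupby` is never called.
    """
    if all(g is one_group for g in grouper):
        yield tuple("" for _ in grouper), df
        return
    for name, group in df.groupby(grouper, sort=False):
        # Depending on the pandas version, a single-key groupby may yield
        # scalar names, so normalize to a tuple.
        yield (name if isinstance(name, tuple) else (name,)), group
```

With this shape, the case where a grouping column exists but happens to hold a single value still goes through `groupby`, which is the trade-off the comment proposes accepting.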

Contributor

Digging into this a bit more, I see/remember that we actually call .unique() on all non-one_group groups in the awkward get_orderings() function anyway, so if we did want to do this other optimization, maybe (re-)inlining the code from get_orderings() into make_figure() and interleaving this check would make more sense?
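One way to picture the interleaving suggested here is to reuse the per-column unique values to decide whether `groupby` is needed at all. This is only a sketch under assumptions: `single_group_key` is a hypothetical helper, `one_group` is a local stand-in for the sentinel in `_core.py`, and the real `get_orderings()` does more than this (it also handles `category_orders`).

```python
import pandas as pd


def one_group(_):
    """Stand-in sentinel grouper: maps every row to the same key."""
    return ""


def single_group_key(df, grouper):
    """Return the lone group key if grouping is a no-op, else None.

    Hypothetical helper: each non-sentinel column is checked with
    Series.unique(); as soon as one column has anything other than
    exactly one distinct value, the caller should fall back to groupby.
    """
    key = []
    for g in grouper:
        if g is one_group:
            key.append("")
            continue
        uniques = df[g].unique()
        if len(uniques) != 1:
            return None  # several values (or an empty frame): group normally
        key.append(uniques[0])
    return tuple(key)
```

A caller would use the whole data frame as the single group when a key is returned. Note that `groupby` drops NaN keys by default, so a column that is entirely NaN would need separate handling before this shortcut matches the grouped result.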

Contributor Author

Do you mean like your implementation in #3765?

+        orders = {g: [args["data_frame"][g].iloc[0]] for g in grouper}
+        sorted_group_names = [tuple(args["data_frame"][g].iloc[0] for g in orders)]
+        if len(sorted_group_names):  # check for length to support also empty plots
+            assert len(sorted_group_names) == 1  # should be only for 1 variable
+            sorted_group_names = list(sorted_group_names[0])  # convert [tuple] to list
+            for idx in range(len(grouper_)):
+                # insert "" in the list when no grouping was used
+                if grouper_[idx] is one_group:
+                    sorted_group_names.insert(idx, "")
+            sorted_group_names = [tuple(sorted_group_names)]  # convert list to [tuple]
+    else:
+        grouper = grouper_
+        grouped = args["data_frame"].groupby(grouper, sort=False)

-    orders, sorted_group_names = get_orderings(args, grouper, grouped)
+        orders, sorted_group_names = get_orderings(args, grouper, grouped)

     col_labels = []
     row_labels = []
@@ -1988,7 +2015,13 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
     trace_name_labels = None
     facet_col_wrap = args.get("facet_col_wrap", 0)
     for group_name in sorted_group_names:
-        group = grouped.get_group(group_name if len(group_name) > 1 else group_name[0])
+        if all_same_group:
+            # No expensive get_group operation when all data from the same group
+            group = args["data_frame"]
+        else:
+            group = grouped.get_group(
+                group_name if len(group_name) > 1 else group_name[0]
+            )
         mapping_labels = OrderedDict()
         trace_name_labels = OrderedDict()
         frame_name = ""
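The single-value test from the first hunk, `(arr[0] == arr).all()`, can be tried in isolation. The sketch below wraps it in a hypothetical `all_same` helper and adds a guard for empty input, since indexing `arr[0]` on an empty array raises `IndexError`. Treating an empty array as "nothing to group" is an assumption of this sketch, not something the PR states.

```python
import numpy as np


def all_same(values):
    """Return True if every element equals the first one (hypothetical helper)."""
    arr = np.asarray(values)
    if arr.size == 0:
        # Assumption: an empty column needs no grouping. Without this guard,
        # arr[0] would raise IndexError.
        return True
    # Elementwise comparison against the first value, as in the diff.
    return bool((arr[0] == arr).all())
```

Because NaN never compares equal to itself, an all-NaN column is reported as not uniform and would fall back to the regular `groupby` path.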