Skip to content

Commit

Permalink
🧹 avoid unnecessary one_group groupby operations
Browse files Browse the repository at this point in the history
  • Loading branch information
jvdd committed Jun 9, 2022
1 parent cdae77c commit b3a4583
Showing 1 changed file with 20 additions and 11 deletions.
31 changes: 20 additions & 11 deletions packages/python/plotly/plotly/express/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1938,21 +1938,30 @@ def get_groups_and_orders(args, grouper):
# we have a single group, so we can skip all group-by operations!
groups = {tuple(single_group_name): df}
else:
grouped = df.groupby(grouper, sort=False)
required_grouper = [g for g in grouper if g != one_group]
grouped = df.groupby(required_grouper, sort=False) # skip one_group groupers
group_indices = grouped.indices
sorted_group_names = [g if len(grouper) != 1 else (g,) for g in group_indices]
sorted_group_names = [
g if len(required_grouper) != 1 else (g,) for g in group_indices
]

for i, col in reversed(list(enumerate(grouper))):
if col != one_group:
sorted_group_names = sorted(
sorted_group_names,
key=lambda g: orders[col].index(g[i])
if g[i] in orders[col]
else -1,
)
for i, col in reversed(list(enumerate(required_grouper))):
sorted_group_names = sorted(
sorted_group_names,
key=lambda g: orders[col].index(g[i]) if g[i] in orders[col] else -1,
)

# calculate the full group_names by inserting "" in the tuple index for one_group groups
full_sorted_group_names = [list(t) for t in sorted_group_names]
for i, col in enumerate(grouper):
if col == one_group:
for g in full_sorted_group_names:
g.insert(i, "")
full_sorted_group_names = [tuple(g) for g in full_sorted_group_names]

groups = {
s: grouped.get_group(s if len(s) > 1 else s[0]) for s in sorted_group_names
sf: grouped.get_group(s if len(s) > 1 else s[0])
for sf, s in zip(full_sorted_group_names, sorted_group_names)
}
return groups, orders

Expand Down

0 comments on commit b3a4583

Please sign in to comment.