Skip to content

Commit

Permalink
reorder computation to avoid an extra scan through the data
Browse files (browse the repository at this point in the history)
  • Loading branch information
nicolaskruchten committed Jun 9, 2022
1 parent dd2137b commit 3ae0645
Showing 1 changed file with 29 additions and 39 deletions.
68 changes: 29 additions & 39 deletions packages/python/plotly/plotly/express/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1904,7 +1904,7 @@ def infer_config(args, constructor, trace_patch, layout_patch):
return trace_specs, grouped_mappings, sizeref, show_colorbar


def get_orderings(args, grouper, grouped, all_same_group):
def get_orderings(args, grouper):
"""
`orders` is the user-supplied ordering with the remaining data-frame-supplied
ordering appended if the column is used for grouping. It includes anything the user
Expand All @@ -1916,47 +1916,42 @@ def get_orderings(args, grouper, grouped, all_same_group):
of a single dimension-group
"""
orders = {} if "category_orders" not in args else args["category_orders"].copy()
sorted_group_names = []

if all_same_group:
for col in grouper:
if col != one_group:
single_val = args["data_frame"][col].iloc[0]
sorted_group_names.append(single_val)
orders[col] = [single_val]
else:
sorted_group_names.append("")
return orders, [tuple(sorted_group_names)]

# figure out orders and what the single group name would be if there were one
single_group_name = []
for col in grouper:
if col != one_group:
if col == one_group:
single_group_name.append("")
else:
uniques = list(args["data_frame"][col].unique())
if len(uniques) == 1:
single_group_name.append(uniques[0])
if col not in orders:
orders[col] = uniques
else:
orders[col] = list(OrderedDict.fromkeys(list(orders[col]) + uniques))

for group_name in grouped.groups:
if len(grouper) == 1:
group_name = (group_name,)
sorted_group_names.append(group_name)

for i, col in reversed(list(enumerate(grouper))):
if col != one_group:
sorted_group_names = sorted(
sorted_group_names,
key=lambda g: orders[col].index(g[i]) if g[i] in orders[col] else -1,
)
return orders, sorted_group_names

if len(single_group_name) == len(grouper):
# we have a single group, so we can skip all group-by operations!
grouped = None
sorted_group_names = [tuple(single_group_name)]
else:
grouped = args["data_frame"].groupby(grouper, sort=False)
sorted_group_names = []
for group_name in grouped.groups:
if len(grouper) == 1:
group_name = (group_name,)
sorted_group_names.append(group_name)

def _all_same_group(args, grouper):
    """
    Report whether every real grouping column holds one repeated value.

    Each distinct entry of `grouper` other than the `one_group` placeholder is
    looked up in `args["data_frame"]`, and its values are compared element-wise
    against the first value. Returns False as soon as one column contains a
    value that differs from its first entry, and True otherwise (including
    when `grouper` contains nothing but `one_group`).
    """
    for column in set(grouper):
        if column == one_group:
            # placeholder entry, not a data-frame column: nothing to check
            continue
        # presumably a numpy array obtained from a pandas Series; verify against caller
        values = args["data_frame"][column].values
        first = values[0]
        if not (first == values).all(axis=0):
            return False
    return True
for i, col in reversed(list(enumerate(grouper))):
if col != one_group:
sorted_group_names = sorted(
sorted_group_names,
key=lambda g: orders[col].index(g[i])
if g[i] in orders[col]
else -1,
)
return grouped, orders, sorted_group_names


def make_figure(args, constructor, trace_patch=None, layout_patch=None):
Expand All @@ -1975,12 +1970,7 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
args, constructor, trace_patch, layout_patch
)
grouper = [x.grouper or one_group for x in grouped_mappings] or [one_group]
grouped = None
all_same_group = _all_same_group(args, grouper)
if not all_same_group:
grouped = args["data_frame"].groupby(grouper, sort=False)

orders, sorted_group_names = get_orderings(args, grouper, grouped, all_same_group)
grouped, orders, sorted_group_names = get_orderings(args, grouper)

col_labels = []
row_labels = []
Expand Down

0 comments on commit 3ae0645

Please sign in to comment.