Skip to content

Commit

Permalink
optimize group access
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolaskruchten committed Jun 9, 2022
1 parent 3ae0645 commit 067d4b0
Showing 1 changed file with 16 additions and 22 deletions.
38 changes: 16 additions & 22 deletions packages/python/plotly/plotly/express/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1904,44 +1904,42 @@ def infer_config(args, constructor, trace_patch, layout_patch):
return trace_specs, grouped_mappings, sizeref, show_colorbar


def get_orderings(args, grouper):
def get_groups_and_orders(args, grouper):
"""
`orders` is the user-supplied ordering with the remaining data-frame-supplied
ordering appended if the column is used for grouping. It includes anything the user
gave, for any variable, including values not present in the dataset. It's a dict
where the keys are e.g. "x" or "color"
`sorted_group_names` is the set of groups, ordered by the order above. It's a list
of tuples like [("value1", ""), ("value2", "")] where each tuple contains the name
`groups` is the dicts of groups, ordered by the order above. Its keys are
tuples like [("value1", ""), ("value2", "")] where each tuple contains the name
of a single dimension-group
"""
orders = {} if "category_orders" not in args else args["category_orders"].copy()

# figure out orders and what the single group name would be if there were one
single_group_name = []
unique_cache = dict()
for col in grouper:
if col == one_group:
single_group_name.append("")
else:
uniques = list(args["data_frame"][col].unique())
if col not in unique_cache:
unique_cache[col] = list(args["data_frame"][col].unique())
uniques = unique_cache[col]
if len(uniques) == 1:
single_group_name.append(uniques[0])
if col not in orders:
orders[col] = uniques
else:
orders[col] = list(OrderedDict.fromkeys(list(orders[col]) + uniques))

df = args["data_frame"]
if len(single_group_name) == len(grouper):
# we have a single group, so we can skip all group-by operations!
grouped = None
sorted_group_names = [tuple(single_group_name)]
groups = {tuple(single_group_name): df}
else:
grouped = args["data_frame"].groupby(grouper, sort=False)
sorted_group_names = []
for group_name in grouped.groups:
if len(grouper) == 1:
group_name = (group_name,)
sorted_group_names.append(group_name)
group_indices = df.groupby(grouper, sort=False).indices
sorted_group_names = [g if len(grouper) != 1 else (g,) for g in group_indices]

for i, col in reversed(list(enumerate(grouper))):
if col != one_group:
Expand All @@ -1951,7 +1949,9 @@ def get_orderings(args, grouper):
if g[i] in orders[col]
else -1,
)
return grouped, orders, sorted_group_names

groups = {s: df.iloc[group_indices[s]] for s in sorted_group_names}
return groups, orders


def make_figure(args, constructor, trace_patch=None, layout_patch=None):
Expand All @@ -1970,7 +1970,7 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
args, constructor, trace_patch, layout_patch
)
grouper = [x.grouper or one_group for x in grouped_mappings] or [one_group]
grouped, orders, sorted_group_names = get_orderings(args, grouper)
groups, orders = get_groups_and_orders(args, grouper)

col_labels = []
row_labels = []
Expand Down Expand Up @@ -1999,13 +1999,7 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
trendline_rows = []
trace_name_labels = None
facet_col_wrap = args.get("facet_col_wrap", 0)
for group_name in sorted_group_names:
if grouped is not None:
group = grouped.get_group(
group_name if len(group_name) > 1 else group_name[0]
)
else:
group = args["data_frame"]
for group_name, group in groups.items():
mapping_labels = OrderedDict()
trace_name_labels = OrderedDict()
frame_name = ""
Expand Down

0 comments on commit 067d4b0

Please sign in to comment.