diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index f4aa02cdc2..f5c3acb9c9 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -1955,10 +1955,37 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None): trace_specs, grouped_mappings, sizeref, show_colorbar = infer_config( args, constructor, trace_patch, layout_patch ) - grouper = [x.grouper or one_group for x in grouped_mappings] or [one_group] - grouped = args["data_frame"].groupby(grouper, sort=False) + grouper_ = [x.grouper or one_group for x in grouped_mappings] or [one_group] + + all_same_group = True # variable indicating if grouping can be avoided + for g in grouper_: + if g is not one_group: + arr = args["data_frame"][g].values + all_same_group &= (arr[0] == arr).all(axis=0) + if not all_same_group: + break # early stopping if not all the same + + if all_same_group: + # Do not perform an expensive groupby operation when there are either + # no groups to group by, or when the group has only one (i.e., the same) value + grouper = [g for g in grouper_ if g is not one_group] + assert len(grouper) <= 1 + # -> create orders, sorted_group_names equivalent to those of get_ordings + orders = {g: [args["data_frame"][g].iloc[0]] for g in grouper} + sorted_group_names = [tuple(args["data_frame"][g].iloc[0] for g in orders)] + if len(sorted_group_names): # check for length to support also empty plots + assert len(sorted_group_names) == 1 # should be only for 1 variable + sorted_group_names = list(sorted_group_names[0]) # convert [tuple] to list + for idx in range(len(grouper_)): + # insert "" in the list when no grouping was used + if grouper_[idx] is one_group: + sorted_group_names.insert(idx, "") + sorted_group_names = [tuple(sorted_group_names)] # convert list to [tuple] + else: + grouper = grouper_ + grouped = args["data_frame"].groupby(grouper, sort=False) - orders, sorted_group_names = get_orderings(args, grouper, grouped) + orders, sorted_group_names = get_orderings(args, grouper, grouped) col_labels = [] row_labels = [] @@ -1988,7 +2015,13 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None): trace_name_labels = None facet_col_wrap = args.get("facet_col_wrap", 0) for group_name in sorted_group_names: - group = grouped.get_group(group_name if len(group_name) > 1 else group_name[0]) + if all_same_group: + # No expensive get_group operation when all data from the same group + group = args["data_frame"] + else: + group = grouped.get_group( + group_name if len(group_name) > 1 else group_name[0] + ) mapping_labels = OrderedDict() trace_name_labels = OrderedDict() frame_name = ""