diff --git a/packages/python/plotly/plotly/express/_chart_types.py b/packages/python/plotly/plotly/express/_chart_types.py index f7e7da8cbf..bff581e6ff 100644 --- a/packages/python/plotly/plotly/express/_chart_types.py +++ b/packages/python/plotly/plotly/express/_chart_types.py @@ -236,7 +236,7 @@ def area( labels={}, color_discrete_sequence=None, color_discrete_map={}, - orientation="v", + orientation=None, groupnorm=None, log_x=False, log_y=False, @@ -256,9 +256,7 @@ def area( return make_figure( args=locals(), constructor=go.Scatter, - trace_patch=dict( - stackgroup=1, mode="lines", orientation=orientation, groupnorm=groupnorm - ), + trace_patch=dict(stackgroup=1, mode="lines", groupnorm=groupnorm), ) @@ -291,7 +289,7 @@ def bar( range_color=None, color_continuous_midpoint=None, opacity=None, - orientation="v", + orientation=None, barmode="relative", log_x=False, log_y=False, @@ -309,7 +307,7 @@ def bar( return make_figure( args=locals(), constructor=go.Bar, - trace_patch=dict(orientation=orientation, textposition="auto"), + trace_patch=dict(textposition="auto"), layout_patch=dict(barmode=barmode), ) @@ -335,7 +333,7 @@ def histogram( color_discrete_map={}, marginal=None, opacity=None, - orientation="v", + orientation=None, barmode="relative", barnorm=None, histnorm=None, @@ -361,13 +359,7 @@ def histogram( args=locals(), constructor=go.Histogram, trace_patch=dict( - orientation=orientation, - histnorm=histnorm, - histfunc=histfunc, - nbinsx=nbins if orientation == "v" else None, - nbinsy=None if orientation == "v" else nbins, - cumulative=dict(enabled=cumulative), - bingroup="x" if orientation == "v" else "y", + histnorm=histnorm, histfunc=histfunc, cumulative=dict(enabled=cumulative), ), layout_patch=dict(barmode=barmode, barnorm=barnorm), ) @@ -393,8 +385,8 @@ def violin( labels={}, color_discrete_sequence=None, color_discrete_map={}, - orientation="v", - violinmode="group", + orientation=None, + violinmode=None, log_x=False, log_y=False, range_x=None, @@ -414,12 +406,7 @@ def violin( args=locals(), constructor=go.Violin, trace_patch=dict( - orientation=orientation, - points=points, - box=dict(visible=box), - scalegroup=True, - x0=" ", - y0=" ", + points=points, box=dict(visible=box), scalegroup=True, x0=" ", y0=" ", ), layout_patch=dict(violinmode=violinmode), ) @@ -445,8 +432,8 @@ def box( labels={}, color_discrete_sequence=None, color_discrete_map={}, - orientation="v", - boxmode="group", + orientation=None, + boxmode=None, log_x=False, log_y=False, range_x=None, @@ -470,9 +457,7 @@ def box( return make_figure( args=locals(), constructor=go.Box, - trace_patch=dict( - orientation=orientation, boxpoints=points, notched=notched, x0=" ", y0=" " - ), + trace_patch=dict(boxpoints=points, notched=notched, x0=" ", y0=" "), layout_patch=dict(boxmode=boxmode), ) @@ -497,8 +482,8 @@ def strip( labels={}, color_discrete_sequence=None, color_discrete_map={}, - orientation="v", - stripmode="group", + orientation=None, + stripmode=None, log_x=False, log_y=False, range_x=None, @@ -516,7 +501,6 @@ def strip( args=locals(), constructor=go.Box, trace_patch=dict( - orientation=orientation, boxpoints="all", pointpos=0, hoveron="points", @@ -1384,7 +1368,7 @@ def funnel( color_discrete_sequence=None, color_discrete_map={}, opacity=None, - orientation="h", + orientation=None, log_x=False, log_y=False, range_x=None, @@ -1398,9 +1382,7 @@ def funnel( In a funnel plot, each row of `data_frame` is represented as a rectangular sector of a funnel. """ - return make_figure( - args=locals(), constructor=go.Funnel, trace_patch=dict(orientation=orientation), - ) + return make_figure(args=locals(), constructor=go.Funnel) funnel.__doc__ = make_docstring(funnel) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 8aae2f14f2..622fe7a51f 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -92,6 +92,10 @@ def get_label(args, column): return column +def _is_continuous(df, col_name): + return df[col_name].dtype.kind in "ifc" + + def get_decorated_label(args, column, role): label = get_label(args, column) if "histfunc" in args and ( @@ -188,7 +192,7 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref): if ((not attr_value) or (name in attr_value)) and ( trace_spec.constructor != go.Parcoords - or args["data_frame"][name].dtype.kind in "ifc" + or _is_continuous(args["data_frame"], name) ) and ( trace_spec.constructor != go.Parcats @@ -1161,7 +1165,7 @@ def aggfunc_discrete(x): agg_f[count_colname] = "sum" if args["color"]: - if df[args["color"]].dtype.kind not in "ifc": + if not _is_continuous(df, args["color"]): aggfunc_color = aggfunc_discrete discrete_color = True elif not aggfunc_color: @@ -1227,7 +1231,7 @@ def aggfunc_continuous(x): return args -def infer_config(args, constructor, trace_patch): +def infer_config(args, constructor, trace_patch, layout_patch): # Declare all supported attributes, across all plot types attrables = ( ["x", "y", "z", "a", "b", "c", "r", "theta", "size", "dimensions"] @@ -1263,10 +1267,7 @@ def infer_config(args, constructor, trace_patch): if "color_discrete_sequence" not in args: attrs.append("color") else: - if ( - args["color"] - and args["data_frame"][args["color"]].dtype.kind in "ifc" - ): + if args["color"] and _is_continuous(args["data_frame"], args["color"]): attrs.append("color") args["color_is_continuous"] = True elif constructor in [go.Sunburst, go.Treemap]: @@ -1305,8 +1306,55 @@ def infer_config(args, constructor, trace_patch): if "symbol" in args: grouped_attrs.append("marker.symbol") - # Compute final trace patch - trace_patch = trace_patch.copy() + if "orientation" in args: + has_x = args["x"] is not None + has_y = args["y"] is not None + if args["orientation"] is None: + if constructor in [go.Histogram, go.Scatter]: + if has_y and not has_x: + args["orientation"] = "h" + elif constructor in [go.Violin, go.Box, go.Bar, go.Funnel]: + if has_x and not has_y: + args["orientation"] = "h" + + if args["orientation"] is None and has_x and has_y: + x_is_continuous = _is_continuous(args["data_frame"], args["x"]) + y_is_continuous = _is_continuous(args["data_frame"], args["y"]) + if x_is_continuous and not y_is_continuous: + args["orientation"] = "h" + if y_is_continuous and not x_is_continuous: + args["orientation"] = "v" + + if args["orientation"] is None: + args["orientation"] = "v" + + if constructor == go.Histogram: + if has_x and has_y and args["histfunc"] is None: + args["histfunc"] = trace_patch["histfunc"] = "sum" + + orientation = args["orientation"] + nbins = args["nbins"] + trace_patch["nbinsx"] = nbins if orientation == "v" else None + trace_patch["nbinsy"] = None if orientation == "v" else nbins + trace_patch["bingroup"] = "x" if orientation == "v" else "y" + trace_patch["orientation"] = args["orientation"] + + if constructor in [go.Violin, go.Box]: + mode = "boxmode" if constructor == go.Box else "violinmode" + if layout_patch[mode] is None and args["color"] is not None: + if args["y"] == args["color"] and args["orientation"] == "h": + layout_patch[mode] = "overlay" + elif args["x"] == args["color"] and args["orientation"] == "v": + layout_patch[mode] = "overlay" + if layout_patch[mode] is None: + layout_patch[mode] = "group" + + if ( + constructor == go.Histogram2d + and args["z"] is not None + and args["histfunc"] is None + ): + args["histfunc"] = trace_patch["histfunc"] = "sum" if constructor in [go.Histogram2d, go.Densitymapbox]: show_colorbar = True @@ -1354,7 +1402,7 @@ def infer_config(args, constructor, trace_patch): # Create trace specs trace_specs = make_trace_spec(args, constructor, attrs, trace_patch) - return args, trace_specs, grouped_mappings, sizeref, show_colorbar + return trace_specs, grouped_mappings, sizeref, show_colorbar def get_orderings(args, grouper, grouped): @@ -1398,11 +1446,13 @@ def get_orderings(args, grouper, grouped): return orders, group_names, group_values -def make_figure(args, constructor, trace_patch={}, layout_patch={}): +def make_figure(args, constructor, trace_patch=None, layout_patch=None): + trace_patch = trace_patch or {} + layout_patch = layout_patch or {} apply_default_cascade(args) - args, trace_specs, grouped_mappings, sizeref, show_colorbar = infer_config( - args, constructor, trace_patch + trace_specs, grouped_mappings, sizeref, show_colorbar = infer_config( + args, constructor, trace_patch, layout_patch ) grouper = [x.grouper or one_group for x in grouped_mappings] or [one_group] grouped = args["data_frame"].groupby(grouper, sort=False)