Skip to content

Commit

Permalink
Merge branch 'auto_orient' into wide_form2
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolaskruchten committed Apr 3, 2020
2 parents 33d03d5 + bbc22bc commit 236cd2c
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 47 deletions.
50 changes: 16 additions & 34 deletions packages/python/plotly/plotly/express/_chart_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def area(
labels={},
color_discrete_sequence=None,
color_discrete_map={},
orientation="v",
orientation=None,
groupnorm=None,
log_x=False,
log_y=False,
Expand All @@ -256,9 +256,7 @@ def area(
return make_figure(
args=locals(),
constructor=go.Scatter,
trace_patch=dict(
stackgroup=1, mode="lines", orientation=orientation, groupnorm=groupnorm
),
trace_patch=dict(stackgroup=1, mode="lines", groupnorm=groupnorm),
)


Expand Down Expand Up @@ -291,7 +289,7 @@ def bar(
range_color=None,
color_continuous_midpoint=None,
opacity=None,
orientation="v",
orientation=None,
barmode="relative",
log_x=False,
log_y=False,
Expand All @@ -309,7 +307,7 @@ def bar(
return make_figure(
args=locals(),
constructor=go.Bar,
trace_patch=dict(orientation=orientation, textposition="auto"),
trace_patch=dict(textposition="auto"),
layout_patch=dict(barmode=barmode),
)

Expand All @@ -335,7 +333,7 @@ def histogram(
color_discrete_map={},
marginal=None,
opacity=None,
orientation="v",
orientation=None,
barmode="relative",
barnorm=None,
histnorm=None,
Expand All @@ -361,13 +359,7 @@ def histogram(
args=locals(),
constructor=go.Histogram,
trace_patch=dict(
orientation=orientation,
histnorm=histnorm,
histfunc=histfunc,
nbinsx=nbins if orientation == "v" else None,
nbinsy=None if orientation == "v" else nbins,
cumulative=dict(enabled=cumulative),
bingroup="x" if orientation == "v" else "y",
histnorm=histnorm, histfunc=histfunc, cumulative=dict(enabled=cumulative),
),
layout_patch=dict(barmode=barmode, barnorm=barnorm),
)
Expand All @@ -393,8 +385,8 @@ def violin(
labels={},
color_discrete_sequence=None,
color_discrete_map={},
orientation="v",
violinmode="group",
orientation=None,
violinmode=None,
log_x=False,
log_y=False,
range_x=None,
Expand All @@ -414,12 +406,7 @@ def violin(
args=locals(),
constructor=go.Violin,
trace_patch=dict(
orientation=orientation,
points=points,
box=dict(visible=box),
scalegroup=True,
x0=" ",
y0=" ",
points=points, box=dict(visible=box), scalegroup=True, x0=" ", y0=" ",
),
layout_patch=dict(violinmode=violinmode),
)
Expand All @@ -445,8 +432,8 @@ def box(
labels={},
color_discrete_sequence=None,
color_discrete_map={},
orientation="v",
boxmode="group",
orientation=None,
boxmode=None,
log_x=False,
log_y=False,
range_x=None,
Expand All @@ -470,9 +457,7 @@ def box(
return make_figure(
args=locals(),
constructor=go.Box,
trace_patch=dict(
orientation=orientation, boxpoints=points, notched=notched, x0=" ", y0=" "
),
trace_patch=dict(boxpoints=points, notched=notched, x0=" ", y0=" "),
layout_patch=dict(boxmode=boxmode),
)

Expand All @@ -497,8 +482,8 @@ def strip(
labels={},
color_discrete_sequence=None,
color_discrete_map={},
orientation="v",
stripmode="group",
orientation=None,
stripmode=None,
log_x=False,
log_y=False,
range_x=None,
Expand All @@ -516,7 +501,6 @@ def strip(
args=locals(),
constructor=go.Box,
trace_patch=dict(
orientation=orientation,
boxpoints="all",
pointpos=0,
hoveron="points",
Expand Down Expand Up @@ -1384,7 +1368,7 @@ def funnel(
color_discrete_sequence=None,
color_discrete_map={},
opacity=None,
orientation="h",
orientation=None,
log_x=False,
log_y=False,
range_x=None,
Expand All @@ -1398,9 +1382,7 @@ def funnel(
In a funnel plot, each row of `data_frame` is represented as a
rectangular sector of a funnel.
"""
return make_figure(
args=locals(), constructor=go.Funnel, trace_patch=dict(orientation=orientation),
)
return make_figure(args=locals(), constructor=go.Funnel)


funnel.__doc__ = make_docstring(funnel)
Expand Down
76 changes: 63 additions & 13 deletions packages/python/plotly/plotly/express/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ def get_label(args, column):
return column


def _is_continuous(df, col_name):
return df[col_name].dtype.kind in "ifc"


def get_decorated_label(args, column, role):
label = get_label(args, column)
if "histfunc" in args and (
Expand Down Expand Up @@ -188,7 +192,7 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
if ((not attr_value) or (name in attr_value))
and (
trace_spec.constructor != go.Parcoords
or args["data_frame"][name].dtype.kind in "ifc"
or _is_continuous(args["data_frame"], name)
)
and (
trace_spec.constructor != go.Parcats
Expand Down Expand Up @@ -1161,7 +1165,7 @@ def aggfunc_discrete(x):
agg_f[count_colname] = "sum"

if args["color"]:
if df[args["color"]].dtype.kind not in "ifc":
if not _is_continuous(df, args["color"]):
aggfunc_color = aggfunc_discrete
discrete_color = True
elif not aggfunc_color:
Expand Down Expand Up @@ -1227,7 +1231,7 @@ def aggfunc_continuous(x):
return args


def infer_config(args, constructor, trace_patch):
def infer_config(args, constructor, trace_patch, layout_patch):
# Declare all supported attributes, across all plot types
attrables = (
["x", "y", "z", "a", "b", "c", "r", "theta", "size", "dimensions"]
Expand Down Expand Up @@ -1263,10 +1267,7 @@ def infer_config(args, constructor, trace_patch):
if "color_discrete_sequence" not in args:
attrs.append("color")
else:
if (
args["color"]
and args["data_frame"][args["color"]].dtype.kind in "ifc"
):
if args["color"] and _is_continuous(args["data_frame"], args["color"]):
attrs.append("color")
args["color_is_continuous"] = True
elif constructor in [go.Sunburst, go.Treemap]:
Expand Down Expand Up @@ -1305,8 +1306,55 @@ def infer_config(args, constructor, trace_patch):
if "symbol" in args:
grouped_attrs.append("marker.symbol")

# Compute final trace patch
trace_patch = trace_patch.copy()
if "orientation" in args:
has_x = args["x"] is not None
has_y = args["y"] is not None
if args["orientation"] is None:
if constructor in [go.Histogram, go.Scatter]:
if has_y and not has_x:
args["orientation"] = "h"
elif constructor in [go.Violin, go.Box, go.Bar, go.Funnel]:
if has_x and not has_y:
args["orientation"] = "h"

if args["orientation"] is None and has_x and has_y:
x_is_continuous = _is_continuous(args["data_frame"], args["x"])
y_is_continuous = _is_continuous(args["data_frame"], args["y"])
if x_is_continuous and not y_is_continuous:
args["orientation"] = "h"
if y_is_continuous and not x_is_continuous:
args["orientation"] = "v"

if args["orientation"] is None:
args["orientation"] = "v"

if constructor == go.Histogram:
if has_x and has_y and args["histfunc"] is None:
args["histfunc"] = trace_patch["histfunc"] = "sum"

orientation = args["orientation"]
nbins = args["nbins"]
trace_patch["nbinsx"] = nbins if orientation == "v" else None
trace_patch["nbinsy"] = None if orientation == "v" else nbins
trace_patch["bingroup"] = "x" if orientation == "v" else "y"
trace_patch["orientation"] = args["orientation"]

if constructor in [go.Violin, go.Box]:
mode = "boxmode" if constructor == go.Box else "violinmode"
if layout_patch[mode] is None and args["color"] is not None:
if args["y"] == args["color"] and args["orientation"] == "h":
layout_patch[mode] = "overlay"
elif args["x"] == args["color"] and args["orientation"] == "v":
layout_patch[mode] = "overlay"
if layout_patch[mode] is None:
layout_patch[mode] = "group"

if (
constructor == go.Histogram2d
and args["z"] is not None
and args["histfunc"] is None
):
args["histfunc"] = trace_patch["histfunc"] = "sum"

if constructor in [go.Histogram2d, go.Densitymapbox]:
show_colorbar = True
Expand Down Expand Up @@ -1354,7 +1402,7 @@ def infer_config(args, constructor, trace_patch):

# Create trace specs
trace_specs = make_trace_spec(args, constructor, attrs, trace_patch)
return args, trace_specs, grouped_mappings, sizeref, show_colorbar
return trace_specs, grouped_mappings, sizeref, show_colorbar


def get_orderings(args, grouper, grouped):
Expand Down Expand Up @@ -1398,11 +1446,13 @@ def get_orderings(args, grouper, grouped):
return orders, group_names, group_values


def make_figure(args, constructor, trace_patch={}, layout_patch={}):
def make_figure(args, constructor, trace_patch=None, layout_patch=None):
trace_patch = trace_patch or {}
layout_patch = layout_patch or {}
apply_default_cascade(args)

args, trace_specs, grouped_mappings, sizeref, show_colorbar = infer_config(
args, constructor, trace_patch
trace_specs, grouped_mappings, sizeref, show_colorbar = infer_config(
args, constructor, trace_patch, layout_patch
)
grouper = [x.grouper or one_group for x in grouped_mappings] or [one_group]
grouped = args["data_frame"].groupby(grouper, sort=False)
Expand Down

0 comments on commit 236cd2c

Please sign in to comment.