Skip to content

Commit

Permalink
Merge pull request #1875 from plotly/px_real_template
Browse files Browse the repository at this point in the history
PX shouldn't modify attrs controlled by template
  • Loading branch information
nicolaskruchten authored Nov 6, 2019
2 parents 06a2cb9 + 1e88daa commit 5830055
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 40 deletions.
80 changes: 40 additions & 40 deletions packages/python/plotly/plotly/express/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,16 +375,18 @@ def configure_cartesian_marginal_axes(args, fig, orders):

# Configure axis ticks on marginal subplots
if args["marginal_x"]:
fig.update_yaxes(
showticklabels=False, showgrid=args["marginal_x"] == "histogram", row=nrows
)
fig.update_xaxes(showgrid=True, row=nrows)
fig.update_yaxes(showticklabels=False, row=nrows)
if args["template"].layout.yaxis.showgrid is None:
fig.update_yaxes(showgrid=args["marginal_x"] == "histogram", row=nrows)
if args["template"].layout.xaxis.showgrid is None:
fig.update_xaxes(showgrid=True, row=nrows)

if args["marginal_y"]:
fig.update_xaxes(
showticklabels=False, showgrid=args["marginal_y"] == "histogram", col=ncols
)
fig.update_yaxes(showgrid=True, col=ncols)
fig.update_xaxes(showticklabels=False, col=ncols)
if args["template"].layout.xaxis.showgrid is None:
fig.update_xaxes(showgrid=args["marginal_y"] == "histogram", col=ncols)
if args["template"].layout.yaxis.showgrid is None:
fig.update_yaxes(showgrid=True, col=ncols)

# Add axis titles to non-marginal subplots
y_title = get_decorated_label(args, args["y"], "y")
Expand Down Expand Up @@ -687,55 +689,47 @@ def apply_default_cascade(args):
else:
args["template"] = "plotly"

# retrieve the actual template if we were given a name
try:
template = pio.templates[args["template"]]
# retrieve the actual template if we were given a name
args["template"] = pio.templates[args["template"]]
except Exception:
template = args["template"]
# otherwise try to build a real template
args["template"] = go.layout.Template(args["template"])

# if colors not set explicitly or in px.defaults, defer to a template
# if the template doesn't have one, we set some final fallback defaults
if "color_continuous_scale" in args:
if args["color_continuous_scale"] is None:
try:
args["color_continuous_scale"] = [
x[1] for x in template.layout.colorscale.sequential
]
except (AttributeError, TypeError):
pass
if (
args["color_continuous_scale"] is None
and args["template"].layout.colorscale.sequential
):
args["color_continuous_scale"] = [
x[1] for x in args["template"].layout.colorscale.sequential
]
if args["color_continuous_scale"] is None:
args["color_continuous_scale"] = sequential.Viridis

if "color_discrete_sequence" in args:
if args["color_discrete_sequence"] is None:
try:
args["color_discrete_sequence"] = template.layout.colorway
except (AttributeError, TypeError):
pass
if args["color_discrete_sequence"] is None and args["template"].layout.colorway:
args["color_discrete_sequence"] = args["template"].layout.colorway
if args["color_discrete_sequence"] is None:
args["color_discrete_sequence"] = qualitative.D3

# if symbol_sequence/line_dash_sequence not set explicitly or in px.defaults,
# see if we can defer to template. If not, set reasonable defaults
if "symbol_sequence" in args:
if args["symbol_sequence"] is None:
try:
args["symbol_sequence"] = [
scatter.marker.symbol for scatter in template.data.scatter
]
except (AttributeError, TypeError):
pass
if args["symbol_sequence"] is None and args["template"].data.scatter:
args["symbol_sequence"] = [
scatter.marker.symbol for scatter in args["template"].data.scatter
]
if not args["symbol_sequence"] or not any(args["symbol_sequence"]):
args["symbol_sequence"] = ["circle", "diamond", "square", "x", "cross"]

if "line_dash_sequence" in args:
if args["line_dash_sequence"] is None:
try:
args["line_dash_sequence"] = [
scatter.line.dash for scatter in template.data.scatter
]
except (AttributeError, TypeError):
pass
if args["line_dash_sequence"] is None and args["template"].data.scatter:
args["line_dash_sequence"] = [
scatter.line.dash for scatter in args["template"].data.scatter
]
if not args["line_dash_sequence"] or not any(args["line_dash_sequence"]):
args["line_dash_sequence"] = [
"solid",
Expand Down Expand Up @@ -1264,13 +1258,17 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
cmax=range_color[1],
colorbar=dict(title=get_decorated_label(args, args[colorvar], colorvar)),
)
for v in ["title", "height", "width", "template"]:
for v in ["title", "height", "width"]:
if args[v]:
layout_patch[v] = args[v]
layout_patch["legend"] = {"tracegroupgap": 0}
if "title" not in layout_patch:
if "title" not in layout_patch and args["template"].layout.margin.t is None:
layout_patch["margin"] = {"t": 60}
if "size" in args and args["size"]:
if (
"size" in args
and args["size"]
and args["template"].layout.legend.itemsizing is None
):
layout_patch["legend"]["itemsizing"] = "constant"

fig = init_figure(
Expand All @@ -1295,6 +1293,8 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
# Add traces, layout and frames to figure
fig.add_traces(frame_list[0]["data"] if len(frame_list) > 0 else [])
fig.layout.update(layout_patch)
if "template" in args and args["template"] is not None:
fig.update_layout(template=args["template"], overwrite=True)
fig.frames = frame_list if len(frames) > 1 else []

fig._px_trendlines = pd.DataFrame(trendline_rows)
Expand Down
107 changes: 107 additions & 0 deletions packages/python/plotly/plotly/tests/test_core/test_px/test_px.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,110 @@ def test_custom_data_scatter():
fig.data[0].hovertemplate
== "sepal_width=%{x}<br>sepal_length=%{y}<br>petal_length=%{customdata[2]}<br>petal_width=%{customdata[3]}<br>species_id=%{customdata[0]}"
)


def test_px_templates():
import plotly.io as pio
import plotly.graph_objects as go

tips = px.data.tips()

# use the normal defaults
fig = px.scatter()
assert fig.layout.template == pio.templates[pio.templates.default]

# respect changes to defaults
pio.templates.default = "seaborn"
fig = px.scatter()
assert fig.layout.template == pio.templates["seaborn"]

# special px-level defaults over pio defaults
pio.templates.default = "seaborn"
px.defaults.template = "ggplot2"
fig = px.scatter()
assert fig.layout.template == pio.templates["ggplot2"]

# accept names in args over pio and px defaults
fig = px.scatter(template="seaborn")
assert fig.layout.template == pio.templates["seaborn"]

# accept objects in args
fig = px.scatter(template={})
assert fig.layout.template == go.layout.Template()

# read colorway from the template
fig = px.scatter(
tips,
x="total_bill",
y="tip",
color="sex",
template=dict(layout_colorway=["red", "blue"]),
)
assert fig.data[0].marker.color == "red"
assert fig.data[1].marker.color == "blue"

# default colorway fallback
fig = px.scatter(tips, x="total_bill", y="tip", color="sex", template=dict())
assert fig.data[0].marker.color == px.colors.qualitative.D3[0]
assert fig.data[1].marker.color == px.colors.qualitative.D3[1]

# pio default template colorway fallback
pio.templates.default = "seaborn"
px.defaults.template = None
fig = px.scatter(tips, x="total_bill", y="tip", color="sex")
assert fig.data[0].marker.color == pio.templates["seaborn"].layout.colorway[0]
assert fig.data[1].marker.color == pio.templates["seaborn"].layout.colorway[1]

# pio default template colorway fallback
pio.templates.default = "seaborn"
px.defaults.template = "ggplot2"
fig = px.scatter(tips, x="total_bill", y="tip", color="sex")
assert fig.data[0].marker.color == pio.templates["ggplot2"].layout.colorway[0]
assert fig.data[1].marker.color == pio.templates["ggplot2"].layout.colorway[1]

# don't overwrite top margin when set in template
fig = px.scatter(title="yo")
assert fig.layout.margin.t is None

fig = px.scatter()
assert fig.layout.margin.t == 60

fig = px.scatter(template=dict(layout_margin_t=2))
assert fig.layout.margin.t is None

# don't force histogram gridlines when set in template
pio.templates.default = "none"
px.defaults.template = None
fig = px.scatter(
tips, x="total_bill", y="tip", marginal_x="histogram", marginal_y="histogram"
)
assert fig.layout.xaxis2.showgrid
assert fig.layout.xaxis3.showgrid
assert fig.layout.yaxis2.showgrid
assert fig.layout.yaxis3.showgrid

fig = px.scatter(
tips,
x="total_bill",
y="tip",
marginal_x="histogram",
marginal_y="histogram",
template=dict(layout_yaxis_showgrid=False),
)
assert fig.layout.xaxis2.showgrid
assert fig.layout.xaxis3.showgrid
assert fig.layout.yaxis2.showgrid is None
assert fig.layout.yaxis3.showgrid is None

fig = px.scatter(
tips,
x="total_bill",
y="tip",
marginal_x="histogram",
marginal_y="histogram",
template=dict(layout_xaxis_showgrid=False),
)
assert fig.layout.xaxis2.showgrid is None
assert fig.layout.xaxis3.showgrid is None
assert fig.layout.yaxis2.showgrid
assert fig.layout.yaxis3.showgrid

0 comments on commit 5830055

Please sign in to comment.