diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py
index 3e01f445c2..1b5aee6e0f 100644
--- a/packages/python/plotly/plotly/express/_core.py
+++ b/packages/python/plotly/plotly/express/_core.py
@@ -375,16 +375,18 @@ def configure_cartesian_marginal_axes(args, fig, orders):
# Configure axis ticks on marginal subplots
if args["marginal_x"]:
- fig.update_yaxes(
- showticklabels=False, showgrid=args["marginal_x"] == "histogram", row=nrows
- )
- fig.update_xaxes(showgrid=True, row=nrows)
+ fig.update_yaxes(showticklabels=False, row=nrows)
+ if args["template"].layout.yaxis.showgrid is None:
+ fig.update_yaxes(showgrid=args["marginal_x"] == "histogram", row=nrows)
+ if args["template"].layout.xaxis.showgrid is None:
+ fig.update_xaxes(showgrid=True, row=nrows)
if args["marginal_y"]:
- fig.update_xaxes(
- showticklabels=False, showgrid=args["marginal_y"] == "histogram", col=ncols
- )
- fig.update_yaxes(showgrid=True, col=ncols)
+ fig.update_xaxes(showticklabels=False, col=ncols)
+ if args["template"].layout.xaxis.showgrid is None:
+ fig.update_xaxes(showgrid=args["marginal_y"] == "histogram", col=ncols)
+ if args["template"].layout.yaxis.showgrid is None:
+ fig.update_yaxes(showgrid=True, col=ncols)
# Add axis titles to non-marginal subplots
y_title = get_decorated_label(args, args["y"], "y")
@@ -687,55 +689,47 @@ def apply_default_cascade(args):
else:
args["template"] = "plotly"
- # retrieve the actual template if we were given a name
try:
- template = pio.templates[args["template"]]
+ # retrieve the actual template if we were given a name
+ args["template"] = pio.templates[args["template"]]
except Exception:
- template = args["template"]
+ # otherwise try to build a real template
+ args["template"] = go.layout.Template(args["template"])
# if colors not set explicitly or in px.defaults, defer to a template
# if the template doesn't have one, we set some final fallback defaults
if "color_continuous_scale" in args:
- if args["color_continuous_scale"] is None:
- try:
- args["color_continuous_scale"] = [
- x[1] for x in template.layout.colorscale.sequential
- ]
- except (AttributeError, TypeError):
- pass
+ if (
+ args["color_continuous_scale"] is None
+ and args["template"].layout.colorscale.sequential
+ ):
+ args["color_continuous_scale"] = [
+ x[1] for x in args["template"].layout.colorscale.sequential
+ ]
if args["color_continuous_scale"] is None:
args["color_continuous_scale"] = sequential.Viridis
if "color_discrete_sequence" in args:
- if args["color_discrete_sequence"] is None:
- try:
- args["color_discrete_sequence"] = template.layout.colorway
- except (AttributeError, TypeError):
- pass
+ if args["color_discrete_sequence"] is None and args["template"].layout.colorway:
+ args["color_discrete_sequence"] = args["template"].layout.colorway
if args["color_discrete_sequence"] is None:
args["color_discrete_sequence"] = qualitative.D3
# if symbol_sequence/line_dash_sequence not set explicitly or in px.defaults,
# see if we can defer to template. If not, set reasonable defaults
if "symbol_sequence" in args:
- if args["symbol_sequence"] is None:
- try:
- args["symbol_sequence"] = [
- scatter.marker.symbol for scatter in template.data.scatter
- ]
- except (AttributeError, TypeError):
- pass
+ if args["symbol_sequence"] is None and args["template"].data.scatter:
+ args["symbol_sequence"] = [
+ scatter.marker.symbol for scatter in args["template"].data.scatter
+ ]
if not args["symbol_sequence"] or not any(args["symbol_sequence"]):
args["symbol_sequence"] = ["circle", "diamond", "square", "x", "cross"]
if "line_dash_sequence" in args:
- if args["line_dash_sequence"] is None:
- try:
- args["line_dash_sequence"] = [
- scatter.line.dash for scatter in template.data.scatter
- ]
- except (AttributeError, TypeError):
- pass
+ if args["line_dash_sequence"] is None and args["template"].data.scatter:
+ args["line_dash_sequence"] = [
+ scatter.line.dash for scatter in args["template"].data.scatter
+ ]
if not args["line_dash_sequence"] or not any(args["line_dash_sequence"]):
args["line_dash_sequence"] = [
"solid",
@@ -1264,13 +1258,17 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
cmax=range_color[1],
colorbar=dict(title=get_decorated_label(args, args[colorvar], colorvar)),
)
- for v in ["title", "height", "width", "template"]:
+ for v in ["title", "height", "width"]:
if args[v]:
layout_patch[v] = args[v]
layout_patch["legend"] = {"tracegroupgap": 0}
- if "title" not in layout_patch:
+ if "title" not in layout_patch and args["template"].layout.margin.t is None:
layout_patch["margin"] = {"t": 60}
- if "size" in args and args["size"]:
+ if (
+ "size" in args
+ and args["size"]
+ and args["template"].layout.legend.itemsizing is None
+ ):
layout_patch["legend"]["itemsizing"] = "constant"
fig = init_figure(
@@ -1295,6 +1293,8 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
# Add traces, layout and frames to figure
fig.add_traces(frame_list[0]["data"] if len(frame_list) > 0 else [])
fig.layout.update(layout_patch)
+ if "template" in args and args["template"] is not None:
+ fig.update_layout(template=args["template"], overwrite=True)
fig.frames = frame_list if len(frames) > 1 else []
fig._px_trendlines = pd.DataFrame(trendline_rows)
diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px.py
index 588bfa3d18..08d4430c1f 100644
--- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px.py
+++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px.py
@@ -51,3 +51,110 @@ def test_custom_data_scatter():
fig.data[0].hovertemplate
== "sepal_width=%{x}
sepal_length=%{y}
petal_length=%{customdata[2]}
petal_width=%{customdata[3]}
species_id=%{customdata[0]}"
)
+
+
+def test_px_templates():
+ import plotly.io as pio
+ import plotly.graph_objects as go
+
+ tips = px.data.tips()
+
+ # use the normal defaults
+ fig = px.scatter()
+ assert fig.layout.template == pio.templates[pio.templates.default]
+
+ # respect changes to defaults
+ pio.templates.default = "seaborn"
+ fig = px.scatter()
+ assert fig.layout.template == pio.templates["seaborn"]
+
+ # special px-level defaults over pio defaults
+ pio.templates.default = "seaborn"
+ px.defaults.template = "ggplot2"
+ fig = px.scatter()
+ assert fig.layout.template == pio.templates["ggplot2"]
+
+ # accept names in args over pio and px defaults
+ fig = px.scatter(template="seaborn")
+ assert fig.layout.template == pio.templates["seaborn"]
+
+ # accept objects in args
+ fig = px.scatter(template={})
+ assert fig.layout.template == go.layout.Template()
+
+ # read colorway from the template
+ fig = px.scatter(
+ tips,
+ x="total_bill",
+ y="tip",
+ color="sex",
+ template=dict(layout_colorway=["red", "blue"]),
+ )
+ assert fig.data[0].marker.color == "red"
+ assert fig.data[1].marker.color == "blue"
+
+ # default colorway fallback
+ fig = px.scatter(tips, x="total_bill", y="tip", color="sex", template=dict())
+ assert fig.data[0].marker.color == px.colors.qualitative.D3[0]
+ assert fig.data[1].marker.color == px.colors.qualitative.D3[1]
+
+ # pio default template colorway fallback
+ pio.templates.default = "seaborn"
+ px.defaults.template = None
+ fig = px.scatter(tips, x="total_bill", y="tip", color="sex")
+ assert fig.data[0].marker.color == pio.templates["seaborn"].layout.colorway[0]
+ assert fig.data[1].marker.color == pio.templates["seaborn"].layout.colorway[1]
+
+ # pio default template colorway fallback
+ pio.templates.default = "seaborn"
+ px.defaults.template = "ggplot2"
+ fig = px.scatter(tips, x="total_bill", y="tip", color="sex")
+ assert fig.data[0].marker.color == pio.templates["ggplot2"].layout.colorway[0]
+ assert fig.data[1].marker.color == pio.templates["ggplot2"].layout.colorway[1]
+
+ # don't overwrite top margin when set in template
+ fig = px.scatter(title="yo")
+ assert fig.layout.margin.t is None
+
+ fig = px.scatter()
+ assert fig.layout.margin.t == 60
+
+ fig = px.scatter(template=dict(layout_margin_t=2))
+ assert fig.layout.margin.t is None
+
+ # don't force histogram gridlines when set in template
+ pio.templates.default = "none"
+ px.defaults.template = None
+ fig = px.scatter(
+ tips, x="total_bill", y="tip", marginal_x="histogram", marginal_y="histogram"
+ )
+ assert fig.layout.xaxis2.showgrid
+ assert fig.layout.xaxis3.showgrid
+ assert fig.layout.yaxis2.showgrid
+ assert fig.layout.yaxis3.showgrid
+
+ fig = px.scatter(
+ tips,
+ x="total_bill",
+ y="tip",
+ marginal_x="histogram",
+ marginal_y="histogram",
+ template=dict(layout_yaxis_showgrid=False),
+ )
+ assert fig.layout.xaxis2.showgrid
+ assert fig.layout.xaxis3.showgrid
+ assert fig.layout.yaxis2.showgrid is None
+ assert fig.layout.yaxis3.showgrid is None
+
+ fig = px.scatter(
+ tips,
+ x="total_bill",
+ y="tip",
+ marginal_x="histogram",
+ marginal_y="histogram",
+ template=dict(layout_xaxis_showgrid=False),
+ )
+ assert fig.layout.xaxis2.showgrid is None
+ assert fig.layout.xaxis3.showgrid is None
+ assert fig.layout.yaxis2.showgrid
+ assert fig.layout.yaxis3.showgrid