From ae1423168fd90207208a0c9b33d1d1c91a79c912 Mon Sep 17 00:00:00 2001 From: Nicolas Kruchten Date: Tue, 5 Nov 2019 12:50:34 -0500 Subject: [PATCH 1/4] PX shouldn't modify attrs controlled by template --- .../python/plotly/plotly/express/_core.py | 76 +++++++++---------- 1 file changed, 37 insertions(+), 39 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 3e01f445c2..b1c619c294 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -375,16 +375,18 @@ def configure_cartesian_marginal_axes(args, fig, orders): # Configure axis ticks on marginal subplots if args["marginal_x"]: - fig.update_yaxes( - showticklabels=False, showgrid=args["marginal_x"] == "histogram", row=nrows - ) - fig.update_xaxes(showgrid=True, row=nrows) + fig.update_yaxes(showticklabels=False, row=nrows) + if args["template"].layout.yaxis.showgrid is None: + fig.update_yaxes(showgrid=args["marginal_x"] == "histogram", row=nrows) + if args["template"].layout.xaxis.showgrid is None: + fig.update_xaxes(showgrid=True, row=nrows) if args["marginal_y"]: - fig.update_xaxes( - showticklabels=False, showgrid=args["marginal_y"] == "histogram", col=ncols - ) - fig.update_yaxes(showgrid=True, col=ncols) + fig.update_xaxes(showticklabels=False, col=ncols) + if args["template"].layout.xaxis.showgrid is None: + fig.update_xaxes(showgrid=args["marginal_y"] == "histogram", col=ncols) + if args["template"].layout.yaxis.showgrid is None: + fig.update_yaxes(showgrid=True, col=ncols) # Add axis titles to non-marginal subplots y_title = get_decorated_label(args, args["y"], "y") @@ -687,55 +689,47 @@ def apply_default_cascade(args): else: args["template"] = "plotly" - # retrieve the actual template if we were given a name try: - template = pio.templates[args["template"]] + # retrieve the actual template if we were given a name + args["template"] = pio.templates[args["template"]] except Exception: - template = args["template"] + # otherwise try to build a real template + args["template"] = go.layout.Template(args["template"]) # if colors not set explicitly or in px.defaults, defer to a template # if the template doesn't have one, we set some final fallback defaults if "color_continuous_scale" in args: - if args["color_continuous_scale"] is None: - try: - args["color_continuous_scale"] = [ - x[1] for x in template.layout.colorscale.sequential - ] - except (AttributeError, TypeError): - pass + if ( + args["color_continuous_scale"] is None + and args["template"].layout.colorscale.sequential + ): + args["color_continuous_scale"] = [ + x[1] for x in args["template"].layout.colorscale.sequential + ] if args["color_continuous_scale"] is None: args["color_continuous_scale"] = sequential.Viridis if "color_discrete_sequence" in args: - if args["color_discrete_sequence"] is None: - try: - args["color_discrete_sequence"] = template.layout.colorway - except (AttributeError, TypeError): - pass + if args["color_discrete_sequence"] is None and args["template"].layout.colorway: + args["color_discrete_sequence"] = args["template"].layout.colorway if args["color_discrete_sequence"] is None: args["color_discrete_sequence"] = qualitative.D3 # if symbol_sequence/line_dash_sequence not set explicitly or in px.defaults, # see if we can defer to template. If not, set reasonable defaults if "symbol_sequence" in args: - if args["symbol_sequence"] is None: - try: - args["symbol_sequence"] = [ - scatter.marker.symbol for scatter in template.data.scatter - ] - except (AttributeError, TypeError): - pass + if args["symbol_sequence"] is None and args["template"].data.scatter: + args["symbol_sequence"] = [ + scatter.marker.symbol for scatter in args["template"].data.scatter + ] if not args["symbol_sequence"] or not any(args["symbol_sequence"]): args["symbol_sequence"] = ["circle", "diamond", "square", "x", "cross"] if "line_dash_sequence" in args: - if args["line_dash_sequence"] is None: - try: - args["line_dash_sequence"] = [ - scatter.line.dash for scatter in template.data.scatter - ] - except (AttributeError, TypeError): - pass + if args["line_dash_sequence"] is None and args["template"].data.scatter: + args["line_dash_sequence"] = [ + scatter.line.dash for scatter in args["template"].data.scatter + ] if not args["line_dash_sequence"] or not any(args["line_dash_sequence"]): args["line_dash_sequence"] = [ "solid", @@ -1268,9 +1262,13 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}): if args[v]: layout_patch[v] = args[v] layout_patch["legend"] = {"tracegroupgap": 0} - if "title" not in layout_patch: + if "title" not in layout_patch and args["template"].layout.margin.t is None: layout_patch["margin"] = {"t": 60} - if "size" in args and args["size"]: + if ( + "size" in args + and args["size"] + and args["template"].layout.legend.itemsizing is None + ): layout_patch["legend"]["itemsizing"] = "constant" fig = init_figure( From c4445185acd129cb8ba98b1fc1d29277138f818f Mon Sep 17 00:00:00 2001 From: Nicolas Kruchten Date: Tue, 5 Nov 2019 16:31:06 -0500 Subject: [PATCH 2/4] force the template via override --- packages/python/plotly/plotly/express/_core.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index b1c619c294..1b5aee6e0f 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -1258,7 +1258,7 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}): cmax=range_color[1], colorbar=dict(title=get_decorated_label(args, args[colorvar], colorvar)), ) - for v in ["title", "height", "width", "template"]: + for v in ["title", "height", "width"]: if args[v]: layout_patch[v] = args[v] layout_patch["legend"] = {"tracegroupgap": 0} @@ -1293,6 +1293,8 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}): # Add traces, layout and frames to figure fig.add_traces(frame_list[0]["data"] if len(frame_list) > 0 else []) fig.layout.update(layout_patch) + if "template" in args and args["template"] is not None: + fig.update_layout(template=args["template"], overwrite=True) fig.frames = frame_list if len(frames) > 1 else [] fig._px_trendlines = pd.DataFrame(trendline_rows) From 778da393172775f097c314cd30701c126cfb7bef Mon Sep 17 00:00:00 2001 From: Nicolas Kruchten Date: Tue, 5 Nov 2019 17:05:54 -0500 Subject: [PATCH 3/4] px template tests galore --- .../plotly/tests/test_core/test_px/test_px.py | 108 ++++++++++++++++++ 1 file changed, 108 insertions(+) diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px.py index 588bfa3d18..e9bd73f6ec 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px.py @@ -51,3 +51,111 @@ def test_custom_data_scatter(): fig.data[0].hovertemplate == "sepal_width=%{x}
sepal_length=%{y}
petal_length=%{customdata[2]}
petal_width=%{customdata[3]}
species_id=%{customdata[0]}" ) + + +def test_px_templates(): + import plotly.io as pio + import plotly.graph_objects as go + + tips = px.data.tips() + + # use the normal defaults + fig = px.scatter() + assert fig.layout.template == pio.templates[pio.templates.default] + + # respect changes to defaults + pio.templates.default = "seaborn" + fig = px.scatter() + assert fig.layout.template == pio.templates["seaborn"] + + # special px-level defaults over pio defaults + pio.templates.default = "seaborn" + px.defaults.template = "ggplot2" + fig = px.scatter() + assert fig.layout.template == pio.templates["ggplot2"] + + # accept names in args over pio and px defaults + fig = px.scatter(template="seaborn") + assert fig.layout.template == pio.templates["seaborn"] + + # accept objects in args + fig = px.scatter(template={}) + assert fig.layout.template == go.layout.Template() + + # read colorway from the template + fig = px.scatter( + tips, + x="total_bill", + y="tip", + color="sex", + template=dict(layout_colorway=["red", "blue"]), + ) + assert fig.data[0].marker.color == "red" + assert fig.data[1].marker.color == "blue" + + # default colorway fallback + fig = px.scatter(tips, x="total_bill", y="tip", color="sex", template=dict()) + assert fig.data[0].marker.color == px.colors.qualitative.D3[0] + assert fig.data[1].marker.color == px.colors.qualitative.D3[1] + + # pio default template colorway fallback + pio.templates.default = "seaborn" + px.defaults.template = None + fig = px.scatter(tips, x="total_bill", y="tip", color="sex") + assert fig.data[0].marker.color == pio.templates["seaborn"].layout.colorway[0] + assert fig.data[1].marker.color == pio.templates["seaborn"].layout.colorway[1] + + # pio default template colorway fallback + pio.templates.default = "seaborn" + px.defaults.template = "ggplot2" + fig = px.scatter(tips, x="total_bill", y="tip", color="sex") + assert fig.data[0].marker.color == pio.templates["ggplot2"].layout.colorway[0] + assert fig.data[1].marker.color == pio.templates["ggplot2"].layout.colorway[1] + + # don't overwrite top margin when set in template + fig = px.scatter(title="yo") + assert fig.layout.margin.t is None + + fig = px.scatter() + assert fig.layout.margin.t == 60 + + fig = px.scatter(template=dict(layout_margin_t=2)) + assert fig.layout.margin.t is None + + # don't force histogram gridlines when set in template + pio.templates.default = "none" + px.defaults.template = None + fig = px.scatter( + tips, x="total_bill", y="tip", marginal_x="histogram", marginal_y="histogram" + ) + assert fig.layout.xaxis2.showgrid + assert fig.layout.xaxis3.showgrid + assert fig.layout.yaxis2.showgrid + assert fig.layout.yaxis3.showgrid + + fig = px.scatter( + tips, + x="total_bill", + y="tip", + marginal_x="histogram", + marginal_y="histogram", + template=dict(layout_yaxis_showgrid=False), + ) + assert fig.layout.xaxis2.showgrid + assert fig.layout.xaxis3.showgrid + assert fig.layout.yaxis2.showgrid is None + assert fig.layout.yaxis3.showgrid is None + + fig = px.scatter( + tips, + x="total_bill", + y="tip", + marginal_x="histogram", + marginal_y="histogram", + template=dict(layout_xaxis_showgrid=False), + ) + assert fig.layout.xaxis2.showgrid is None + assert fig.layout.xaxis3.showgrid is None + assert fig.layout.yaxis2.showgrid + assert fig.layout.yaxis3.showgrid + From 1e88daa9683886c7ae9d56ec70d149542ea75e2f Mon Sep 17 00:00:00 2001 From: Nicolas Kruchten Date: Wed, 6 Nov 2019 09:17:48 -0500 Subject: [PATCH 4/4] blacken --- packages/python/plotly/plotly/tests/test_core/test_px/test_px.py | 1 - 1 file changed, 1 deletion(-) diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px.py index e9bd73f6ec..08d4430c1f 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px.py @@ -158,4 +158,3 @@ def test_px_templates(): assert fig.layout.xaxis3.showgrid is None assert fig.layout.yaxis2.showgrid assert fig.layout.yaxis3.showgrid -