From 65abd71b194bb13d5c1d6b346b9ac1b5020597bd Mon Sep 17 00:00:00 2001 From: Nicolas Kruchten Date: Mon, 8 Jun 2020 15:39:32 -0400 Subject: [PATCH 1/4] make trendlines more robust --- .../python/plotly/plotly/express/_core.py | 36 +++++--- .../tests/test_core/test_px/test_trendline.py | 90 +++++++++++++++++-- 2 files changed, 108 insertions(+), 18 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 5d6388ae59..635bd870b9 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -277,17 +277,24 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref): attr_value in ["ols", "lowess"] and args["x"] and args["y"] - and len(trace_data) > 1 + and len(trace_data[[args["x"], args["y"]]].dropna()) > 1 ): import statsmodels.api as sm # sorting is bad but trace_specs with "trendline" have no other attrs sorted_trace_data = trace_data.sort_values(by=args["x"]) - y = sorted_trace_data[args["y"]] - x = sorted_trace_data[args["x"]] + y = sorted_trace_data[args["y"]].values + x = sorted_trace_data[args["x"]].values + x_is_date = False if x.dtype.type == np.datetime64: x = x.astype(int) / 10 ** 9 # convert to unix epoch seconds + x_is_date = True + elif x.dtype.type == np.object_: + x = x.astype(np.float64) + + if y.dtype.type == np.object_: + y = y.astype(np.float64) if attr_value == "lowess": # missing ='drop' is the default value for lowess but not for OLS (None) @@ -298,25 +305,32 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref): hover_header = "LOWESS trendline

" elif attr_value == "ols": fit_results = sm.OLS( - y.values, sm.add_constant(x.values), missing="drop" + y, sm.add_constant(x), missing="drop" ).fit() trace_patch["y"] = fit_results.predict() trace_patch["x"] = x[ np.logical_not(np.logical_or(np.isnan(y), np.isnan(x))) ] hover_header = "OLS trendline
" - hover_header += "%s = %g * %s + %g
" % ( - args["y"], - fit_results.params[1], - args["x"], - fit_results.params[0], - ) + if len(fit_results.params) == 2: + hover_header += "%s = %g * %s + %g
" % ( + args["y"], + fit_results.params[1], + args["x"], + fit_results.params[0], + ) + else: + hover_header += "%s = %g
" % ( + args["y"], + fit_results.params[0], + ) hover_header += ( "R2=%f

" % fit_results.rsquared ) + if x_is_date: + trace_patch["x"] = pd.to_datetime(trace_patch["x"] * 10 ** 9) mapping_labels[get_label(args, args["x"])] = "%{x}" mapping_labels[get_label(args, args["y"])] = "%{y} (trend)" - elif attr_name.startswith("error"): error_xy = attr_name[:7] arr = "arrayminus" if attr_name.endswith("minus") else "array" diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_trendline.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_trendline.py index 4c151148c1..125319be97 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_trendline.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_trendline.py @@ -1,14 +1,90 @@ import plotly.express as px import numpy as np +import pandas as pd +import pytest +from datetime import datetime -def test_trendline_nan_values(): +@pytest.mark.parametrize("mode", ["ols", "lowess"]) +def test_trendline_results_passthrough(mode): + df = px.data.gapminder().query("continent == 'Oceania'") + fig = px.scatter(df, x="year", y="pop", color="country", trendline=mode) + assert len(fig.data) == 4 + for trace in fig["data"][0::2]: + assert "trendline" not in trace.hovertemplate + for trendline in fig["data"][1::2]: + assert "trendline" in trendline.hovertemplate + if mode == "ols": + assert "R2" in trendline.hovertemplate + results = px.get_trendline_results(fig) + if mode == "ols": + assert len(results) == 2 + assert results["country"].values[0] == "Australia" + assert results["country"].values[0] == "Australia" + au_result = results["px_fit_results"].values[0] + assert len(au_result.params) == 2 + else: + assert len(results) == 0 + + +@pytest.mark.parametrize("mode", ["ols", "lowess"]) +def test_trendline_enough_values(mode): + fig = px.scatter(x=[0, 1], y=[0, 1], trendline=mode) + assert len(fig.data) == 2 + assert len(fig.data[1].x) == 2 + fig = px.scatter(x=[0], y=[0], trendline=mode) + assert len(fig.data) == 2 + assert fig.data[1].x is None + fig = px.scatter(x=[0, 1], y=[0, None], trendline=mode) + assert len(fig.data) == 2 + assert fig.data[1].x is None + fig = px.scatter(x=[0, 1, None], y=[0, None, 1], trendline=mode) + assert len(fig.data) == 2 + assert fig.data[1].x is None + fig = px.scatter(x=[0, 1, None, 2], y=[1, None, 1, 2], trendline=mode) + assert len(fig.data) == 2 + assert len(fig.data[1].x) == 2 + + +@pytest.mark.parametrize("mode", ["ols", "lowess"]) +def test_trendline_nan_values(mode): df = px.data.gapminder().query("continent == 'Oceania'") start_date = 1970 df["pop"][df["year"] < start_date] = np.nan - modes = ["ols", "lowess"] - for mode in modes: - fig = px.scatter(df, x="year", y="pop", color="country", trendline=mode) - for trendline in fig["data"][1::2]: - assert trendline.x[0] >= start_date - assert len(trendline.x) == len(trendline.y) + fig = px.scatter(df, x="year", y="pop", color="country", trendline=mode) + for trendline in fig["data"][1::2]: + assert trendline.x[0] >= start_date + assert len(trendline.x) == len(trendline.y) + + +def test_no_slope_ols_trendline(): + fig = px.scatter(x=[0, 1], y=[0, 1], trendline="ols") + assert "y = 1" in fig.data[1].hovertemplate # then + x*(some small number) + results = px.get_trendline_results(fig) + params = results["px_fit_results"].iloc[0].params + assert np.all(np.isclose(params, [0, 1])) + + fig = px.scatter(x=[1, 1], y=[0, 0], trendline="ols") + assert "y = 0" in fig.data[1].hovertemplate + results = px.get_trendline_results(fig) + params = results["px_fit_results"].iloc[0].params + assert np.all(np.isclose(params, [0])) + + fig = px.scatter(x=[1, 2], y=[0, 0], trendline="ols") + assert "y = 0" in fig.data[1].hovertemplate + fig = px.scatter(x=[0, 0], y=[1, 1], trendline="ols") + assert "y = 0 * x + 1" in fig.data[1].hovertemplate + fig = px.scatter(x=[0, 0], y=[1, 2], trendline="ols") + assert "y = 0 * x + 1.5" in fig.data[1].hovertemplate + + +@pytest.mark.parametrize("mode", ["ols", "lowess"]) +def test_trendline_on_timeseries(mode): + df = px.data.stocks() + df["date"] = pd.to_datetime(df["date"]) + fig = px.scatter(df, x="date", y="GOOG", trendline=mode) + assert len(fig.data) == 2 + assert len(fig.data[0].x) == len(fig.data[1].x) + assert type(fig.data[0].x[0]) == datetime + assert type(fig.data[1].x[0]) == datetime + assert np.all(fig.data[0].x == fig.data[1].x) From c19c49fe82bdcd87e0eacbd7ae6fac6b8387addc Mon Sep 17 00:00:00 2001 From: Nicolas Kruchten Date: Mon, 8 Jun 2020 15:50:42 -0400 Subject: [PATCH 2/4] more tests --- .../tests/test_core/test_px/test_trendline.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_trendline.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_trendline.py index 125319be97..6c767f1068 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_trendline.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_trendline.py @@ -38,12 +38,25 @@ def test_trendline_enough_values(mode): fig = px.scatter(x=[0, 1], y=[0, None], trendline=mode) assert len(fig.data) == 2 assert fig.data[1].x is None + fig = px.scatter(x=[0, 1], y=np.array([0, np.nan]), trendline=mode) + assert len(fig.data) == 2 + assert fig.data[1].x is None fig = px.scatter(x=[0, 1, None], y=[0, None, 1], trendline=mode) assert len(fig.data) == 2 assert fig.data[1].x is None + fig = px.scatter( + x=np.array([0, 1, np.nan]), y=np.array([0, np.nan, 1]), trendline=mode + ) + assert len(fig.data) == 2 + assert fig.data[1].x is None fig = px.scatter(x=[0, 1, None, 2], y=[1, None, 1, 2], trendline=mode) assert len(fig.data) == 2 assert len(fig.data[1].x) == 2 + fig = px.scatter( + x=np.array([0, 1, np.nan, 2]), y=np.array([1, np.nan, 1, 2]), trendline=mode + ) + assert len(fig.data) == 2 + assert len(fig.data[1].x) == 2 @pytest.mark.parametrize("mode", ["ols", "lowess"]) From 216331e267c674e3c9515f8ea07c3519f7156329 Mon Sep 17 00:00:00 2001 From: Nicolas Kruchten Date: Mon, 8 Jun 2020 15:53:51 -0400 Subject: [PATCH 3/4] changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5aca477f45..f6e9fdc370 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ This project adheres to [Semantic Versioning](http://semver.org/). - Fixed special cases with `px.sunburst` and `px.treemap` with `path` input ([#2524](https://github.com/plotly/plotly.py/pull/2524)) - Fixed bug in `hover_data` argument of `px` functions, when the column name is changed with labels and `hover_data` is a dictionary setting up a specific format for the hover data ([#2544](https://github.com/plotly/plotly.py/pull/2544)). +- Made the Plotly Express `trendline` argument more robust and made it work with datetime `x` values ([#2554](https://github.com/plotly/plotly.py/pull/2554)) ## [4.8.1] - 2020-05-28 From e8234377dc3b5e043dfac3b50fb5ba52ae7f222f Mon Sep 17 00:00:00 2001 From: Nicolas Kruchten Date: Mon, 22 Jun 2020 09:26:05 -0400 Subject: [PATCH 4/4] nicer error message --- packages/python/plotly/plotly/express/_core.py | 17 ++++++++++++++--- .../tests/test_core/test_px/test_trendline.py | 7 +++++++ 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 635bd870b9..d89794a5a4 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -291,10 +291,21 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref): x = x.astype(int) / 10 ** 9 # convert to unix epoch seconds x_is_date = True elif x.dtype.type == np.object_: - x = x.astype(np.float64) - + try: + x = x.astype(np.float64) + except ValueError: + raise ValueError( + "Could not convert value of 'x' ('%s') into a numeric type. " + "If 'x' contains stringified dates, please convert to a datetime column." + % args["x"] + ) if y.dtype.type == np.object_: - y = y.astype(np.float64) + try: + y = y.astype(np.float64) + except ValueError: + raise ValueError( + "Could not convert value of 'y' into a numeric type." + ) if attr_value == "lowess": # missing ='drop' is the default value for lowess but not for OLS (None) diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_trendline.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_trendline.py index 6c767f1068..e908d7dee1 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_trendline.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_trendline.py @@ -94,6 +94,13 @@ def test_no_slope_ols_trendline(): @pytest.mark.parametrize("mode", ["ols", "lowess"]) def test_trendline_on_timeseries(mode): df = px.data.stocks() + + with pytest.raises(ValueError) as err_msg: + px.scatter(df, x="date", y="GOOG", trendline=mode) + assert "Could not convert value of 'x' ('date') into a numeric type." in str( + err_msg.value + ) + df["date"] = pd.to_datetime(df["date"]) fig = px.scatter(df, x="date", y="GOOG", trendline=mode) assert len(fig.data) == 2