Skip to content

Commit

Permalink
Merge pull request #2554 from plotly/trendline_fix
Browse files Browse the repository at this point in the history
make trendlines more robust
  • Loading branch information
nicolaskruchten authored Jun 22, 2020
2 parents fd3b741 + e823437 commit 8be4915
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 18 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ This project adheres to [Semantic Versioning](http://semver.org/).

- Fixed special cases with `px.sunburst` and `px.treemap` with `path` input ([#2524](https://github.com/plotly/plotly.py/pull/2524))
- Fixed bug in `hover_data` argument of `px` functions, when the column name is changed with labels and `hover_data` is a dictionary setting up a specific format for the hover data ([#2544](https://github.com/plotly/plotly.py/pull/2544)).
- Made the Plotly Express `trendline` argument more robust and made it work with datetime `x` values ([#2554](https://github.com/plotly/plotly.py/pull/2554))

## [4.8.1] - 2020-05-28

Expand Down
47 changes: 36 additions & 11 deletions packages/python/plotly/plotly/express/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,17 +277,35 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
attr_value in ["ols", "lowess"]
and args["x"]
and args["y"]
and len(trace_data) > 1
and len(trace_data[[args["x"], args["y"]]].dropna()) > 1
):
import statsmodels.api as sm

# sorting is bad but trace_specs with "trendline" have no other attrs
sorted_trace_data = trace_data.sort_values(by=args["x"])
y = sorted_trace_data[args["y"]]
x = sorted_trace_data[args["x"]]
y = sorted_trace_data[args["y"]].values
x = sorted_trace_data[args["x"]].values

x_is_date = False
if x.dtype.type == np.datetime64:
x = x.astype(int) / 10 ** 9 # convert to unix epoch seconds
x_is_date = True
elif x.dtype.type == np.object_:
try:
x = x.astype(np.float64)
except ValueError:
raise ValueError(
"Could not convert value of 'x' ('%s') into a numeric type. "
"If 'x' contains stringified dates, please convert to a datetime column."
% args["x"]
)
if y.dtype.type == np.object_:
try:
y = y.astype(np.float64)
except ValueError:
raise ValueError(
"Could not convert value of 'y' into a numeric type."
)

if attr_value == "lowess":
# missing ='drop' is the default value for lowess but not for OLS (None)
Expand All @@ -298,25 +316,32 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
hover_header = "<b>LOWESS trendline</b><br><br>"
elif attr_value == "ols":
fit_results = sm.OLS(
y.values, sm.add_constant(x.values), missing="drop"
y, sm.add_constant(x), missing="drop"
).fit()
trace_patch["y"] = fit_results.predict()
trace_patch["x"] = x[
np.logical_not(np.logical_or(np.isnan(y), np.isnan(x)))
]
hover_header = "<b>OLS trendline</b><br>"
hover_header += "%s = %g * %s + %g<br>" % (
args["y"],
fit_results.params[1],
args["x"],
fit_results.params[0],
)
if len(fit_results.params) == 2:
hover_header += "%s = %g * %s + %g<br>" % (
args["y"],
fit_results.params[1],
args["x"],
fit_results.params[0],
)
else:
hover_header += "%s = %g<br>" % (
args["y"],
fit_results.params[0],
)
hover_header += (
"R<sup>2</sup>=%f<br><br>" % fit_results.rsquared
)
if x_is_date:
trace_patch["x"] = pd.to_datetime(trace_patch["x"] * 10 ** 9)
mapping_labels[get_label(args, args["x"])] = "%{x}"
mapping_labels[get_label(args, args["y"])] = "%{y} <b>(trend)</b>"

elif attr_name.startswith("error"):
error_xy = attr_name[:7]
arr = "arrayminus" if attr_name.endswith("minus") else "array"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,110 @@
import plotly.express as px
import numpy as np
import pandas as pd
import pytest
from datetime import datetime


def test_trendline_nan_values():
@pytest.mark.parametrize("mode", ["ols", "lowess"])
def test_trendline_results_passthrough(mode):
df = px.data.gapminder().query("continent == 'Oceania'")
fig = px.scatter(df, x="year", y="pop", color="country", trendline=mode)
assert len(fig.data) == 4
for trace in fig["data"][0::2]:
assert "trendline" not in trace.hovertemplate
for trendline in fig["data"][1::2]:
assert "trendline" in trendline.hovertemplate
if mode == "ols":
assert "R<sup>2</sup>" in trendline.hovertemplate
results = px.get_trendline_results(fig)
if mode == "ols":
assert len(results) == 2
assert results["country"].values[0] == "Australia"
assert results["country"].values[0] == "Australia"
au_result = results["px_fit_results"].values[0]
assert len(au_result.params) == 2
else:
assert len(results) == 0


@pytest.mark.parametrize("mode", ["ols", "lowess"])
def test_trendline_enough_values(mode):
fig = px.scatter(x=[0, 1], y=[0, 1], trendline=mode)
assert len(fig.data) == 2
assert len(fig.data[1].x) == 2
fig = px.scatter(x=[0], y=[0], trendline=mode)
assert len(fig.data) == 2
assert fig.data[1].x is None
fig = px.scatter(x=[0, 1], y=[0, None], trendline=mode)
assert len(fig.data) == 2
assert fig.data[1].x is None
fig = px.scatter(x=[0, 1], y=np.array([0, np.nan]), trendline=mode)
assert len(fig.data) == 2
assert fig.data[1].x is None
fig = px.scatter(x=[0, 1, None], y=[0, None, 1], trendline=mode)
assert len(fig.data) == 2
assert fig.data[1].x is None
fig = px.scatter(
x=np.array([0, 1, np.nan]), y=np.array([0, np.nan, 1]), trendline=mode
)
assert len(fig.data) == 2
assert fig.data[1].x is None
fig = px.scatter(x=[0, 1, None, 2], y=[1, None, 1, 2], trendline=mode)
assert len(fig.data) == 2
assert len(fig.data[1].x) == 2
fig = px.scatter(
x=np.array([0, 1, np.nan, 2]), y=np.array([1, np.nan, 1, 2]), trendline=mode
)
assert len(fig.data) == 2
assert len(fig.data[1].x) == 2


@pytest.mark.parametrize("mode", ["ols", "lowess"])
def test_trendline_nan_values(mode):
df = px.data.gapminder().query("continent == 'Oceania'")
start_date = 1970
df["pop"][df["year"] < start_date] = np.nan
modes = ["ols", "lowess"]
for mode in modes:
fig = px.scatter(df, x="year", y="pop", color="country", trendline=mode)
for trendline in fig["data"][1::2]:
assert trendline.x[0] >= start_date
assert len(trendline.x) == len(trendline.y)
fig = px.scatter(df, x="year", y="pop", color="country", trendline=mode)
for trendline in fig["data"][1::2]:
assert trendline.x[0] >= start_date
assert len(trendline.x) == len(trendline.y)


def test_no_slope_ols_trendline():
fig = px.scatter(x=[0, 1], y=[0, 1], trendline="ols")
assert "y = 1" in fig.data[1].hovertemplate # then + x*(some small number)
results = px.get_trendline_results(fig)
params = results["px_fit_results"].iloc[0].params
assert np.all(np.isclose(params, [0, 1]))

fig = px.scatter(x=[1, 1], y=[0, 0], trendline="ols")
assert "y = 0" in fig.data[1].hovertemplate
results = px.get_trendline_results(fig)
params = results["px_fit_results"].iloc[0].params
assert np.all(np.isclose(params, [0]))

fig = px.scatter(x=[1, 2], y=[0, 0], trendline="ols")
assert "y = 0" in fig.data[1].hovertemplate
fig = px.scatter(x=[0, 0], y=[1, 1], trendline="ols")
assert "y = 0 * x + 1" in fig.data[1].hovertemplate
fig = px.scatter(x=[0, 0], y=[1, 2], trendline="ols")
assert "y = 0 * x + 1.5" in fig.data[1].hovertemplate


@pytest.mark.parametrize("mode", ["ols", "lowess"])
def test_trendline_on_timeseries(mode):
df = px.data.stocks()

with pytest.raises(ValueError) as err_msg:
px.scatter(df, x="date", y="GOOG", trendline=mode)
assert "Could not convert value of 'x' ('date') into a numeric type." in str(
err_msg.value
)

df["date"] = pd.to_datetime(df["date"])
fig = px.scatter(df, x="date", y="GOOG", trendline=mode)
assert len(fig.data) == 2
assert len(fig.data[0].x) == len(fig.data[1].x)
assert type(fig.data[0].x[0]) == datetime
assert type(fig.data[1].x[0]) == datetime
assert np.all(fig.data[0].x == fig.data[1].x)

0 comments on commit 8be4915

Please sign in to comment.