From 3ac8507706b9309f6eeabfb45348445a4c9b6add Mon Sep 17 00:00:00 2001 From: Michael Chapman Date: Fri, 5 Jun 2020 15:11:45 -0500 Subject: [PATCH 1/3] bug_2410: Allowing Ints to be passed for rows/cols and refactored int checks --- packages/python/plotly/_plotly_utils/utils.py | 9 ++++++ .../python/plotly/plotly/basedatatypes.py | 25 +++++++++++----- .../test_figure_messages/test_add_traces.py | 30 +++++++++++++++++++ .../tests/test_core/test_utils/test_utils.py | 25 ++++++++++++++++ 4 files changed, 81 insertions(+), 8 deletions(-) diff --git a/packages/python/plotly/_plotly_utils/utils.py b/packages/python/plotly/_plotly_utils/utils.py index c1ba92951d..cbf8d3a6b9 100644 --- a/packages/python/plotly/_plotly_utils/utils.py +++ b/packages/python/plotly/_plotly_utils/utils.py @@ -247,3 +247,12 @@ def key(v): return tuple(v_parts) return sorted(vals, key=key, reverse=reverse) + + +def _get_int_type(): + np = get_module("numpy", should_load=False) + if np: + int_type = (int, np.integer) + else: + int_type = (int,) + return int_type diff --git a/packages/python/plotly/plotly/basedatatypes.py b/packages/python/plotly/plotly/basedatatypes.py index f520a1669a..cce3eb86c9 100644 --- a/packages/python/plotly/plotly/basedatatypes.py +++ b/packages/python/plotly/plotly/basedatatypes.py @@ -9,7 +9,7 @@ from contextlib import contextmanager from copy import deepcopy, copy -from _plotly_utils.utils import _natural_sort_strings +from _plotly_utils.utils import _natural_sort_strings, _get_int_type from .optional_imports import get_module # Create Undefined sentinel value @@ -1560,12 +1560,7 @@ def _validate_rows_cols(name, n, vals): if len(vals) != n: BaseFigure._raise_invalid_rows_cols(name=name, n=n, invalid=vals) - try: - import numpy as np - - int_type = (int, np.integer) - except ImportError: - int_type = (int,) + int_type = _get_int_type() if [r for r in vals if not isinstance(r, int_type)]: BaseFigure._raise_invalid_rows_cols(name=name, n=n, invalid=vals) @@ -1677,14 +1672,19 @@ def add_traces(self, data, rows=None, cols=None, secondary_ys=None): - All remaining properties are passed to the constructor of the specified trace type. - rows : None or list[int] (default None) + rows : None, list[int], or int (default None) List of subplot row indexes (starting from 1) for the traces to be added. Only valid if figure was created using `plotly.tools.make_subplots` + If a single integer is added, all traces will be added to row number + cols : None or list[int] (default None) List of subplot column indexes (starting from 1) for the traces to be added. Only valid if figure was created using `plotly.tools.make_subplots` + If a single integer is added, all traces will be added to column number + + secondary_ys: None or list[boolean] (default None) List of secondary_y booleans for traces to be added. See the docstring for `add_trace` for more info. @@ -1723,6 +1723,15 @@ def add_traces(self, data, rows=None, cols=None, secondary_ys=None): for ind, new_trace in enumerate(data): new_trace._trace_ind = ind + len(self.data) + # Allow integers as inputs to subplots + int_type = _get_int_type() + + if isinstance(rows, int_type): + rows = [rows] * len(data) + + if isinstance(cols, int_type): + cols = [cols] * len(data) + # Validate rows / cols n = len(data) BaseFigure._validate_rows_cols("rows", n, rows) diff --git a/packages/python/plotly/plotly/tests/test_core/test_figure_messages/test_add_traces.py b/packages/python/plotly/plotly/tests/test_core/test_figure_messages/test_add_traces.py index 2322b30efa..aac4cfa972 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_figure_messages/test_add_traces.py +++ b/packages/python/plotly/plotly/tests/test_core/test_figure_messages/test_add_traces.py @@ -63,3 +63,33 @@ def test_add_traces(self): {"type": "histogram2dcontour", "line": {"color": "cyan"}}, ] ) + + +class TestAddTracesRowsColsDataTypes(TestCase): + def test_add_traces_with_iterable(self): + import plotly.express as px + + df = px.data.tips() + fig = px.scatter(df, x="total_bill", y="tip", color="day") + from plotly.subplots import make_subplots + + fig2 = make_subplots(1, 2) + fig2.add_traces(fig.data, rows=[1,] * len(fig.data), cols=[1,] * len(fig.data)) + + expected_data_length = 4 + + self.assertEqual(expected_data_length, len(fig2.data)) + + def test_add_traces_with_integers(self): + import plotly.express as px + + df = px.data.tips() + fig = px.scatter(df, x="total_bill", y="tip", color="day") + from plotly.subplots import make_subplots + + fig2 = make_subplots(1, 2) + fig2.add_traces(fig.data, rows=1, cols=2) + + expected_data_length = 4 + + self.assertEqual(expected_data_length, len(fig2.data)) diff --git a/packages/python/plotly/plotly/tests/test_core/test_utils/test_utils.py b/packages/python/plotly/plotly/tests/test_core/test_utils/test_utils.py index d35a33d376..a3732d8525 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_utils/test_utils.py +++ b/packages/python/plotly/plotly/tests/test_core/test_utils/test_utils.py @@ -70,3 +70,28 @@ def test_numpy_integer_import(self): value = get_by_path(fig, data_path) expected_value = (1,) self.assertEqual(value, expected_value) + + def test_get_numpy_int_type(self): + import numpy as np + from _plotly_utils.utils import _get_int_type + + int_type_tuple = _get_int_type() + expected_tuple = (int, np.integer) + + self.assertEqual(int_type_tuple, expected_tuple) + + +class TestNoNumpyIntegerBaseType(TestCase): + def test_no_numpy_int_type(self): + import sys + from _plotly_utils.utils import _get_int_type + from _plotly_utils.optional_imports import get_module + + np = get_module("numpy", should_load=False) + if np: + sys.modules.pop("numpy") + + int_type_tuple = _get_int_type() + expected_tuple = (int,) + + self.assertEqual(int_type_tuple, expected_tuple) From 30f5dd55ad01d85e1692b035b4dd33c19da0bd33 Mon Sep 17 00:00:00 2001 From: Michael Chapman Date: Mon, 22 Jun 2020 16:38:38 -0500 Subject: [PATCH 2/3] Update packages/python/plotly/plotly/basedatatypes.py Quick cosmetic update to the docstring Co-authored-by: Emmanuelle Gouillart --- packages/python/plotly/plotly/basedatatypes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/python/plotly/plotly/basedatatypes.py b/packages/python/plotly/plotly/basedatatypes.py index cce3eb86c9..8e474abfe2 100644 --- a/packages/python/plotly/plotly/basedatatypes.py +++ b/packages/python/plotly/plotly/basedatatypes.py @@ -1676,7 +1676,7 @@ def add_traces(self, data, rows=None, cols=None, secondary_ys=None): List of subplot row indexes (starting from 1) for the traces to be added. Only valid if figure was created using `plotly.tools.make_subplots` - If a single integer is added, all traces will be added to row number + If a single integer is passed, all traces will be added to row number cols : None or list[int] (default None) List of subplot column indexes (starting from 1) for the traces From 4efac29b6de54cb241a7f8128f59cb3a12d4b794 Mon Sep 17 00:00:00 2001 From: Michael Chapman Date: Mon, 22 Jun 2020 16:39:01 -0500 Subject: [PATCH 3/3] Update packages/python/plotly/plotly/basedatatypes.py Quick cosmetic update to the docstring Co-authored-by: Emmanuelle Gouillart --- packages/python/plotly/plotly/basedatatypes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/python/plotly/plotly/basedatatypes.py b/packages/python/plotly/plotly/basedatatypes.py index 8e474abfe2..6a36f9bdae 100644 --- a/packages/python/plotly/plotly/basedatatypes.py +++ b/packages/python/plotly/plotly/basedatatypes.py @@ -1682,7 +1682,7 @@ def add_traces(self, data, rows=None, cols=None, secondary_ys=None): List of subplot column indexes (starting from 1) for the traces to be added. Only valid if figure was created using `plotly.tools.make_subplots` - If a single integer is added, all traces will be added to column number + If a single integer is passed, all traces will be added to column number secondary_ys: None or list[boolean] (default None)