diff --git a/packages/python/plotly/_plotly_utils/utils.py b/packages/python/plotly/_plotly_utils/utils.py index c1ba92951d..cbf8d3a6b9 100644 --- a/packages/python/plotly/_plotly_utils/utils.py +++ b/packages/python/plotly/_plotly_utils/utils.py @@ -247,3 +247,12 @@ def key(v): return tuple(v_parts) return sorted(vals, key=key, reverse=reverse) + + +def _get_int_type(): + np = get_module("numpy", should_load=False) + if np: + int_type = (int, np.integer) + else: + int_type = (int,) + return int_type diff --git a/packages/python/plotly/plotly/basedatatypes.py b/packages/python/plotly/plotly/basedatatypes.py index f520a1669a..6a36f9bdae 100644 --- a/packages/python/plotly/plotly/basedatatypes.py +++ b/packages/python/plotly/plotly/basedatatypes.py @@ -9,7 +9,7 @@ from contextlib import contextmanager from copy import deepcopy, copy -from _plotly_utils.utils import _natural_sort_strings +from _plotly_utils.utils import _natural_sort_strings, _get_int_type from .optional_imports import get_module # Create Undefined sentinel value @@ -1560,12 +1560,7 @@ def _validate_rows_cols(name, n, vals): if len(vals) != n: BaseFigure._raise_invalid_rows_cols(name=name, n=n, invalid=vals) - try: - import numpy as np - - int_type = (int, np.integer) - except ImportError: - int_type = (int,) + int_type = _get_int_type() if [r for r in vals if not isinstance(r, int_type)]: BaseFigure._raise_invalid_rows_cols(name=name, n=n, invalid=vals) @@ -1677,14 +1672,19 @@ def add_traces(self, data, rows=None, cols=None, secondary_ys=None): - All remaining properties are passed to the constructor of the specified trace type. - rows : None or list[int] (default None) + rows : None, list[int], or int (default None) List of subplot row indexes (starting from 1) for the traces to be added. Only valid if figure was created using `plotly.tools.make_subplots` + If a single integer is passed, all traces will be added to row number + cols : None or list[int] (default None) List of subplot column indexes (starting from 1) for the traces to be added. Only valid if figure was created using `plotly.tools.make_subplots` + If a single integer is passed, all traces will be added to column number + + secondary_ys: None or list[boolean] (default None) List of secondary_y booleans for traces to be added. See the docstring for `add_trace` for more info. @@ -1723,6 +1723,15 @@ def add_traces(self, data, rows=None, cols=None, secondary_ys=None): for ind, new_trace in enumerate(data): new_trace._trace_ind = ind + len(self.data) + # Allow integers as inputs to subplots + int_type = _get_int_type() + + if isinstance(rows, int_type): + rows = [rows] * len(data) + + if isinstance(cols, int_type): + cols = [cols] * len(data) + # Validate rows / cols n = len(data) BaseFigure._validate_rows_cols("rows", n, rows) diff --git a/packages/python/plotly/plotly/tests/test_core/test_figure_messages/test_add_traces.py b/packages/python/plotly/plotly/tests/test_core/test_figure_messages/test_add_traces.py index 2322b30efa..aac4cfa972 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_figure_messages/test_add_traces.py +++ b/packages/python/plotly/plotly/tests/test_core/test_figure_messages/test_add_traces.py @@ -63,3 +63,33 @@ def test_add_traces(self): {"type": "histogram2dcontour", "line": {"color": "cyan"}}, ] ) + + +class TestAddTracesRowsColsDataTypes(TestCase): + def test_add_traces_with_iterable(self): + import plotly.express as px + + df = px.data.tips() + fig = px.scatter(df, x="total_bill", y="tip", color="day") + from plotly.subplots import make_subplots + + fig2 = make_subplots(1, 2) + fig2.add_traces(fig.data, rows=[1,] * len(fig.data), cols=[1,] * len(fig.data)) + + expected_data_length = 4 + + self.assertEqual(expected_data_length, len(fig2.data)) + + def test_add_traces_with_integers(self): + import plotly.express as px + + df = px.data.tips() + fig = px.scatter(df, x="total_bill", y="tip", color="day") + from plotly.subplots import make_subplots + + fig2 = make_subplots(1, 2) + fig2.add_traces(fig.data, rows=1, cols=2) + + expected_data_length = 4 + + self.assertEqual(expected_data_length, len(fig2.data)) diff --git a/packages/python/plotly/plotly/tests/test_core/test_utils/test_utils.py b/packages/python/plotly/plotly/tests/test_core/test_utils/test_utils.py index d35a33d376..a3732d8525 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_utils/test_utils.py +++ b/packages/python/plotly/plotly/tests/test_core/test_utils/test_utils.py @@ -70,3 +70,28 @@ def test_numpy_integer_import(self): value = get_by_path(fig, data_path) expected_value = (1,) self.assertEqual(value, expected_value) + + def test_get_numpy_int_type(self): + import numpy as np + from _plotly_utils.utils import _get_int_type + + int_type_tuple = _get_int_type() + expected_tuple = (int, np.integer) + + self.assertEqual(int_type_tuple, expected_tuple) + + +class TestNoNumpyIntegerBaseType(TestCase): + def test_no_numpy_int_type(self): + import sys + from _plotly_utils.utils import _get_int_type + from _plotly_utils.optional_imports import get_module + + np = get_module("numpy", should_load=False) + if np: + sys.modules.pop("numpy") + + int_type_tuple = _get_int_type() + expected_tuple = (int,) + + self.assertEqual(int_type_tuple, expected_tuple)