Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bug_2410: Allowing Ints to be passed for rows/cols and refactored int… #2546

Merged
merged 3 commits into from
Jun 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions packages/python/plotly/_plotly_utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,3 +247,12 @@ def key(v):
return tuple(v_parts)

return sorted(vals, key=key, reverse=reverse)


def _get_int_type():
np = get_module("numpy", should_load=False)
if np:
int_type = (int, np.integer)
else:
int_type = (int,)
return int_type
25 changes: 17 additions & 8 deletions packages/python/plotly/plotly/basedatatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from contextlib import contextmanager
from copy import deepcopy, copy

from _plotly_utils.utils import _natural_sort_strings
from _plotly_utils.utils import _natural_sort_strings, _get_int_type
from .optional_imports import get_module

# Create Undefined sentinel value
Expand Down Expand Up @@ -1560,12 +1560,7 @@ def _validate_rows_cols(name, n, vals):
if len(vals) != n:
BaseFigure._raise_invalid_rows_cols(name=name, n=n, invalid=vals)

try:
import numpy as np

int_type = (int, np.integer)
except ImportError:
int_type = (int,)
int_type = _get_int_type()

if [r for r in vals if not isinstance(r, int_type)]:
BaseFigure._raise_invalid_rows_cols(name=name, n=n, invalid=vals)
Expand Down Expand Up @@ -1677,14 +1672,19 @@ def add_traces(self, data, rows=None, cols=None, secondary_ys=None):
- All remaining properties are passed to the constructor
of the specified trace type.

rows : None or list[int] (default None)
rows : None, list[int], or int (default None)
List of subplot row indexes (starting from 1) for the traces to be
added. Only valid if figure was created using
`plotly.tools.make_subplots`
If a single integer is passed, all traces will be added to row number

cols : None or list[int] (default None)
List of subplot column indexes (starting from 1) for the traces
to be added. Only valid if figure was created using
`plotly.tools.make_subplots`
If a single integer is passed, all traces will be added to column number


secondary_ys: None or list[boolean] (default None)
List of secondary_y booleans for traces to be added. See the
docstring for `add_trace` for more info.
Expand Down Expand Up @@ -1723,6 +1723,15 @@ def add_traces(self, data, rows=None, cols=None, secondary_ys=None):
for ind, new_trace in enumerate(data):
new_trace._trace_ind = ind + len(self.data)

# Allow integers as inputs to subplots
int_type = _get_int_type()

if isinstance(rows, int_type):
rows = [rows] * len(data)

if isinstance(cols, int_type):
cols = [cols] * len(data)

# Validate rows / cols
n = len(data)
BaseFigure._validate_rows_cols("rows", n, rows)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,33 @@ def test_add_traces(self):
{"type": "histogram2dcontour", "line": {"color": "cyan"}},
]
)


class TestAddTracesRowsColsDataTypes(TestCase):
def test_add_traces_with_iterable(self):
import plotly.express as px

df = px.data.tips()
fig = px.scatter(df, x="total_bill", y="tip", color="day")
from plotly.subplots import make_subplots

fig2 = make_subplots(1, 2)
fig2.add_traces(fig.data, rows=[1,] * len(fig.data), cols=[1,] * len(fig.data))

expected_data_length = 4

self.assertEqual(expected_data_length, len(fig2.data))

def test_add_traces_with_integers(self):
import plotly.express as px

df = px.data.tips()
fig = px.scatter(df, x="total_bill", y="tip", color="day")
from plotly.subplots import make_subplots

fig2 = make_subplots(1, 2)
fig2.add_traces(fig.data, rows=1, cols=2)

expected_data_length = 4

self.assertEqual(expected_data_length, len(fig2.data))
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,28 @@ def test_numpy_integer_import(self):
value = get_by_path(fig, data_path)
expected_value = (1,)
self.assertEqual(value, expected_value)

def test_get_numpy_int_type(self):
import numpy as np
from _plotly_utils.utils import _get_int_type

int_type_tuple = _get_int_type()
expected_tuple = (int, np.integer)

self.assertEqual(int_type_tuple, expected_tuple)


class TestNoNumpyIntegerBaseType(TestCase):
def test_no_numpy_int_type(self):
import sys
from _plotly_utils.utils import _get_int_type
from _plotly_utils.optional_imports import get_module

np = get_module("numpy", should_load=False)
if np:
sys.modules.pop("numpy")

int_type_tuple = _get_int_type()
expected_tuple = (int,)

self.assertEqual(int_type_tuple, expected_tuple)