Skip to content

Commit

Permalink
refactor: utils.check_data (#39)
Browse files Browse the repository at this point in the history
  • Loading branch information
premsrii authored Jan 4, 2023
1 parent 252358a commit f335f5a
Show file tree
Hide file tree
Showing 8 changed files with 120 additions and 140 deletions.
160 changes: 80 additions & 80 deletions poetry.lock

Large diffs are not rendered by default.

44 changes: 22 additions & 22 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -366,28 +366,28 @@ scikit-learn==1.2.0 ; python_version >= "3.8" and python_version < "3.12" \
--hash=sha256:f17420a8e3f40129aeb7e0f5ee35822d6178617007bb8f69521a2cefc20d5f00 \
--hash=sha256:fc0a72237f0c56780cf550df87201a702d3bdcbbb23c6ef7d54c19326fa23f19 \
--hash=sha256:fd3480c982b9e616b9f76ad8587804d3f4e91b4e2a6752e7dafb8a2e1f541098
scipy==1.9.3 ; python_version >= "3.8" and python_version < "3.12" \
--hash=sha256:06d2e1b4c491dc7d8eacea139a1b0b295f74e1a1a0f704c375028f8320d16e31 \
--hash=sha256:0d54222d7a3ba6022fdf5773931b5d7c56efe41ede7f7128c7b1637700409108 \
--hash=sha256:1884b66a54887e21addf9c16fb588720a8309a57b2e258ae1c7986d4444d3bc0 \
--hash=sha256:1a72d885fa44247f92743fc20732ae55564ff2a519e8302fb7e18717c5355a8b \
--hash=sha256:2318bef588acc7a574f5bfdff9c172d0b1bf2c8143d9582e05f878e580a3781e \
--hash=sha256:4db5b30849606a95dcf519763dd3ab6fe9bd91df49eba517359e450a7d80ce2e \
--hash=sha256:545c83ffb518094d8c9d83cce216c0c32f8c04aaf28b92cc8283eda0685162d5 \
--hash=sha256:5a04cd7d0d3eff6ea4719371cbc44df31411862b9646db617c99718ff68d4840 \
--hash=sha256:5b88e6d91ad9d59478fafe92a7c757d00c59e3bdc3331be8ada76a4f8d683f58 \
--hash=sha256:68239b6aa6f9c593da8be1509a05cb7f9efe98b80f43a5861cd24c7557e98523 \
--hash=sha256:83b89e9586c62e787f5012e8475fbb12185bafb996a03257e9675cd73d3736dd \
--hash=sha256:83c06e62a390a9167da60bedd4575a14c1f58ca9dfde59830fc42e5197283dab \
--hash=sha256:90453d2b93ea82a9f434e4e1cba043e779ff67b92f7a0e85d05d286a3625df3c \
--hash=sha256:abaf921531b5aeaafced90157db505e10345e45038c39e5d9b6c7922d68085cb \
--hash=sha256:b41bc822679ad1c9a5f023bc93f6d0543129ca0f37c1ce294dd9d386f0a21096 \
--hash=sha256:c68db6b290cbd4049012990d7fe71a2abd9ffbe82c0056ebe0f01df8be5436b0 \
--hash=sha256:cff3a5295234037e39500d35316a4c5794739433528310e117b8a9a0c76d20fc \
--hash=sha256:d01e1dd7b15bd2449c8bfc6b7cc67d630700ed655654f0dfcf121600bad205c9 \
--hash=sha256:d644a64e174c16cb4b2e41dfea6af722053e83d066da7343f333a54dae9bc31c \
--hash=sha256:da8245491d73ed0a994ed9c2e380fd058ce2fa8a18da204681f2fe1f57f98f95 \
--hash=sha256:fbc5c05c85c1a02be77b1ff591087c83bc44579c6d2bd9fb798bb64ea5e1a027
scipy==1.10.0 ; python_version >= "3.8" and python_version < "3.12" \
--hash=sha256:0490dc499fe23e4be35b8b6dd1e60a4a34f0c4adb30ac671e6332446b3cbbb5a \
--hash=sha256:0ab2a58064836632e2cec31ca197d3695c86b066bc4818052b3f5381bfd2a728 \
--hash=sha256:151f066fe7d6653c3ffefd489497b8fa66d7316e3e0d0c0f7ff6acca1b802809 \
--hash=sha256:16ba05d3d1b9f2141004f3f36888e05894a525960b07f4c2bfc0456b955a00be \
--hash=sha256:27e548276b5a88b51212b61f6dda49a24acf5d770dff940bd372b3f7ced8c6c2 \
--hash=sha256:2ad449db4e0820e4b42baccefc98ec772ad7818dcbc9e28b85aa05a536b0f1a2 \
--hash=sha256:2f9ea0a37aca111a407cb98aa4e8dfde6e5d9333bae06dfa5d938d14c80bb5c3 \
--hash=sha256:38bfbd18dcc69eeb589811e77fae552fa923067fdfbb2e171c9eac749885f210 \
--hash=sha256:3afcbddb4488ac950ce1147e7580178b333a29cd43524c689b2e3543a080a2c8 \
--hash=sha256:42ab8b9e7dc1ebe248e55f54eea5307b6ab15011a7883367af48dd781d1312e4 \
--hash=sha256:441cab2166607c82e6d7a8683779cb89ba0f475b983c7e4ab88f3668e268c143 \
--hash=sha256:4bd0e3278126bc882d10414436e58fa3f1eca0aa88b534fcbf80ed47e854f46c \
--hash=sha256:4df25a28bd22c990b22129d3c637fd5c3be4b7c94f975dca909d8bab3309b694 \
--hash=sha256:5cd7a30970c29d9768a7164f564d1fbf2842bfc77b7d114a99bc32703ce0bf48 \
--hash=sha256:6e4497e5142f325a5423ff5fda2fff5b5d953da028637ff7c704378c8c284ea7 \
--hash=sha256:6faf86ef7717891195ae0537e48da7524d30bc3b828b30c9b115d04ea42f076f \
--hash=sha256:954ff69d2d1bf666b794c1d7216e0a746c9d9289096a64ab3355a17c7c59db54 \
--hash=sha256:9b878c671655864af59c108c20e4da1e796154bd78c0ed6bb02bc41c84625686 \
--hash=sha256:b901b423c91281a974f6cd1c36f5c6c523e665b5a6d5e80fcb2334e14670eefd \
--hash=sha256:c8b3cbc636a87a89b770c6afc999baa6bcbb01691b5ccbbc1b1791c7c0a07540 \
--hash=sha256:e096b062d2efdea57f972d232358cb068413dc54eec4f24158bcbb5cb8bddfd8
setuptools==65.6.3 ; python_version >= "3.8" and python_version < "3.12" \
--hash=sha256:57f6f22bde4e042978bcd50176fdb381d7c21a9efa4041202288d3737a0c6a54 \
--hash=sha256:a7620757bf984b58deaf32fc8a4577a9bbc0850cf92c20e1ce41c38c19e5fb75
Expand Down
11 changes: 9 additions & 2 deletions src/blitzly/etc/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def check_data(
max_rows: Optional[int] = None,
min_columns: Optional[int] = None,
max_columns: Optional[int] = None,
keep_as_pandas: bool = False,
as_pandas: bool = False,
) -> Union[NDArray[Any], pd.DataFrame, pd.Series]:
"""
Checks if the data is valid for plotting. The function checks for:
Expand Down Expand Up @@ -109,6 +109,10 @@ def check_data(
max_rows (Optional[int]): The maximum number of rows the data must have.
min_columns (Optional[int]): The minimum number of columns the data must have.
max_columns (Optional[int]): The maximum number of columns the data must have.
as_pandas (bool): Whether to keep data as or convert data to pd.DataFrame
Returns:
Union[pd.DataFrame, NDArray]: The data that passes all checks, and is converted to the required dtype.
Raises:
TypeError: If the data is not a DataFrame, numpy array, or list of values.
Expand All @@ -125,6 +129,9 @@ def check_data(
"""
)

if isinstance(data, np.ndarray) and as_pandas:
df = pd.DataFrame(data)

if isinstance(data, (pd.DataFrame, pd.Series)):
df = data.copy()
data = data.to_numpy()
Expand Down Expand Up @@ -160,7 +167,7 @@ def check_data(
if max_columns and data.shape[1] > max_columns:
raise ValueError(f"The data must have a maximum of {max_columns} column(s)!")

if "df" in locals() and keep_as_pandas:
if "df" in locals() and as_pandas:
return df

return data.copy()
8 changes: 1 addition & 7 deletions src/blitzly/plots/dumbbell.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Optional, Tuple, Union

import numpy as np
import pandas as pd
import plotly.graph_objects as go
from numpy.typing import NDArray
Expand Down Expand Up @@ -59,12 +58,7 @@ def simple_dumbbell(
BaseFigure: The dumbbell plot.
"""

data = check_data(
data, min_rows=1, min_columns=2, max_columns=2, keep_as_pandas=True
)

if isinstance(data, np.ndarray):
data = pd.DataFrame(data)
data = check_data(data, min_rows=1, min_columns=2, max_columns=2, as_pandas=True)

fig = go.Figure()

Expand Down
5 changes: 1 addition & 4 deletions src/blitzly/plots/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,10 +162,7 @@ def multi_scatter(
BaseFigure: The multi scatter plot.
"""

df = check_data(data, min_rows=1, min_columns=1, keep_as_pandas=True)

if isinstance(df, pd.DataFrame) is False:
raise TypeError("`data` must be a Pandas DataFrame!")
df = check_data(data, min_rows=1, min_columns=1, as_pandas=True)

if len([i for i in list(sum(x_y_columns, ())) if i not in df.columns]) > 0:
raise ValueError(
Expand Down
8 changes: 1 addition & 7 deletions src/blitzly/plots/slope.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Optional, Tuple, Union

import numpy as np
import pandas as pd
import plotly.graph_objects as go
from numpy.typing import NDArray
Expand Down Expand Up @@ -59,12 +58,7 @@ def simple_slope(
BaseFigure: The slope plot.
"""

data = check_data(
data, min_rows=1, min_columns=2, max_columns=2, keep_as_pandas=True
)

if isinstance(data, np.ndarray):
data = pd.DataFrame(data)
data = check_data(data, min_rows=1, min_columns=2, max_columns=2, as_pandas=True)

data_max = data.to_numpy().max()
data_min = data.to_numpy().min()
Expand Down
17 changes: 0 additions & 17 deletions tests/test_cases/test_scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,23 +174,6 @@ def test_multi_scatter_with_valid_values_and_size(
fig_to_array(expected_multi_scatter_with_valid_values_size),
)

@staticmethod
def test_multi_scatter_no_dataframe_exception():
np.random.seed(42)
random_a = np.linspace(0, 1, 100)
random_b = np.random.randn(100) + 5
random_c = np.random.randn(100)
random_d = np.random.randn(100) - 5

with pytest.raises(TypeError) as error:
_ = multi_scatter(
np.array([random_a, random_b, random_c, random_d]).T,
x_y_columns=[("a", "b"), ("a", "c"), ("a", "d")],
modes=["lines", "markers", "lines+markers"],
show=False,
)
assert str(error.value) == "`data` must be a Pandas DataFrame!"

@staticmethod
def test_multi_scatter_incompatible_columns_exception():
np.random.seed(42)
Expand Down
7 changes: 6 additions & 1 deletion tests/test_cases/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,5 +68,10 @@ def test_check_data_for_max_columns():

@staticmethod
def test_check_data_with_keep_as_pandas():
data = check_data(pd.DataFrame(np.array([[1, 2]])), keep_as_pandas=True)
data = check_data(pd.DataFrame(np.array([[1, 2]])), as_pandas=True)
assert isinstance(data, pd.DataFrame)

@staticmethod
def test_check_data_with_convert_to_pandas():
data = check_data(np.array([[1, 2]]), as_pandas=True)
assert isinstance(data, pd.DataFrame)

0 comments on commit f335f5a

Please sign in to comment.