Skip to content

Commit

Permalink
feat: add slope plots (#38)
Browse files Browse the repository at this point in the history
  • Loading branch information
premsrii authored Jan 4, 2023
1 parent d879b57 commit 252358a
Show file tree
Hide file tree
Showing 9 changed files with 289 additions and 122 deletions.
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ repos:
- id: poetry-export
args:
[
"--dev",
"--with",
"dev",
"--format",
"requirements.txt",
"--output",
Expand Down
194 changes: 97 additions & 97 deletions poetry.lock

Large diffs are not rendered by default.

24 changes: 12 additions & 12 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,15 @@ gitdb==4.0.10 ; python_version >= "3.8" and python_version < "3.12" \
gitpython==3.1.30 ; python_version >= "3.8" and python_version < "3.12" \
--hash=sha256:769c2d83e13f5d938b7688479da374c4e3d49f71549aaf462b646db9602ea6f8 \
--hash=sha256:cd455b0000615c60e286208ba540271af9fe531fa6a87cc590a7298785ab2882
identify==2.5.11 ; python_version >= "3.8" and python_version < "3.12" \
--hash=sha256:14b7076b29c99b1b0b8b08e96d448c7b877a9b07683cd8cfda2ea06af85ffa1c \
--hash=sha256:e7db36b772b188099616aaf2accbee122949d1c6a1bac4f38196720d6f9f06db
identify==2.5.12 ; python_version >= "3.8" and python_version < "3.12" \
--hash=sha256:0bc96b09c838310b6fcfcc61f78a981ea07f94836ef6ef553da5bb5d4745d662 \
--hash=sha256:e8a400c3062d980243d27ce10455a52832205649bbcaf27ffddb3dfaaf477bad
importlib-resources==5.10.2 ; python_version >= "3.8" and python_version < "3.9" \
--hash=sha256:7d543798b0beca10b6a01ac7cafda9f822c54db9e8376a6bf57e0cbd74d486b6 \
--hash=sha256:e4a96c8cc0339647ff9a5e0550d9f276fc5a01ffa276012b58ec108cfd7b8484
ipython==8.7.0 ; python_version >= "3.8" and python_version < "3.12" \
--hash=sha256:352042ddcb019f7c04e48171b4dd78e4c4bb67bf97030d170e154aac42b656d9 \
--hash=sha256:882899fe78d5417a0aa07f995db298fa28b58faeba2112d2e3a4c95fe14bb738
ipython==8.8.0 ; python_version >= "3.8" and python_version < "3.12" \
--hash=sha256:da01e6df1501e6e7c32b5084212ddadd4ee2471602e2cf3e0190f4de6b0ea481 \
--hash=sha256:f3bf2c08505ad2c3f4ed5c46ae0331a8547d36bf4b21a451e8ae80c0791db95b
isort==5.11.4 ; python_version >= "3.8" and python_version < "3.12" \
--hash=sha256:6db30c5ded9815d813932c04c2f85a360bcdd35fed496f4d8f35495ef0a261b6 \
--hash=sha256:c033fd0edb91000a7f09527fe5c75321878f98322a77ddcc81adbd83724afb7b
Expand All @@ -83,9 +83,9 @@ joblib==1.2.0 ; python_version >= "3.8" and python_version < "3.12" \
jsonschema==4.17.3 ; python_version >= "3.8" and python_version < "3.12" \
--hash=sha256:0f864437ab8b6076ba6707453ef8f98a6a0d512a80e93f8abdb676f737ecb60d \
--hash=sha256:a870ad254da1a8ca84b6a2905cac29d265f805acc57af304784962a2aa6508f6
jupyter-core==5.1.1 ; python_version >= "3.8" and python_version < "3.12" \
--hash=sha256:f1038179d0f179b0e92c8fa2289c012b29dafdc9484b41821079f1a496f5a0f2 \
--hash=sha256:f342d29eb6edb06f8dffa69adea987b3a9ee2b6702338a8cb6911516ea0b432d
jupyter-core==5.1.2 ; python_version >= "3.8" and python_version < "3.12" \
--hash=sha256:0f99cc639c8d00d591acfcc028aeea81473ea6c72fabe86426398220e2d91b1d \
--hash=sha256:62b00d52f030643d29f86aafdfd9b36d42421823599a272eb4c2df1d1cc7f723
lazy-object-proxy==1.8.0 ; python_version >= "3.8" and python_version < "3.12" \
--hash=sha256:0c1c7c0433154bb7c54185714c6929acc0ba04ee1b167314a779b9025517eada \
--hash=sha256:14010b49a2f56ec4943b6cf925f597b534ee2fe1f0738c84b3bce0c1a11ff10d \
Expand Down Expand Up @@ -248,9 +248,9 @@ ptyprocess==0.7.0 ; python_version >= "3.8" and python_version < "3.12" and sys_
pure-eval==0.2.2 ; python_version >= "3.8" and python_version < "3.12" \
--hash=sha256:01eaab343580944bc56080ebe0a674b39ec44a945e6d09ba7db3cb8cec289350 \
--hash=sha256:2b45320af6dfaa1750f543d714b6d1c520a1688dec6fd24d339063ce0aaa9ac3
pygments==2.13.0 ; python_version >= "3.8" and python_version < "3.12" \
--hash=sha256:56a8508ae95f98e2b9bdf93a6be5ae3f7d8af858b43e02c5a2ff083726be40c1 \
--hash=sha256:f643f331ab57ba3c9d89212ee4a2dabc6e94f117cf4eefde99a0574720d14c42
pygments==2.14.0 ; python_version >= "3.8" and python_version < "3.12" \
--hash=sha256:b3ed06a9e8ac9a9aae5a6f5dbe78a8a58655d17b43b93c078f094ddc476ae297 \
--hash=sha256:fa7bd7bd2771287c0de303af8bfdfc731f51bd2c6a47ab69d117138893b82717
pylint==2.15.9 ; python_version >= "3.8" and python_version < "3.12" \
--hash=sha256:18783cca3cfee5b83c6c5d10b3cdb66c6594520ffae61890858fe8d932e1c6b4 \
--hash=sha256:349c8cd36aede4d50a0754a8c0218b43323d13d5d88f4b2952ddfe3e169681eb
Expand Down
127 changes: 127 additions & 0 deletions src/blitzly/plots/slope.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
from typing import Optional, Tuple, Union

import numpy as np
import pandas as pd
import plotly.graph_objects as go
from numpy.typing import NDArray
from plotly.basedatatypes import BaseFigure

from blitzly.etc.utils import check_data, save_show_return, update_figure_layout


def simple_slope(
data: Union[pd.DataFrame, NDArray],
title: str = "Slope plot",
marker_size: int = 16,
marker_line_width: int = 4,
margin_size: int = 250,
plotly_kwargs: Optional[dict] = None,
size: Optional[Tuple[int, int]] = None,
show: bool = True,
write_html_path: Optional[str] = None,
) -> BaseFigure:

"""
Creates a slope plot. These are useful to show the difference between
two sets of data which have the same categories. For instance, it can be
used to compare two binary classifiers by plotting the various classification
metrics.
Example:
```python
from blitzly.plots.slope import simple_slope
import numpy as np
import pandas as pd
data = {
"foo": np.random.randn(10),
"bar": np.random.randn(10),
}
index = [f"category_{i+1}" for i in range(10)]
df = pd.DataFrame(data, index=index)
simple_slope(df)
```
Args:
data (Union[pd.DataFrame, NDArray]): Data to plot.
title (str): Title of the plot.
marker_size (int): Size of the circular marker.
marker_line_width (int): Thickness of the line joining the markers.
margin_size (int): Margin for displaying text labels in pixels.
plotly_kwargs (Optional[dict]): Additional keyword arguments to pass to Plotly `go.Scatter`.
size (Optional[Tuple[int, int]): Size of the plot.
show (bool): Whether to show the figure.
write_html_path (Optional[str]): The path to which the histogram should be written as an HTML file.
If None, the histogram will not be saved.
Returns:
BaseFigure: The slope plot.
"""

data = check_data(
data, min_rows=1, min_columns=2, max_columns=2, keep_as_pandas=True
)

if isinstance(data, np.ndarray):
data = pd.DataFrame(data)

data_max = data.to_numpy().max()
data_min = data.to_numpy().min()
data_range = data_max - data_min

y_range_max = data_max + 0.05 * data_range
y_range_min = data_min - 0.05 * data_range

fig = go.Figure()

for column_idx in range(2):
fig.add_trace(
go.Scatter(
x=[column_idx, column_idx],
y=[y_range_max, y_range_min],
mode="lines",
line={
"color": "black",
"width": 2,
},
showlegend=False,
)
)

for index, row in data.iterrows():
if row.iloc[0] > row.iloc[1]:
line_color = "red"
else:
line_color = "green"

fig.add_trace(
go.Scatter(
x=[0, 1],
y=[row.iloc[0], row.iloc[1]],
mode="markers+lines+text",
marker={
"size": marker_size,
},
line={
"width": marker_line_width,
"color": line_color,
},
text=index,
textposition=["middle left", "middle right"],
showlegend=False,
**plotly_kwargs if plotly_kwargs else {},
)
)

xaxis_offset = margin_size / size[0] if size is not None else 1
fig.update_layout(
xaxis={
"tickvals": [0, 1],
"ticktext": data.columns,
"range": [-xaxis_offset, 1 + xaxis_offset],
}
)

fig = update_figure_layout(fig, title, size)
return save_show_return(fig, write_html_path, show)
16 changes: 16 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import numpy as np
import pandas as pd
import pytest

# pylint: disable=missing-function-docstring


@pytest.fixture()
def X_numbers_two_column() -> pd.DataFrame:
np.random.seed(42)
data = {
"foo": np.random.rand(10),
"bar": np.random.rand(10),
}
index = [f"category_{i+1}" for i in range(10)]
return pd.DataFrame(data, index=index)
Binary file not shown.
Binary file not shown.
14 changes: 2 additions & 12 deletions tests/test_cases/test_dumbbell.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import joblib
import numpy as np
import pandas as pd
import pytest

from blitzly.plots.dumbbell import simple_dumbbell
Expand All @@ -25,17 +24,8 @@ def expected_2d_numpy():

class TestSimpleDumbbell:
@staticmethod
def test_simple_dumbbell_with_pandas(expected_pandas):
np.random.seed(42)
data = {
"foo": np.random.rand(10),
"bar": np.random.rand(10),
}
index = [f"category_{i+1}" for i in range(10)]
df = pd.DataFrame(data, index=index)

fig = simple_dumbbell(df, size=(500, 500), show=False)

def test_simple_dumbbell_with_pandas(X_numbers_two_column, expected_pandas):
fig = simple_dumbbell(X_numbers_two_column, size=(500, 500), show=False)
np.testing.assert_equal(fig_to_array(fig), fig_to_array(expected_pandas))

@staticmethod
Expand Down
33 changes: 33 additions & 0 deletions tests/test_cases/test_slope.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import joblib
import numpy as np
import pytest

from blitzly.plots.slope import simple_slope
from tests.helper import fig_to_array

# pylint: disable=missing-function-docstring, missing-class-docstring, redefined-outer-name


@pytest.fixture(scope="session")
def expected_pandas():
return joblib.load("tests/expected_figs/slope/simple_slope/expected_pandas.joblib")


@pytest.fixture(scope="session")
def expected_2d_numpy():
return joblib.load(
"tests/expected_figs/slope/simple_slope/expected_2d_numpy.joblib"
)


class TestSimpleSlope:
@staticmethod
def test_simple_slope_with_pandas(X_numbers_two_column, expected_pandas):
fig = simple_slope(X_numbers_two_column, size=(500, 500), show=False)
np.testing.assert_equal(fig_to_array(fig), fig_to_array(expected_pandas))

@staticmethod
def test_simple_slope_with_2d_numpy(expected_2d_numpy):
np.random.seed(42)
fig = simple_slope(np.random.randn(10, 2), size=(500, 500), show=False)
np.testing.assert_equal(fig_to_array(fig), fig_to_array(expected_2d_numpy))

0 comments on commit 252358a

Please sign in to comment.