diff --git a/examples/playground.ipynb b/examples/playground.ipynb index 928b89e..c18272a 100644 --- a/examples/playground.ipynb +++ b/examples/playground.ipynb @@ -81,32 +81,31 @@ "\n", "**In Plotly you do:**\n", "```python\n", - "fig = go.Figure(\n", - " data=[\n", - " go.Scatter(\n", - " x=data.iloc[:, 0],\n", - " y=data.index,\n", - " mode=\"markers\",\n", - " name=data.columns[0],\n", - " ),\n", - " go.Scatter(\n", - " x=data.iloc[:, 1],\n", - " y=data.index,\n", - " mode=\"markers\",\n", - " name=data.columns[1],\n", - " ),\n", - " ]\n", - " )\n", + "fig = go.Figure()\n", "\n", "for index, row in data.iterrows():\n", - " fig.add_shape(\n", - " type=\"line\",\n", - " layer=\"below\",\n", - " x0=row.iloc[0],\n", - " x1=row.iloc[1],\n", - " y0=index,\n", - " y1=index,\n", - " line=dict(width=8),\n", + " fig.add_trace(\n", + " go.Scatter(\n", + " x=[row.iloc[0], row.iloc[1]],\n", + " y=[index, index],\n", + " mode=\"lines\",\n", + " showlegend=False,\n", + " line={\n", + " \"color\": \"black\",\n", + " \"width\": marker_line_width,\n", + " },\n", + " )\n", + " )\n", + "\n", + "for column_idx, column_name in enumerate(data.columns):\n", + " fig.add_trace(\n", + " go.Scatter(\n", + " x=data.iloc[:, column_idx],\n", + " y=data.index,\n", + " mode=\"markers\",\n", + " name=column_name,\n", + " **plotly_kwargs if plotly_kwargs else {},\n", + " )\n", " )\n", "\n", "fig.update_traces(\n", @@ -117,8 +116,8 @@ " title=f\"Dumbbell plot\",\n", ")\n", "fig.update_layout(\n", - " width=size[0],\n", - " height=size[1],\n", + " width=500,\n", + " height=500,\n", ")\n", "```\n", "\n", @@ -225,7 +224,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.0" + "version": "3.11.0 | packaged by conda-forge | (main, Oct 25 2022, 06:21:25) [Clang 14.0.4 ]" }, "vscode": { "interpreter": { diff --git a/src/blitzly/plots/dumbbell.py b/src/blitzly/plots/dumbbell.py index 8de3065..e91b0f4 100644 --- a/src/blitzly/plots/dumbbell.py +++ b/src/blitzly/plots/dumbbell.py @@ -66,34 +66,31 @@ def simple_dumbbell( if isinstance(data, np.ndarray): data = pd.DataFrame(data) - fig = go.Figure( - data=[ + fig = go.Figure() + + for index, row in data.iterrows(): + fig.add_trace( go.Scatter( - x=data.iloc[:, 0], - y=data.index, - mode="markers", - name=data.columns[0], - **plotly_kwargs if plotly_kwargs else {}, - ), + x=[row.iloc[0], row.iloc[1]], + y=[index, index], + mode="lines", + showlegend=False, + line={ + "color": "black", + "width": marker_line_width, + }, + ) + ) + + for column_idx, column_name in enumerate(data.columns): + fig.add_trace( go.Scatter( - x=data.iloc[:, 1], + x=data.iloc[:, column_idx], y=data.index, mode="markers", - name=data.columns[1], + name=column_name, **plotly_kwargs if plotly_kwargs else {}, - ), - ] - ) - - for index, row in data.iterrows(): - fig.add_shape( - type="line", - layer="below", - x0=row.iloc[0], - x1=row.iloc[1], - y0=index, - y1=index, - line=dict(width=marker_line_width), + ) ) fig.update_traces( diff --git a/tests/expected_figs/dumbbell/simple_dumbbell/expected_2d_numpy.joblib b/tests/expected_figs/dumbbell/simple_dumbbell/expected_2d_numpy.joblib index 7dc4a3b..3d12981 100644 Binary files a/tests/expected_figs/dumbbell/simple_dumbbell/expected_2d_numpy.joblib and b/tests/expected_figs/dumbbell/simple_dumbbell/expected_2d_numpy.joblib differ diff --git a/tests/expected_figs/dumbbell/simple_dumbbell/expected_pandas.joblib b/tests/expected_figs/dumbbell/simple_dumbbell/expected_pandas.joblib index 614941e..bd48e92 100644 Binary files a/tests/expected_figs/dumbbell/simple_dumbbell/expected_pandas.joblib and b/tests/expected_figs/dumbbell/simple_dumbbell/expected_pandas.joblib differ