Skip to content

Commit

Permalink
bug categorical annotatoins points (#304)
Browse files Browse the repository at this point in the history
* bug categorical annotatoins points

* added test for continuous column
  • Loading branch information
LucaMarconato authored Aug 23, 2024
1 parent 060d5d0 commit 33fd676
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 7 deletions.
35 changes: 28 additions & 7 deletions src/napari_spatialdata/_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,14 +198,35 @@ def _(self, vec: pd.Series, **kwargs: Any) -> dict[str, Any]:
if isinstance(layer, Labels):
element_indices = element_indices[element_indices != 0]
# When merging if the row is not present in the other table it will be nan so we can give it a default color
if (vec_color_name := vec.name + "_color") not in self.model.adata.uns:
colorer = AnnData(shape=(len(vec), 0), obs=pd.DataFrame(index=vec.index, data={"vec": vec}))
_set_colors_for_categorical_obs(colorer, "vec", palette="tab20")
colors = colorer.uns["vec_colors"]
color_dict = dict(zip(vec.cat.categories, colors))
color_dict.update({np.nan: "#808080ff"})
vec_color_name = vec.name + "_color"
if self._attr != "columns_df":
if vec_color_name not in self.model.adata.uns:
colorer = AnnData(shape=(len(vec), 0), obs=pd.DataFrame(index=vec.index, data={"vec": vec}))
_set_colors_for_categorical_obs(colorer, "vec", palette="tab20")
colors = colorer.uns["vec_colors"]
color_dict = dict(zip(vec.cat.categories, colors))
color_dict.update({np.nan: "#808080ff"})
else:
color_dict = self.model.adata.uns[vec_color_name]
else:
color_dict = self.model.adata.uns[vec_color_name]
df = layer.metadata["_columns_df"]
if vec_color_name not in df.columns:
colorer = AnnData(shape=(len(vec), 0), obs=pd.DataFrame(index=vec.index, data={"vec": vec}))
_set_colors_for_categorical_obs(colorer, "vec", palette="tab20")
colors = colorer.uns["vec_colors"]
color_dict = dict(zip(vec.cat.categories, colors))
color_dict.update({np.nan: "#808080ff"})
color_column = vec.apply(lambda x: color_dict[x])
df[vec_color_name] = color_column
else:
unique_colors = df[[vec.name, vec_color_name]].drop_duplicates()
unique_colors.set_index(vec.name, inplace=True)
if not unique_colors.index.is_unique:
raise ValueError(
f"The {vec_color_name} column must have unique values for the each {vec.name} level. Found:\n"
f"{unique_colors}"
)
color_dict = unique_colors.to_dict()["genes_color"]

if self.model.instance_key is not None and self.model.instance_key == vec.index.name:
merge_df = pd.merge(
Expand Down
30 changes: 30 additions & 0 deletions tests/test_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from napari.utils.events import EventedList
from napari_spatialdata._sdata_widgets import SdataWidget
from napari_spatialdata._view import QtAdataViewWidget
from napari_spatialdata.utils._test_utils import click_list_widget_item, get_center_pos_listitem
from spatialdata.datasets import blobs


Expand Down Expand Up @@ -33,3 +34,32 @@ def test_channel_slider_images(qtbot, make_napari_viewer: any, widget: Any, n_ch
qtbot.wait(50) # wait for a short time to simulate user interaction

viewer.close()


@pytest.mark.parametrize("widget", [QtAdataViewWidget])
@pytest.mark.parametrize("column", ["genes", "instance_id"])
def test_plot_dataframe_annotation_on_points(qtbot, make_napari_viewer: any, widget: Any, column: str):
sdata_blobs = blobs()
viewer = make_napari_viewer()
sdata_widget = SdataWidget(viewer, EventedList([sdata_blobs]))

viewer.window.add_dock_widget(sdata_widget, name="SpatialData")

# init the adata view widget
widget = widget(viewer)

sdata_widget.viewer_model.add_sdata_points(sdata_blobs, "blobs_points", "global", False)

# plot dataframe annotations on the points
center_pos = get_center_pos_listitem(widget.dataframe_columns_widget, "instance_id")
# TODO: the double click doesn't trigger the signal, so below we are calling _onAction directly (looking at a
# screenshot of the qtbot, the interface shows to be correctly clicked)
click_list_widget_item(
qtbot,
widget=widget.dataframe_columns_widget,
position=center_pos,
wait_signal="currentItemChanged",
click="double",
)
widget.dataframe_columns_widget._onAction([column])
viewer.close()

0 comments on commit 33fd676

Please sign in to comment.