Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bug categorical annotatoins points #304

Merged
merged 2 commits into from
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 28 additions & 7 deletions src/napari_spatialdata/_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,14 +198,35 @@ def _(self, vec: pd.Series, **kwargs: Any) -> dict[str, Any]:
if isinstance(layer, Labels):
element_indices = element_indices[element_indices != 0]
# When merging if the row is not present in the other table it will be nan so we can give it a default color
if (vec_color_name := vec.name + "_color") not in self.model.adata.uns:
colorer = AnnData(shape=(len(vec), 0), obs=pd.DataFrame(index=vec.index, data={"vec": vec}))
_set_colors_for_categorical_obs(colorer, "vec", palette="tab20")
colors = colorer.uns["vec_colors"]
color_dict = dict(zip(vec.cat.categories, colors))
color_dict.update({np.nan: "#808080ff"})
vec_color_name = vec.name + "_color"
if self._attr != "columns_df":
if vec_color_name not in self.model.adata.uns:
colorer = AnnData(shape=(len(vec), 0), obs=pd.DataFrame(index=vec.index, data={"vec": vec}))
_set_colors_for_categorical_obs(colorer, "vec", palette="tab20")
colors = colorer.uns["vec_colors"]
color_dict = dict(zip(vec.cat.categories, colors))
color_dict.update({np.nan: "#808080ff"})
else:
color_dict = self.model.adata.uns[vec_color_name]
else:
color_dict = self.model.adata.uns[vec_color_name]
df = layer.metadata["_columns_df"]
if vec_color_name not in df.columns:
colorer = AnnData(shape=(len(vec), 0), obs=pd.DataFrame(index=vec.index, data={"vec": vec}))
_set_colors_for_categorical_obs(colorer, "vec", palette="tab20")
colors = colorer.uns["vec_colors"]
color_dict = dict(zip(vec.cat.categories, colors))
color_dict.update({np.nan: "#808080ff"})
color_column = vec.apply(lambda x: color_dict[x])
df[vec_color_name] = color_column
else:
unique_colors = df[[vec.name, vec_color_name]].drop_duplicates()
unique_colors.set_index(vec.name, inplace=True)
if not unique_colors.index.is_unique:
raise ValueError(
f"The {vec_color_name} column must have unique values for the each {vec.name} level. Found:\n"
f"{unique_colors}"
)
color_dict = unique_colors.to_dict()["genes_color"]

if self.model.instance_key is not None and self.model.instance_key == vec.index.name:
merge_df = pd.merge(
Expand Down
30 changes: 30 additions & 0 deletions tests/test_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from napari.utils.events import EventedList
from napari_spatialdata._sdata_widgets import SdataWidget
from napari_spatialdata._view import QtAdataViewWidget
from napari_spatialdata.utils._test_utils import click_list_widget_item, get_center_pos_listitem
from spatialdata.datasets import blobs


Expand Down Expand Up @@ -33,3 +34,32 @@ def test_channel_slider_images(qtbot, make_napari_viewer: any, widget: Any, n_ch
qtbot.wait(50) # wait for a short time to simulate user interaction

viewer.close()


@pytest.mark.parametrize("widget", [QtAdataViewWidget])
@pytest.mark.parametrize("column", ["genes", "instance_id"])
def test_plot_dataframe_annotation_on_points(qtbot, make_napari_viewer: any, widget: Any, column: str):
sdata_blobs = blobs()
viewer = make_napari_viewer()
sdata_widget = SdataWidget(viewer, EventedList([sdata_blobs]))

viewer.window.add_dock_widget(sdata_widget, name="SpatialData")

# init the adata view widget
widget = widget(viewer)

sdata_widget.viewer_model.add_sdata_points(sdata_blobs, "blobs_points", "global", False)

# plot dataframe annotations on the points
center_pos = get_center_pos_listitem(widget.dataframe_columns_widget, "instance_id")
# TODO: the double click doesn't trigger the signal, so below we are calling _onAction directly (looking at a
# screenshot of the qtbot, the interface shows to be correctly clicked)
click_list_widget_item(
qtbot,
widget=widget.dataframe_columns_widget,
position=center_pos,
wait_signal="currentItemChanged",
click="double",
)
widget.dataframe_columns_widget._onAction([column])
viewer.close()
Loading