diff --git a/orangecontrib/geo/widgets/owchoropleth.py b/orangecontrib/geo/widgets/owchoropleth.py index 8333192..12cd4a5 100644 --- a/orangecontrib/geo/widgets/owchoropleth.py +++ b/orangecontrib/geo/widgets/owchoropleth.py @@ -1,7 +1,7 @@ import sys import itertools from xml.sax.saxutils import escape -from typing import List, NamedTuple, Optional, Union, Callable +from typing import List, NamedTuple, Optional, Union, Callable, Tuple from math import floor, log10 from functools import reduce @@ -207,7 +207,7 @@ def __init__(self, widget, parent=None): self.choropleth_items = [] # type: List[ChoroplethItem] - self.n_ids = 0 + self.id_to_index = {} self.selection = None # np.ndarray self.palette = None @@ -217,6 +217,10 @@ def __init__(self, widget, parent=None): self._tooltip_delegate = HelpEventDelegate(self.help_event) self.plot_widget.scene().installEventFilter(self._tooltip_delegate) + @property + def n_ids(self): + return len(self.id_to_index) + def _create_legend(self, anchor): legend = LegendItem() legend.setParentItem(self.plot_widget.getViewBox()) @@ -261,7 +265,7 @@ def clear(self): self.color_legend.clear() self.update_legend_visibility() self.choropleth_items = [] - self.n_ids = 0 + self.id_to_index = {} self.selection = None def reset_graph(self): @@ -283,7 +287,8 @@ def update_choropleth(self): self.choropleth_items.append(choropleth_item) if self.choropleth_items: - self.n_ids = len(self.master.region_ids) + self.id_to_index = { + id_: cnt for cnt, id_ in enumerate(self.master.region_ids)} def update_colors(self): """Update agg_value and inner color of existing polygons.""" @@ -403,13 +408,50 @@ def select_button_clicked(self): self.plot_widget.getViewBox().setMouseMode( self.plot_widget.getViewBox().RectMode) + def set_selection_from_ids(self, sel_ids): + """ + Select regions with given ids. + + Graph stores ids in array, where each element stores an index of + selection group for the corresponding region, while the widget + stores selection as tuples with region ids and groups. + + This method receives widget-like selection and stores it in + graph's array. + + Args: + ids (dict of str to int): selection, as stored by widget + """ + self.selection = np.zeros(self.n_ids, dtype=np.uint8) + if not sel_ids: + return + by_ids = np.array( + [[self.id_to_index[id_], grp] + for id_, grp in sel_ids if id_ in self.id_to_index]) + if by_ids.size: + idx, grp = by_ids.T + self.selection[idx] = grp + + def selected_ids(self): + """ + Return ids of selected regions. See `set_selection_from_ids`. + + Returns: + ids (dict of str to int): selection, as stored by widget + """ + + if self.selection is None: + return [] + ids = self.master.region_ids + sel = np.flatnonzero(self.selection) + return list(zip(ids[sel], self.selection[sel])) + def select_by_id(self, region_id): """ This is called by a `ChoroplethItem` on click. The selection is then based on the corresponding region. """ - indices = np.where(self.master.region_ids == region_id)[0] - self.select_by_indices(indices) + self.select_by_indices(self.id_to_index[region_id]) def select_by_rectangle(self, rect: QRectF): """ @@ -419,7 +461,7 @@ def select_by_rectangle(self, rect: QRectF): indices = set() for ci in self.choropleth_items: if ci.intersects(poly_rect): - indices.add(np.where(self.master.region_ids == ci.region.id)[0][0]) + indices.add(self.id_to_index[ci.region.id]) if indices: self.select_by_indices(np.array(list(indices))) @@ -579,7 +621,7 @@ class Outputs: settings_version = 2 settingsHandler = DomainContextHandler() - selection = Setting(None, schema_only=True) + selection: Optional[List[Tuple[str, int]]] = Setting(None, schema_only=True) auto_commit = Setting(True) attr_lat = ContextSetting(None) @@ -782,17 +824,21 @@ def setup_plot(self): def apply_selection(self): if self.data is not None and self.selection is not None: - index_group = np.array(self.selection).T - selection = np.zeros(self.graph.n_ids, dtype=np.uint8) - selection[index_group[0]] = index_group[1] - self.graph.selection = selection + # on-the-spot migration of a context-like setting + if self.selection and isinstance(self.selection[0][0], int): + self.selection = [ + (self.region_ids[id_], grp) for id_, grp in self.selection + if id_ < len(self.region_ids) + ] + self.graph.set_selection_from_ids(self.selection) + # Retrieve selection back from graph + # to remove any regions that no longer exist in the new data + self.selection = self.graph.selected_ids() self.graph.update_selection_colors() def selection_changed(self): - sel = None if self.data and isinstance(self.data, SqlTable) \ - else self.graph.selection - self.selection = [(i, x) for i, x in enumerate(sel) if x] \ - if sel is not None else None + self.selection = None if self.data and isinstance(self.data, SqlTable) \ + else self.graph.selected_ids() self.commit() def commit(self): @@ -804,9 +850,9 @@ def send_data(self): if data: group_sel = np.zeros(len(data), dtype=int) - if len(graph_sel): + if self.selection: # we get selection by region ids so we have to map it to points - for id, s in zip(self.region_ids, graph_sel): + for id, s in self.selection: if s == 0: continue id_indices = np.where(self.data_ids == id)[0] diff --git a/orangecontrib/geo/widgets/tests/test_owchoropleth.py b/orangecontrib/geo/widgets/tests/test_owchoropleth.py index 954cc1a..fb018a8 100644 --- a/orangecontrib/geo/widgets/tests/test_owchoropleth.py +++ b/orangecontrib/geo/widgets/tests/test_owchoropleth.py @@ -122,5 +122,44 @@ def test_no_data(self): self.send_signal(self.widget.Inputs.data, self.data) +class TestOWChoroplethPlotGraph(WidgetTest): + @classmethod + def setUpClass(cls): + super().setUpClass() + WidgetOutputsTestMixin.init(cls) + cls.same_input_output_domain = False + cls.signal_name = "Data" + + @patch("orangecontrib.geo.widgets.plotutils.ImageLoader") + def setUp(self, _): + self.widget = self.create_widget(OWChoropleth) + self.widget.admin_level = 1 + data = self.data = Table("India_census_district_population") + self.send_signal(self.widget.Inputs.data, data) + self.graph = self.widget.graph + + def test_set_get_selection_by_ids(self): + selection = list(zip(self.widget.region_ids[[1, 3, 4]], [1, 1, 2])) + self.graph.set_selection_from_ids(selection) + np.testing.assert_equal(self.graph.selection[:6], [0, 1, 0, 1, 2, 0]) + self.assertEqual(self.graph.selected_ids(), selection) + + selection.append(("foo", 1)) + selection.pop(0) + self.graph.set_selection_from_ids(selection) + np.testing.assert_equal(self.graph.selection[:6], [0, 0, 0, 1, 2, 0]) + self.assertEqual(self.graph.selected_ids(), selection[:2]) + + self.graph.set_selection_from_ids([]) + self.assertFalse(np.any(self.graph.selection)) + self.assertEqual(self.graph.selected_ids(), []) + + self.graph.set_selection_from_ids(selection) + np.testing.assert_equal(self.graph.selection[:6], [0, 0, 0, 1, 2, 0]) + self.graph.set_selection_from_ids(None) + self.assertFalse(np.any(self.graph.selection)) + self.assertEqual(self.graph.selected_ids(), []) + + if __name__ == "__main__": unittest.main()