Skip to content

Commit

Permalink
Merge pull request #168 from janezd/choroplet-fix-selection
Browse files Browse the repository at this point in the history
[FIX] Choropleth: Store selection by region ids, not indices
  • Loading branch information
VesnaT authored Feb 24, 2023
2 parents 6247c5d + d3b8a91 commit 4e72ff9
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 18 deletions.
82 changes: 64 additions & 18 deletions orangecontrib/geo/widgets/owchoropleth.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import sys
import itertools
from xml.sax.saxutils import escape
from typing import List, NamedTuple, Optional, Union, Callable
from typing import List, NamedTuple, Optional, Union, Callable, Tuple
from math import floor, log10
from functools import reduce

Expand Down Expand Up @@ -207,7 +207,7 @@ def __init__(self, widget, parent=None):

self.choropleth_items = [] # type: List[ChoroplethItem]

self.n_ids = 0
self.id_to_index = {}
self.selection = None # np.ndarray

self.palette = None
Expand All @@ -217,6 +217,10 @@ def __init__(self, widget, parent=None):
self._tooltip_delegate = HelpEventDelegate(self.help_event)
self.plot_widget.scene().installEventFilter(self._tooltip_delegate)

@property
def n_ids(self):
return len(self.id_to_index)

def _create_legend(self, anchor):
legend = LegendItem()
legend.setParentItem(self.plot_widget.getViewBox())
Expand Down Expand Up @@ -261,7 +265,7 @@ def clear(self):
self.color_legend.clear()
self.update_legend_visibility()
self.choropleth_items = []
self.n_ids = 0
self.id_to_index = {}
self.selection = None

def reset_graph(self):
Expand All @@ -283,7 +287,8 @@ def update_choropleth(self):
self.choropleth_items.append(choropleth_item)

if self.choropleth_items:
self.n_ids = len(self.master.region_ids)
self.id_to_index = {
id_: cnt for cnt, id_ in enumerate(self.master.region_ids)}

def update_colors(self):
"""Update agg_value and inner color of existing polygons."""
Expand Down Expand Up @@ -403,13 +408,50 @@ def select_button_clicked(self):
self.plot_widget.getViewBox().setMouseMode(
self.plot_widget.getViewBox().RectMode)

def set_selection_from_ids(self, sel_ids):
"""
Select regions with given ids.
Graph stores ids in array, where each element stores an index of
selection group for the corresponding region, while the widget
stores selection as tuples with region ids and groups.
This method receives widget-like selection and stores it in
graph's array.
Args:
ids (dict of str to int): selection, as stored by widget
"""
self.selection = np.zeros(self.n_ids, dtype=np.uint8)
if not sel_ids:
return
by_ids = np.array(
[[self.id_to_index[id_], grp]
for id_, grp in sel_ids if id_ in self.id_to_index])
if by_ids.size:
idx, grp = by_ids.T
self.selection[idx] = grp

def selected_ids(self):
"""
Return ids of selected regions. See `set_selection_from_ids`.
Returns:
ids (dict of str to int): selection, as stored by widget
"""

if self.selection is None:
return []
ids = self.master.region_ids
sel = np.flatnonzero(self.selection)
return list(zip(ids[sel], self.selection[sel]))

def select_by_id(self, region_id):
"""
This is called by a `ChoroplethItem` on click.
The selection is then based on the corresponding region.
"""
indices = np.where(self.master.region_ids == region_id)[0]
self.select_by_indices(indices)
self.select_by_indices(self.id_to_index[region_id])

def select_by_rectangle(self, rect: QRectF):
"""
Expand All @@ -419,7 +461,7 @@ def select_by_rectangle(self, rect: QRectF):
indices = set()
for ci in self.choropleth_items:
if ci.intersects(poly_rect):
indices.add(np.where(self.master.region_ids == ci.region.id)[0][0])
indices.add(self.id_to_index[ci.region.id])
if indices:
self.select_by_indices(np.array(list(indices)))

Expand Down Expand Up @@ -579,7 +621,7 @@ class Outputs:

settings_version = 2
settingsHandler = DomainContextHandler()
selection = Setting(None, schema_only=True)
selection: Optional[List[Tuple[str, int]]] = Setting(None, schema_only=True)
auto_commit = Setting(True)

attr_lat = ContextSetting(None)
Expand Down Expand Up @@ -782,17 +824,21 @@ def setup_plot(self):

def apply_selection(self):
if self.data is not None and self.selection is not None:
index_group = np.array(self.selection).T
selection = np.zeros(self.graph.n_ids, dtype=np.uint8)
selection[index_group[0]] = index_group[1]
self.graph.selection = selection
# on-the-spot migration of a context-like setting
if self.selection and isinstance(self.selection[0][0], int):
self.selection = [
(self.region_ids[id_], grp) for id_, grp in self.selection
if id_ < len(self.region_ids)
]
self.graph.set_selection_from_ids(self.selection)
# Retrieve selection back from graph
# to remove any regions that no longer exist in the new data
self.selection = self.graph.selected_ids()
self.graph.update_selection_colors()

def selection_changed(self):
sel = None if self.data and isinstance(self.data, SqlTable) \
else self.graph.selection
self.selection = [(i, x) for i, x in enumerate(sel) if x] \
if sel is not None else None
self.selection = None if self.data and isinstance(self.data, SqlTable) \
else self.graph.selected_ids()
self.commit()

def commit(self):
Expand All @@ -804,9 +850,9 @@ def send_data(self):
if data:
group_sel = np.zeros(len(data), dtype=int)

if len(graph_sel):
if self.selection:
# we get selection by region ids so we have to map it to points
for id, s in zip(self.region_ids, graph_sel):
for id, s in self.selection:
if s == 0:
continue
id_indices = np.where(self.data_ids == id)[0]
Expand Down
39 changes: 39 additions & 0 deletions orangecontrib/geo/widgets/tests/test_owchoropleth.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,5 +122,44 @@ def test_no_data(self):
self.send_signal(self.widget.Inputs.data, self.data)


class TestOWChoroplethPlotGraph(WidgetTest):
@classmethod
def setUpClass(cls):
super().setUpClass()
WidgetOutputsTestMixin.init(cls)
cls.same_input_output_domain = False
cls.signal_name = "Data"

@patch("orangecontrib.geo.widgets.plotutils.ImageLoader")
def setUp(self, _):
self.widget = self.create_widget(OWChoropleth)
self.widget.admin_level = 1
data = self.data = Table("India_census_district_population")
self.send_signal(self.widget.Inputs.data, data)
self.graph = self.widget.graph

def test_set_get_selection_by_ids(self):
selection = list(zip(self.widget.region_ids[[1, 3, 4]], [1, 1, 2]))
self.graph.set_selection_from_ids(selection)
np.testing.assert_equal(self.graph.selection[:6], [0, 1, 0, 1, 2, 0])
self.assertEqual(self.graph.selected_ids(), selection)

selection.append(("foo", 1))
selection.pop(0)
self.graph.set_selection_from_ids(selection)
np.testing.assert_equal(self.graph.selection[:6], [0, 0, 0, 1, 2, 0])
self.assertEqual(self.graph.selected_ids(), selection[:2])

self.graph.set_selection_from_ids([])
self.assertFalse(np.any(self.graph.selection))
self.assertEqual(self.graph.selected_ids(), [])

self.graph.set_selection_from_ids(selection)
np.testing.assert_equal(self.graph.selection[:6], [0, 0, 0, 1, 2, 0])
self.graph.set_selection_from_ids(None)
self.assertFalse(np.any(self.graph.selection))
self.assertEqual(self.graph.selected_ids(), [])


if __name__ == "__main__":
unittest.main()

0 comments on commit 4e72ff9

Please sign in to comment.