Skip to content

Commit

Permalink
Distributions: Fix selection output and context
Browse files Browse the repository at this point in the history
  • Loading branch information
janezd committed Sep 18, 2023
1 parent f502668 commit 9254145
Showing 1 changed file with 67 additions and 29 deletions.
96 changes: 67 additions & 29 deletions Orange/widgets/visualize/owdistributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,7 @@ def __init__(self):
self.curve_items = []
self.curve_descriptions = None
self.binnings = []
self.ordered_values = []

self.last_click_idx = None
self.drag_operation = self.DragNone
Expand Down Expand Up @@ -563,6 +564,7 @@ def replot(self):
self._set_axis_names()
self._update_controls_state()
self._call_plotting()
self.__migrate_context_selection_from_int()
self._display_legend()
self.show_selection()

Expand All @@ -571,6 +573,8 @@ def _clear_plot(self):
self.plot_pdf.clear()
self.plot_mark.clear()
self.bar_items = []
self.ordered_values = []
self.last_click_idx = None
self.curve_items = []
self._legend.clear()
self._legend.hide()
Expand Down Expand Up @@ -628,11 +632,11 @@ def _disc_plot(self):
else:
order = np.arange(len(dist))

ordered_values = np.array(var.values)[order]
self.ploti.getAxis("bottom").setTicks([list(enumerate(ordered_values))])
self.ordered_values = list(np.array(var.values)[order])
self.ploti.getAxis("bottom").setTicks([list(enumerate(self.ordered_values))])

colors = [QColor(0, 128, 255)]
for i, freq, desc in zip(count(), dist[order], ordered_values):
for i, freq, desc in zip(count(), dist[order], self.ordered_values):
tooltip = \
"<p style='white-space:pre;'>" \
f"<b>{escape(desc)}</b>: {int(freq)} " \
Expand All @@ -650,13 +654,13 @@ def _disc_split_plot(self):
else:
order = np.arange(len(conts))

ordered_values = np.array(var.values)[order]
self.ploti.getAxis("bottom").setTicks([list(enumerate(ordered_values))])
self.ordered_values = list(np.array(var.values)[order])
self.ploti.getAxis("bottom").setTicks([list(enumerate(self.ordered_values))])

gcolors = [QColor(*col) for col in self.cvar.colors]
gvalues = self.cvar.values
total = len(self.data)
for i, freqs, desc in zip(count(), conts[order], ordered_values):
for i, freqs, desc in zip(count(), conts[order], self.ordered_values):
self._add_bar(
i - 0.5, 1, 0.1, freqs, gcolors,
stacked=self.stacked_columns, expanded=self.show_probs,
Expand All @@ -682,6 +686,7 @@ def _cont_plot(self):
for i, (x0, x1), freq in zip(count(), zip(x, x[1:]), y):
tot_freq += freq
desc = self.str_int(x0, x1, not i, i == lasti, unique)
self.ordered_values.append(desc)
tooltip = \
"<p style='white-space:pre;'>" \
f"<b>{escape(desc)}</b>: " \
Expand Down Expand Up @@ -731,6 +736,7 @@ def _cont_split_plot(self):
tot_freqs += freqs
plotfreqs = tot_freqs.copy() if self.cumulative_distr else freqs
desc = self.str_int(x0, x1, not i, i == lasti, unique)
self.ordered_values.append(desc)
bar_width = width if unique else x1 - x0
self._add_bar(
x0 + xoff, bar_width, 0 if self.stacked_columns else 0.1,
Expand Down Expand Up @@ -933,23 +939,24 @@ def str_int(self, x0, x1, first, last, unique=False):
# Selection

def _on_item_clicked(self, item, modifiers, drag):
def add_or_remove(idx, add):
def add_or_remove(value, add):
self.drag_operation = [self.DragRemove, self.DragAdd][add]
if add:
self.selection.add(idx)
self.selection.add(value)
else:
if idx in self.selection:
if value in self.selection:
# This can be False when removing with dragging and the
# mouse crosses unselected items
self.selection.remove(idx)
self.selection.remove(value)

def add_range(add):
if self.last_click_idx is None:
add = True
idx_range = {idx}
idx_range = {self.ordered_values[idx]}
else:
from_idx, to_idx = sorted((self.last_click_idx, idx))
idx_range = set(range(from_idx, to_idx + 1))
idx_range = {self.ordered_values[idx]
for idx in range(from_idx, to_idx + 1)}
self.drag_operation = [self.DragRemove, self.DragAdd][add]
if add:
self.selection |= idx_range
Expand All @@ -966,19 +973,20 @@ def add_range(add):
# Dragging has to add a range, otherwise fast dragging skips bars
add_range(self.drag_operation == self.DragAdd)
else:
value = self.ordered_values[idx]
if modifiers & Qt.ShiftModifier:
add_range(self.drag_operation == self.DragAdd)
elif modifiers & Qt.ControlModifier:
add_or_remove(idx, add=idx not in self.selection)
add_or_remove(value, add=value not in self.selection)
else:
if self.selection == {idx}:
# Clicking on a single selected bar deselects it,
if self.selection == {value}:
# Clicking on a single selected bar deselects it,
# but dragging from here will select
add_or_remove(idx, add=False)
add_or_remove(value, add=False)
self.drag_operation = self.DragAdd
else:
self.selection.clear()
add_or_remove(idx, add=True)
add_or_remove(value, add=True)
self.last_click_idx = idx

self.show_selection()
Expand Down Expand Up @@ -1047,44 +1055,56 @@ def _padding(i):

def grouped_selection(self):
return [[g[1] for g in group]
for _, group in groupby(enumerate(sorted(self.selection)),
for _, group in groupby(enumerate(sorted(map(self.ordered_values.index,
self.selection))),
key=lambda x: x[1] - x[0])]
# Alternative:
# groups = []
# last = None
# for idx, value in enumerate(self.ordered_values):
# if value in self.selection:
# if last is None:
# groups.append(last := [])
# last.append(idx)
# else:
# last = None
# return groups

def keyPressEvent(self, e):
def on_nothing_selected():
if e.key() == Qt.Key_Left:
self.last_click_idx = len(self.bar_items) - 1
else:
self.last_click_idx = 0
self.selection.add(self.last_click_idx)
self.selection.add(self.ordered_values[self.last_click_idx])

def on_key_left():
if e.modifiers() & Qt.ShiftModifier:
if self.key_operation == Qt.Key_Right and first != last:
self.selection.remove(last)
self.selection.remove(self.ordered_values[last])
self.last_click_idx = last - 1
elif first:
self.key_operation = Qt.Key_Left
self.selection.add(first - 1)
self.selection.add(self.ordered_values[first - 1])
self.last_click_idx = first - 1
else:
self.selection.clear()
self.last_click_idx = max(first - 1, 0)
self.selection.add(self.last_click_idx)
self.selection.add(self.ordered_values[self.last_click_idx])

def on_key_right():
if e.modifiers() & Qt.ShiftModifier:
if self.key_operation == Qt.Key_Left and first != last:
self.selection.remove(first)
self.selection.remove(self.ordered_values[first])
self.last_click_idx = first + 1
elif not self._is_last_bar(last):
self.key_operation = Qt.Key_Right
self.selection.add(last + 1)
self.selection.add(self.ordered_values[last + 1])
self.last_click_idx = last + 1
else:
self.selection.clear()
self.last_click_idx = min(last + 1, len(self.bar_items) - 1)
self.selection.add(self.last_click_idx)
self.selection.add(self.ordered_values[self.last_click_idx])

if not self.is_valid or not self.bar_items \
or e.key() not in (Qt.Key_Left, Qt.Key_Right):
Expand All @@ -1095,7 +1115,8 @@ def on_key_right():
if not self.selection:
on_nothing_selected()
else:
first, last = min(self.selection), max(self.selection)
sel_indices = list(map(self.ordered_values.index, self.selection))
first, last = min(sel_indices), max(sel_indices)
if e.key() == Qt.Key_Left:
on_key_left()
else:
Expand Down Expand Up @@ -1143,9 +1164,16 @@ def apply(self):
def _get_output_indices_disc(self):
group_indices = np.zeros(len(self.data), dtype=np.int32)
col = self.data.get_column(self.var)
for group_idx, val_idx in enumerate(self.selection, start=1):
group_indices[col == val_idx] = group_idx
values = [self.var.values[i] for i in self.selection]
group_idx = 1
values = []
# self.selection is a set, so its order is random;
# we iterate through ordered_value to get the same order as in chart
for value in self.ordered_values:
if value not in self.selection:
continue
group_indices[col == self.var.to_val(value)] = group_idx
group_idx += 1
values.append(value)
return group_indices, values

def _get_output_indices_cont(self):
Expand Down Expand Up @@ -1220,6 +1248,16 @@ def send_report(self):
text += f" with columns split by '{self.cvar.name}'"
self.report_caption(text)

def __migrate_context_selection_from_int(self):
# Selection used to be stored as bar indices, which would require
# stricter context (matching of values) to work.
# This migration is called after the graph is plotted because it
# requires self.ordered_values.
if self.selection and isinstance(next(iter(self.selection)), int):
self.selection = {
self.ordered_values[idx] for idx in self.selection
if idx < len(self.ordered_values)}


if __name__ == "__main__": # pragma: no cover
WidgetPreview(OWDistributions).run(Table("heart_disease.tab"))

0 comments on commit 9254145

Please sign in to comment.