Skip to content

Commit

Permalink
Distributions: Fix selection output and context
Browse files Browse the repository at this point in the history
  • Loading branch information
janezd committed Sep 18, 2023
1 parent f502668 commit 6ce7f67
Showing 1 changed file with 77 additions and 29 deletions.
106 changes: 77 additions & 29 deletions Orange/widgets/visualize/owdistributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,7 @@ def __init__(self):
self.curve_items = []
self.curve_descriptions = None
self.binnings = []
self.ordered_values = []

self.last_click_idx = None
self.drag_operation = self.DragNone
Expand Down Expand Up @@ -563,6 +564,7 @@ def replot(self):
self._set_axis_names()
self._update_controls_state()
self._call_plotting()
self._reduce_selection()
self._display_legend()
self.show_selection()

Expand All @@ -571,6 +573,8 @@ def _clear_plot(self):
self.plot_pdf.clear()
self.plot_mark.clear()
self.bar_items = []
self.ordered_values = []
self.last_click_idx = None
self.curve_items = []
self._legend.clear()
self._legend.hide()
Expand Down Expand Up @@ -628,11 +632,11 @@ def _disc_plot(self):
else:
order = np.arange(len(dist))

ordered_values = np.array(var.values)[order]
self.ploti.getAxis("bottom").setTicks([list(enumerate(ordered_values))])
self.ordered_values = list(np.array(var.values)[order])
self.ploti.getAxis("bottom").setTicks([list(enumerate(self.ordered_values))])

colors = [QColor(0, 128, 255)]
for i, freq, desc in zip(count(), dist[order], ordered_values):
for i, freq, desc in zip(count(), dist[order], self.ordered_values):
tooltip = \
"<p style='white-space:pre;'>" \
f"<b>{escape(desc)}</b>: {int(freq)} " \
Expand All @@ -650,13 +654,13 @@ def _disc_split_plot(self):
else:
order = np.arange(len(conts))

ordered_values = np.array(var.values)[order]
self.ploti.getAxis("bottom").setTicks([list(enumerate(ordered_values))])
self.ordered_values = list(np.array(var.values)[order])
self.ploti.getAxis("bottom").setTicks([list(enumerate(self.ordered_values))])

gcolors = [QColor(*col) for col in self.cvar.colors]
gvalues = self.cvar.values
total = len(self.data)
for i, freqs, desc in zip(count(), conts[order], ordered_values):
for i, freqs, desc in zip(count(), conts[order], self.ordered_values):
self._add_bar(
i - 0.5, 1, 0.1, freqs, gcolors,
stacked=self.stacked_columns, expanded=self.show_probs,
Expand All @@ -682,6 +686,7 @@ def _cont_plot(self):
for i, (x0, x1), freq in zip(count(), zip(x, x[1:]), y):
tot_freq += freq
desc = self.str_int(x0, x1, not i, i == lasti, unique)
self.ordered_values.append(desc)
tooltip = \
"<p style='white-space:pre;'>" \
f"<b>{escape(desc)}</b>: " \
Expand Down Expand Up @@ -731,6 +736,7 @@ def _cont_split_plot(self):
tot_freqs += freqs
plotfreqs = tot_freqs.copy() if self.cumulative_distr else freqs
desc = self.str_int(x0, x1, not i, i == lasti, unique)
self.ordered_values.append(desc)
bar_width = width if unique else x1 - x0
self._add_bar(
x0 + xoff, bar_width, 0 if self.stacked_columns else 0.1,
Expand Down Expand Up @@ -933,23 +939,24 @@ def str_int(self, x0, x1, first, last, unique=False):
# Selection

def _on_item_clicked(self, item, modifiers, drag):
def add_or_remove(idx, add):
def add_or_remove(value, add):
self.drag_operation = [self.DragRemove, self.DragAdd][add]
if add:
self.selection.add(idx)
self.selection.add(value)
else:
if idx in self.selection:
if value in self.selection:
# This can be False when removing with dragging and the
# mouse crosses unselected items
self.selection.remove(idx)
self.selection.remove(value)

def add_range(add):
if self.last_click_idx is None:
add = True
idx_range = {idx}
idx_range = {self.ordered_values[idx]}
else:
from_idx, to_idx = sorted((self.last_click_idx, idx))
idx_range = set(range(from_idx, to_idx + 1))
idx_range = {self.ordered_values[idx]
for idx in range(from_idx, to_idx + 1)}
self.drag_operation = [self.DragRemove, self.DragAdd][add]
if add:
self.selection |= idx_range
Expand All @@ -966,19 +973,20 @@ def add_range(add):
# Dragging has to add a range, otherwise fast dragging skips bars
add_range(self.drag_operation == self.DragAdd)
else:
value = self.ordered_values[idx]
if modifiers & Qt.ShiftModifier:
add_range(self.drag_operation == self.DragAdd)
elif modifiers & Qt.ControlModifier:
add_or_remove(idx, add=idx not in self.selection)
add_or_remove(value, add=value not in self.selection)
else:
if self.selection == {idx}:
# Clicking on a single selected bar deselects it,
if self.selection == {value}:
# Clicking on a single selected bar deselects it,
# but dragging from here will select
add_or_remove(idx, add=False)
add_or_remove(value, add=False)
self.drag_operation = self.DragAdd
else:
self.selection.clear()
add_or_remove(idx, add=True)
add_or_remove(value, add=True)
self.last_click_idx = idx

self.show_selection()
Expand Down Expand Up @@ -1047,44 +1055,56 @@ def _padding(i):

def grouped_selection(self):
return [[g[1] for g in group]
for _, group in groupby(enumerate(sorted(self.selection)),
for _, group in groupby(enumerate(sorted(map(self.ordered_values.index,
self.selection))),
key=lambda x: x[1] - x[0])]
# Alternative:
# groups = []
# last = None
# for idx, value in enumerate(self.ordered_values):
# if value in self.selection:
# if last is None:
# groups.append(last := [])
# last.append(idx)
# else:
# last = None
# return groups

def keyPressEvent(self, e):
def on_nothing_selected():
if e.key() == Qt.Key_Left:
self.last_click_idx = len(self.bar_items) - 1
else:
self.last_click_idx = 0
self.selection.add(self.last_click_idx)
self.selection.add(self.ordered_values[self.last_click_idx])

def on_key_left():
if e.modifiers() & Qt.ShiftModifier:
if self.key_operation == Qt.Key_Right and first != last:
self.selection.remove(last)
self.selection.remove(self.ordered_values[last])
self.last_click_idx = last - 1
elif first:
self.key_operation = Qt.Key_Left
self.selection.add(first - 1)
self.selection.add(self.ordered_values[first - 1])
self.last_click_idx = first - 1
else:
self.selection.clear()
self.last_click_idx = max(first - 1, 0)
self.selection.add(self.last_click_idx)
self.selection.add(self.ordered_values[self.last_click_idx])

def on_key_right():
if e.modifiers() & Qt.ShiftModifier:
if self.key_operation == Qt.Key_Left and first != last:
self.selection.remove(first)
self.selection.remove(self.ordered_values[first])
self.last_click_idx = first + 1
elif not self._is_last_bar(last):
self.key_operation = Qt.Key_Right
self.selection.add(last + 1)
self.selection.add(self.ordered_values[last + 1])
self.last_click_idx = last + 1
else:
self.selection.clear()
self.last_click_idx = min(last + 1, len(self.bar_items) - 1)
self.selection.add(self.last_click_idx)
self.selection.add(self.ordered_values[self.last_click_idx])

if not self.is_valid or not self.bar_items \
or e.key() not in (Qt.Key_Left, Qt.Key_Right):
Expand All @@ -1095,7 +1115,8 @@ def on_key_right():
if not self.selection:
on_nothing_selected()
else:
first, last = min(self.selection), max(self.selection)
sel_indices = list(map(self.ordered_values.index, self.selection))
first, last = min(sel_indices), max(sel_indices)
if e.key() == Qt.Key_Left:
on_key_left()
else:
Expand All @@ -1111,6 +1132,26 @@ def keyReleaseEvent(self, ev):
self.key_operation = None
super().keyReleaseEvent(ev)

def _reduce_selection(self):
"""
Unselect any bars that no longer appear in the plot; migrate from ints
This function is called after plotting to remove any bars that have
been selected but are no longer plotted. This occurs in particular
when the widget receives new data with discrete variables that lack
some values.
This function also migrates from previous settings, which stored ints
instead of values. This migration requires bar labels and cannot be
done before plotting.
"""
if self.selection and isinstance(next(iter(self.selection)), int):
self.selection = {
self.ordered_values[idx] for idx in self.selection
if idx < len(self.ordered_values)}
else:
self.selection = {value for value in self.selection
if value in self.ordered_values}

# -----------------------------
# Output
Expand Down Expand Up @@ -1143,9 +1184,16 @@ def apply(self):
def _get_output_indices_disc(self):
group_indices = np.zeros(len(self.data), dtype=np.int32)
col = self.data.get_column(self.var)
for group_idx, val_idx in enumerate(self.selection, start=1):
group_indices[col == val_idx] = group_idx
values = [self.var.values[i] for i in self.selection]
group_idx = 1
values = []
# self.selection is a set, so its order is random;
# we iterate through ordered_value to get the same order as in chart
for value in self.ordered_values:
if value not in self.selection:
continue
group_indices[col == self.var.to_val(value)] = group_idx
group_idx += 1
values.append(value)
return group_indices, values

def _get_output_indices_cont(self):
Expand Down

0 comments on commit 6ce7f67

Please sign in to comment.