From 0113657dd5074953764d97e83f008a61fce0941b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pavlin=20Poli=C4=8Dar?= Date: Sat, 27 Apr 2019 11:53:57 +0200 Subject: [PATCH 1/6] OwPythagoreanForest: Remember selection --- .../widgets/visualize/owpythagoreanforest.py | 25 ++++++++++++++----- .../tests/test_owpythagoreanforest.py | 22 +++++++++++++++- 2 files changed, 40 insertions(+), 7 deletions(-) diff --git a/Orange/widgets/visualize/owpythagoreanforest.py b/Orange/widgets/visualize/owpythagoreanforest.py index 3cfe7d041d0..05c37aa22a9 100644 --- a/Orange/widgets/visualize/owpythagoreanforest.py +++ b/Orange/widgets/visualize/owpythagoreanforest.py @@ -3,7 +3,7 @@ from typing import Any, Callable, Optional from AnyQt.QtCore import Qt, QRectF, QSize, QPointF, QSizeF, QModelIndex, \ - QItemSelection, QT_VERSION + QItemSelection, QItemSelectionModel, QT_VERSION from AnyQt.QtGui import QPainter, QPen, QColor, QBrush, QMouseEvent from AnyQt.QtWidgets import QSizePolicy, QGraphicsScene, QLabel, QSlider, \ QListView, QStyledItemDelegate, QStyleOptionViewItem, QStyle @@ -174,11 +174,15 @@ class Outputs: graph_name = 'scene' # Settings + settingsHandler = settings.DomainContextHandler() + depth_limit = settings.ContextSetting(10) target_class_index = settings.ContextSetting(0) size_calc_idx = settings.Setting(0) zoom = settings.Setting(200) + selected_index = settings.ContextSetting(None) + SIZE_CALCULATION = [ ('Normal', lambda x: x), ('Square root', lambda x: sqrt(x)), @@ -265,6 +269,7 @@ def __init__(self): @Inputs.random_forest def set_rf(self, model=None): """When a different forest is given.""" + self.closeContext() self.clear() self.rf_model = model @@ -283,11 +288,19 @@ def set_rf(self, model=None): self._update_target_class_combo() self._update_depth_slider() + self.openContext(model) + # Restore item selection + if self.selected_index is not None: + index = self.list_view.model().index(self.selected_index) + selection = QItemSelection(index, index) + self.list_view.selectionModel().select(selection, QItemSelectionModel.ClearAndSelect) + def clear(self): """Clear all relevant data from the widget.""" self.rf_model = None self.forest = None self.forest_model.clear() + self.selected_index = None self._clear_info_box() self._clear_target_class_combo() @@ -342,19 +355,19 @@ def onDeleteWidget(self): super().onDeleteWidget() self.clear() - def commit(self, selection): - # type: (QItemSelection) -> None + def commit(self, selection: QItemSelection) -> None: """Commit the selected tree to output.""" selected_indices = selection.indexes() if not len(selected_indices): + self.selected_index = None self.Outputs.tree.send(None) return - selected_index, = selection.indexes() + # We only allow selecting a single tree so there will always be one index + self.selected_index = selected_indices[0].row() - idx = selected_index.row() - tree = self.rf_model.trees[idx] + tree = self.rf_model.trees[self.selected_index] tree.instances = self.instances tree.meta_target_class_index = self.target_class_index tree.meta_size_calc_idx = self.size_calc_idx diff --git a/Orange/widgets/visualize/tests/test_owpythagoreanforest.py b/Orange/widgets/visualize/tests/test_owpythagoreanforest.py index 49ac389e16d..c1c62f2e344 100644 --- a/Orange/widgets/visualize/tests/test_owpythagoreanforest.py +++ b/Orange/widgets/visualize/tests/test_owpythagoreanforest.py @@ -2,7 +2,7 @@ from unittest.mock import Mock -from AnyQt.QtCore import Qt +from AnyQt.QtCore import Qt, QItemSelection, QItemSelectionModel from Orange.classification.random_forest import RandomForestLearner from Orange.data import Table @@ -201,3 +201,23 @@ def _callback(): # Check that individual squares all have the same color colors_same = [self._check_all_same(x) for x in zip(*colors)] self.assertTrue(all(colors_same)) + + def select_tree(self, idx: int) -> None: + list_view = self.widget.list_view + index = list_view.model().index(idx) + selection = QItemSelection(index, index) + list_view.selectionModel().select(selection, QItemSelectionModel.ClearAndSelect) + + def test_storing_selection(self): + # Select one of the trees + idx = 1 + self.send_signal(self.widget.Inputs.random_forest, self.titanic) + self.select_tree(idx) + # Clear input + self.send_signal(self.widget.Inputs.random_forest, None) + # Restore previous data; context settings should be restored + self.send_signal(self.widget.Inputs.random_forest, self.titanic) + + output = self.get_output(self.widget.Outputs.tree) + self.assertIsNotNone(output) + self.assertIs(output.skl_model, self.titanic.trees[idx].skl_model) From 2f109597d31017a66b76b8484b2fe6058cbb2f3b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pavlin=20Poli=C4=8Dar?= Date: Sat, 27 Apr 2019 12:14:19 +0200 Subject: [PATCH 2/6] OwPythagoreanForest: Remove dead code --- Orange/widgets/visualize/owpythagoreanforest.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/Orange/widgets/visualize/owpythagoreanforest.py b/Orange/widgets/visualize/owpythagoreanforest.py index 05c37aa22a9..25183aeca0c 100644 --- a/Orange/widgets/visualize/owpythagoreanforest.py +++ b/Orange/widgets/visualize/owpythagoreanforest.py @@ -203,7 +203,6 @@ def __init__(self): self.rf_model = None self.forest = None self.instances = None - self.clf_dataset = None self.color_palette = None @@ -276,13 +275,7 @@ def set_rf(self, model=None): if model is not None: self.forest = self._get_forest_adapter(self.rf_model) self.forest_model[:] = self.forest.trees - self.instances = model.instances - # This bit is important for the regression classifier - if self.instances is not None and self.instances.domain != model.domain: - self.clf_dataset = self.instances.transform(self.rf_model.domain) - else: - self.clf_dataset = self.instances self._update_info_box() self._update_target_class_combo() From 4aa374637cf8e1c86b7cc5b03577fac8395ccf8e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pavlin=20Poli=C4=8Dar?= Date: Sat, 27 Apr 2019 12:41:53 +0200 Subject: [PATCH 3/6] OwPythagorasTree: Properly use context settings --- Orange/widgets/visualize/owpythagorastree.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/Orange/widgets/visualize/owpythagorastree.py b/Orange/widgets/visualize/owpythagorastree.py index 5440b7c7966..6981dcf285b 100644 --- a/Orange/widgets/visualize/owpythagorastree.py +++ b/Orange/widgets/visualize/owpythagorastree.py @@ -57,6 +57,8 @@ class Outputs: graph_name = 'scene' # Settings + settingsHandler = settings.DomainContextHandler() + depth_limit = settings.ContextSetting(10) target_class_index = settings.ContextSetting(0) size_calc_idx = settings.Setting(0) @@ -194,7 +196,9 @@ def set_tree(self, model=None): # self.depth_limit = model.meta_depth_limit # self.update_depth() - self.Outputs.annotated_data.send(create_annotated_table(self.instances, None)) + self.openContext(self.model) + + self.Outputs.annotated_data.send(create_annotated_table(self.data, None)) def clear(self): """Clear all relevant data from the widget.""" From 85729102d11908d2fb005933dc8350a701cda999 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pavlin=20Poli=C4=8Dar?= Date: Sat, 27 Apr 2019 12:43:01 +0200 Subject: [PATCH 4/6] OwPythagorasTree: Remove dead code --- Orange/widgets/visualize/owpythagorastree.py | 39 +++++++++----------- 1 file changed, 18 insertions(+), 21 deletions(-) diff --git a/Orange/widgets/visualize/owpythagorastree.py b/Orange/widgets/visualize/owpythagorastree.py index 6981dcf285b..836f47ec422 100644 --- a/Orange/widgets/visualize/owpythagorastree.py +++ b/Orange/widgets/visualize/owpythagorastree.py @@ -75,8 +75,7 @@ def __init__(self): super().__init__() # Instance variables self.model = None - self.instances = None - self.clf_dataset = None + self.data = None # The tree adapter instance which is passed from the outside self.tree_adapter = None self.legend = None @@ -149,18 +148,12 @@ def __init__(self): @Inputs.tree def set_tree(self, model=None): """When a different tree is given.""" + self.closeContext() self.clear() self.model = model if model is not None: - self.instances = model.instances - # this bit is important for the regression classifier - if self.instances is not None and \ - self.instances.domain != model.domain: - self.clf_dataset = self.instances.transform(self.model.domain) - else: - self.clf_dataset = self.instances - + self.data = model.instances self.tree_adapter = self._get_tree_adapter(self.model) self.ptree.clear() @@ -203,8 +196,7 @@ def set_tree(self, model=None): def clear(self): """Clear all relevant data from the widget.""" self.model = None - self.instances = None - self.clf_dataset = None + self.data = None self.tree_adapter = None if self.legend is not None: @@ -311,16 +303,21 @@ def onDeleteWidget(self): def commit(self): """Commit the selected data to output.""" - if self.instances is None: + if self.data is None: self.Outputs.selected_data.send(None) self.Outputs.annotated_data.send(None) return - nodes = [i.tree_node.label for i in self.scene.selectedItems() - if isinstance(i, SquareGraphicsItem)] + + nodes = [ + i.tree_node.label for i in self.scene.selectedItems() + if isinstance(i, SquareGraphicsItem) + ] data = self.tree_adapter.get_instances_in_nodes(nodes) self.Outputs.selected_data.send(data) selected_indices = self.tree_adapter.get_indices(nodes) - self.Outputs.annotated_data.send(create_annotated_table(self.instances, selected_indices)) + self.Outputs.annotated_data.send( + create_annotated_table(self.data, selected_indices) + ) def send_report(self): """Send report.""" @@ -331,9 +328,9 @@ def _update_target_class_combo(self): label = [x for x in self.target_class_combo.parent().children() if isinstance(x, QLabel)][0] - if self.instances.domain.has_discrete_class: + if self.data.domain.has_discrete_class: label_text = 'Target class' - values = [c.title() for c in self.instances.domain.class_vars[0].values] + values = [c.title() for c in self.data.domain.class_vars[0].values] values.insert(0, 'None') else: label_text = 'Node color' @@ -346,7 +343,7 @@ def _update_legend_colors(self): if self.legend is not None: self.scene.removeItem(self.legend) - if self.instances.domain.has_discrete_class: + if self.data.domain.has_discrete_class: self._classification_update_legend_colors() else: self._regression_update_legend_colors() @@ -379,14 +376,14 @@ def _get_colors_domain(domain): # The colors are the class mean if self.target_class_index == 1: - values = (np.min(self.clf_dataset.Y), np.max(self.clf_dataset.Y)) + values = (np.min(self.data.Y), np.max(self.data.Y)) colors = _get_colors_domain(self.model.domain) while len(values) != len(colors): values.insert(1, -1) items = list(zip(values, colors)) # Colors are the stddev elif self.target_class_index == 2: - values = (0, np.std(self.clf_dataset.Y)) + values = (0, np.std(self.data.Y)) colors = _get_colors_domain(self.model.domain) while len(values) != len(colors): values.insert(1, -1) From dd8b66b50791e084242cf0dfe8f1f2e3b1bd26d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pavlin=20Poli=C4=8Dar?= Date: Sat, 27 Apr 2019 12:43:16 +0200 Subject: [PATCH 5/6] OwPythagorasTree: Fix redraw crash when no data --- Orange/widgets/visualize/owpythagorastree.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Orange/widgets/visualize/owpythagorastree.py b/Orange/widgets/visualize/owpythagorastree.py index 836f47ec422..b98f0e34ac6 100644 --- a/Orange/widgets/visualize/owpythagorastree.py +++ b/Orange/widgets/visualize/owpythagorastree.py @@ -224,6 +224,8 @@ def update_size_calc(self): self.invalidate_tree() def redraw(self): + if self.data is None: + return self.tree_adapter.shuffle_children() self.invalidate_tree() From 57d4eccafa727dcbb721730c93bf907de396d4d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pavlin=20Poli=C4=8Dar?= Date: Wed, 22 May 2019 09:36:28 +0200 Subject: [PATCH 6/6] OWPythagoreanTree: Fix depth limit not being restored properly from context --- Orange/widgets/visualize/owpythagorastree.py | 33 +++++++++---------- .../visualize/tests/test_owpythagorastree.py | 12 +++++++ 2 files changed, 28 insertions(+), 17 deletions(-) diff --git a/Orange/widgets/visualize/owpythagorastree.py b/Orange/widgets/visualize/owpythagorastree.py index b98f0e34ac6..5eb1c63d10f 100644 --- a/Orange/widgets/visualize/owpythagorastree.py +++ b/Orange/widgets/visualize/owpythagorastree.py @@ -172,25 +172,24 @@ def set_tree(self, model=None): self._update_main_area() - # The target class can also be passed from the meta properties - # This must be set after `_update_target_class_combo` - if hasattr(model, 'meta_target_class_index'): - self.target_class_index = model.meta_target_class_index - self.update_colors() - - # Get meta variables describing what the settings should look like - # if the tree is passed from the Pythagorean forest widget. - if hasattr(model, 'meta_size_calc_idx'): - self.size_calc_idx = model.meta_size_calc_idx - self.update_size_calc() - - # TODO There is still something wrong with this - # if hasattr(model, 'meta_depth_limit'): - # self.depth_limit = model.meta_depth_limit - # self.update_depth() - self.openContext(self.model) + self.update_depth() + + # The forest widget sets the following attributes on the tree, + # describing the settings on the forest widget. To keep the tree + # looking the same as on the forest widget, we prefer these settings to + # context settings, if set. + if hasattr(model, "meta_target_class_index"): + self.target_class_index = model.meta_target_class_index + self.update_colors() + if hasattr(model, "meta_size_calc_idx"): + self.size_calc_idx = model.meta_size_calc_idx + self.update_size_calc() + if hasattr(model, "meta_depth_limit"): + self.depth_limit = model.meta_depth_limit + self.update_depth() + self.Outputs.annotated_data.send(create_annotated_table(self.data, None)) def clear(self): diff --git a/Orange/widgets/visualize/tests/test_owpythagorastree.py b/Orange/widgets/visualize/tests/test_owpythagorastree.py index 301d6c697f5..f05dd53b7ff 100644 --- a/Orange/widgets/visualize/tests/test_owpythagorastree.py +++ b/Orange/widgets/visualize/tests/test_owpythagorastree.py @@ -383,3 +383,15 @@ def test_forest_tree_table(self): square.setSelected(True) tab = self.get_output(tree_w.Outputs.selected_data, widget=tree_w) self.assertGreater(len(tab), 0) + + def test_changing_data_restores_depth_from_previous_settings(self): + titanic_data = Table("titanic")[::50] + forest = RandomForestLearner(n_estimators=3)(titanic_data) + forest.instances = titanic_data + + self.send_signal(self.widget.Inputs.tree, forest.trees[0]) + self.widget.controls.depth_limit.setValue(1) + + # The domain is still the same, so restore the depth limit from before + self.send_signal(self.widget.Inputs.tree, forest.trees[1]) + self.assertEqual(self.widget.ptree._depth_limit, 1)