Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIX] Minor improvements to pythagorean trees #3777

Merged
merged 6 commits into from
May 24, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 38 additions & 36 deletions Orange/widgets/visualize/owpythagorastree.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ class Outputs:
graph_name = 'scene'

# Settings
settingsHandler = settings.DomainContextHandler()

depth_limit = settings.ContextSetting(10)
target_class_index = settings.ContextSetting(0)
size_calc_idx = settings.Setting(0)
Expand All @@ -73,8 +75,7 @@ def __init__(self):
super().__init__()
# Instance variables
self.model = None
self.instances = None
self.clf_dataset = None
self.data = None
# The tree adapter instance which is passed from the outside
self.tree_adapter = None
self.legend = None
Expand Down Expand Up @@ -147,18 +148,12 @@ def __init__(self):
@Inputs.tree
def set_tree(self, model=None):
"""When a different tree is given."""
self.closeContext()
self.clear()
self.model = model

if model is not None:
self.instances = model.instances
# this bit is important for the regression classifier
if self.instances is not None and \
self.instances.domain != model.domain:
self.clf_dataset = self.instances.transform(self.model.domain)
else:
self.clf_dataset = self.instances

self.data = model.instances
self.tree_adapter = self._get_tree_adapter(self.model)
self.ptree.clear()

Expand All @@ -177,30 +172,30 @@ def set_tree(self, model=None):

self._update_main_area()

# The target class can also be passed from the meta properties
# This must be set after `_update_target_class_combo`
if hasattr(model, 'meta_target_class_index'):
self.target_class_index = model.meta_target_class_index
self.update_colors()
self.openContext(self.model)

# Get meta variables describing what the settings should look like
# if the tree is passed from the Pythagorean forest widget.
if hasattr(model, 'meta_size_calc_idx'):
self.size_calc_idx = model.meta_size_calc_idx
self.update_size_calc()
self.update_depth()

# TODO There is still something wrong with this
# if hasattr(model, 'meta_depth_limit'):
# self.depth_limit = model.meta_depth_limit
# self.update_depth()
# The forest widget sets the following attributes on the tree,
# describing the settings on the forest widget. To keep the tree
# looking the same as on the forest widget, we prefer these settings to
# context settings, if set.
if hasattr(model, "meta_target_class_index"):
self.target_class_index = model.meta_target_class_index
self.update_colors()
if hasattr(model, "meta_size_calc_idx"):
self.size_calc_idx = model.meta_size_calc_idx
self.update_size_calc()
if hasattr(model, "meta_depth_limit"):
self.depth_limit = model.meta_depth_limit
self.update_depth()

self.Outputs.annotated_data.send(create_annotated_table(self.instances, None))
self.Outputs.annotated_data.send(create_annotated_table(self.data, None))

def clear(self):
"""Clear all relevant data from the widget."""
self.model = None
self.instances = None
self.clf_dataset = None
self.data = None
self.tree_adapter = None

if self.legend is not None:
Expand Down Expand Up @@ -228,6 +223,8 @@ def update_size_calc(self):
self.invalidate_tree()

def redraw(self):
if self.data is None:
return
self.tree_adapter.shuffle_children()
self.invalidate_tree()

Expand Down Expand Up @@ -307,16 +304,21 @@ def onDeleteWidget(self):

def commit(self):
"""Commit the selected data to output."""
if self.instances is None:
if self.data is None:
self.Outputs.selected_data.send(None)
self.Outputs.annotated_data.send(None)
return
nodes = [i.tree_node.label for i in self.scene.selectedItems()
if isinstance(i, SquareGraphicsItem)]

nodes = [
i.tree_node.label for i in self.scene.selectedItems()
if isinstance(i, SquareGraphicsItem)
]
data = self.tree_adapter.get_instances_in_nodes(nodes)
self.Outputs.selected_data.send(data)
selected_indices = self.tree_adapter.get_indices(nodes)
self.Outputs.annotated_data.send(create_annotated_table(self.instances, selected_indices))
self.Outputs.annotated_data.send(
create_annotated_table(self.data, selected_indices)
)

def send_report(self):
"""Send report."""
Expand All @@ -327,9 +329,9 @@ def _update_target_class_combo(self):
label = [x for x in self.target_class_combo.parent().children()
if isinstance(x, QLabel)][0]

if self.instances.domain.has_discrete_class:
if self.data.domain.has_discrete_class:
label_text = 'Target class'
values = [c.title() for c in self.instances.domain.class_vars[0].values]
values = [c.title() for c in self.data.domain.class_vars[0].values]
values.insert(0, 'None')
else:
label_text = 'Node color'
Expand All @@ -342,7 +344,7 @@ def _update_legend_colors(self):
if self.legend is not None:
self.scene.removeItem(self.legend)

if self.instances.domain.has_discrete_class:
if self.data.domain.has_discrete_class:
self._classification_update_legend_colors()
else:
self._regression_update_legend_colors()
Expand Down Expand Up @@ -375,14 +377,14 @@ def _get_colors_domain(domain):

# The colors are the class mean
if self.target_class_index == 1:
values = (np.min(self.clf_dataset.Y), np.max(self.clf_dataset.Y))
values = (np.min(self.data.Y), np.max(self.data.Y))
colors = _get_colors_domain(self.model.domain)
while len(values) != len(colors):
values.insert(1, -1)
items = list(zip(values, colors))
# Colors are the stddev
elif self.target_class_index == 2:
values = (0, np.std(self.clf_dataset.Y))
values = (0, np.std(self.data.Y))
colors = _get_colors_domain(self.model.domain)
while len(values) != len(colors):
values.insert(1, -1)
Expand Down
32 changes: 19 additions & 13 deletions Orange/widgets/visualize/owpythagoreanforest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any, Callable, Optional

from AnyQt.QtCore import Qt, QRectF, QSize, QPointF, QSizeF, QModelIndex, \
QItemSelection, QT_VERSION
QItemSelection, QItemSelectionModel, QT_VERSION
from AnyQt.QtGui import QPainter, QPen, QColor, QBrush, QMouseEvent
from AnyQt.QtWidgets import QSizePolicy, QGraphicsScene, QLabel, QSlider, \
QListView, QStyledItemDelegate, QStyleOptionViewItem, QStyle
Expand Down Expand Up @@ -174,11 +174,15 @@ class Outputs:
graph_name = 'scene'

# Settings
settingsHandler = settings.DomainContextHandler()

depth_limit = settings.ContextSetting(10)
target_class_index = settings.ContextSetting(0)
size_calc_idx = settings.Setting(0)
zoom = settings.Setting(200)

selected_index = settings.ContextSetting(None)

SIZE_CALCULATION = [
('Normal', lambda x: x),
('Square root', lambda x: sqrt(x)),
Expand All @@ -199,7 +203,6 @@ def __init__(self):
self.rf_model = None
self.forest = None
self.instances = None
self.clf_dataset = None

self.color_palette = None

Expand Down Expand Up @@ -265,29 +268,32 @@ def __init__(self):
@Inputs.random_forest
def set_rf(self, model=None):
"""When a different forest is given."""
self.closeContext()
self.clear()
self.rf_model = model

if model is not None:
self.forest = self._get_forest_adapter(self.rf_model)
self.forest_model[:] = self.forest.trees

self.instances = model.instances
# This bit is important for the regression classifier
if self.instances is not None and self.instances.domain != model.domain:
self.clf_dataset = self.instances.transform(self.rf_model.domain)
else:
self.clf_dataset = self.instances

self._update_info_box()
self._update_target_class_combo()
self._update_depth_slider()

self.openContext(model)
# Restore item selection
if self.selected_index is not None:
index = self.list_view.model().index(self.selected_index)
selection = QItemSelection(index, index)
self.list_view.selectionModel().select(selection, QItemSelectionModel.ClearAndSelect)

def clear(self):
"""Clear all relevant data from the widget."""
self.rf_model = None
self.forest = None
self.forest_model.clear()
self.selected_index = None

self._clear_info_box()
self._clear_target_class_combo()
Expand Down Expand Up @@ -342,19 +348,19 @@ def onDeleteWidget(self):
super().onDeleteWidget()
self.clear()

def commit(self, selection):
# type: (QItemSelection) -> None
def commit(self, selection: QItemSelection) -> None:
"""Commit the selected tree to output."""
selected_indices = selection.indexes()

if not len(selected_indices):
self.selected_index = None
self.Outputs.tree.send(None)
return

selected_index, = selection.indexes()
# We only allow selecting a single tree so there will always be one index
self.selected_index = selected_indices[0].row()

idx = selected_index.row()
tree = self.rf_model.trees[idx]
tree = self.rf_model.trees[self.selected_index]
tree.instances = self.instances
tree.meta_target_class_index = self.target_class_index
tree.meta_size_calc_idx = self.size_calc_idx
Expand Down
12 changes: 12 additions & 0 deletions Orange/widgets/visualize/tests/test_owpythagorastree.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,3 +383,15 @@ def test_forest_tree_table(self):
square.setSelected(True)
tab = self.get_output(tree_w.Outputs.selected_data, widget=tree_w)
self.assertGreater(len(tab), 0)

def test_changing_data_restores_depth_from_previous_settings(self):
titanic_data = Table("titanic")[::50]
forest = RandomForestLearner(n_estimators=3)(titanic_data)
forest.instances = titanic_data

self.send_signal(self.widget.Inputs.tree, forest.trees[0])
self.widget.controls.depth_limit.setValue(1)

# The domain is still the same, so restore the depth limit from before
self.send_signal(self.widget.Inputs.tree, forest.trees[1])
self.assertEqual(self.widget.ptree._depth_limit, 1)
22 changes: 21 additions & 1 deletion Orange/widgets/visualize/tests/test_owpythagoreanforest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from unittest.mock import Mock

from AnyQt.QtCore import Qt
from AnyQt.QtCore import Qt, QItemSelection, QItemSelectionModel

from Orange.classification.random_forest import RandomForestLearner
from Orange.data import Table
Expand Down Expand Up @@ -201,3 +201,23 @@ def _callback():
# Check that individual squares all have the same color
colors_same = [self._check_all_same(x) for x in zip(*colors)]
self.assertTrue(all(colors_same))

def select_tree(self, idx: int) -> None:
list_view = self.widget.list_view
index = list_view.model().index(idx)
selection = QItemSelection(index, index)
list_view.selectionModel().select(selection, QItemSelectionModel.ClearAndSelect)

def test_storing_selection(self):
# Select one of the trees
idx = 1
self.send_signal(self.widget.Inputs.random_forest, self.titanic)
self.select_tree(idx)
# Clear input
self.send_signal(self.widget.Inputs.random_forest, None)
# Restore previous data; context settings should be restored
self.send_signal(self.widget.Inputs.random_forest, self.titanic)

output = self.get_output(self.widget.Outputs.tree)
self.assertIsNotNone(output)
self.assertIs(output.skl_model, self.titanic.trees[idx].skl_model)