From e94f5000256dedf3205e0a6c1b5b2d0ad80720a5 Mon Sep 17 00:00:00 2001
From: janezd
Date: Wed, 15 Jan 2020 17:44:36 +0100
Subject: [PATCH] Test and Score: Minor changes after review; can be squashed
 into the first commit

---
 Orange/widgets/evaluate/owtestlearners.py     | 73 +++++++++++++------
 .../evaluate/tests/test_owtestlearners.py     |  4 +-
 2 files changed, 51 insertions(+), 26 deletions(-)

diff --git a/Orange/widgets/evaluate/owtestlearners.py b/Orange/widgets/evaluate/owtestlearners.py
index 47c3ecf604d..2fa6fdbbc51 100644
--- a/Orange/widgets/evaluate/owtestlearners.py
+++ b/Orange/widgets/evaluate/owtestlearners.py
@@ -181,7 +181,7 @@ class Outputs:
 
     use_rope = settings.Setting(False)
     rope = settings.Setting(0.1)
-    comparison_criterion = settings.Setting(0)
+    comparison_criterion = settings.Setting(0, schema_only=True)
 
     TARGET_AVERAGE = "(Average over classes)"
     class_selection = settings.ContextSetting(TARGET_AVERAGE)
@@ -224,6 +224,7 @@ def __init__(self):
         self.train_data_missing_vals = False
         self.test_data_missing_vals = False
         self.scorers = []
+        self.__pending_comparison_criterion = self.comparison_criterion
 
         #: An Ordered dictionary with current inputs and their testing results.
         self.learners = OrderedDict()  # type: Dict[Any, Input]
@@ -291,7 +292,7 @@ def __init__(self):
         hbox = gui.hBox(box)
         gui.checkBox(hbox, self, "use_rope",
                      "Negligible difference: ",
-                     callback=self.update_comparison_table)
+                     callback=self._on_use_rope_changed)
         gui.lineEdit(hbox, self, "rope", validator=QDoubleValidator(),
                      controlWidth=70, callback=self.update_comparison_table,
                      alignment=Qt.AlignRight)
@@ -315,14 +316,15 @@ def __init__(self):
         header.setSectionsClickable(False)
 
         header = table.horizontalHeader()
-        header.setSectionResizeMode(QHeaderView.ResizeToContents)
-        avg_width = self.fontMetrics().averageCharWidth()
-        header.setMinimumSectionSize(8 * avg_width)
-        header.setMaximumSectionSize(15 * avg_width)
         header.setTextElideMode(Qt.ElideRight)
         header.setDefaultAlignment(Qt.AlignCenter)
         header.setSectionsClickable(False)
         header.setStretchLastSection(False)
+        header.setSectionResizeMode(QHeaderView.ResizeToContents)
+        avg_width = self.fontMetrics().averageCharWidth()
+        header.setMinimumSectionSize(8 * avg_width)
+        header.setMaximumSectionSize(15 * avg_width)
+        header.setDefaultSectionSize(15 * avg_width)
         box.layout().addWidget(table)
         box.layout().addWidget(QLabel(
             "Table shows probabilities that the score for the model in "
@@ -490,6 +492,12 @@ def _update_scorers(self):
         self.scorers = usable_scorers(self.data.domain.class_var)
         self.controls.comparison_criterion.model()[:] = \
             [scorer.long_name or scorer.name for scorer in self.scorers]
+        if self.__pending_comparison_criterion is not None:
+            # Check for the unlikely case that some scorers have been removed
+            # from modules
+            if self.__pending_comparison_criterion < len(self.scorers):
+                self.comparison_criterion = self.__pending_comparison_criterion
+            self.__pending_comparison_criterion = None
 
     @Inputs.preprocessor
     def set_preprocessor(self, preproc):
@@ -503,7 +511,7 @@ def handleNewSignals(self):
         """Reimplemented from OWWidget.handleNewSignals."""
         self._update_class_selection()
         self.score_table.update_header(self.scorers)
-        self._update_comparison_enabled()
+        self._update_view_enabled()
         self.update_stats_model()
         if self.__needupdate:
             self.__update()
@@ -522,14 +530,18 @@ def shuffle_split_changed(self):
 
     def _param_changed(self):
         self.modcompbox.setEnabled(self.resampling == OWTestLearners.KFold)
-        self._update_comparison_enabled()
+        self._update_view_enabled()
         self._invalidate()
         self.__update()
 
-    def _update_comparison_enabled(self):
+    def _update_view_enabled(self):
         self.comparison_table.setEnabled(
             self.resampling == OWTestLearners.KFold
-            and len(self.learners) > 1)
+            and len(self.learners) > 1
+            and self.data is not None)
+        self.score_table.view.setEnabled(
+            len(self.learners) > 1
+            and self.data is not None)
 
     def update_stats_model(self):
         # Update the results_model with up to date scores.
@@ -552,8 +564,10 @@ def update_stats_model(self):
         errors = []
         has_missing_scores = False
 
+        names = []
         for key, slot in self.learners.items():
             name = learner_name(slot.learner)
+            names.append(name)
             head = QStandardItem(name)
             head.setData(key, Qt.UserRole)
             results = slot.results
@@ -616,18 +630,23 @@ def update_stats_model(self):
                 header.sortIndicatorSection(),
                 header.sortIndicatorOrder()
             )
+        self._set_comparison_headers(names)
 
         self.error("\n".join(errors), shown=bool(errors))
         self.Warning.scores_not_computed(shown=has_missing_scores)
 
+    def _on_use_rope_changed(self):
+        self.controls.rope.setEnabled(self.use_rope)
+        self.update_comparison_table()
+
     def update_comparison_table(self):
         self.comparison_table.clearContents()
-        if self.resampling != OWTestLearners.KFold:
-            return
-
         slots = self._successful_slots()
-        scores = self._scores_by_folds(slots)
-        self._fill_table(slots, scores)
+        names = [learner_name(slot.learner) for slot in slots]
+        self._set_comparison_headers(names)
+        if self.resampling == OWTestLearners.KFold:
+            scores = self._scores_by_folds(slots)
+            self._fill_table(names, scores)
 
     def _successful_slots(self):
         model = self.score_table.model
@@ -639,6 +658,19 @@ def _successful_slots(self):
                  if slot.results is not None and slot.results.success]
         return slots
 
+    def _set_comparison_headers(self, names):
+        table = self.comparison_table
+        table.setRowCount(len(names))
+        table.setColumnCount(len(names))
+        table.setVerticalHeaderLabels(names)
+        table.setHorizontalHeaderLabels(names)
+        header = table.horizontalHeader()
+        if len(names) > 2:
+            header.setSectionResizeMode(QHeaderView.Stretch)
+        else:
+            header.setSectionResizeMode(QHeaderView.Fixed)
+
+
     def _scores_by_folds(self, slots):
         scorer = self.scorers[self.comparison_criterion]()
         self.compbox.setTitle(f"Model comparison by {scorer.name}")
@@ -665,15 +697,8 @@ def thunked():
                 self.Warning.scores_not_computed()
             return scores
 
-    def _fill_table(self, slots, scores):
+    def _fill_table(self, names, scores):
         table = self.comparison_table
-        table.setRowCount(len(slots))
-        table.setColumnCount(len(slots))
-
-        names = [learner_name(slot.learner) for slot in slots]
-        table.setVerticalHeaderLabels(names)
-        table.setHorizontalHeaderLabels(names)
-
         for row, row_name, row_scores in zip(count(), names, scores):
             for col, col_name, col_scores in zip(range(row), names, scores):
                 if row_scores is None or col_scores is None:
@@ -682,7 +707,7 @@ def _fill_table(self, slots, scores):
                 p0, rope, p1 = baycomp.two_on_single(
                     row_scores, col_scores, self.rope)
                 self._set_cell(table, row, col,
-                               f"{p0:.3f}<br/><small>{rope:.3f})</small>",
+                               f"{p0:.3f}<br/><small>{rope:.3f}</small>",
                                f"p({row_name} > {col_name}) = {p0:.3f}\n"
                                f"p({row_name} = {col_name}) = {rope:.3f}")
                 self._set_cell(table, col, row,
diff --git a/Orange/widgets/evaluate/tests/test_owtestlearners.py b/Orange/widgets/evaluate/tests/test_owtestlearners.py
index 149f8665e8b..6b9b81c0b5b 100644
--- a/Orange/widgets/evaluate/tests/test_owtestlearners.py
+++ b/Orange/widgets/evaluate/tests/test_owtestlearners.py
@@ -418,6 +418,8 @@ def _set_three_majorities(self):
 
     @patch("baycomp.two_on_single", Mock(wraps=baycomp.two_on_single))
     def test_comparison_requires_cv(self):
         w = self.widget
+        self.send_signal(w.Inputs.train_data, Table("iris")[::15])
+        w.comparison_criterion = 1
 
         rbs = w.controls.resampling.buttons
@@ -448,7 +450,6 @@ def test_comparison_requires_cv(self):
         baycomp.two_on_single.assert_called()
         baycomp.two_on_single.reset_mock()
 
-    @patch("baycomp.two_on_single", Mock(wraps=baycomp.two_on_single))
     def test_comparison_requires_multiple_models(self):
         w = self.widget
         w.comparison_criterion = 1
@@ -482,7 +483,6 @@ def test_comparison_requires_multiple_models(self):
         self.get_output(self.widget.Outputs.evaluations_results, wait=5000)
         self.assertTrue(w.comparison_table.isEnabled())
 
-    @patch("baycomp.two_on_single", Mock(wraps=baycomp.two_on_single))
     def test_comparison_bad_slots(self):
         w = self.widget
         self._set_three_majorities()