Skip to content

Commit

Permalink
Merge pull request biolab#1996 from ales-erjavec/fixes/test-learners-…
Browse files Browse the repository at this point in the history
…fix-one-vs-rest

[FIX] Test Learners: Fix AUC for selected single target class
  • Loading branch information
markotoplak authored Feb 3, 2017
2 parents 8ca36a3 + 4d5d38a commit 4830c26
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 3 deletions.
36 changes: 33 additions & 3 deletions Orange/widgets/evaluate/owtestlearners.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,11 +709,41 @@ def results_merge(results):


def results_one_vs_rest(results, pos_index):
from Orange.preprocess.transformation import Indicator
actual = results.actual == pos_index
predicted = results.predicted == pos_index
return Orange.evaluation.Results(
nmethods=1, domain=results.domain,
actual=actual, predicted=predicted)
if results.probabilities is not None:
c = results.probabilities.shape[2]
assert c >= 2
neg_indices = [i for i in range(c) if i != pos_index]
pos_prob = results.probabilities[:, :, [pos_index]]
neg_prob = np.sum(results.probabilities[:, :, neg_indices],
axis=2, keepdims=True)
probabilities = np.dstack((neg_prob, pos_prob))
else:
probabilities = None

res = Orange.evaluation.Results()
res.actual = actual
res.predicted = predicted
res.folds = results.folds
res.row_indices = results.row_indices
res.probabilities = probabilities

value = results.domain.class_var.values[pos_index]
class_var = Orange.data.DiscreteVariable(
"I({}=={})".format(results.domain.class_var.name, value),
values=["False", "True"],
compute_value=Indicator(results.domain.class_var, pos_index)
)
domain = Orange.data.Domain(
results.domain.attributes,
[class_var],
results.domain.metas
)
res.data = None
res.domain = domain
return res


def main(argv=None):
Expand Down
67 changes: 67 additions & 0 deletions Orange/widgets/evaluate/tests/test_owtestlearners.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# pylint: disable=missing-docstring
import numpy as np

import unittest

from Orange.data import Table
from Orange.classification import MajorityLearner
from Orange.regression import MeanLearner

from Orange.evaluation import Results, TestOnTestData
from Orange.widgets.tests.base import WidgetTest
from Orange.widgets.evaluate.owtestlearners import OWTestLearners
from Orange.widgets.evaluate import owtestlearners


class TestOWTestLearners(WidgetTest):
def setUp(self):
super().setUp()
self.widget = self.create_widget(OWTestLearners) # type: OWTestLearners

def test_basic(self):
data = Table("iris")[::3]
self.send_signal("Data", data)
self.send_signal("Learner", MajorityLearner(), 0)
res = self.get_output("Evaluation Results")
self.assertIsInstance(res, Results)
self.assertIsNotNone(res.domain)
self.assertIsNotNone(res.data)
self.assertIsNotNone(res.probabilities)

self.send_signal("Learner", None, 0)

data = Table("housing")[::10]
self.send_signal("Data", data)
self.send_signal("Learner", MeanLearner(), 0)
res = self.get_output("Evaluation Results")
self.assertIsInstance(res, Results)
self.assertIsNotNone(res.domain)
self.assertIsNotNone(res.data)


class TestHelpers(unittest.TestCase):
def test_results_one_vs_rest(self):
data = Table("lenses")
learners = [MajorityLearner()]
res = TestOnTestData(data[1::2], data[::2], learners=learners)
r1 = owtestlearners.results_one_vs_rest(res, pos_index=0)
r2 = owtestlearners.results_one_vs_rest(res, pos_index=1)
r3 = owtestlearners.results_one_vs_rest(res, pos_index=2)

np.testing.assert_almost_equal(np.sum(r1.probabilities, axis=2), 1.0)
np.testing.assert_almost_equal(np.sum(r2.probabilities, axis=2), 1.0)
np.testing.assert_almost_equal(np.sum(r3.probabilities, axis=2), 1.0)

np.testing.assert_almost_equal(
r1.probabilities[:, :, 1] +
r2.probabilities[:, :, 1] +
r3.probabilities[:, :, 1],
1.0
)
self.assertEqual(r1.folds, res.folds)
self.assertEqual(r2.folds, res.folds)
self.assertEqual(r3.folds, res.folds)

np.testing.assert_equal(r1.row_indices, res.row_indices)
np.testing.assert_equal(r2.row_indices, res.row_indices)
np.testing.assert_equal(r3.row_indices, res.row_indices)

0 comments on commit 4830c26

Please sign in to comment.