Skip to content

Commit

Permalink
Fixing error in CUR test
Browse files Browse the repository at this point in the history
  • Loading branch information
rosecers committed May 18, 2023
1 parent b83f986 commit 1ba16f8
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 8 deletions.
2 changes: 1 addition & 1 deletion src/skmatter/decomposition/_pcovr.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ class PCovR(_BasePCA, LinearModel):
>>> Y = np.array([[0, -5], [-1, 1], [1, -5], [-3, 2]])
>>> pcovr = PCovR(mixing=0.1, n_components=2)
>>> pcovr.fit(X, Y)
PCovR(mixing=0.1, n_components=2, space='sample')
PCovR(mixing=0.1, n_components=2)
>>> pcovr.transform(X)
array([[ 3.2630561 , 0.06663787],
[-2.69395511, -0.41582771],
Expand Down
5 changes: 1 addition & 4 deletions tests/test_check_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,7 @@
from skmatter.feature_selection import PCovCUR as fPCovCUR
from skmatter.feature_selection import PCovFPS as fPCovFPS
from skmatter.linear_model import RidgeRegression2FoldCV # OrthogonalRegression,
from skmatter.preprocessing import (
KernelNormalizer,
StandardFlexibleScaler,
)
from skmatter.preprocessing import KernelNormalizer, StandardFlexibleScaler


@parametrize_with_checks(
Expand Down
5 changes: 2 additions & 3 deletions tests/test_sample_simple_cur.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import unittest

import numpy as np
from sklearn import exceptions
from sklearn.datasets import fetch_california_housing as load

from skmatter.sample_selection import CUR, FPS
Expand All @@ -10,12 +9,12 @@
class TestCUR(unittest.TestCase):
def setUp(self):
self.X, _ = load(return_X_y=True)
self.X = FPS(n_to_select=100).fit(self.X).transform(self.X)
self.X = self.X[FPS(n_to_select=100).fit(self.X).selected_idx_]
self.n_select = min(20, min(self.X.shape) // 2)

def test_bad_transform(self):
selector = CUR(n_to_select=2)
with self.assertRaises(exceptions.NotFittedError):
with self.assertRaises(ValueError):
_ = selector.transform(self.X)

def test_restart(self):
Expand Down

0 comments on commit 1ba16f8

Please sign in to comment.