Skip to content

Commit

Permalink
[enhancement] remove sklearn_check_version dependence from `onedal/sv…
Browse files Browse the repository at this point in the history
…m` (#1835)

* Update svm.py

* Update _common.py

* Update nusvc.py

* Update nusvr.py

* Update svc.py

* Update svr.py

* Update test_csr_svm.py

* Update test_svc.py

* Update test_svr.py

* Update svm.cpp

* Update svm.py

* Update test_csr_svm.py

* Update test_svc.py

* Update test_svr.py

* Update _common.py

* Update nusvc.py

* Update nusvr.py

* Update svc.py

* Update svr.py

* Update _common.py

* Update sklearnex/svm/_common.py

Co-authored-by: Alexander Andreev <alexander.andreev@intel.com>

* Update _common.py

* Update nusvc.py

* linting

---------

Co-authored-by: Alexander Andreev <alexander.andreev@intel.com>
  • Loading branch information
icfaust and Alexsandruss authored May 22, 2024
1 parent 236d2fe commit 8b5de5a
Show file tree
Hide file tree
Showing 10 changed files with 289 additions and 260 deletions.
11 changes: 11 additions & 0 deletions onedal/svm/svm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,17 @@ void init_train_ops(py::module_& m) {
train_ops ops(policy, input_t{ data, responses, weights }, params2desc{});
return fptype2t{ method2t{ Task{}, kernel2t{ ops } } }(params);
});
m.def("train",
[](const Policy& policy,
const py::dict& params,
const table& data,
const table& responses) {
using namespace dal::svm;
using input_t = train_input<Task>;

train_ops ops(policy, input_t{ data, responses}, params2desc{});
return fptype2t{ method2t{ Task{}, kernel2t{ ops } } }(params);
});
}

template <typename Policy, typename Task>
Expand Down
176 changes: 39 additions & 137 deletions onedal/svm/svm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,10 @@

from abc import ABCMeta, abstractmethod
from enum import Enum
from numbers import Number, Real

import numpy as np
from scipy import sparse as sp
from sklearn.base import BaseEstimator

from daal4py.sklearn._utils import sklearn_check_version
from onedal import _backend

from ..common._estimator_checks import _check_is_fitted
Expand All @@ -45,7 +42,7 @@ class SVMtype(Enum):
nu_svr = 3


class BaseSVM(BaseEstimator, metaclass=ABCMeta):
class BaseSVM(metaclass=ABCMeta):
@abstractmethod
def __init__(
self,
Expand Down Expand Up @@ -87,133 +84,11 @@ def __init__(
self.algorithm = algorithm
self.svm_type = svm_type

def _compute_gamma_sigma(self, gamma, X):
if isinstance(gamma, str):
if gamma == "scale":
if sp.issparse(X):
# var = E[X^2] - E[X]^2
X_sc = (X.multiply(X)).mean() - (X.mean()) ** 2
else:
X_sc = X.var()
_gamma = 1.0 / (X.shape[1] * X_sc) if X_sc != 0 else 1.0
elif gamma == "auto":
_gamma = 1.0 / X.shape[1]
else:
raise ValueError(
"When 'gamma' is a string, it should be either 'scale' or "
"'auto'. Got '{}' instead.".format(gamma)
)
else:
if sklearn_check_version("1.1") and not sklearn_check_version("1.2"):
if isinstance(gamma, Real):
if gamma <= 0:
msg = (
f"gamma value must be > 0; {gamma!r} is invalid. Use"
" a positive number or use 'auto' to set gamma to a"
" value of 1 / n_features."
)
raise ValueError(msg)
_gamma = gamma
else:
msg = (
"The gamma value should be set to 'scale', 'auto' or a"
f" positive float value. {gamma!r} is not a valid option"
)
raise ValueError(msg)
else:
_gamma = gamma
return _gamma, np.sqrt(0.5 / _gamma)

def _validate_targets(self, y, dtype):
self.class_weight_ = None
self.classes_ = None
return _column_or_1d(y, warn=True).astype(dtype, copy=False)

def _get_sample_weight(self, X, y, sample_weight):
n_samples = X.shape[0]
dtype = X.dtype
if n_samples == 1:
raise ValueError("n_samples=1")

sample_weight = np.asarray(
[] if sample_weight is None else sample_weight, dtype=np.float64
)

sample_weight_count = sample_weight.shape[0]
if sample_weight_count != 0 and sample_weight_count != n_samples:
raise ValueError(
"sample_weight and X have incompatible shapes: "
"%r vs %r\n"
"Note: Sparse matrices cannot be indexed w/"
"boolean masks (use `indices=True` in CV)."
% (len(sample_weight), X.shape)
)

ww = None
if sample_weight_count == 0 and self.class_weight_ is None:
return ww

if sample_weight_count == 0:
sample_weight = np.ones(n_samples, dtype=dtype)
elif isinstance(sample_weight, Number):
sample_weight = np.full(n_samples, sample_weight, dtype=dtype)
else:
sample_weight = _check_array(
sample_weight,
accept_sparse=False,
ensure_2d=False,
dtype=dtype,
order="C",
)
if sample_weight.ndim != 1:
raise ValueError("Sample weights must be 1D array or scalar")

if sample_weight.shape != (n_samples,):
raise ValueError(
"sample_weight.shape == {}, expected {}!".format(
sample_weight.shape, (n_samples,)
)
)

if self.svm_type == SVMtype.nu_svc:
weight_per_class = [
np.sum(sample_weight[y == class_label]) for class_label in np.unique(y)
]

for i in range(len(weight_per_class)):
for j in range(i + 1, len(weight_per_class)):
if self.nu * (weight_per_class[i] + weight_per_class[j]) / 2 > min(
weight_per_class[i], weight_per_class[j]
):
raise ValueError("specified nu is infeasible")

if np.all(sample_weight <= 0):
if self.svm_type == SVMtype.nu_svc:
err_msg = "negative dimensions are not allowed"
else:
err_msg = "Invalid input - all samples have zero or negative weights."
raise ValueError(err_msg)
if np.any(sample_weight <= 0):
if self.svm_type == SVMtype.c_svc and len(
np.unique(y[sample_weight > 0])
) != len(self.classes_):
raise ValueError(
"Invalid input - all samples with positive weights "
"belong to the same class"
if sklearn_check_version("1.2")
else "Invalid input - all samples with positive weights "
"have the same label."
)
ww = sample_weight
if self.class_weight_ is not None:
for i, v in enumerate(self.class_weight_):
ww[y == i] *= v

if not ww.flags.c_contiguous and not ww.flags.f_contiguous:
ww = np.ascontiguousarray(ww, dtype)

return ww

def _get_onedal_params(self, data):
max_iter = 10000 if self.max_iter == -1 else self.max_iter
# TODO: remove this workaround
Expand Down Expand Up @@ -247,12 +122,6 @@ def _fit(self, X, y, sample_weight, module, queue):
f"got {self.decision_function_shape}."
)

if y is None:
if self._get_tags()["requires_y"]:
raise ValueError(
f"This {self.__class__.__name__} estimator "
f"requires y to be passed, but the target y is None."
)
X, y = _check_X_y(
X,
y,
Expand All @@ -261,19 +130,52 @@ def _fit(self, X, y, sample_weight, module, queue):
accept_sparse="csr",
)
y = self._validate_targets(y, X.dtype)
sample_weight = self._get_sample_weight(X, y, sample_weight)

if sample_weight is not None and len(sample_weight) > 0:
sample_weight = _check_array(
sample_weight,
accept_sparse=False,
ensure_2d=False,
dtype=X.dtype,
order="C",
)
elif self.class_weight is not None:
sample_weight = np.ones(X.shape[0], dtype=X.dtype)

if sample_weight is not None:
if self.class_weight_ is not None:
for i, v in enumerate(self.class_weight_):
sample_weight[y == i] *= v
data = (X, y, sample_weight)
else:
data = (X, y)
self._sparse = sp.issparse(X)

if self.kernel == "linear":
self._scale_, self._sigma_ = 1.0, 1.0
self.coef0 = 0.0
else:
self._scale_, self._sigma_ = self._compute_gamma_sigma(self.gamma, X)
if isinstance(self.gamma, str):
if self.gamma == "scale":
if sp.issparse(X):
# var = E[X^2] - E[X]^2
X_sc = (X.multiply(X)).mean() - (X.mean()) ** 2
else:
X_sc = X.var()
_gamma = 1.0 / (X.shape[1] * X_sc) if X_sc != 0 else 1.0
elif self.gamma == "auto":
_gamma = 1.0 / X.shape[1]
else:
raise ValueError(
"When 'gamma' is a string, it should be either 'scale' or "
"'auto'. Got '{}' instead.".format(self.gamma)
)
else:
_gamma = self.gamma
self._scale_, self._sigma_ = _gamma, np.sqrt(0.5 / _gamma)

policy = _get_policy(queue, X, y, sample_weight)
policy = _get_policy(queue, *data)
params = self._get_onedal_params(X)
result = module.train(policy, params, *to_table(X, y, sample_weight))
result = module.train(policy, params, *to_table(*data))

if self._sparse:
self.dual_coef_ = sp.csr_matrix(from_table(result.coeffs).T)
Expand Down
34 changes: 17 additions & 17 deletions onedal/svm/tests/test_csr_svm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,19 @@
from numpy.testing import assert_array_almost_equal, assert_array_equal
from scipy import sparse as sp
from sklearn import datasets
from sklearn.base import clone as clone_estimator
from sklearn.datasets import make_classification

from onedal.common._mixin import ClassifierMixin
from onedal.svm import SVC, SVR
from onedal.tests.utils._device_selection import (
get_queues,
pass_if_not_implemented_for_gpu,
)


def is_classifier(estimator):
return getattr(estimator, "_estimator_type", None) == "classifier"


def check_svm_model_equal(queue, svm, X_train, y_train, X_test, decimal=6):
sparse_svm = clone_estimator(svm)
dense_svm = clone_estimator(svm)
def check_svm_model_equal(
queue, dense_svm, sparse_svm, X_train, y_train, X_test, decimal=6
):
dense_svm.fit(X_train.toarray(), y_train, queue=queue)
if sp.issparse(X_test):
X_test_dense = X_test.toarray()
Expand All @@ -56,7 +52,7 @@ def check_svm_model_equal(queue, svm, X_train, y_train, X_test, decimal=6):
sparse_svm.predict(X_test, queue=queue),
)

if is_classifier(svm):
if isinstance(dense_svm, ClassifierMixin) and isinstance(sparse_svm, ClassifierMixin):
assert_array_almost_equal(
dense_svm.decision_function(X_test_dense, queue=queue),
sparse_svm.decision_function(X_test, queue=queue),
Expand All @@ -73,8 +69,9 @@ def _test_simple_dataset(queue, kernel):
sparse_X2 = sp.dok_matrix(X2)

dataset = sparse_X, Y, sparse_X2
clf = SVC(kernel=kernel, gamma=1)
check_svm_model_equal(queue, clf, *dataset)
clf0 = SVC(kernel=kernel, gamma=1)
clf1 = SVC(kernel=kernel, gamma=1)
check_svm_model_equal(queue, clf0, clf1, *dataset)


@pass_if_not_implemented_for_gpu(reason="csr svm is not implemented")
Expand All @@ -101,8 +98,9 @@ def _test_binary_dataset(queue, kernel):
sparse_X = sp.csr_matrix(X)

dataset = sparse_X, y, sparse_X
clf = SVC(kernel=kernel)
check_svm_model_equal(queue, clf, *dataset)
clf0 = SVC(kernel=kernel)
clf1 = SVC(kernel=kernel)
check_svm_model_equal(queue, clf0, clf1, *dataset)


@pass_if_not_implemented_for_gpu(reason="csr svm is not implemented")
Expand Down Expand Up @@ -135,8 +133,9 @@ def _test_iris(queue, kernel):

dataset = sparse_iris_data, iris.target, sparse_iris_data

clf = SVC(kernel=kernel)
check_svm_model_equal(queue, clf, *dataset, decimal=2)
clf0 = SVC(kernel=kernel)
clf1 = SVC(kernel=kernel)
check_svm_model_equal(queue, clf0, clf1, *dataset, decimal=2)


@pass_if_not_implemented_for_gpu(reason="csr svm is not implemented")
Expand All @@ -152,8 +151,9 @@ def _test_diabetes(queue, kernel):
sparse_diabetes_data = sp.csr_matrix(diabetes.data)
dataset = sparse_diabetes_data, diabetes.target, sparse_diabetes_data

clf = SVR(kernel=kernel, C=0.1)
check_svm_model_equal(queue, clf, *dataset)
clf0 = SVR(kernel=kernel, C=0.1)
clf1 = SVR(kernel=kernel, C=0.1)
check_svm_model_equal(queue, clf0, clf1, *dataset)


@pass_if_not_implemented_for_gpu(reason="csr svm is not implemented")
Expand Down
37 changes: 0 additions & 37 deletions onedal/svm/tests/test_svc.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from sklearn.datasets import make_blobs
from sklearn.metrics.pairwise import rbf_kernel
from sklearn.model_selection import train_test_split
from sklearn.utils.estimator_checks import check_estimator

from onedal.svm import SVC
from onedal.tests.utils._device_selection import (
Expand All @@ -31,42 +30,6 @@
)


def _replace_and_save(md, fns, replacing_fn):
saved = dict()
for check_f in fns:
try:
fn = getattr(md, check_f)
setattr(md, check_f, replacing_fn)
saved[check_f] = fn
except RuntimeError:
pass
return saved


def _restore_from_saved(md, saved_dict):
for check_f in saved_dict:
setattr(md, check_f, saved_dict[check_f])


def test_estimator():
def dummy(*args, **kwargs):
pass

md = sklearn.utils.estimator_checks
saved = _replace_and_save(
md,
[
"check_sample_weights_invariance", # Max absolute difference: 0.0008
"check_estimators_fit_returns_self", # ValueError: empty metadata
"check_classifiers_train", # assert y_pred.shape == (n_samples,)
"check_estimators_unfitted", # Call 'fit' with appropriate arguments
],
dummy,
)
check_estimator(SVC())
_restore_from_saved(md, saved)


def _test_libsvm_parameters(queue, array_constr, dtype):
X = array_constr([[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]], dtype=dtype)
y = array_constr([1, 1, 1, 2, 2, 2], dtype=dtype)
Expand Down
Loading

0 comments on commit 8b5de5a

Please sign in to comment.