Skip to content

Commit

Permalink
Add SVC::decision_function (#135)
Browse files Browse the repository at this point in the history
  • Loading branch information
ferrouille committed Jun 21, 2022
1 parent ff456df commit b4a807e
Showing 1 changed file with 57 additions and 10 deletions.
67 changes: 57 additions & 10 deletions src/svm/svc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,33 +263,41 @@ impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> SVC<T, M, K> {
/// Predicts estimated class labels from `x`
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
let (n, _) = x.shape();

let mut y_hat = M::RowVector::zeros(n);
let mut y_hat = self.decision_function(x)?;

for i in 0..n {
let cls_idx = match self.predict_for_row(x.get_row(i)) == T::one() {
for i in 0..y_hat.len() {
let cls_idx = match y_hat.get(i) > T::zero() {
false => self.classes[0],
true => self.classes[1],
};

y_hat.set(i, cls_idx);
}

Ok(y_hat)
}

/// Evaluates the decision function for the rows in `x`
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
pub fn decision_function(&self, x: &M) -> Result<M::RowVector, Failed> {
let (n, _) = x.shape();
let mut y_hat = M::RowVector::zeros(n);

for i in 0..n {
y_hat.set(i, self.predict_for_row(x.get_row(i)));
}

Ok(y_hat)
}

fn predict_for_row(&self, x: M::RowVector) -> T {
let mut f = self.b;

for i in 0..self.instances.len() {
f += self.w[i] * self.kernel.apply(&x, &self.instances[i]);
}

if f > T::zero() {
T::one()
} else {
-T::one()
}
f
}
}

Expand Down Expand Up @@ -772,6 +780,45 @@ mod tests {
assert!(accuracy(&y_hat, &y) >= 0.9);
}

#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn svc_fit_decision_function() {
let x = DenseMatrix::from_2d_array(&[&[4.0, 0.0], &[0.0, 4.0], &[8.0, 0.0], &[0.0, 8.0]]);

let x2 = DenseMatrix::from_2d_array(&[
&[3.0, 3.0],
&[4.0, 4.0],
&[6.0, 6.0],
&[10.0, 10.0],
&[1.0, 1.0],
&[0.0, 0.0],
]);

let y: Vec<f64> = vec![0., 0., 1., 1.];

let y_hat = SVC::fit(
&x,
&y,
SVCParameters::default()
.with_c(200.0)
.with_kernel(Kernels::linear()),
)
.and_then(|lr| lr.decision_function(&x2))
.unwrap();

// x can be classified by a straight line through [6.0, 0.0] and [0.0, 6.0],
// so the score should increase as points get further away from that line
println!("{:?}", y_hat);
assert!(y_hat[1] < y_hat[2]);
assert!(y_hat[2] < y_hat[3]);

// for negative scores the score should decrease
assert!(y_hat[4] > y_hat[5]);

// y_hat[0] is on the line, so its score should be close to 0
assert!(y_hat[0].abs() <= 0.1);
}

#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn svc_fit_predict_rbf() {
Expand Down

0 comments on commit b4a807e

Please sign in to comment.