Skip to content

Commit

Permalink
Small improvements to linear regression (#28)
Browse files Browse the repository at this point in the history
* Rename intercept setters and use intercept by default

* Enable cargo test without specifying features

To run tests, we need to instantiate ndarray-linalg with a blas
implementation. For dev-dependencies we use ndarray-linalg
with openblas.

* Comply with C-GETTER in linear regression
  • Loading branch information
paulkoerbitz authored Jul 31, 2020
1 parent d7a5828 commit c171cf0
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 49 deletions.
5 changes: 3 additions & 2 deletions linfa-linear/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[package]
name = "linfa-linear"
version = "0.1.0"
authors = ["Paul Körbitz <paul@koerbitz.me>"]
authors = ["Paul Körbitz / Google <koerbitz@google.com>"]
edition = "2018"
workspace = ".."

Expand All @@ -21,4 +21,5 @@ num-traits = {version="0.2"}
csv = "1.1"
ndarray-csv = "0.4"
approx = "0.3.2"
flate2 = "1.0"
flate2 = "1.0"
ndarray-linalg = {version = "0.12", features = ["openblas"]}
6 changes: 3 additions & 3 deletions linfa-linear/examples/diabetes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ fn main() -> Result<(), Box<dyn Error>> {
let target = read_array("../datasets/diabetes_target.csv.gz")?;
let target = target.column(0);

let lin_reg = LinearRegression::new().fit_intercept();
let lin_reg = LinearRegression::new();
let model = lin_reg.fit(&data, &target)?;

println!("intercept: {}", model.get_intercept());
println!("parameters: {}", model.get_params());
println!("intercept: {}", model.intercept());
println!("parameters: {}", model.params());

Ok(())
}
99 changes: 55 additions & 44 deletions linfa-linear/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,18 @@ pub struct LinearRegression {
#[derive(Clone, Copy, PartialEq, Eq)]
enum Options {
None,
FitIntercept,
FitInterceptAndNormalize,
WithIntercept,
WithInterceptAndNormalize,
}

fn fit_intercept(options: Options) -> bool {
options == Options::FitIntercept || options == Options::FitInterceptAndNormalize
}
impl Options {
fn should_use_intercept(&self) -> bool {
*self == Options::WithIntercept || *self == Options::WithInterceptAndNormalize
}

fn normalize(options: Options) -> bool {
options == Options::FitInterceptAndNormalize
fn should_normalize(&self) -> bool {
*self == Options::WithInterceptAndNormalize
}
}

/// A fitted linear regression model which can be used for making predictions.
Expand All @@ -70,19 +72,28 @@ pub struct FittedLinearRegression<A> {

/// Configure and fit a linear regression model
impl LinearRegression {
/// Create a default linear regression model. By default, no intercept
/// will be fitted and the feature matrix will not be normalized.
/// To change this, call `fit_intercept()` or
/// `fit_intercept_and_normalize()` before calling `fit()`.
/// Create a default linear regression model.
///
/// By default, an intercept will be fitted. To disable fitting an
/// intercept, call `.with_intercept(false)` before calling `.fit()`.
///
/// To additionally normalize the feature matrix before fitting, call
/// `fit_intercept_and_normalize()` before calling `fit()`. The feature
/// matrix will not be normalized by default.
pub fn new() -> LinearRegression {
LinearRegression {
options: Options::None,
options: Options::WithIntercept,
}
}

/// Configure the linear regression model to fit an intercept.
pub fn fit_intercept(mut self) -> Self {
self.options = Options::FitIntercept;
/// Defaults to `true` if not set.
pub fn with_intercept(mut self, with_intercept: bool) -> Self {
if with_intercept {
self.options = Options::WithIntercept;
} else {
self.options = Options::None;
}
self
}

Expand All @@ -94,8 +105,8 @@ impl LinearRegression {
/// are all within in a small range and all features are of similar size.
///
/// Normalization implies fitting an intercept.
pub fn fit_intercept_and_normalize(mut self) -> Self {
self.options = Options::FitInterceptAndNormalize;
pub fn with_intercept_and_normalize(mut self) -> Self {
self.options = Options::WithInterceptAndNormalize;
self
}

Expand Down Expand Up @@ -124,7 +135,7 @@ impl LinearRegression {
// Check that our inputs have compatible shapes
assert_eq!(y.dim(), n_samples);

if fit_intercept(self.options) {
if self.options.should_use_intercept() {
// If we are fitting the intercept, we first center X and y,
// compute the models parameters based on the centered X and y
// and the intercept as the residual of fitted parameters applied
Expand All @@ -136,7 +147,7 @@ impl LinearRegression {
let y_offset: A = y.mean().ok_or(String::from("cannot compute mean of y"))?;
let y_centered: Array1<A> = y - y_offset;
let params: Array1<A> =
compute_params(&X_centered, &y_centered, normalize(self.options))?;
compute_params(&X_centered, &y_centered, self.options.should_normalize())?;
let intercept: A = y_offset - X_offset.dot(&params);
return Ok(FittedLinearRegression {
intercept: intercept,
Expand Down Expand Up @@ -204,12 +215,12 @@ impl<A: Scalar + ScalarOperand> FittedLinearRegression<A> {
}

/// Get the fitted parameters
pub fn get_params(&self) -> &Array1<A> {
pub fn params(&self) -> &Array1<A> {
&self.params
}

/// Get the fitted intercept, 0. if no intercept was fitted
pub fn get_intercept(&self) -> A {
pub fn intercept(&self) -> A {
self.intercept
}
}
Expand All @@ -222,7 +233,7 @@ mod tests {

#[test]
fn fits_a_line_through_two_dots() {
let lin_reg = LinearRegression::new().fit_intercept();
let lin_reg = LinearRegression::new();
let A: Array2<f64> = array![[0.], [1.]];
let b: Array1<f64> = array![1., 2.];
let model = lin_reg.fit(&A, &b).unwrap();
Expand All @@ -231,12 +242,12 @@ mod tests {
assert!(result.abs_diff_eq(&array![1., 2.], 1e-12));
}

/// When `fit_intercept` is not set (the default), the
/// When `with_intercept` is set to false, the
/// fitted line runs through the origin. For a perfect
/// fit we only need to provide one point.
#[test]
fn without_intercept_fits_line_through_origin() {
let lin_reg = LinearRegression::new();
let lin_reg = LinearRegression::new().with_intercept(false);
let A: Array2<f64> = array![[1.]];
let b: Array1<f64> = array![1.];
let model = lin_reg.fit(&A, &b).unwrap();
Expand All @@ -252,7 +263,7 @@ mod tests {
/// f(x) = 0
#[test]
fn fits_least_squares_line_through_two_dots() {
let lin_reg = LinearRegression::new();
let lin_reg = LinearRegression::new().with_intercept(false);
let A: Array2<f64> = array![[-1.], [1.]];
let b: Array1<f64> = array![1., 1.];
let model = lin_reg.fit(&A, &b).unwrap();
Expand All @@ -268,7 +279,7 @@ mod tests {
/// f(x) = -1./3. + x
#[test]
fn fits_least_squares_line_through_three_dots() {
let lin_reg = LinearRegression::new().fit_intercept();
let lin_reg = LinearRegression::new();
let A: Array2<f64> = array![[0.], [1.], [2.]];
let b: Array1<f64> = array![0., 0., 2.];
let model = lin_reg.fit(&A, &b).unwrap();
Expand All @@ -282,41 +293,41 @@ mod tests {
/// f(x) = (x + 1)^2 = x^2 + 2x + 1
#[test]
fn fits_three_parameters_through_three_dots() {
let lin_reg = LinearRegression::new().fit_intercept();
let lin_reg = LinearRegression::new();
let A: Array2<f64> = array![[0., 0.], [1., 1.], [2., 4.]];
let b: Array1<f64> = array![1., 4., 9.];
let model = lin_reg.fit(&A, &b).unwrap();

assert!(model.get_params().abs_diff_eq(&array![2., 1.], 1e-12));
assert!(model.get_intercept().abs_diff_eq(&1., 1e-12));
assert!(model.params().abs_diff_eq(&array![2., 1.], 1e-12));
assert!(model.intercept().abs_diff_eq(&1., 1e-12));
}

/// Check that the linear regression prefectly fits four datapoints for
/// the model
/// f(x) = (x + 1)^3 = x^3 + 3x^2 + 3x + 1
#[test]
fn fits_four_parameters_through_four_dots() {
let lin_reg = LinearRegression::new().fit_intercept();
let lin_reg = LinearRegression::new();
let A: Array2<f64> = array![[0., 0., 0.], [1., 1., 1.], [2., 4., 8.], [3., 9., 27.]];
let b: Array1<f64> = array![1., 8., 27., 64.];
let model = lin_reg.fit(&A, &b).unwrap();

assert!(model.get_params().abs_diff_eq(&array![3., 3., 1.], 1e-12));
assert!(model.get_intercept().abs_diff_eq(&1., 1e-12));
assert!(model.params().abs_diff_eq(&array![3., 3., 1.], 1e-12));
assert!(model.intercept().abs_diff_eq(&1., 1e-12));
}

/// Check that the linear regression prefectly fits three datapoints for
/// the model
/// f(x) = (x + 1)^2 = x^2 + 2x + 1
#[test]
fn fits_three_parameters_through_three_dots_f32() {
let lin_reg = LinearRegression::new().fit_intercept();
let lin_reg = LinearRegression::new();
let A: Array2<f32> = array![[0., 0.], [1., 1.], [2., 4.]];
let b: Array1<f32> = array![1., 4., 9.];
let model = lin_reg.fit(&A, &b).unwrap();

assert!(model.get_params().abs_diff_eq(&array![2., 1.], 1e-4));
assert!(model.get_intercept().abs_diff_eq(&1., 1e-6));
assert!(model.params().abs_diff_eq(&array![2., 1.], 1e-4));
assert!(model.intercept().abs_diff_eq(&1., 1e-6));
}

/// Check that the linear regression prefectly fits four datapoints for
Expand All @@ -325,20 +336,20 @@ mod tests {
/// when normalization is enabled
#[test]
fn fits_four_parameters_through_four_dots_with_normalization() {
let lin_reg = LinearRegression::new().fit_intercept_and_normalize();
let lin_reg = LinearRegression::new().with_intercept_and_normalize();
let A: Array2<f64> = array![[0., 0., 0.], [1., 1., 1.], [2., 4., 8.], [3., 9., 27.]];
let b: Array1<f64> = array![1., 8., 27., 64.];
let model = lin_reg.fit(&A, &b).unwrap();

assert!(model.get_params().abs_diff_eq(&array![3., 3., 1.], 1e-12));
assert!(model.get_intercept().abs_diff_eq(&1., 1e-12));
assert!(model.params().abs_diff_eq(&array![3., 3., 1.], 1e-12));
assert!(model.intercept().abs_diff_eq(&1., 1e-12));
}

/// Check that the linear regression model works with both owned and view
/// representations of arrays
#[test]
fn works_with_viewed_and_owned_representations() {
let lin_reg = LinearRegression::new().fit_intercept_and_normalize();
let lin_reg = LinearRegression::new().with_intercept_and_normalize();
let A: Array2<f64> = array![[0., 0., 0.], [1., 1., 1.], [2., 4., 8.], [3., 9., 27.]];
let b: Array1<f64> = array![1., 8., 27., 64.];
let A_view = A.slice(s![.., ..]);
Expand All @@ -355,12 +366,12 @@ mod tests {
.fit(&A_view, &b_view)
.expect("can't fit viewed arrays");

assert_eq!(model1.get_params(), model2.get_params());
assert_eq!(model2.get_params(), model3.get_params());
assert_eq!(model3.get_params(), model4.get_params());
assert_eq!(model1.params(), model2.params());
assert_eq!(model2.params(), model3.params());
assert_eq!(model3.params(), model4.params());

assert_eq!(model1.get_intercept(), model2.get_intercept());
assert_eq!(model2.get_intercept(), model3.get_intercept());
assert_eq!(model3.get_intercept(), model4.get_intercept());
assert_eq!(model1.intercept(), model2.intercept());
assert_eq!(model2.intercept(), model3.intercept());
assert_eq!(model3.intercept(), model4.intercept());
}
}

0 comments on commit c171cf0

Please sign in to comment.