diff --git a/linfa-linear/Cargo.toml b/linfa-linear/Cargo.toml
index a6fddd973..a169fb0a2 100644
--- a/linfa-linear/Cargo.toml
+++ b/linfa-linear/Cargo.toml
@@ -1,7 +1,7 @@
 [package]
 name = "linfa-linear"
 version = "0.1.0"
-authors = ["Paul Körbitz "]
+authors = ["Paul Körbitz / Google "]
 edition = "2018"
 
 workspace = ".."
@@ -21,4 +21,5 @@ num-traits = {version="0.2"}
 csv = "1.1"
 ndarray-csv = "0.4"
 approx = "0.3.2"
-flate2 = "1.0"
\ No newline at end of file
+flate2 = "1.0"
+ndarray-linalg = {version = "0.12", features = ["openblas"]}
\ No newline at end of file
diff --git a/linfa-linear/examples/diabetes.rs b/linfa-linear/examples/diabetes.rs
index 302532c7a..516ec8098 100644
--- a/linfa-linear/examples/diabetes.rs
+++ b/linfa-linear/examples/diabetes.rs
@@ -18,11 +18,11 @@ fn main() -> Result<(), Box> {
     let target = read_array("../datasets/diabetes_target.csv.gz")?;
     let target = target.column(0);
 
-    let lin_reg = LinearRegression::new().fit_intercept();
+    let lin_reg = LinearRegression::new();
     let model = lin_reg.fit(&data, &target)?;
 
-    println!("intercept: {}", model.get_intercept());
-    println!("parameters: {}", model.get_params());
+    println!("intercept: {}", model.intercept());
+    println!("parameters: {}", model.params());
 
     Ok(())
 }
diff --git a/linfa-linear/src/lib.rs b/linfa-linear/src/lib.rs
index 1b38b2dcc..ed7870d08 100644
--- a/linfa-linear/src/lib.rs
+++ b/linfa-linear/src/lib.rs
@@ -50,16 +50,18 @@ pub struct LinearRegression {
 #[derive(Clone, Copy, PartialEq, Eq)]
 enum Options {
     None,
-    FitIntercept,
-    FitInterceptAndNormalize,
+    WithIntercept,
+    WithInterceptAndNormalize,
 }
 
-fn fit_intercept(options: Options) -> bool {
-    options == Options::FitIntercept || options == Options::FitInterceptAndNormalize
-}
+impl Options {
+    fn should_use_intercept(&self) -> bool {
+        *self == Options::WithIntercept || *self == Options::WithInterceptAndNormalize
+    }
 
-fn normalize(options: Options) -> bool {
-    options == Options::FitInterceptAndNormalize
+    fn should_normalize(&self) -> bool {
+        *self == Options::WithInterceptAndNormalize
+    }
 }
 
 /// A fitted linear regression model which can be used for making predictions.
@@ -70,19 +72,28 @@ pub struct FittedLinearRegression {
 
 /// Configure and fit a linear regression model
 impl LinearRegression {
-    /// Create a default linear regression model. By default, no intercept
-    /// will be fitted and the feature matrix will not be normalized.
-    /// To change this, call `fit_intercept()` or
-    /// `fit_intercept_and_normalize()` before calling `fit()`.
+    /// Create a default linear regression model.
+    ///
+    /// By default, an intercept will be fitted. To disable fitting an
+    /// intercept, call `.with_intercept(false)` before calling `.fit()`.
+    ///
+    /// To additionally normalize the feature matrix before fitting, call
+    /// `fit_intercept_and_normalize()` before calling `fit()`. The feature
+    /// matrix will not be normalized by default.
     pub fn new() -> LinearRegression {
         LinearRegression {
-            options: Options::None,
+            options: Options::WithIntercept,
         }
     }
 
     /// Configure the linear regression model to fit an intercept.
-    pub fn fit_intercept(mut self) -> Self {
-        self.options = Options::FitIntercept;
+    /// Defaults to `true` if not set.
+    pub fn with_intercept(mut self, with_intercept: bool) -> Self {
+        if with_intercept {
+            self.options = Options::WithIntercept;
+        } else {
+            self.options = Options::None;
+        }
         self
     }
 
@@ -94,8 +105,8 @@ impl LinearRegression {
     /// are all within in a small range and all features are of similar size.
     ///
     /// Normalization implies fitting an intercept.
-    pub fn fit_intercept_and_normalize(mut self) -> Self {
-        self.options = Options::FitInterceptAndNormalize;
+    pub fn with_intercept_and_normalize(mut self) -> Self {
+        self.options = Options::WithInterceptAndNormalize;
         self
     }
 
@@ -124,7 +135,7 @@ impl LinearRegression {
         // Check that our inputs have compatible shapes
         assert_eq!(y.dim(), n_samples);
 
-        if fit_intercept(self.options) {
+        if self.options.should_use_intercept() {
             // If we are fitting the intercept, we first center X and y,
             // compute the models parameters based on the centered X and y
             // and the intercept as the residual of fitted parameters applied
@@ -136,7 +147,7 @@ impl LinearRegression {
             let y_offset: A = y.mean().ok_or(String::from("cannot compute mean of y"))?;
             let y_centered: Array1 = y - y_offset;
             let params: Array1 =
-                compute_params(&X_centered, &y_centered, normalize(self.options))?;
+                compute_params(&X_centered, &y_centered, self.options.should_normalize())?;
             let intercept: A = y_offset - X_offset.dot(&params);
             return Ok(FittedLinearRegression {
                 intercept: intercept,
@@ -204,12 +215,12 @@ impl FittedLinearRegression {
     }
 
     /// Get the fitted parameters
-    pub fn get_params(&self) -> &Array1 {
+    pub fn params(&self) -> &Array1 {
         &self.params
     }
 
     /// Get the fitted intercept, 0. if no intercept was fitted
-    pub fn get_intercept(&self) -> A {
+    pub fn intercept(&self) -> A {
         self.intercept
     }
 }
@@ -222,7 +233,7 @@ mod tests {
 
     #[test]
     fn fits_a_line_through_two_dots() {
-        let lin_reg = LinearRegression::new().fit_intercept();
+        let lin_reg = LinearRegression::new();
         let A: Array2 = array![[0.], [1.]];
         let b: Array1 = array![1., 2.];
         let model = lin_reg.fit(&A, &b).unwrap();
@@ -231,12 +242,12 @@
         assert!(result.abs_diff_eq(&array![1., 2.], 1e-12));
     }
 
-    /// When `fit_intercept` is not set (the default), the
+    /// When `with_intercept` is set to false, the
    /// fitted line runs through the origin. For a perfect
     /// fit we only need to provide one point.
     #[test]
     fn without_intercept_fits_line_through_origin() {
-        let lin_reg = LinearRegression::new();
+        let lin_reg = LinearRegression::new().with_intercept(false);
         let A: Array2 = array![[1.]];
         let b: Array1 = array![1.];
         let model = lin_reg.fit(&A, &b).unwrap();
@@ -252,7 +263,7 @@
     /// f(x) = 0
     #[test]
     fn fits_least_squares_line_through_two_dots() {
-        let lin_reg = LinearRegression::new();
+        let lin_reg = LinearRegression::new().with_intercept(false);
         let A: Array2 = array![[-1.], [1.]];
         let b: Array1 = array![1., 1.];
         let model = lin_reg.fit(&A, &b).unwrap();
@@ -268,7 +279,7 @@
     /// f(x) = -1./3. + x
     #[test]
     fn fits_least_squares_line_through_three_dots() {
-        let lin_reg = LinearRegression::new().fit_intercept();
+        let lin_reg = LinearRegression::new();
         let A: Array2 = array![[0.], [1.], [2.]];
         let b: Array1 = array![0., 0., 2.];
         let model = lin_reg.fit(&A, &b).unwrap();
@@ -282,13 +293,13 @@
     /// f(x) = (x + 1)^2 = x^2 + 2x + 1
     #[test]
     fn fits_three_parameters_through_three_dots() {
-        let lin_reg = LinearRegression::new().fit_intercept();
+        let lin_reg = LinearRegression::new();
         let A: Array2 = array![[0., 0.], [1., 1.], [2., 4.]];
         let b: Array1 = array![1., 4., 9.];
         let model = lin_reg.fit(&A, &b).unwrap();
 
-        assert!(model.get_params().abs_diff_eq(&array![2., 1.], 1e-12));
-        assert!(model.get_intercept().abs_diff_eq(&1., 1e-12));
+        assert!(model.params().abs_diff_eq(&array![2., 1.], 1e-12));
+        assert!(model.intercept().abs_diff_eq(&1., 1e-12));
     }
     /// Check that the linear regression prefectly fits four datapoints for
     /// the model
@@ -296,13 +307,13 @@
     /// f(x) = (x + 1)^3 = x^3 + 3x^2 + 3x + 1
     #[test]
     fn fits_four_parameters_through_four_dots() {
-        let lin_reg = LinearRegression::new().fit_intercept();
+        let lin_reg = LinearRegression::new();
         let A: Array2 = array![[0., 0., 0.], [1., 1., 1.], [2., 4., 8.], [3., 9., 27.]];
         let b: Array1 = array![1., 8., 27., 64.];
         let model = lin_reg.fit(&A, &b).unwrap();
 
-        assert!(model.get_params().abs_diff_eq(&array![3., 3., 1.], 1e-12));
-        assert!(model.get_intercept().abs_diff_eq(&1., 1e-12));
+        assert!(model.params().abs_diff_eq(&array![3., 3., 1.], 1e-12));
+        assert!(model.intercept().abs_diff_eq(&1., 1e-12));
     }
     /// Check that the linear regression prefectly fits three datapoints for
     /// the model
@@ -310,13 +321,13 @@
     /// f(x) = (x + 1)^2 = x^2 + 2x + 1
     #[test]
     fn fits_three_parameters_through_three_dots_f32() {
-        let lin_reg = LinearRegression::new().fit_intercept();
+        let lin_reg = LinearRegression::new();
         let A: Array2 = array![[0., 0.], [1., 1.], [2., 4.]];
         let b: Array1 = array![1., 4., 9.];
         let model = lin_reg.fit(&A, &b).unwrap();
 
-        assert!(model.get_params().abs_diff_eq(&array![2., 1.], 1e-4));
-        assert!(model.get_intercept().abs_diff_eq(&1., 1e-6));
+        assert!(model.params().abs_diff_eq(&array![2., 1.], 1e-4));
+        assert!(model.intercept().abs_diff_eq(&1., 1e-6));
     }
     /// Check that the linear regression prefectly fits four datapoints for
     /// the model
@@ -325,20 +336,20 @@
     /// f(x) = (x + 1)^3 = x^3 + 3x^2 + 3x + 1
     /// when normalization is enabled
     #[test]
     fn fits_four_parameters_through_four_dots_with_normalization() {
-        let lin_reg = LinearRegression::new().fit_intercept_and_normalize();
+        let lin_reg = LinearRegression::new().with_intercept_and_normalize();
         let A: Array2 = array![[0., 0., 0.], [1., 1., 1.], [2., 4., 8.], [3., 9., 27.]];
         let b: Array1 = array![1., 8., 27., 64.];
         let model = lin_reg.fit(&A, &b).unwrap();
 
-        assert!(model.get_params().abs_diff_eq(&array![3., 3., 1.], 1e-12));
-        assert!(model.get_intercept().abs_diff_eq(&1., 1e-12));
+        assert!(model.params().abs_diff_eq(&array![3., 3., 1.], 1e-12));
+        assert!(model.intercept().abs_diff_eq(&1., 1e-12));
     }
     /// Check that the linear regression model works with both owned and view
     /// representations of arrays
     #[test]
     fn works_with_viewed_and_owned_representations() {
-        let lin_reg = LinearRegression::new().fit_intercept_and_normalize();
+        let lin_reg = LinearRegression::new().with_intercept_and_normalize();
         let A: Array2 = array![[0., 0., 0.], [1., 1., 1.], [2., 4., 8.], [3., 9., 27.]];
         let b: Array1 = array![1., 8., 27., 64.];
         let A_view = A.slice(s![.., ..]);
@@ -355,12 +366,12 @@
             .fit(&A_view, &b_view)
             .expect("can't fit viewed arrays");
 
-        assert_eq!(model1.get_params(), model2.get_params());
-        assert_eq!(model2.get_params(), model3.get_params());
-        assert_eq!(model3.get_params(), model4.get_params());
+        assert_eq!(model1.params(), model2.params());
+        assert_eq!(model2.params(), model3.params());
+        assert_eq!(model3.params(), model4.params());
 
-        assert_eq!(model1.get_intercept(), model2.get_intercept());
-        assert_eq!(model2.get_intercept(), model3.get_intercept());
-        assert_eq!(model3.get_intercept(), model4.get_intercept());
+        assert_eq!(model1.intercept(), model2.intercept());
+        assert_eq!(model2.intercept(), model3.intercept());
+        assert_eq!(model3.intercept(), model4.intercept());
     }
 }
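Usage sketch (illustrative, not part of the patch above): after this change, LinearRegression::new() fits an intercept by default, .with_intercept(false) opts out, and fitted values are read through params() and intercept(). Only new(), with_intercept(), fit(), params(), and intercept() are taken from the diff; the linfa_linear import path, the boxed error type in main, and the toy data are assumptions made for illustration.

use linfa_linear::LinearRegression; // assumed import path for the linfa-linear crate
use ndarray::array;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Three points on the line y = 2x + 1.
    let x = array![[0.], [1.], [2.]];
    let y = array![1., 3., 5.];

    // The intercept is now fitted by default.
    let model = LinearRegression::new().fit(&x, &y)?;
    println!("intercept: {}", model.intercept()); // close to 1.0
    println!("parameters: {}", model.params()); // close to [2.0]

    // Opting out of the intercept constrains the fitted line to the origin.
    let through_origin = LinearRegression::new().with_intercept(false).fit(&x, &y)?;
    println!("parameters without intercept: {}", through_origin.params());

    Ok(())
}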