Skip to content

Commit

Permalink
GH-673: add logistic struct
Browse files Browse the repository at this point in the history
  • Loading branch information
rain1024 committed Jul 14, 2023
1 parent ae5505d commit c19e269
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 1 deletion.
1 change: 1 addition & 0 deletions extensions/underthesea_core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ serde = { version = "1.0", features = [ "derive" ] }
regex = "1"
rayon = "1.5"
crfs = "0.1"
ndarray = { version = "0.15", features = ["approx"] }

[dependencies.pyo3]
version = "0.15.0"
Expand Down
2 changes: 2 additions & 0 deletions extensions/underthesea_core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ use std::collections::HashSet;

pub mod featurizers;

pub mod logistic;

#[pyclass]
pub struct CRFFeaturizer {
pub object: featurizers::CRFFeaturizer
Expand Down
90 changes: 90 additions & 0 deletions extensions/underthesea_core/src/logistic.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
extern crate ndarray;

use ndarray::{Array1, Array2};
use ndarray::prelude::*;

// Binary logistic-regression classifier trained by batch gradient descent.
pub struct LogisticRegression {
    // Learned coefficient vector, one weight per feature; sized/overwritten by `fit`.
    weights: Array1<f64>,
    // Step size applied to each gradient-descent update.
    learning_rate: f64,
    // Number of full gradient-descent passes performed by `fit`.
    iterations: usize,
}

impl LogisticRegression {
    /// Creates a model with default hyperparameters
    /// (`learning_rate = 0.01`, `iterations = 1000`).
    ///
    /// The weight vector starts as a 1-element placeholder and is resized to
    /// the feature count by [`fit`]; calling [`predict`] before [`fit`] is not
    /// meaningful (and panics unless the input happens to have one feature).
    pub fn new() -> Self {
        LogisticRegression {
            weights: Array1::zeros(1),
            learning_rate: 0.01,
            iterations: 1000,
        }
    }

    /// Builder-style override of the learning rate and iteration count.
    pub fn with_hyperparams(mut self, learning_rate: f64, iterations: usize) -> Self {
        self.learning_rate = learning_rate;
        self.iterations = iterations;
        self
    }

    /// Logistic sigmoid: maps any real `z` into the open interval (0, 1).
    fn sigmoid(z: f64) -> f64 {
        1.0 / (1.0 + (-z).exp())
    }

    /// Fits the weights to `x_train` (samples × features) and `y_train`
    /// (one 0.0/1.0 label per sample) by batch gradient descent on the
    /// logistic loss. Any previously learned weights are discarded.
    ///
    /// An empty training set leaves the (empty) weights untouched instead of
    /// producing a division-by-zero scale factor.
    pub fn fit(&mut self, x_train: &Array2<f64>, y_train: &Array1<f64>) {
        let m = x_train.nrows(); // number of samples
        let n = x_train.ncols(); // number of features
        self.weights = Array1::zeros(n);
        if m == 0 {
            return; // nothing to learn; avoids `learning_rate / 0`
        }
        // Loop-invariant: combined step size, hoisted out of the descent loop.
        let scale = self.learning_rate / m as f64;

        for _ in 0..self.iterations {
            // Vectorized batch gradient: X^T (sigmoid(X w) - y).
            // Replaces the previous per-sample loop (which cloned each row
            // via `.to_owned()`) with two matrix products over the whole batch.
            let predictions = x_train.dot(&self.weights).mapv(Self::sigmoid);
            let errors = &predictions - y_train;
            let gradient = x_train.t().dot(&errors);

            self.weights = &self.weights - &(gradient * scale);
        }
    }

    /// Returns the estimated P(y = 1 | x) for a single feature vector.
    /// `x.len()` must match the feature count seen by [`fit`].
    pub fn predict(&self, x: &Array1<f64>) -> f64 {
        Self::sigmoid(x.dot(&self.weights))
    }
}

/// `Default` mirrors [`LogisticRegression::new`] (clippy: `new_without_default`).
impl Default for LogisticRegression {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn logistic_regression_test() {
        // Train long enough for gradient descent to separate the two classes.
        let mut model = LogisticRegression::new().with_hyperparams(0.01, 20000);

        // Four samples, three features: the first two belong to class 0,
        // the last two to class 1.
        let x_train = array![[0., 1., 2.], [1., 2., 3.], [2., 3., 4.], [3., 4., 5.]];
        let y_train = array![0., 0., 1., 1.];

        model.fit(&x_train, &y_train);

        // A training sample from the positive class must be classified as
        // positive. The previous assertion, `(prediction - 0.1).abs() < 1.0`,
        // was vacuous: every sigmoid output lies in (0, 1), so it could
        // never fail regardless of what `fit` learned.
        let x_test = array![2., 3., 4.];
        let prediction = model.predict(&x_test);
        println!("prediction = {}", prediction);
        assert!(
            prediction > 0.5,
            "expected positive-class probability > 0.5, got {}",
            prediction
        );
    }
}
2 changes: 1 addition & 1 deletion extensions/underthesea_core/tests/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@ mod tests {
#[test]
fn test_crfs(){
    // Smoke test: the bundled CRF model binary must load without error.
    // NOTE(review): this is a diff rendering — the `let model = ...` line is
    // the pre-change version of the `let _model = ...` line below it; the
    // commit renames the binding to `_model` to silence the unused-variable
    // warning. Only one of the two lines exists in the actual file.
    let buf = fs::read("tests/wt_crf_2018_09_13.bin").unwrap();
    let model = crfs::Model::new(&buf).unwrap();
    let _model = crfs::Model::new(&buf).unwrap();
}
}

0 comments on commit c19e269

Please sign in to comment.