-
-
Notifications
You must be signed in to change notification settings - Fork 6
/
regression.rs
59 lines (53 loc) · 2.07 KB
/
regression.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
//! MSE regression model training and evaluation example
use lightgbm3::{Booster, Dataset};
use serde_json::json;
use std::iter::zip;
/// Loads a headerless, tab-separated file whose first column is the target.
///
/// Returns `(xs, ys, n_features)`: `xs` is the feature matrix flattened in
/// row-major order, `ys` holds one target per record, and `n_features` is the
/// number of feature columns per record (`0` when the file has no records).
///
/// # Panics
///
/// Panics if the file cannot be opened or read, if a field is not a valid
/// number, if records have differing numbers of features, or if the feature
/// count does not fit in an `i32`.
fn load_file(file_path: &str) -> (Vec<f64>, Vec<f32>, i32) {
    let mut reader = csv::ReaderBuilder::new()
        .has_headers(false)
        .delimiter(b'\t')
        .from_path(file_path)
        .unwrap_or_else(|e| panic!("failed to open {file_path}: {e}"));
    let mut ys: Vec<f32> = Vec::new();
    let mut xs: Vec<f64> = Vec::new();
    // Feature count of the first record; every later record must match it.
    let mut n_features: Option<usize> = None;
    for (row, result) in reader.records().enumerate() {
        let record =
            result.unwrap_or_else(|e| panic!("{file_path}: failed to read record {row}: {e}"));
        let mut fields = record.iter();
        // The first column is the regression target; the rest are features.
        let y = fields
            .next()
            .unwrap_or_else(|| panic!("{file_path}: record {row} has no fields"));
        ys.push(
            y.parse::<f32>()
                .unwrap_or_else(|e| panic!("{file_path}: record {row}: bad target {y:?}: {e}")),
        );
        let before = xs.len();
        xs.extend(fields.map(|x| {
            x.parse::<f64>()
                .unwrap_or_else(|e| panic!("{file_path}: record {row}: bad feature {x:?}: {e}"))
        }));
        let row_features = xs.len() - before;
        let expected = *n_features.get_or_insert(row_features);
        // A ragged record would silently corrupt the row-major layout of `xs`.
        assert_eq!(
            row_features, expected,
            "{file_path}: record {row} has {row_features} features, expected {expected}"
        );
    }
    // With no records there is no feature count; report 0 instead of dividing by zero.
    let n_features = n_features.unwrap_or(0);
    let n_features = i32::try_from(n_features)
        .unwrap_or_else(|_| panic!("{file_path}: feature count {n_features} exceeds i32::MAX"));
    (xs, ys, n_features)
}
/// Trains a LightGBM L2 regression model on the example training data and
/// prints its mean squared error and coefficient of determination (R²) on the
/// held-out test data.
///
/// The data paths are relative. NOTE(review): presumably this is meant to be
/// run from the repository root — confirm.
///
/// # Panics
///
/// Panics if loading, training or prediction fails, if the train and test
/// feature counts differ, if the test set is empty, or if the number of
/// predictions does not match the number of test rows.
fn main() -> std::io::Result<()> {
    let (train_xs, train_ys, n_features) =
        load_file("lightgbm3-sys/lightgbm/examples/regression/regression.train");
    let (test_xs, test_ys, n_features_test) =
        load_file("lightgbm3-sys/lightgbm/examples/regression/regression.test");
    assert_eq!(
        n_features, n_features_test,
        "train and test sets have different feature counts"
    );
    assert!(!test_ys.is_empty(), "test set is empty");
    let train_dataset = Dataset::from_slice(&train_xs, &train_ys, n_features, true).unwrap();
    let params = json! {
        {
            "num_iterations": 100,
            "objective": "regression",
            "metric": "l2"
        }
    };
    // Train a model
    let booster = Booster::train(train_dataset, &params).unwrap();
    // Predict one value per test row (features passed row-major).
    let y_pred = booster.predict(&test_xs, n_features, true).unwrap();
    // `zip` below would silently truncate on a length mismatch and hide a bug.
    assert_eq!(
        y_pred.len(),
        test_ys.len(),
        "expected exactly one prediction per test row"
    );

    // Calculate regression metrics in f64 so predictions are not down-cast
    // and the sums do not accumulate f32 rounding error.
    let n = test_ys.len() as f64;
    let mean = test_ys.iter().map(|&y| f64::from(y)).sum::<f64>() / n;
    let var = test_ys
        .iter()
        .map(|&y| (f64::from(y) - mean).powi(2))
        .sum::<f64>()
        / n;
    let var_model = zip(&test_ys, &y_pred)
        .map(|(&y, &p)| (f64::from(y) - f64::from(p)).powi(2))
        .sum::<f64>()
        / n;
    // R² is undefined when the targets are constant (zero variance).
    let r2 = if var > 0.0 {
        1.0 - var_model / var
    } else {
        f64::NAN
    };
    println!("test mse = {var_model:.3}");
    println!("test r^2 = {r2:.3}");
    Ok(())
}