Skip to content

Commit

Permalink
Change tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jinlow committed Jan 18, 2024
1 parent c8b7f5c commit 2821fd4
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 31 deletions.
52 changes: 34 additions & 18 deletions scripts/make_resources.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
import argparse

import pandas as pd
import seaborn as sns

if __name__ == "__main__":

class Inputs(argparse.Namespace):
make_benchmark_files: bool


def main(args: Inputs):
df = sns.load_dataset("titanic")
X = df.select_dtypes("number").drop(columns=["survived"]).astype(float)
y = df["survived"].astype(float)
Expand Down Expand Up @@ -29,24 +36,33 @@
index=False,
header=False,
)
if args.make_benchmark_files:
dfb = df.sample(
100_000,
random_state=0,
replace=True,
).reset_index(drop=True)

dfb = df.sample(
100_000,
random_state=0,
replace=True,
).reset_index(drop=True)
Xb = dfb.select_dtypes("number").drop(columns=["survived"]).astype(float)
Xb = pd.concat([Xb] * 10, axis=0)
yb = dfb["survived"].astype(float)
print("benmark files sizes: {Xb.shape}, {y.shape}")

Xb = dfb.select_dtypes("number").drop(columns=["survived"]).astype(float)
yb = dfb["survived"].astype(float)
pd.Series(Xb.fillna(0).to_numpy().ravel(order="F")).to_csv(
"resources/contiguous_no_missing_100k_samp_seed0.csv",
index=False,
header=False,
)

pd.Series(Xb.fillna(0).to_numpy().ravel(order="F")).to_csv(
"resources/contiguous_no_missing_100k_samp_seed0.csv",
index=False,
header=False,
)
yb.to_csv(
"resources/performance_100k_samp_seed0.csv",
index=False,
header=False,
)

yb.to_csv(
"resources/performance_100k_samp_seed0.csv",
index=False,
header=False,
)

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--make-benchmark-files", "-mbf", action="store_true")
args = parser.parse_args(namespace=Inputs())
main(args)
25 changes: 12 additions & 13 deletions src/histogram.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,16 +89,8 @@ pub fn create_feature_histogram(
sorted_hess: &[f32],
index: &[usize],
) -> Vec<Bin<f32>> {
let mut histogram: Vec<Bin<f32>> = Vec::with_capacity(cuts.len());

let mut gradient_sums: Vec<f64> = Vec::with_capacity(cuts.len());
let mut hessian_sums: Vec<f64> = Vec::with_capacity(cuts.len());

histogram.push(Bin::new_f32(f64::NAN));
// The last cut value is simply the maximum possible value, so we don't need it.
// This value is needed initially for binning, but we don't need to count it as
// a histogram bin.
histogram.extend(cuts[..(cuts.len() - 1)].iter().map(|c| Bin::new_f32(*c)));
let mut gradient_sums: Vec<f64> = vec![0.; cuts.len()];
let mut hessian_sums: Vec<f64> = vec![0.; cuts.len()];

index
.iter()
Expand All @@ -112,14 +104,21 @@ pub fn create_feature_histogram(
*v += f64::from(*h);
}
});

// The first value si reserved for missing.
// The last cut value is simply the maximum possible value, so we don't need it.
// This value is needed initially for binning, but we don't need to count it as
// a histogram bin.
histogram
.iter_mut()
Some(&f64::NAN)
.into_iter()
.chain(cuts[..(cuts.len() - 1)].iter())
.zip(gradient_sums)
.zip(hessian_sums)
.map(|((hist, g), h)| hist.update(g, h))
.map(|((c, g), h)| Bin {
gradient_sum: g as f32,
hessian_sum: h as f32,
cut_value: *c,
})
.collect()
}

Expand Down

0 comments on commit 2821fd4

Please sign in to comment.