Skip to content

Commit

Permalink
Merge pull request #1 from Pardoxa/main
Browse files Browse the repository at this point in the history
Merge main into newBranch
  • Loading branch information
Pardoxa authored Feb 27, 2024
2 parents a77eb41 + 203f151 commit a75b0e8
Show file tree
Hide file tree
Showing 7 changed files with 223 additions and 41 deletions.
9 changes: 3 additions & 6 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,14 @@ rand = { version = "^0.8.2"}
serde = { version = "1.0", optional = true, features = ["derive"] }
num-traits = "^0.2"
transpose = "^0.2"
average = { version = "^0.13", optional = true }
average = { version = "^0.14", optional = true }
rayon = { version = "^1.5", optional = true }
paste = "1.0"

[dev-dependencies]
serde_json = "1.0"
criterion = { version = "^0.4", features=["html_reports"] }

criterion = "0.5"
statrs = "0.16.0"
rand_pcg = { version = "^0.3.0", features = ["serde1"]}

Expand All @@ -56,8 +57,4 @@ default = ["serde_support", "bootstrap", "replica_exchange"]

[[bench]]
name = "bench"
harness = true

[[bench]]
name = "hists"
harness = false
18 changes: 13 additions & 5 deletions src/examples/coin_flips.rs
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ mod tests{
1,
NonZeroUsize::new(1999).unwrap(),
NonZeroUsize::new(2).unwrap(),
0.000003
0.000001
).unwrap();

let rewl1 = rewl_builder1.greedy_build(|e| Some(e.head_count()));
Expand Down Expand Up @@ -307,6 +307,7 @@ mod tests{
rewl_slice.par_iter_mut()
.for_each(|rewl| rewl.simulate_until_convergence(|e| Some(e.head_count())));


rewl_slice.iter()
.for_each(
|r|
Expand All @@ -315,7 +316,7 @@ mod tests{
.for_each(|w| println!("rewl replica_frac {}", w.replica_exchange_frac()));
}
);

let steps: u64 = rewl_slice
.iter()
.flat_map(|r|
Expand All @@ -336,7 +337,7 @@ mod tests{
let mut rees_slice: Vec<_> = rewl_slice.into_iter()
.map(|r| r.into_rees())
.collect();

rees_slice
.par_iter_mut()
.for_each(
Expand All @@ -356,7 +357,7 @@ mod tests{
).sum();
println!("Ges steps rees {}", steps);

let prob_rees = rees::merged_log_prob(&rees_slice).unwrap();
let prob_rees = rees::merged_log_prob_rees(&rees_slice).unwrap();

let mut max_ln_difference_rewl = f64::NEG_INFINITY;
let mut max_difference_rewl = f64::NEG_INFINITY;
Expand All @@ -368,7 +369,14 @@ mod tests{
let mut max_difference_rees = f64::NEG_INFINITY;
let mut frac_difference_max_rees = f64::NEG_INFINITY;
let mut frac_difference_min_rees = f64::INFINITY;
for (index, ((val_sim1, val_sim2), val_true)) in prob.0.into_iter().zip(prob_rees.0).zip(ln_prob_true).enumerate()

let iter = prob.0.
into_iter()
.zip(prob_rees.0)
.zip(ln_prob_true)
.enumerate();

for (index, ((val_sim1, val_sim2), val_true)) in iter
{
println!("{} {} {} {}", index, val_sim1, val_sim2, val_true);

Expand Down
12 changes: 7 additions & 5 deletions src/glue/derivative.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ pub fn five_point_derivitive(data: &[f64]) -> Vec<f64>
d
}

fn derivative(data: &[f64]) -> Vec<f64>
/// # Calculates the derivative of a Vector
/// * will return a Vector of NaN if `data.len() < 2`
pub fn derivative(data: &[f64]) -> Vec<f64>
{
let mut d = vec![f64::NAN; data.len()];
if data.len() >= 3 {
Expand All @@ -21,9 +23,9 @@ fn derivative(data: &[f64]) -> Vec<f64>
}
}
if data.len() >= 2 {
d[0] = (data[1] - data[0]) / 2.0;
d[0] = data[1] - data[0];

d[data.len() - 1] = (data[data.len() - 1] - data[data.len() - 2]) / 2.0;
d[data.len() - 1] = data[data.len() - 1] - data[data.len() - 2];
}
d
}
Expand All @@ -41,9 +43,9 @@ pub fn derivative_merged(data: &[f64]) -> Vec<f64>
d[1] = (data[2] - data[0]) / 2.0;
d[data.len() - 2] = (data[data.len() - 1] - data[data.len() - 3]) / 2.0;

d[0] = (data[1] - data[0]) / 2.0;
d[0] = data[1] - data[0];

d[data.len() - 1] = (data[data.len() - 1] - data[data.len() - 2]) / 2.0;
d[data.len() - 1] = data[data.len() - 1] - data[data.len() - 2];

d
}
113 changes: 95 additions & 18 deletions src/heatmap/gnuplot.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
use{
std::{
use std::{
fmt,
io::Write,
convert::From,
borrow::*
}
};
borrow::*,
path::Path
};

#[cfg(feature = "serde_support")]
use serde::{Serialize, Deserialize};

#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde_support", derive(Serialize, Deserialize))]
/// For labeling the gnuplot plots axis
pub enum GnuplotAxis{
pub enum Labels{
/// construct the labels
FromValues{
/// minimum value for axis labels
Expand All @@ -24,17 +23,31 @@ pub enum GnuplotAxis{
tics: usize,
},
/// use labels
Labels{
FromStrings{
/// this are the labels
labels: Vec<String>
}
}
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde_support", derive(Serialize, Deserialize))]
/// For labeling the gnuplot plots axis
pub struct GnuplotAxis{
labels: Labels,
rotation: f32
}

impl GnuplotAxis{
/// Set the rotation value.
/// Tics will be displayed rotaded to the right by the requested amount
pub fn set_rotation(&mut self, rotation_degrees: f32)
{
self.rotation = rotation_degrees;
}

pub(crate) fn write_tics<W: Write>(&self, mut w: W, num_bins: usize, axis: &str) -> std::io::Result<()>
{
match self {
Self::FromValues{min, max, tics} => {
match &self.labels {
Labels::FromValues{min, max, tics} => {
if min.is_nan() || max.is_nan() || *tics < 2 || num_bins < 2 {
Ok(())
} else {
Expand All @@ -49,10 +62,10 @@ impl GnuplotAxis{
let pos = i as f64 * bin_dif;
write!(w, "\"{:#}\" {:e}, ", val, pos)?;
}
writeln!(w, "\"{:#}\" {:e} )", max, num_bins - 1)
writeln!(w, "\"{:#}\" {:e} ) rotate by {} right", max, num_bins - 1, self.rotation)
}
},
Self::Labels{labels} => {
Labels::FromStrings{labels} => {
let tics = labels.len();
match tics {
0 => Ok(()),
Expand All @@ -67,7 +80,7 @@ impl GnuplotAxis{
let pos = i as f64 * bin_dif;
write!(w, "\"{}\" {:e}, ", lab, pos)?;
}
writeln!(w, " )")
writeln!(w, " ) rotate by {} right", self.rotation)
}
}
}
Expand All @@ -77,20 +90,20 @@ impl GnuplotAxis{

/// Create new GnuplotAxis::FromValues
pub fn new(min: f64, max: f64, tics: usize) -> Self {
Self::FromValues{
let labels = Labels::FromValues{
min,
max,
tics
}
};
Self { labels, rotation: 0.0 }
}

/// Create new GnuplotAxis::Labels
/// - Vector contains labels used for axis
pub fn from_labels(labels: Vec<String>) -> Self
{
Self::Labels{
labels
}
let labels = Labels::FromStrings { labels };
Self{labels, rotation: 0.0}
}

/// Similar to `from_labels`
Expand Down Expand Up @@ -125,6 +138,10 @@ pub struct GnuplotSettings{

/// Color palette for heatmap
pub palette: GnuplotPalette,

/// Define the cb range if this option is set
pub cb_range: Option<(f64, f64)>,

/// # Size of the terminal
/// * Anything gnuplot accepts (e.g. "2cm, 2.9cm") is acceptable
/// # Note
Expand All @@ -143,6 +160,20 @@ impl GnuplotSettings {
self
}

/// # Builder pattern - set cb_range
pub fn cb_range(&'_ mut self, range_start: f64, range_end: f64) -> &'_ mut Self
{
self.cb_range = Some((range_start, range_end));
self
}

/// # Builder pattern - remove cb_range
pub fn remove_cb_range(&'_ mut self) -> &'_ mut Self
{
self.cb_range = None;
self
}

/// # Builder pattern - set x_label
pub fn x_label<S: Into<String>>(&'_ mut self, x_label: S) -> &'_ mut Self
{
Expand Down Expand Up @@ -216,13 +247,27 @@ impl GnuplotSettings {
self
}

/// Remove x_axis
pub fn remove_x_axis(&'_ mut self) -> &'_ mut Self
{
self.x_axis = None;
self
}

/// Set y_axis - See GnuplotAxis or try it out
pub fn y_axis(&'_ mut self, axis: GnuplotAxis) -> &'_ mut Self
{
self.y_axis = Some(axis);
self
}

/// Remove y_axis
pub fn remove_y_axis(&'_ mut self) -> &'_ mut Self
{
self.y_axis = None;
self
}

pub(crate) fn write_axis<W: Write>(&self, mut w: W, num_bins_x: usize, num_bins_y: usize) -> std::io::Result<()>
{
if let Some(ax) = self.x_axis.as_ref() {
Expand All @@ -248,6 +293,9 @@ impl GnuplotSettings {

writeln!(writer, "set xrange[-0.5:{}]", x_len as f64 - 0.5)?;
writeln!(writer, "set yrange[-0.5:{}]", y_len as f64 - 0.5)?;
if let Some((range_start, range_end)) = self.cb_range{
writeln!(writer, "set cbrange [{range_start:e}:{range_end:e}]")?;
}
if !self.title.is_empty(){
writeln!(writer, "set title '{}'", self.title)?;
}
Expand Down Expand Up @@ -293,6 +341,34 @@ impl GnuplotSettings {

self.terminal.finish(&mut writer)
}

/// Same as write_heatmap but it assumes that the heatmap
/// matrix is available in the file "heatmap"
pub fn write_heatmap_external_matrix<W, P>(
&self,
mut writer: W,
matrix_width: usize,
matrix_height: usize,
matrix_path: P
) -> std::io::Result<()>
where W: Write,
P: AsRef<Path>
{
self.write_heatmap_helper1(
&mut writer,
matrix_width,
matrix_height
)?;

writeln!(
writer,
"splot \"{}\" matrix with image t \"{}\" ",
matrix_path.as_ref().to_string_lossy(),
&self.title
)?;

self.terminal.finish(&mut writer)
}
}

impl Default for GnuplotSettings{
Expand All @@ -305,7 +381,8 @@ impl Default for GnuplotSettings{
palette: GnuplotPalette::PresetHSV,
x_axis: None,
y_axis: None,
size: "7.4cm, 5cm".into()
size: "7.4cm, 5cm".into(),
cb_range: None
}
}
}
Expand Down
Loading

0 comments on commit a75b0e8

Please sign in to comment.