Skip to content

Commit

Permalink
Migrate Channel methods to new definition
Browse files Browse the repository at this point in the history
  • Loading branch information
t7phy committed Jul 18, 2024
1 parent f32902a commit 2417c92
Showing 1 changed file with 63 additions and 66 deletions.
129 changes: 63 additions & 66 deletions pineappl/src/boc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -311,22 +311,22 @@ impl Channel {

// sort `entry` because the ordering doesn't matter and because it makes it easier to
// compare `Channel` objects with each other
entry.sort_by(|x, y| (x.0, x.1).cmp(&(y.0, y.1)));
entry.sort_by(|x, y| x.0.cmp(&y.0));

Self {
entry: entry
.into_iter()
.coalesce(|lhs, rhs| {
// sum the factors of repeated elements
if (lhs.0, lhs.1) == (rhs.0, rhs.1) {
Ok((lhs.0, lhs.1, lhs.2 + rhs.2))
if lhs.0 == rhs.0 {
Ok((lhs.0, lhs.1 + rhs.1))
} else {
Err((lhs, rhs))
}
})
// filter zeros
// TODO: find a better than to hardcode the epsilon limit
.filter(|&(_, _, f)| !approx_eq!(f64, f.abs(), 0.0, epsilon = 1e-14))
.filter(|&(_, f)| !approx_eq!(f64, f.abs(), 0.0, epsilon = 1e-14))
.collect(),
}
}
Expand All @@ -347,17 +347,22 @@ impl Channel {
/// assert_eq!(entry, channel![2, 11, 1.0; -2, 11, -1.0; 1, 11, -1.0; -1, 11, 1.0]);
/// ```
pub fn translate(entry: &Self, translator: &dyn Fn(i32) -> Vec<(i32, f64)>) -> Self {
let mut tuples = Vec::new();
let mut result = Vec::new();

for &(a, b, factor) in &entry.entry {
for (aid, af) in translator(a) {
for (bid, bf) in translator(b) {
tuples.push((aid, bid, factor * af * bf));
}
for &(pids, factor) in &entry.entry {
for tuples in pids
.iter()
.map(|&pid| translator(pid))
.multi_cartesian_product()
{
result.push((
tuples.iter().map(|&(pid, _)| pid).collect(),
tuples.iter().map(|(_, f)| f).product::<f64>(),
));
}
}

Self::new(tuples)
Self::new(result)
}

/// Returns a tuple representation of this entry.
Expand All @@ -377,11 +382,11 @@ impl Channel {
&self.entry
}

/// Creates a new object with the initial states transposed.
#[must_use]
pub fn transpose(&self) -> Self {
Self::new(self.entry.iter().map(|(a, b, c)| (*b, *a, *c)).collect())
}
// /// Creates a new object with the initial states transposed.
// #[must_use]
// pub fn transpose(&self) -> Self {
// Self::new(self.entry.iter().map(|(a, b, c)| (*b, *a, *c)).collect())
// }

/// If `other` is the same channel when only comparing PIDs and neglecting the factors, return
/// the number `f1 / f2`, where `f1` is the factor from `self` and `f2` is the factor from
Expand All @@ -392,10 +397,10 @@ impl Channel {
/// ```rust
/// use pineappl::boc::Channel;
///
/// let entry1 = Channel::new(vec![(2, 2, 2.0), (4, 4, 2.0)]);
/// let entry2 = Channel::new(vec![(4, 4, 1.0), (2, 2, 1.0)]);
/// let entry3 = Channel::new(vec![(3, 4, 1.0), (2, 2, 1.0)]);
/// let entry4 = Channel::new(vec![(4, 3, 1.0), (2, 3, 2.0)]);
/// let entry1 = Channel::new(vec![(vec![2, 2], 2.0), (vec![4, 4], 2.0)]);
/// let entry2 = Channel::new(vec![(vec![4, 4], 1.0), (vec![2, 2], 1.0)]);
/// let entry3 = Channel::new(vec![(vec![3, 4], 1.0), (vec![2, 2], 1.0)]);
/// let entry4 = Channel::new(vec![(vec![4, 3], 1.0), (vec![2, 3], 2.0)]);
///
/// assert_eq!(entry1.common_factor(&entry2), Some(2.0));
/// assert_eq!(entry1.common_factor(&entry3), None);
Expand All @@ -411,7 +416,7 @@ impl Channel {
.entry
.iter()
.zip(&other.entry)
.map(|(a, b)| ((a.0 == b.0) && (a.1 == b.1)).then_some(a.2 / b.2))
.map(|(a, b)| (a == b).then_some(a.1 / b.1))
.collect();

result.and_then(|factors| {
Expand All @@ -436,51 +441,43 @@ impl FromStr for Channel {
type Err = ParseChannelError;

fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(Self::new(
s.split('+')
.map(|sub| {
sub.split_once('*').map_or_else(
|| Err(ParseChannelError(format!("missing '*' in '{sub}'"))),
|(factor, pids)| {
let tuple = pids.split_once(',').map_or_else(
|| Err(ParseChannelError(format!("missing ',' in '{pids}'"))),
|(a, b)| {
Ok((
a.trim()
.strip_prefix('(')
.ok_or_else(|| {
ParseChannelError(format!(
"missing '(' in '{pids}'"
))
})?
.trim()
.parse::<i32>()
.map_err(|err| ParseChannelError(err.to_string()))?,
b.trim()
.strip_suffix(')')
.ok_or_else(|| {
ParseChannelError(format!(
"missing ')' in '{pids}'"
))
})?
.trim()
.parse::<i32>()
.map_err(|err| ParseChannelError(err.to_string()))?,
))
},
)?;

Ok((
tuple.0,
tuple.1,
str::parse::<f64>(factor.trim())
.map_err(|err| ParseChannelError(err.to_string()))?,
))
},
)
})
.collect::<Result<_, _>>()?,
))
let result: Vec<_> = s
.split('+')
.map(|sub| {
sub.split_once('*').map_or_else(
|| Err(ParseChannelError(format!("missing '*' in '{sub}'"))),
|(factor, pids)| {
let vector: Vec<_> = pids
.strip_prefix('(')
.ok_or_else(|| ParseChannelError(format!("missing '(' in '{pids}'")))?
.strip_suffix(')')
.ok_or_else(|| ParseChannelError(format!("missing ')' in '{pids}'")))?
.split(',')
.map(|pid| {
Ok(pid
.trim()
.parse::<i32>()
.map_err(|err| ParseChannelError(err.to_string()))?)
})
.collect::<Result<_, _>>()?;

Ok((
vector,
str::parse::<f64>(factor.trim())
.map_err(|err| ParseChannelError(err.to_string()))?,
))
},
)
})
.collect::<Result<_, _>>()?;

if !result.iter().map(|(pids, _)| pids.len()).all_equal() {
return Err(ParseChannelError(format!(
"PID tuples have different lengths"
)));
}

Ok(Self::new(result))
}
}

Expand Down

0 comments on commit 2417c92

Please sign in to comment.