Skip to content

Commit

Permalink
feat(linear): add multiplication support
Browse files Browse the repository at this point in the history
  • Loading branch information
Shi-Raida committed Nov 12, 2024
1 parent 6ce5e19 commit 56c659a
Showing 1 changed file with 44 additions and 2 deletions.
46 changes: 44 additions & 2 deletions solver/src/model/lang/linear.rs
Original file line number Diff line number Diff line change
Expand Up @@ -338,8 +338,8 @@ impl<T: Into<LinearSum>> std::ops::Add<T> for LinearSum {
type Output = LinearSum;

fn add(self, rhs: T) -> Self::Output {
let mut new = self;
new += rhs.into();
let mut new = self.clone();
new += rhs;
new
}
}
Expand All @@ -352,6 +352,16 @@ impl<T: Into<LinearSum>> std::ops::Sub<T> for LinearSum {
}
}

impl<T: Into<i32>> std::ops::Mul<T> for LinearSum {
type Output = LinearSum;

fn mul(self, rhs: T) -> Self::Output {
let mut new = self.clone();
new *= rhs;
new
}
}

impl<T: Into<LinearSum>> std::ops::AddAssign<T> for LinearSum {
fn add_assign(&mut self, rhs: T) {
let rhs: LinearSum = rhs.into();
Expand All @@ -369,6 +379,16 @@ impl<T: Into<LinearSum>> std::ops::SubAssign<T> for LinearSum {
}
}

impl<T: Into<i32>> std::ops::MulAssign<T> for LinearSum {
fn mul_assign(&mut self, rhs: T) {
let rhs = rhs.into();
self.constant *= rhs;
for term in self.terms.iter_mut() {
term.factor *= rhs;
}
}
}

impl std::ops::Neg for LinearSum {
type Output = LinearSum;

Expand Down Expand Up @@ -1041,6 +1061,16 @@ mod tests {
assert_eq!(result.terms, vec![]);
}

#[test]
fn test_sum_mul() {
let v = IVar::new(VarRef::from_u32(5));
let s = LinearSum::of(vec![FAtom::new(IAtom::new(v, 5), 28)]);
let result = (s * 3).simplify();
assert_eq!(result.constant, 15);
assert_eq!(result.denom, 28);
assert_eq!(result.terms, vec![LinearTerm::new(3, v, 28)]);
}

#[test]
fn test_sum_sub() {
let s1 = LinearSum::of(vec![FAtom::new(5.into(), 28)]);
Expand Down Expand Up @@ -1075,6 +1105,18 @@ mod tests {
assert_eq!(result.terms, vec![]);
}

#[test]
fn test_sum_mul_assign() {
let v = IVar::new(VarRef::from_u32(5));
let s = LinearSum::of(vec![FAtom::new(IAtom::new(v, 5), 28)]);
let mut result = s.clone();
result *= 3;
let result = result.simplify();
assert_eq!(result.constant, 15);
assert_eq!(result.denom, 28);
assert_eq!(result.terms, vec![LinearTerm::new(3, v, 28)]);
}

#[test]
fn test_sum_neg() {
let s1 = LinearSum::of(vec![FAtom::new(5.into(), 28)]);
Expand Down

0 comments on commit 56c659a

Please sign in to comment.