diff --git a/solver/src/model/lang/linear.rs b/solver/src/model/lang/linear.rs index fb967890..28e2adfd 100644 --- a/solver/src/model/lang/linear.rs +++ b/solver/src/model/lang/linear.rs @@ -338,8 +338,8 @@ impl> std::ops::Add for LinearSum { type Output = LinearSum; fn add(self, rhs: T) -> Self::Output { - let mut new = self; - new += rhs.into(); + let mut new = self.clone(); + new += rhs; new } } @@ -352,6 +352,16 @@ impl> std::ops::Sub for LinearSum { } } +impl> std::ops::Mul for LinearSum { + type Output = LinearSum; + + fn mul(self, rhs: T) -> Self::Output { + let mut new = self.clone(); + new *= rhs; + new + } +} + impl> std::ops::AddAssign for LinearSum { fn add_assign(&mut self, rhs: T) { let rhs: LinearSum = rhs.into(); @@ -369,6 +379,16 @@ impl> std::ops::SubAssign for LinearSum { } } +impl> std::ops::MulAssign for LinearSum { + fn mul_assign(&mut self, rhs: T) { + let rhs = rhs.into(); + self.constant *= rhs; + for term in self.terms.iter_mut() { + term.factor *= rhs; + } + } +} + impl std::ops::Neg for LinearSum { type Output = LinearSum; @@ -1041,6 +1061,16 @@ mod tests { assert_eq!(result.terms, vec![]); } + #[test] + fn test_sum_mul() { + let v = IVar::new(VarRef::from_u32(5)); + let s = LinearSum::of(vec![FAtom::new(IAtom::new(v, 5), 28)]); + let result = (s * 3).simplify(); + assert_eq!(result.constant, 15); + assert_eq!(result.denom, 28); + assert_eq!(result.terms, vec![LinearTerm::new(3, v, 28)]); + } + #[test] fn test_sum_sub() { let s1 = LinearSum::of(vec![FAtom::new(5.into(), 28)]); @@ -1075,6 +1105,18 @@ mod tests { assert_eq!(result.terms, vec![]); } + #[test] + fn test_sum_mul_assign() { + let v = IVar::new(VarRef::from_u32(5)); + let s = LinearSum::of(vec![FAtom::new(IAtom::new(v, 5), 28)]); + let mut result = s.clone(); + result *= 3; + let result = result.simplify(); + assert_eq!(result.constant, 15); + assert_eq!(result.denom, 28); + assert_eq!(result.terms, vec![LinearTerm::new(3, v, 28)]); + } + #[test] fn test_sum_neg() { let s1 = LinearSum::of(vec![FAtom::new(5.into(), 28)]);