Skip to content

Commit

Permalink
going through traits and splitting op ones out
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Oct 13, 2024
1 parent 581e7c7 commit 4126bac
Show file tree
Hide file tree
Showing 39 changed files with 740 additions and 523 deletions.
29 changes: 13 additions & 16 deletions src/jacobian/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
use std::collections::HashSet;

use crate::op::{LinearOp, Op};
use crate::vector::Vector;
use crate::Scalar;
use crate::{op::NonLinearOp, Matrix, MatrixSparsityRef, VectorIndex};
use crate::{NonLinearOp, Matrix, MatrixSparsityRef, VectorIndex, LinearOp, Op, Vector, Scalar, NonLinearOpAdjoint, NonLinearOpJacobian, NonLinearOpSens, NonLinearOpSensAdjoint};
use num_traits::{One, Zero};

use self::{coloring::nonzeros2graph, greedy_coloring::color_graph_greedy};
Expand All @@ -13,9 +10,9 @@ pub mod graph;
pub mod greedy_coloring;

macro_rules! gen_find_non_zeros_nonlinear {
($name:ident, $op_fn:ident) => {
($name:ident, $op_fn:ident, $op_trait:ident) => {
/// Find the non-zero entries of the $name matrix of a non-linear operator.
pub fn $name<F: NonLinearOp + ?Sized>(op: &F, x: &F::V, t: F::T) -> Vec<(usize, usize)> {
pub fn $name<F: NonLinearOp + $op_trait + ?Sized>(op: &F, x: &F::V, t: F::T) -> Vec<(usize, usize)> {
let mut v = F::V::zeros(op.nstates());
let mut col = F::V::zeros(op.nout());
let mut triplets = Vec::with_capacity(op.nstates());
Expand All @@ -35,10 +32,10 @@ macro_rules! gen_find_non_zeros_nonlinear {
};
}

gen_find_non_zeros_nonlinear!(find_jacobian_non_zeros, jac_mul_inplace);
gen_find_non_zeros_nonlinear!(find_adjoint_non_zeros, jac_transpose_mul_inplace);
gen_find_non_zeros_nonlinear!(find_sens_non_zeros, sens_mul_inplace);
gen_find_non_zeros_nonlinear!(find_sens_adjoint_non_zeros, sens_transpose_mul_inplace);
gen_find_non_zeros_nonlinear!(find_jacobian_non_zeros, jac_mul_inplace, NonLinearOpJacobian);
gen_find_non_zeros_nonlinear!(find_adjoint_non_zeros, jac_transpose_mul_inplace, NonLinearOpAdjoint);
gen_find_non_zeros_nonlinear!(find_sens_non_zeros, sens_mul_inplace, NonLinearOpSens);
gen_find_non_zeros_nonlinear!(find_sens_adjoint_non_zeros, sens_transpose_mul_inplace, NonLinearOpSensAdjoint);

macro_rules! gen_find_non_zeros_linear {
($name:ident, $op_fn:ident) => {
Expand Down Expand Up @@ -119,7 +116,7 @@ impl<M: Matrix> JacobianColoring<M> {
// Self::new_from_non_zeros(op, non_zeros)
//}

pub fn jacobian_inplace<F: NonLinearOp<M = M, V = M::V, T = M::T>>(
pub fn jacobian_inplace<F: NonLinearOpJacobian<M = M, V = M::V, T = M::T>>(
&self,
op: &F,
x: &F::V,
Expand All @@ -139,7 +136,7 @@ impl<M: Matrix> JacobianColoring<M> {
}
}

pub fn adjoint_inplace<F: NonLinearOp<M = M, V = M::V, T = M::T>>(
pub fn adjoint_inplace<F: NonLinearOpAdjoint<M = M, V = M::V, T = M::T>>(
&self,
op: &F,
x: &F::V,
Expand All @@ -159,7 +156,7 @@ impl<M: Matrix> JacobianColoring<M> {
}
}

pub fn sens_adjoint_inplace<F: NonLinearOp<M = M, V = M::V, T = M::T>>(
pub fn sens_adjoint_inplace<F: NonLinearOpSensAdjoint<M = M, V = M::V, T = M::T>>(
&self,
op: &F,
x: &F::V,
Expand Down Expand Up @@ -207,13 +204,13 @@ mod tests {
use crate::matrix::sparsity::MatrixSparsityRef;
use crate::matrix::Matrix;
use crate::op::linear_closure::LinearClosure;
use crate::op::{LinearOp, Op};
use crate::vector::Vector;
use crate::{
jacobian::{coloring::nonzeros2graph, greedy_coloring::color_graph_greedy},
op::closure::Closure,
LinearOp, Op,
};
use crate::{scale, NonLinearOp, SparseColMat};
use crate::{scale, NonLinearOpJacobian, SparseColMat};
use nalgebra::DMatrix;
use num_traits::{One, Zero};
use std::ops::MulAssign;
Expand All @@ -224,7 +221,7 @@ mod tests {
triplets: &'a [(usize, usize, M::T)],
nrows: usize,
ncols: usize,
) -> impl NonLinearOp<M = M, V = M::V, T = M::T> + 'a {
) -> impl NonLinearOpJacobian<M = M, V = M::V, T = M::T> + 'a {
let nstates = ncols;
let nout = nrows;
let f = move |x: &M::V, y: &mut M::V| {
Expand Down
7 changes: 5 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,16 +157,19 @@ use ode_solver::jacobian_update::JacobianUpdate;
pub use ode_solver::{
adjoint_equations::AdjointEquations, adjoint_equations::AdjointInit, adjoint_equations::AdjointContext,
adjoint_equations::AdjointRhs, bdf::Bdf, bdf::BdfAdj, bdf::BdfAug, bdf_state::BdfState,
builder::OdeBuilder, checkpointing::Checkpointing, equations::AugmentedOdeEquations,
builder::OdeBuilder, checkpointing::Checkpointing, equations::AugmentedOdeEquations, equations::OdeEquationsAdjoint, equations::OdeEquationsImplicit,
equations::NoAug, equations::OdeEquations, equations::OdeSolverEquations,
method::OdeSolverMethod, method::OdeSolverStopReason, problem::OdeSolverProblem, sdirk::Sdirk,
sdirk::SdirkAdj, sdirk::SdirkAug, sdirk_state::SdirkState, sens_equations::SensEquations,
sens_equations::SensInit, sens_equations::SensRhs, state::OdeSolverState, tableau::Tableau,
};
pub use ode_solver::state::{StateRef, StateRefMut};
use op::nonlinear_op::{NonLinearOp, NonLinearOpSens, NonLinearOpSensAdjoint, NonLinearOpAdjoint, NonLinearOpJacobian};
use op::linear_op::{LinearOp, LinearOpMatrix, LinearOpSens, LinearOpTranspose};
use op::constant_op::{ConstantOp, ConstantOpSens, ConstantOpSensAdjoint};
pub use op::{
closure::Closure, constant_closure::ConstantClosure, constant_closure_with_adjoint::ConstantClosureWithAdjoint, linear_closure::LinearClosure,
unit::UnitCallable, ConstantOp, LinearOp, NonLinearOp, Op,
unit::UnitCallable, Op,
};
use op::{
closure_no_jac::ClosureNoJac, closure_with_sens::ClosureWithSens,
Expand Down
10 changes: 5 additions & 5 deletions src/linear_solver/faer/lu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@ use std::rc::Rc;

use crate::{
error::DiffsolError, linear_solver::LinearSolver, op::linearise::LinearisedOp,
solver::SolverProblem, LinearOp, Matrix, MatrixSparsityRef, NonLinearOp, Op, Scalar,
solver::SolverProblem, Matrix, LinearOpMatrix, MatrixSparsityRef, Op, Scalar, NonLinearOpJacobian
};

use faer::{linalg::solvers::FullPivLu, solvers::SpSolver, Col, Mat};
/// A [LinearSolver] that uses the LU decomposition in the [`faer`](https://github.com/sarah-ek/faer-rs) library to solve the linear system.
pub struct LU<T, C>
where
T: Scalar,
C: NonLinearOp<M = Mat<T>, V = Col<T>, T = T>,
C: NonLinearOpJacobian<M = Mat<T>, V = Col<T>, T = T>,
{
lu: Option<FullPivLu<T>>,
problem: Option<SolverProblem<LinearisedOp<C>>>,
Expand All @@ -21,7 +21,7 @@ where
impl<T, C> Default for LU<T, C>
where
T: Scalar,
C: NonLinearOp<M = Mat<T>, V = Col<T>, T = T>,
C: NonLinearOpJacobian<M = Mat<T>, V = Col<T>, T = T>,
{
fn default() -> Self {
Self {
Expand All @@ -32,8 +32,8 @@ where
}
}

impl<T: Scalar, C: NonLinearOp<M = Mat<T>, V = Col<T>, T = T>> LinearSolver<C> for LU<T, C> {
type SelfNewOp<C2: NonLinearOp<T = C::T, V = C::V, M = C::M>> = LU<T, C2>;
impl<T: Scalar, C: NonLinearOpJacobian<M = Mat<T>, V = Col<T>, T = T>> LinearSolver<C> for LU<T, C> {
type SelfNewOp<C2: NonLinearOpJacobian<T = C::T, V = C::V, M = C::M>> = LU<T, C2>;

fn set_linearisation(&mut self, x: &C::V, t: C::T) {
Rc::<LinearisedOp<C>>::get_mut(&mut self.problem.as_mut().expect("Problem not set").f)
Expand Down
11 changes: 6 additions & 5 deletions src/linear_solver/faer/sparse_lu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ use crate::{
op::linearise::LinearisedOp,
scalar::IndexType,
solver::SolverProblem,
LinearOp, Matrix, NonLinearOp, Op, Scalar, SparseColMat,
LinearOpMatrix, Matrix, NonLinearOpJacobian, Op, Scalar, SparseColMat,

};

use faer::{
Expand All @@ -21,7 +22,7 @@ use faer::{
pub struct FaerSparseLU<T, C>
where
T: Scalar,
C: NonLinearOp<M = SparseColMat<T>, V = Col<T>, T = T>,
C: NonLinearOpJacobian<M = SparseColMat<T>, V = Col<T>, T = T>,
{
lu: Option<Lu<IndexType, T>>,
lu_symbolic: Option<SymbolicLu<IndexType>>,
Expand All @@ -32,7 +33,7 @@ where
impl<T, C> Default for FaerSparseLU<T, C>
where
T: Scalar,
C: NonLinearOp<M = SparseColMat<T>, V = Col<T>, T = T>,
C: NonLinearOpJacobian<M = SparseColMat<T>, V = Col<T>, T = T>,
{
fn default() -> Self {
Self {
Expand All @@ -44,10 +45,10 @@ where
}
}

impl<T: Scalar, C: NonLinearOp<M = SparseColMat<T>, V = Col<T>, T = T>> LinearSolver<C>
impl<T: Scalar, C: NonLinearOpJacobian<M = SparseColMat<T>, V = Col<T>, T = T>> LinearSolver<C>
for FaerSparseLU<T, C>
{
type SelfNewOp<C2: NonLinearOp<T = C::T, V = C::V, M = C::M>> = FaerSparseLU<T, C2>;
type SelfNewOp<C2: NonLinearOpJacobian<T = C::T, V = C::V, M = C::M>> = FaerSparseLU<T, C2>;

fn set_linearisation(&mut self, x: &C::V, t: C::T) {
Rc::<LinearisedOp<C>>::get_mut(&mut self.problem.as_mut().expect("Problem not set").f)
Expand Down
11 changes: 6 additions & 5 deletions src/linear_solver/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{error::DiffsolError, op::Op, solver::SolverProblem, NonLinearOp};
use crate::{error::DiffsolError, op::Op, solver::SolverProblem, NonLinearOpJacobian};

#[cfg(feature = "nalgebra")]
pub mod nalgebra;
Expand All @@ -17,7 +17,7 @@ pub use nalgebra::lu::LU as NalgebraLU;

/// A solver for the linear problem `Ax = b`, where `A` is a linear operator that is obtained by taking the linearisation of a nonlinear operator `C`
pub trait LinearSolver<C: Op>: Default {
type SelfNewOp<C2: NonLinearOp<T = C::T, V = C::V, M = C::M>>: LinearSolver<C2>;
type SelfNewOp<C2: NonLinearOpJacobian<T = C::T, V = C::V, M = C::M>>: LinearSolver<C2>;

/// Set the problem to be solved, any previous problem is discarded.
/// Any internal state of the solver is reset.
Expand Down Expand Up @@ -57,7 +57,8 @@ pub mod tests {

use crate::{
linear_solver::{FaerLU, NalgebraLU},
op::{closure::Closure, NonLinearOp},
op::closure::Closure,
NonLinearOpJacobian,
scalar::scale,
vector::VectorRef,
LinearSolver, Matrix, SolverProblem, Vector,
Expand All @@ -67,7 +68,7 @@ pub mod tests {
use super::LinearSolveSolution;

pub fn linear_problem<M: Matrix + 'static>() -> (
SolverProblem<impl NonLinearOp<M = M, V = M::V, T = M::T>>,
SolverProblem<impl NonLinearOpJacobian<M = M, V = M::V, T = M::T>>,
Vec<LinearSolveSolution<M::V>>,
) {
let diagonal = M::V::from_vec(vec![2.0.into(), 2.0.into()]);
Expand Down Expand Up @@ -99,7 +100,7 @@ pub mod tests {
problem: SolverProblem<C>,
solns: Vec<LinearSolveSolution<C::V>>,
) where
C: NonLinearOp,
C: NonLinearOpJacobian,
for<'a> &'a C::V: VectorRef<C::V>,
{
solver.set_problem(&problem);
Expand Down
13 changes: 7 additions & 6 deletions src/linear_solver/nalgebra/lu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,17 @@ use crate::{
error::{DiffsolError, LinearSolverError},
linear_solver_error,
matrix::sparsity::MatrixSparsityRef,
op::{linearise::LinearisedOp, NonLinearOp},
LinearOp, LinearSolver, Matrix, Op, Scalar, SolverProblem,
op::linearise::LinearisedOp,
NonLinearOpJacobian,
LinearOpMatrix, LinearSolver, Matrix, Op, Scalar, SolverProblem,
};

/// A [LinearSolver] that uses the LU decomposition in the [`nalgebra` library](https://nalgebra.org/) to solve the linear system.
#[derive(Clone)]
pub struct LU<T, C>
where
T: Scalar,
C: NonLinearOp<M = DMatrix<T>, V = DVector<T>, T = T>,
C: NonLinearOpJacobian<M = DMatrix<T>, V = DVector<T>, T = T>,
{
matrix: Option<DMatrix<T>>,
lu: Option<nalgebra::LU<T, Dyn, Dyn>>,
Expand All @@ -24,7 +25,7 @@ where
impl<T, C> Default for LU<T, C>
where
T: Scalar,
C: NonLinearOp<M = DMatrix<T>, V = DVector<T>, T = T>,
C: NonLinearOpJacobian<M = DMatrix<T>, V = DVector<T>, T = T>,
{
fn default() -> Self {
Self {
Expand All @@ -35,10 +36,10 @@ where
}
}

impl<T: Scalar, C: NonLinearOp<M = DMatrix<T>, V = DVector<T>, T = T>> LinearSolver<C>
impl<T: Scalar, C: NonLinearOpJacobian<M = DMatrix<T>, V = DVector<T>, T = T>> LinearSolver<C>
for LU<T, C>
{
type SelfNewOp<C2: NonLinearOp<T = C::T, V = C::V, M = C::M>> = LU<T, C2>;
type SelfNewOp<C2: NonLinearOpJacobian<T = C::T, V = C::V, M = C::M>> = LU<T, C2>;

fn solve_in_place(&self, state: &mut C::V) -> Result<(), DiffsolError> {
if self.lu.is_none() {
Expand Down
10 changes: 5 additions & 5 deletions src/linear_solver/suitesparse/klu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use crate::{
matrix::MatrixCommon,
op::linearise::LinearisedOp,
vector::Vector,
LinearOp, Matrix, MatrixSparsityRef, NonLinearOp, Op, SolverProblem, SparseColMat,
LinearOpMatrix, Matrix, MatrixSparsityRef, NonLinearOpJacobian, Op, SolverProblem, SparseColMat,
};

trait MatrixKLU: Matrix<T = f64> {
Expand Down Expand Up @@ -159,7 +159,7 @@ impl KluCommon {
pub struct KLU<M, C>
where
M: Matrix,
C: NonLinearOp<M = M, V = M::V, T = M::T>,
C: NonLinearOpJacobian<M = M, V = M::V, T = M::T>,
{
klu_common: RefCell<KluCommon>,
klu_symbolic: Option<KluSymbolic>,
Expand All @@ -171,7 +171,7 @@ where
impl<M, C> Default for KLU<M, C>
where
M: Matrix,
C: NonLinearOp<M = M, V = M::V, T = M::T>,
C: NonLinearOpJacobian<M = M, V = M::V, T = M::T>,
{
fn default() -> Self {
let klu_common = KluCommon::default();
Expand All @@ -190,9 +190,9 @@ impl<M, C> LinearSolver<C> for KLU<M, C>
where
M: MatrixKLU,
M::V: VectorKLU,
C: NonLinearOp<M = M, V = M::V, T = M::T>,
C: NonLinearOpJacobian<M = M, V = M::V, T = M::T>,
{
type SelfNewOp<C2: NonLinearOp<T = C::T, V = C::V, M = C::M>> = KLU<M, C2>;
type SelfNewOp<C2: NonLinearOpJacobian<T = C::T, V = C::V, M = C::M>> = KLU<M, C2>;

fn set_linearisation(&mut self, x: &C::V, t: C::T) {
Rc::<LinearisedOp<C>>::get_mut(&mut self.problem.as_mut().expect("Problem not set").f)
Expand Down
14 changes: 7 additions & 7 deletions src/linear_solver/sundials.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::sundials_sys::{

use crate::{
error::*, linear_solver_error, ode_solver::sundials::sundials_check,
op::linearise::LinearisedOp, vector::sundials::SundialsVector, LinearOp, Matrix, NonLinearOp,
op::linearise::LinearisedOp, vector::sundials::SundialsVector, LinearOpMatrix, Matrix, NonLinearOpJacobian,
Op, SolverProblem, SundialsMatrix,
};

Expand All @@ -17,7 +17,7 @@ use super::LinearSolver;

pub struct SundialsLinearSolver<Op>
where
Op: NonLinearOp<M = SundialsMatrix, V = SundialsVector, T = realtype>,
Op: NonLinearOpJacobian<M = SundialsMatrix, V = SundialsVector, T = realtype>,
{
linear_solver: Option<SUNLinearSolver>,
problem: Option<SolverProblem<LinearisedOp<Op>>>,
Expand All @@ -27,7 +27,7 @@ where

impl<Op> Default for SundialsLinearSolver<Op>
where
Op: NonLinearOp<M = SundialsMatrix, V = SundialsVector, T = realtype>,
Op: NonLinearOpJacobian<M = SundialsMatrix, V = SundialsVector, T = realtype>,
{
fn default() -> Self {
Self::new_dense()
Expand All @@ -36,7 +36,7 @@ where

impl<Op> SundialsLinearSolver<Op>
where
Op: NonLinearOp<M = SundialsMatrix, V = SundialsVector, T = realtype>,
Op: NonLinearOpJacobian<M = SundialsMatrix, V = SundialsVector, T = realtype>,
{
pub fn new_dense() -> Self {
Self {
Expand All @@ -50,7 +50,7 @@ where

impl<Op> Drop for SundialsLinearSolver<Op>
where
Op: NonLinearOp<M = SundialsMatrix, V = SundialsVector, T = realtype>,
Op: NonLinearOpJacobian<M = SundialsMatrix, V = SundialsVector, T = realtype>,
{
fn drop(&mut self) {
if let Some(linear_solver) = self.linear_solver {
Expand All @@ -61,9 +61,9 @@ where

impl<Op> LinearSolver<Op> for SundialsLinearSolver<Op>
where
Op: NonLinearOp<M = SundialsMatrix, V = SundialsVector, T = realtype>,
Op: NonLinearOpJacobian<M = SundialsMatrix, V = SundialsVector, T = realtype>,
{
type SelfNewOp<C2: NonLinearOp<T = Op::T, V = Op::V, M = Op::M>> = SundialsLinearSolver<C2>;
type SelfNewOp<C2: NonLinearOpJacobian<T = Op::T, V = Op::V, M = Op::M>> = SundialsLinearSolver<C2>;

fn set_problem(&mut self, problem: &SolverProblem<Op>) {
let linearised_problem = problem.linearise();
Expand Down
6 changes: 3 additions & 3 deletions src/matrix/default_solver.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use crate::{LinearSolver, NonLinearOp};
use crate::{LinearSolver, NonLinearOpJacobian};

use super::Matrix;

pub trait DefaultSolver: Matrix {
type LS<C: NonLinearOp<M = Self, V = Self::V, T = Self::T>>: LinearSolver<C> + Default;
fn default_solver<C: NonLinearOp<M = Self, V = Self::V, T = Self::T>>() -> Self::LS<C> {
type LS<C: NonLinearOpJacobian<M = Self, V = Self::V, T = Self::T>>: LinearSolver<C> + Default;
fn default_solver<C: NonLinearOpJacobian<M = Self, V = Self::V, T = Self::T>>() -> Self::LS<C> {
Self::LS::default()
}
}
Loading

0 comments on commit 4126bac

Please sign in to comment.