Skip to content

Commit

Permalink
refactor: remove use of Rc in OdeEquations trait (#104)
Browse files Browse the repository at this point in the history
refactor: sparsity returned as owned
refactor: sparsity functions moved from Op to relevent traits
refactor: diffsl struct owns context, now has 'static lifetime
  • Loading branch information
martinjrobins authored Nov 1, 2024
1 parent 033e8be commit 29e2a4c
Show file tree
Hide file tree
Showing 44 changed files with 1,009 additions and 983 deletions.
3 changes: 1 addition & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ nalgebra = []
sundials = ["suitesparse_sys", "bindgen", "cc"]
suitesparse = ["suitesparse_sys"]
diffsl-cranelift = ["diffsl-no-llvm", "diffsl"]
diffsl = []
diffsl = [ ]
diffsl-llvm = []
diffsl-llvm13 = ["diffsl13-0", "diffsl-llvm", "diffsl"]
diffsl-llvm14 = ["diffsl14-0", "diffsl-llvm", "diffsl"]
Expand All @@ -29,7 +29,6 @@ diffsl-llvm17 = ["diffsl17-0", "diffsl-llvm", "diffsl"]
nalgebra = "0.33"
nalgebra-sparse = { version = "0.10", features = ["io"] }
num-traits = "0.2.17"
ouroboros = "0.18.2"
serde = { version = "1.0.196", features = ["derive"] }
diffsl-no-llvm = { package = "diffsl", version = "=0.2.0", optional = true }
diffsl13-0 = { package = "diffsl", version = "=0.2.0", features = ["llvm13-0"], optional = true }
Expand Down
18 changes: 5 additions & 13 deletions benches/ode_solvers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@ use criterion::{criterion_group, criterion_main, Criterion};
use diffsol::{
ode_solver::test_models::{
exponential_decay::exponential_decay_problem, foodweb::foodweb_problem,
foodweb::FoodWebContext, heat2d::head2d_problem, robertson::robertson,
robertson_ode::robertson_ode,
heat2d::head2d_problem, robertson::robertson, robertson_ode::robertson_ode,
},
FaerLU, FaerSparseLU, NalgebraLU, SparseColMat,
};
Expand Down Expand Up @@ -222,10 +221,8 @@ fn criterion_benchmark(c: &mut Criterion) {
c.bench_function(stringify!($name), |b| {
use diffsol::diffsl::LlvmModule;
use diffsol::ode_solver::test_models::robertson::*;
let mut context = diffsol::DiffSlContext::default();
robertson_diffsl_compile::<$matrix, LlvmModule>(&mut context);
b.iter(|| {
let (problem, soln) = robertson_diffsl_problem(&mut context, false);
let (problem, soln) = robertson_diffsl_problem::<$matrix, LlvmModule>();
let ls = $linear_solver::default();
benchmarks::$solver(&problem, soln.solution_points.last().unwrap().t, ls)
})
Expand Down Expand Up @@ -336,8 +333,7 @@ fn criterion_benchmark(c: &mut Criterion) {
($name:ident, $solver:ident, $linear_solver:ident, $model:ident, $model_problem:ident, $matrix:ty, $($N:expr),+) => {
$(c.bench_function(concat!(stringify!($name), "_", $N), |b| {
b.iter(|| {
let context = FoodWebContext::default();
let (problem, soln) = $model_problem::<$matrix, $N>(&context);
let (problem, soln) = $model_problem::<$matrix, $N>();
let ls = $linear_solver::default();
benchmarks::$solver(&problem, soln.solution_points.last().unwrap().t, ls)
})
Expand Down Expand Up @@ -429,10 +425,8 @@ fn criterion_benchmark(c: &mut Criterion) {
c.bench_function(concat!(stringify!($name), "_", $N), |b| {
use diffsol::ode_solver::test_models::heat2d::*;
use diffsol::diffsl::LlvmModule;
let mut context = diffsol::DiffSlContext::default();
heat2d_diffsl_compile::<$matrix, LlvmModule, $N>(&mut context);
b.iter(|| {
let (problem, soln) = heat2d_diffsl_problem(&mut context);
let (problem, soln) = heat2d_diffsl_problem::<$matrix, LlvmModule, $N>();
let ls = $linear_solver::default();
benchmarks::$solver(&problem, soln.solution_points.last().unwrap().t, ls)
})
Expand Down Expand Up @@ -506,10 +500,8 @@ fn criterion_benchmark(c: &mut Criterion) {
c.bench_function(concat!(stringify!($name), "_", $N), |b| {
use diffsol::ode_solver::test_models::foodweb::*;
use diffsol::diffsl::LlvmModule;
let mut context = diffsol::DiffSlContext::default();
foodweb_diffsl_compile::<$matrix, LlvmModule, $N>(&mut context);
b.iter(|| {
let (problem, soln) = foodweb_diffsl_problem(&mut context);
let (problem, soln) = foodweb_diffsl_problem::<$matrix, LlvmModule, $N>();
let ls = $linear_solver::default();
benchmarks::$solver(&problem, soln.solution_points.last().unwrap().t, ls)
})
Expand Down
34 changes: 17 additions & 17 deletions src/jacobian/mod.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use std::collections::HashSet;

use crate::{
LinearOp, LinearOpTranspose, Matrix, MatrixSparsityRef, NonLinearOp, NonLinearOpAdjoint,
NonLinearOpJacobian, NonLinearOpSens, NonLinearOpSensAdjoint, Op, Scalar, Vector, VectorIndex,
LinearOp, LinearOpTranspose, Matrix, MatrixSparsity, NonLinearOp, NonLinearOpAdjoint,
NonLinearOpJacobian, NonLinearOpSens, NonLinearOpSensAdjoint, Scalar, Vector, VectorIndex,
};
use num_traits::{One, Zero};

Expand Down Expand Up @@ -93,12 +93,9 @@ pub struct JacobianColoring<M: Matrix> {
}

impl<M: Matrix> JacobianColoring<M> {
pub fn new_from_non_zeros<F: Op<M = M>>(op: &F, non_zeros: Vec<(usize, usize)>) -> Self {
let sparsity = op
.sparsity()
.expect("Jacobian sparsity not defined, cannot use coloring");
let ncols = op.nstates();
let graph = nonzeros2graph(non_zeros.as_slice(), ncols);
pub fn new(sparsity: &impl MatrixSparsity<M>, non_zeros: &[(usize, usize)]) -> Self {
let ncols = sparsity.ncols();
let graph = nonzeros2graph(non_zeros, ncols);
let coloring = color_graph_greedy(&graph);
let max_color = coloring.iter().max().copied().unwrap_or(0);
let mut dst_indices_per_color = Vec::new();
Expand Down Expand Up @@ -224,7 +221,6 @@ mod tests {
use std::rc::Rc;

use crate::jacobian::{find_jacobian_non_zeros, JacobianColoring};
use crate::matrix::sparsity::MatrixSparsityRef;
use crate::matrix::Matrix;
use crate::op::linear_closure::LinearClosure;
use crate::vector::Vector;
Expand All @@ -238,8 +234,6 @@ mod tests {
use num_traits::{One, Zero};
use std::ops::MulAssign;

use super::find_matrix_non_zeros;

fn helper_triplets2op_nonlinear<'a, M: Matrix + 'a>(
triplets: &'a [(usize, usize, M::T)],
nrows: usize,
Expand Down Expand Up @@ -394,9 +388,12 @@ mod tests {
let op = helper_triplets2op_nonlinear::<M>(triplets.as_slice(), n, n);
let y0 = M::V::zeros(n);
let t0 = M::T::zero();
let non_zeros = find_jacobian_non_zeros(&op, &y0, t0);
let coloring = JacobianColoring::new_from_non_zeros(&op, non_zeros);
let mut jac = M::new_from_sparsity(3, 3, op.sparsity().map(|s| s.to_owned()));
let nonzeros = triplets
.iter()
.map(|(i, j, _v)| (*i, *j))
.collect::<Vec<_>>();
let coloring = JacobianColoring::new(&op.jacobian_sparsity().unwrap(), &nonzeros);
let mut jac = M::new_from_sparsity(3, 3, op.jacobian_sparsity());
coloring.jacobian_inplace(&op, &y0, t0, &mut jac);
let mut gemv1 = M::V::zeros(n);
let v = M::V::from_element(3, M::T::one());
Expand All @@ -410,9 +407,12 @@ mod tests {
for triplets in test_triplets {
let op = helper_triplets2op_linear::<M>(triplets.as_slice(), n, n);
let t0 = M::T::zero();
let non_zeros = find_matrix_non_zeros(&op, t0);
let coloring = JacobianColoring::new_from_non_zeros(&op, non_zeros);
let mut jac = M::new_from_sparsity(3, 3, op.sparsity().map(|s| s.to_owned()));
let nonzeros = triplets
.iter()
.map(|(i, j, _v)| (*i, *j))
.collect::<Vec<_>>();
let coloring = JacobianColoring::new(&op.sparsity().unwrap(), &nonzeros);
let mut jac = M::new_from_sparsity(3, 3, op.sparsity());
coloring.matrix_inplace(&op, t0, &mut jac);
let mut gemv1 = M::V::zeros(n);
let v = M::V::from_element(3, M::T::one());
Expand Down
19 changes: 10 additions & 9 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
//! ## Solving ODEs
//!
//! The simplest way to create a new problem is to use the [OdeBuilder] struct. You can set the initial time, initial step size, relative tolerance, absolute tolerance, and parameters,
//! or leave them at their default values. Then, call one of the `build_*` functions (e.g. [OdeBuilder::build_ode], [OdeBuilder::build_ode_with_mass], [OdeBuilder::build_diffsl]) to create a [OdeSolverProblem].
//! or leave them at their default values. Then, call one of the `build_*` functions (e.g. [OdeBuilder::build_ode], [OdeBuilder::build_ode_with_mass], [OdeBuilder::build_from_eqn]) to create a [OdeSolverProblem].
//!
//! You will also need to choose a matrix type to use. DiffSol can use the [nalgebra](https://nalgebra.org) `DMatrix` type, the [faer](https://github.com/sarah-ek/faer-rs) `Mat` type, or any other type that implements the
//! [Matrix] trait.
Expand Down Expand Up @@ -35,7 +35,7 @@
//! DiffSL is a domain-specific language for specifying differential equations <https://github.com/martinjrobins/diffsl>. It uses the LLVM compiler framwork
//! to compile the equations to efficient machine code and uses the EnzymeAD library to compute the jacobian.
//!
//! You can use DiffSL with DiffSol using the [DiffSlContext] struct and [OdeBuilder::build_diffsl] method. You need to enable one of the `diffsl-llvm*` features
//! You can use DiffSL with DiffSol using the [DiffSlContext] and [DiffSl] structs and [OdeBuilder::build_from_eqn] method. You need to enable one of the `diffsl-llvm*` features
//! corresponding to the version of LLVM you have installed. E.g. to use your LLVM 10 installation, enable the `diffsl-llvm10` feature.
//!
//! For more information on the DiffSL language, see the [DiffSL documentation](https://martinjrobins.github.io/diffsl/)
Expand All @@ -54,7 +54,7 @@
//! of the output vector `J(x) v` are also `NaN`, using the fact that `NaN`s propagate through most operations. However, this method is not foolproof and will fail if,
//! for example, your jacobian function uses any control flow that depends on the input vector. If this is the case, you can provide the jacobian matrix directly by
//! implementing the optional [NonLinearOpJacobian::jacobian_inplace] and the [LinearOp::matrix_inplace] (if applicable) functions,
//! or by providing a sparsity pattern using the [Op::sparsity] function.
//! or by providing a sparsity pattern using the [NonLinearOpJacobian::jacobian_sparsity] and [LinearOp::sparsity] functions.
//!
//! ## Events / Root finding
//!
Expand Down Expand Up @@ -173,7 +173,7 @@ pub use ode_solver::sundials::SundialsIda;
pub use linear_solver::suitesparse::klu::KLU;

#[cfg(feature = "diffsl")]
pub use ode_solver::diffsl::DiffSlContext;
pub use ode_solver::diffsl::{DiffSl, DiffSlContext};

pub use jacobian::{
find_adjoint_non_zeros, find_jacobian_non_zeros, find_matrix_non_zeros,
Expand All @@ -196,11 +196,12 @@ pub use ode_solver::{
bdf_state::BdfState, builder::OdeBuilder, checkpointing::Checkpointing,
checkpointing::HermiteInterpolator, equations::AugmentedOdeEquations,
equations::AugmentedOdeEquationsImplicit, equations::NoAug, equations::OdeEquations,
equations::OdeEquationsAdjoint, equations::OdeEquationsImplicit, equations::OdeEquationsSens,
equations::OdeSolverEquations, method::AdjointOdeSolverMethod, method::OdeSolverMethod,
method::OdeSolverStopReason, method::SensitivitiesOdeSolverMethod, problem::OdeSolverProblem,
sdirk::Sdirk, sdirk::SdirkAdj, sdirk_state::SdirkState, sens_equations::SensEquations,
sens_equations::SensInit, sens_equations::SensRhs, state::OdeSolverState, tableau::Tableau,
equations::OdeEquationsAdjoint, equations::OdeEquationsImplicit, equations::OdeEquationsRef,
equations::OdeEquationsSens, equations::OdeSolverEquations, method::AdjointOdeSolverMethod,
method::OdeSolverMethod, method::OdeSolverStopReason, method::SensitivitiesOdeSolverMethod,
problem::OdeSolverProblem, sdirk::Sdirk, sdirk::SdirkAdj, sdirk_state::SdirkState,
sens_equations::SensEquations, sens_equations::SensInit, sens_equations::SensRhs,
state::OdeSolverState, tableau::Tableau,
};
pub use op::constant_op::{ConstantOp, ConstantOpSens, ConstantOpSensAdjoint};
pub use op::linear_op::{LinearOp, LinearOpSens, LinearOpTranspose};
Expand Down
5 changes: 2 additions & 3 deletions src/linear_solver/faer/lu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@ use std::rc::Rc;
use crate::{error::LinearSolverError, linear_solver_error};

use crate::{
error::DiffsolError, linear_solver::LinearSolver, Matrix, MatrixSparsityRef,
NonLinearOpJacobian, Scalar,
error::DiffsolError, linear_solver::LinearSolver, Matrix, NonLinearOpJacobian, Scalar,
};

use faer::{linalg::solvers::FullPivLu, solvers::SpSolver, Col, Mat};
Expand Down Expand Up @@ -58,7 +57,7 @@ impl<T: Scalar> LinearSolver<Mat<T>> for LU<T> {
) {
let ncols = op.nstates();
let nrows = op.nout();
let matrix = C::M::new_from_sparsity(nrows, ncols, op.sparsity().map(|s| s.to_owned()));
let matrix = C::M::new_from_sparsity(nrows, ncols, op.jacobian_sparsity());
self.matrix = Some(matrix);
}
}
8 changes: 1 addition & 7 deletions src/linear_solver/faer/sparse_lu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use crate::{
error::{DiffsolError, LinearSolverError},
linear_solver::LinearSolver,
linear_solver_error,
matrix::sparsity::MatrixSparsityRef,
scalar::IndexType,
Matrix, NonLinearOpJacobian, Scalar, SparseColMat,
};
Expand Down Expand Up @@ -73,12 +72,7 @@ impl<T: Scalar> LinearSolver<SparseColMat<T>> for FaerSparseLU<T> {
) {
let ncols = op.nstates();
let nrows = op.nout();
let matrix = C::M::new_from_sparsity(
nrows,
ncols,
op.sparsity()
.map(|s| MatrixSparsityRef::<SparseColMat<T>>::to_owned(&s)),
);
let matrix = C::M::new_from_sparsity(nrows, ncols, op.jacobian_sparsity());
self.matrix = Some(matrix);
self.lu_symbolic = Some(
SymbolicLu::try_new(self.matrix.as_ref().unwrap().faer().symbolic())
Expand Down
6 changes: 2 additions & 4 deletions src/linear_solver/nalgebra/lu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@ use nalgebra::{DMatrix, DVector, Dyn};

use crate::{
error::{DiffsolError, LinearSolverError},
linear_solver_error,
matrix::sparsity::MatrixSparsityRef,
LinearSolver, Matrix, NonLinearOpJacobian, Scalar,
linear_solver_error, LinearSolver, Matrix, NonLinearOpJacobian, Scalar,
};

/// A [LinearSolver] that uses the LU decomposition in the [`nalgebra` library](https://nalgebra.org/) to solve the linear system.
Expand Down Expand Up @@ -62,7 +60,7 @@ impl<T: Scalar> LinearSolver<DMatrix<T>> for LU<T> {
) {
let ncols = op.nstates();
let nrows = op.nout();
let matrix = C::M::new_from_sparsity(nrows, ncols, op.sparsity().map(|s| s.to_owned()));
let matrix = C::M::new_from_sparsity(nrows, ncols, op.jacobian_sparsity());
self.matrix = Some(matrix);
}
}
4 changes: 2 additions & 2 deletions src/linear_solver/suitesparse/klu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use crate::{
linear_solver_error,
matrix::MatrixCommon,
vector::Vector,
Matrix, MatrixSparsityRef, NonLinearOpJacobian, SparseColMat,
Matrix, NonLinearOpJacobian, SparseColMat,
};

trait MatrixKLU: Matrix<T = f64> {
Expand Down Expand Up @@ -231,7 +231,7 @@ where
) {
let ncols = op.nstates();
let nrows = op.nout();
let mut matrix = C::M::new_from_sparsity(nrows, ncols, op.sparsity().map(|s| s.to_owned()));
let mut matrix = C::M::new_from_sparsity(nrows, ncols, op.jacobian_sparsity());
let mut klu_common = self.klu_common.borrow_mut();
self.klu_symbolic = KluSymbolic::try_from_matrix(&mut matrix, klu_common.as_mut()).ok();
self.matrix = Some(matrix);
Expand Down
2 changes: 1 addition & 1 deletion src/matrix/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ pub trait MatrixView<'a>:
}

/// A base matrix trait (including sparse and dense matrices)
pub trait Matrix: MatrixCommon + Mul<Scale<Self::T>, Output = Self> + Clone {
pub trait Matrix: MatrixCommon + Mul<Scale<Self::T>, Output = Self> + Clone + 'static {
type Sparsity: MatrixSparsity<Self>;
type SparsityRef<'a>: MatrixSparsityRef<'a, Self>
where
Expand Down
52 changes: 26 additions & 26 deletions src/matrix/sparse_faer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,32 @@ impl<T: Scalar> MatrixSparsity<SparseColMat<T>> for SymbolicSparseColMat<IndexTy
Err(e) => Err(DiffsolError::Other(e.to_string())),
}
}

fn get_index(
&self,
rows: &[IndexType],
cols: &[IndexType],
) -> <<SparseColMat<T> as MatrixCommon>::V as Vector>::Index {
let col_ptrs = self.col_ptrs();
let row_indices = self.row_indices();
let mut indices = Vec::with_capacity(rows.len());
for (&i, &j) in rows.iter().zip(cols.iter()) {
let col_ptr = col_ptrs[j];
let next_col_ptr = col_ptrs[j + 1];
for (ii, &ri) in row_indices
.iter()
.enumerate()
.take(next_col_ptr)
.skip(col_ptr)
{
if ri == i {
indices.push(ii);
break;
}
}
}
indices
}
}

impl<'a, T: Scalar> MatrixSparsityRef<'a, SparseColMat<T>>
Expand Down Expand Up @@ -132,32 +158,6 @@ impl<'a, T: Scalar> MatrixSparsityRef<'a, SparseColMat<T>>
}
indices
}

fn get_index(
&self,
rows: &[IndexType],
cols: &[IndexType],
) -> <<SparseColMat<T> as MatrixCommon>::V as Vector>::Index {
let col_ptrs = self.col_ptrs();
let row_indices = self.row_indices();
let mut indices = Vec::with_capacity(rows.len());
for (&i, &j) in rows.iter().zip(cols.iter()) {
let col_ptr = col_ptrs[j];
let next_col_ptr = col_ptrs[j + 1];
for (ii, &ri) in row_indices
.iter()
.enumerate()
.take(next_col_ptr)
.skip(col_ptr)
{
if ri == i {
indices.push(ii);
break;
}
}
}
indices
}
}

impl<T: Scalar> Mul<Scale<T>> for SparseColMat<T> {
Expand Down
13 changes: 6 additions & 7 deletions src/matrix/sparse_serial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,13 +138,6 @@ impl<T: Scalar> MatrixSparsity<CscMatrix<T>> for SparsityPattern {
major_offsets.push(n);
SparsityPattern::try_from_offsets_and_indices(n, n, major_offsets, minor_indices).unwrap()
}
}

impl<'a, T: Scalar> MatrixSparsityRef<'a, CscMatrix<T>> for &'a SparsityPattern {
fn to_owned(&self) -> SparsityPattern {
SparsityPattern::clone(self)
}

fn get_index(&self, rows: &[IndexType], cols: &[IndexType]) -> DVector<IndexType> {
let mut index = DVector::<IndexType>::zeros(rows.len());
#[allow(unused_mut)]
Expand All @@ -156,6 +149,12 @@ impl<'a, T: Scalar> MatrixSparsityRef<'a, CscMatrix<T>> for &'a SparsityPattern
}
index
}
}

impl<'a, T: Scalar> MatrixSparsityRef<'a, CscMatrix<T>> for &'a SparsityPattern {
fn to_owned(&self) -> SparsityPattern {
SparsityPattern::clone(self)
}

fn nrows(&self) -> IndexType {
self.minor_dim()
Expand Down
Loading

0 comments on commit 29e2a4c

Please sign in to comment.