Inconvenience when using Executor inside a function with generics #377
Comments
Fruit from a Discord discussion with @stefan-k: One way to avoid having to specify many

    fn generic<T, F>(cost: T, init_param: na::DVector<F>) -> na::DVector<F>
    where
        F: argmin::core::ArgminFloat + na::RealField + argmin_math::ArgminZero + std::iter::Sum + argmin_math::ArgminMul<na::DVector<F>, na::DVector<F>>,
        T: argmin::core::CostFunction<Output = F, Param = na::DVector<F>> + argmin::core::Gradient<Gradient = na::DVector<F>, Param = na::DVector<F>>,
    {
        let linesearch = MoreThuenteLineSearch::new()
            .with_c(F::from(1e-4).unwrap(), F::from(0.9).unwrap())
            .unwrap();
        let solver = LBFGS::new(linesearch, 7);
        let res = Executor::new(cost, solver)
            .configure(|state| state.param(init_param).max_iters(100))
            .run()
            .unwrap();
        res.state().get_prev_best_param().unwrap().clone()
    }

with `Cargo.toml` entries

    argmin-math = { version = "0.3", features = ["nalgebra_latest-serde"] }
    argmin = "0.8"
    nalgebra = { version = "0.32", features = ["rand", "serde-serialize", "rayon"] }
    nalgebra-lapack = "0.24"

Omitting any of the associated type constraints
This only works for the nalgebra backend, and not for ndarray:

    fn optimize_generic5<T, F>(cost: T, init_param: Array1<F>) -> Array1<F>
    where
        F: ArgminFloat + ArgminZero + std::iter::Sum + ArgminMul<Array1<F>, Array1<F>>,
        T: argmin::core::CostFunction<Output = F, Param = Array1<F>>
            + argmin::core::Gradient<Gradient = Array1<F>, Param = Array1<F>>,
    {
        let linesearch = MoreThuenteLineSearch::new()
            .with_c(F::from(1e-4).unwrap(), F::from(0.9).unwrap())
            .unwrap();
        let solver = LBFGS::new(linesearch, 7);
        let res = Executor::new(cost, solver)
            .configure(|state| state.param(init_param).max_iters(100))
            .run()
            .unwrap();
        res.state().get_prev_best_param().unwrap().clone()
    }

vec:

    fn optimize_generic7<T, F>(cost: T, init_param: Vec<F>) -> Vec<F>
    where
        F: ArgminFloat + ArgminZero + std::iter::Sum,
        T: argmin::core::CostFunction<Output = F, Param = Vec<F>>
            + argmin::core::Gradient<Gradient = Vec<F>, Param = Vec<F>>,
    {
        let solver = ParticleSwarm::new((init_param, init_param), 40);
        let res = Executor::new(cost, solver)
            .configure(|state| state.max_iters(100))
            .run()?;
        res.state().get_prev_best_param().unwrap().clone()
    }
A potential reason for the difference between the backends could be the way the math traits are implemented on the data types. For instance, comparing the

    impl<N, R, C> ArgminSignum for OMatrix<N, R, C>
    where
        N: SimdComplexField,
        R: Dim,
        C: Dim,
        DefaultAllocator: Allocator<N, R, C>,
    {
        #[inline]
        fn signum(self) -> OMatrix<N, R, C> {
            self.map(|v| v.simd_signum())
        }
    }

with

    macro_rules! make_signum {
        ($t:ty) => {
            impl ArgminSignum for Array1<$t> {
                #[inline]
                fn signum(mut self) -> Array1<$t> {
                    for a in &mut self {
                        *a = a.signum();
                    }
                    self
                }
            }
            impl ArgminSignum for Array2<$t> {
                #[inline]
                fn signum(mut self) -> Array2<$t> {
                    let m = self.shape()[0];
                    let n = self.shape()[1];
                    for i in 0..m {
                        for j in 0..n {
                            self[(i, j)] = self[(i, j)].signum();
                        }
                    }
                    self
                }
            }
        };
    }
    // [...]
    make_signum!(f32);
    make_signum!(f64);

For I remember trying to implement the math traits for
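To make the contrast between the two implementation styles concrete, here is a minimal std-only sketch. All names in it (`ElementwiseSignum`, `Scalar`, `GenericVec`, `MacroVec`) are hypothetical stand-ins and not part of argmin, argmin-math, nalgebra or ndarray. It assumes the difference comes down to a blanket impl over a generic scalar versus macro-generated impls for concrete scalars; whether that fully explains the behaviour of the real backends is not verified here.

```rust
// Hypothetical stand-in for a math trait such as ArgminSignum.
trait ElementwiseSignum {
    fn signum_all(self) -> Self;
}

// Hypothetical scalar abstraction used by the blanket impl.
trait Scalar: Copy {
    fn sign(self) -> Self;
}
impl Scalar for f32 {
    fn sign(self) -> Self {
        self.signum()
    }
}
impl Scalar for f64 {
    fn sign(self) -> Self {
        self.signum()
    }
}

// Two hypothetical containers, one per implementation style.
#[derive(Debug, PartialEq)]
struct GenericVec<F>(Vec<F>);
#[derive(Debug, PartialEq)]
struct MacroVec<F>(Vec<F>);

// Style 1: a single blanket impl that is generic over the scalar.
impl<F: Scalar> ElementwiseSignum for GenericVec<F> {
    fn signum_all(self) -> Self {
        GenericVec(self.0.into_iter().map(|x| x.sign()).collect())
    }
}

// Style 2: one impl per concrete scalar type, generated by a macro.
macro_rules! make_signum {
    ($t:ty) => {
        impl ElementwiseSignum for MacroVec<$t> {
            fn signum_all(mut self) -> Self {
                for a in &mut self.0 {
                    *a = a.signum();
                }
                self
            }
        }
    };
}
make_signum!(f32);
make_signum!(f64);

// With the blanket impl, a bound on the scalar alone is enough.
fn signs_generic<F: Scalar>(v: GenericVec<F>) -> GenericVec<F> {
    v.signum_all()
}

// With per-type impls, no bound on F implies the container impl,
// so the container bound has to be spelled out by the caller.
fn signs_macro<F>(v: MacroVec<F>) -> MacroVec<F>
where
    MacroVec<F>: ElementwiseSignum,
{
    v.signum_all()
}
```

In this sketch, removing the `where` clause from `signs_macro` makes it fail to compile, because the compiler cannot derive `MacroVec<F>: ElementwiseSignum` for an arbitrary `F` from two concrete impls.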
Wrapping calls to the Executor inside a function is typically not an issue when the types are known:
However, if this function is to be generic over the float type, it becomes quite inconvenient:
I can only imagine that things get worse when the function is to be generic over the parameter vector itself.
I'm not sure at this point how to improve this situation. A few ideas I have:
Array1<F>: ArgminMathNdarray or something like that would suffice. However, due to F being generic, I'm not sure if this will work.
I'm open to further ideas on that topic :)
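As a rough illustration of the single-bound idea (`Array1<F>: ArgminMathNdarray`), here is a std-only sketch using `Vec<F>` in place of `Array1<F>`. The traits `Scale`, `Dot` and `VecMath` and the function `scaled_self_dot` are hypothetical and only stand in for the argmin-math traits. It relies on the fact that supertrait bounds are implied by a bound on the subtrait, so a generic function only has to name the bundle. Whether this carries over to the actual solver bounds in argmin is an open assumption.

```rust
// Hypothetical math traits, standing in for ArgminMul, ArgminDot, etc.
trait Scale<F> {
    fn scale(self, factor: F) -> Self;
}
trait Dot<F> {
    fn dot(&self, other: &Self) -> F;
}

// Per-type impls generated by a macro, as in the ndarray backend.
macro_rules! impl_vec_math {
    ($t:ty) => {
        impl Scale<$t> for Vec<$t> {
            fn scale(self, factor: $t) -> Self {
                self.into_iter().map(|x| x * factor).collect()
            }
        }
        impl Dot<$t> for Vec<$t> {
            fn dot(&self, other: &Self) -> $t {
                self.iter().zip(other.iter()).map(|(a, b)| a * b).sum()
            }
        }
    };
}
impl_vec_math!(f32);
impl_vec_math!(f64);

// Hypothetical bundle trait: every individual bound becomes a supertrait.
trait VecMath<F>: Scale<F> + Dot<F> {}

// Blanket impl: anything that satisfies all parts satisfies the bundle.
impl<F, V> VecMath<F> for V where V: Scale<F> + Dot<F> {}

// The generic function names one bound on the container; the supertraits
// Scale<F> and Dot<F> are implied and their methods are usable.
fn scaled_self_dot<F>(v: Vec<F>, factor: F) -> F
where
    Vec<F>: VecMath<F>,
{
    let scaled = v.scale(factor);
    scaled.dot(&scaled)
}
```

For example, `scaled_self_dot(vec![1.0f64, 2.0], 2.0)` scales the vector to `[2.0, 4.0]` and returns its dot product with itself. Note that only bounds on `Self` are implied this way; extra `where` clauses on other types inside the bundle trait would still have to be repeated at each use site.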