Merge pull request #13 from KGrewal1/v0.5
V0.5
KGrewal1 authored May 4, 2024
2 parents 1fbac76 + 772f45d commit 11655e2
Showing 4 changed files with 80 additions and 108 deletions.
20 changes: 10 additions & 10 deletions Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "candle-optimisers"
-version = "0.4.0"
+version = "0.5.0"
 edition = "2021"
 readme = "README.md"
 license = "MIT"
@@ -17,15 +17,15 @@ exclude = [

 [dependencies]

-candle-core = "0.4.0"
-candle-nn = "0.4.0"
+candle-core = "0.5.0"
+candle-nn = "0.5.0"
 log = "0.4.20"


 [dev-dependencies]
 anyhow = { version = "1", features = ["backtrace"] }
 assert_approx_eq = "1.1.0"
-candle-datasets = "0.4.0"
+candle-datasets = "0.5.0"
 clap = {version = "4.4.6", features = ["derive"] }
 criterion = { version = "0.5.1", features = ["html_reports"] }

@@ -41,12 +41,12 @@ cuda = ["candle-core/cuda", "candle-nn/cuda"]
 lto = true # maximal LTO optimisation

 [lints.clippy]
-pedantic = "warn"
-suspicious = "warn"
-perf = "warn"
-complexity = "warn"
-style = "warn"
-cargo = "warn"
+pedantic = {level = "warn", priority = -1}
+suspicious = {level = "warn", priority = -1}
+perf = {level = "warn", priority = -1}
+complexity = {level = "warn", priority = -1}
+style = {level = "warn", priority = -1}
+cargo = {level = "warn", priority = -1}
 imprecise_flops = "warn"
 missing_errors_doc = {level = "allow", priority = 1}
 uninlined_format_args = {level = "allow", priority = 1}
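A note on the `[lints.clippy]` change: in Cargo's lint tables, entries are applied in `priority` order, lowest first, so higher-priority entries override lower ones. Moving the lint groups to `priority = -1` lets individual lints take precedence over them: the plain `warn` entries sit at the default priority 0 and the `allow` overrides at priority 1. Recent Cargo releases also warn when a lint group shares a priority with an individual lint it could override, which is presumably what this change addresses.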
5 changes: 5 additions & 0 deletions Changelog.md
@@ -1,5 +1,10 @@
 # Changelog

+## v0.5.0 (2024-02-28)
+
+* Bump candle requirement to 0.5.0: this is considered a breaking change due to the reliance of this library on candle-core and candle-nn
+* Internal changes for LBFGS line search
+
 ## v0.4.0 (2024-02-28)

 * Bump candle requirement to 0.4.0: this is considered a breaking change due to the reliance of this library on candle-core and candle-nn
8 changes: 3 additions & 5 deletions src/lbfgs.rs
@@ -153,7 +153,7 @@ impl<M: Model> LossOptimizer<M> for Lbfgs<M> {
         let mut evals = 1;

         let grad = if let Some(this_grad) = &self.next_grad {
-            this_grad.as_tensor().clone()
+            this_grad.as_tensor().copy()?
         } else {
             flat_grads(&self.vars, loss, self.params.weight_decay)?
         };
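Why `copy()?` over `clone()`: a minimal sketch of the distinction, assuming candle's usual semantics (`Tensor::clone` duplicates the reference-counted handle so storage stays shared, while `copy()?` materialises fresh storage); the values and variable names are illustrative, not from this crate:

use candle_core::{Device, Result, Tensor, Var};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let v = Var::from_tensor(&Tensor::from_slice(&[1f64, 2., 3.], 3, &dev)?)?;

    let shared = v.as_tensor().clone(); // clones the handle; storage is shared
    let owned = v.as_tensor().copy()?; // allocates fresh, independent storage

    // Assumption: Var::set writes into the existing buffer in place, so the
    // cloned handle observes the update while the copy does not.
    v.set(&Tensor::from_slice(&[9f64, 9., 9.], 3, &dev)?)?;
    println!("{shared}"); // reflects the update through the shared storage
    println!("{owned}"); // still [1., 2., 3.]
    Ok(())
}

Under that assumption, taking a genuine copy of the gradient snapshot prevents a later `set` on `next_grad` from silently mutating the value captured here.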
@@ -302,10 +302,8 @@
         if let Some(ls) = &self.params.line_search {
             match ls {
                 LineSearch::StrongWolfe(c1, c2, tol) => {
-                    let (loss, grad, t, steps) = self.strong_wolfe(
-                        lr, &q, loss, //.to_dtype(candle_core::DType::F64)?.to_scalar()?
-                        &grad, dd, *c1, *c2, *tol, 25,
-                    )?;
+                    let (loss, grad, t, steps) =
+                        self.strong_wolfe(lr, &q, loss, &grad, dd, *c1, *c2, *tol, 25)?;
                     if let Some(next_grad) = &self.next_grad {
                         next_grad.set(&grad)?;
                     } else {
155 changes: 62 additions & 93 deletions src/lbfgs/strong_wolfe.rs
@@ -83,6 +83,10 @@ impl<M: Model> Lbfgs<M> {
     ) -> CResult<(Tensor, Tensor, f64, usize)> {
         // ported from https://github.com/torch/optim/blob/master/lswolfe.lua

+        let dtype = loss.dtype();
+        let shape = loss.shape();
+        let dev = loss.device();
+
         let d_norm = &direction
             .abs()?
             .max(0)?
@@ -92,7 +96,9 @@
         // evaluate objective and gradient using initial step
         let (f_new, g_new, mut l2_new) = self.directional_evaluate(step_size, direction)?;
         let g_new = Var::from_tensor(&g_new)?;
-        let f_new = Var::from_tensor(&f_new)?;
+        let mut f_new = f_new
+            .to_dtype(candle_core::DType::F64)?
+            .to_scalar::<f64>()?;
         let mut ls_func_evals = 1;
         let mut gtd_new = g_new
             .unsqueeze(0)?
@@ -103,8 +109,10 @@
             .to_scalar::<f64>()?;

         // bracket an interval containing a point satisfying the Wolfe criteria
-        let g_prev = Var::from_tensor(grad)?;
-        let f_prev = Var::from_tensor(loss)?;
+        let grad_det = grad.copy()?;
+        let g_prev = Var::from_tensor(&grad_det)?;
+        let scalar_loss = loss.to_dtype(candle_core::DType::F64)?.to_scalar::<f64>()?;
+        let mut f_prev = scalar_loss;
         let l2_init = self.l2_reg()?;
         let mut l2_prev = l2_init;
         let (mut t_prev, mut gtd_prev) = (0., directional_grad);
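These two hunks converge on one pattern: collapse the 0-dimensional loss tensor to a plain `f64` once, up front, instead of re-extracting it at every comparison. A minimal sketch of that pattern (the helper name is illustrative, not from this crate):

use candle_core::{DType, Result, Tensor};

/// Collapse a 0-dim loss tensor to a plain f64. Hypothetical helper;
/// upcasting to f64 first avoids precision loss for f16/f32 models.
fn loss_to_f64(loss: &Tensor) -> Result<f64> {
    loss.to_dtype(DType::F64)?.to_scalar::<f64>()
}

With plain `f64`s in hand, the bracketing comparisons further down become ordinary float comparisons instead of tensor round-trips.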
@@ -113,21 +121,13 @@

         let mut bracket_gtd;
         let mut bracket_l2;
-        let bracket_f;
+        let mut bracket_f;
         let (mut bracket, bracket_g) = loop {
             // check conditions
-            if f_new
-                .to_dtype(candle_core::DType::F64)?
-                .to_scalar::<f64>()?
-                + l2_new
-                >= f_prev
-                    .to_dtype(candle_core::DType::F64)?
-                    .to_scalar::<f64>()?
-                    + l2_prev
-            {
+            if f_new + l2_new >= f_prev + l2_prev {
                 bracket_gtd = [gtd_prev, gtd_new];
                 bracket_l2 = [l2_prev, l2_new];
-                bracket_f = [f_prev, Var::from_tensor(f_new.as_tensor())?];
+                bracket_f = [f_prev, f_new];
                 break (
                     [t_prev, step_size],
                     [g_prev, Var::from_tensor(g_new.as_tensor())?],
@@ -138,14 +138,11 @@
                 done = true;
                 bracket_gtd = [gtd_prev, gtd_new];
                 bracket_l2 = [l2_prev, l2_new];
-                bracket_f = [
-                    Var::from_tensor(f_new.as_tensor())?,
-                    Var::from_tensor(f_new.as_tensor())?,
-                ];
+                bracket_f = [f_new, f_new];
                 break (
                     [step_size, step_size],
                     [
-                        Var::from_tensor(g_new.as_tensor())?,
+                        Var::from_tensor(&g_new.as_tensor().copy()?)?,
                         Var::from_tensor(g_new.as_tensor())?,
                     ],
                 );
@@ -154,7 +151,7 @@
             if gtd_new >= 0. {
                 bracket_gtd = [gtd_prev, gtd_new];
                 bracket_l2 = [l2_prev, l2_new];
-                bracket_f = [f_prev, Var::from_tensor(f_new.as_tensor())?];
+                bracket_f = [f_prev, f_new];
                 break (
                     [t_prev, step_size],
                     [g_prev, Var::from_tensor(g_new.as_tensor())?],
@@ -167,31 +164,27 @@
             let tmp = step_size;
             step_size = cubic_interpolate(
                 t_prev,
-                f_prev
-                    .to_dtype(candle_core::DType::F64)?
-                    .to_scalar::<f64>()?
-                    + l2_prev,
+                f_prev + l2_prev,
                 gtd_prev,
                 step_size,
-                f_new
-                    .to_dtype(candle_core::DType::F64)?
-                    .to_scalar::<f64>()?
-                    + l2_new,
+                f_new + l2_new,
                 gtd_new,
                 Some((min_step, max_step)),
             );

             // next step
             t_prev = tmp;
-            f_prev.set(f_new.as_tensor())?;
+            f_prev = f_new;
             g_prev.set(g_new.as_tensor())?;
             l2_prev = l2_new;
             gtd_prev = gtd_new;
             // assign to temp vars:
             let (next_f, next_g, next_l2) = self.directional_evaluate(step_size, direction)?;

             // overwrite
-            f_new.set(&next_f)?;
+            f_new = next_f
+                .to_dtype(candle_core::DType::F64)?
+                .to_scalar::<f64>()?;
             g_new.set(&next_g)?;
             l2_new = next_l2;

@@ -210,10 +203,7 @@
             if ls_iter == max_ls {
                 bracket_gtd = [gtd_prev, gtd_new];
                 bracket_l2 = [l2_prev, l2_new];
-                bracket_f = [
-                    Var::from_tensor(loss)?,
-                    Var::from_tensor(f_new.as_tensor())?,
-                ];
+                bracket_f = [scalar_loss, f_new];
                 break (
                     [0., step_size],
                     [
@@ -229,19 +219,12 @@
         // exact point satisfying the criteria
         let mut insuf_progress = false;
         // find high and low points in bracket
-        let (mut low_pos, mut high_pos) = if bracket_f[0]
-            .to_dtype(candle_core::DType::F64)?
-            .to_scalar::<f64>()?
-            + bracket_l2[0]
-            <= bracket_f[1]
-                .to_dtype(candle_core::DType::F64)?
-                .to_scalar::<f64>()?
-                + bracket_l2[1]
-        {
-            (0, 1)
-        } else {
-            (1, 0)
-        };
+        let (mut low_pos, mut high_pos) =
+            if bracket_f[0] + bracket_l2[0] <= bracket_f[1] + bracket_l2[1] {
+                (0, 1)
+            } else {
+                (1, 0)
+            };
         while !done && ls_iter < max_ls {
             // line-search bracket is so small
             if (bracket[1] - bracket[0]).abs() * d_norm < tolerance_change {
@@ -251,16 +234,10 @@
             // compute new trial value
             step_size = cubic_interpolate(
                 bracket[0],
-                bracket_f[0]
-                    .to_dtype(candle_core::DType::F64)?
-                    .to_scalar::<f64>()?
-                    + bracket_l2[0],
+                bracket_f[0] + bracket_l2[0],
                 bracket_gtd[0],
                 bracket[1],
-                bracket_f[1]
-                    .to_dtype(candle_core::DType::F64)?
-                    .to_scalar::<f64>()?
-                    + bracket_l2[1],
+                bracket_f[1] + bracket_l2[1],
                 bracket_gtd[1],
                 None,
             );
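Both call sites now feed plain `f64` values to `cubic_interpolate`, which is defined elsewhere in this file and not shown in the diff. For reference, a hedged sketch of the standard cubic-interpolation step used by such line searches (following the torch/optim `lswolfe` code this port cites and PyTorch's `_cubic_interpolate`); the crate's actual implementation may differ in details:

/// Hedged sketch: fit a cubic to (f1, g1) at x1 and (f2, g2) at x2, and
/// return its minimiser clamped to the bounds. Illustrative only.
fn cubic_interpolate(
    x1: f64, f1: f64, g1: f64,
    x2: f64, f2: f64, g2: f64,
    bounds: Option<(f64, f64)>,
) -> f64 {
    let (xmin, xmax) = bounds.unwrap_or((x1.min(x2), x1.max(x2)));
    let d1 = g1 + g2 - 3. * (f1 - f2) / (x1 - x2);
    let d2_square = d1 * d1 - g1 * g2;
    if d2_square >= 0. {
        let d2 = d2_square.sqrt();
        let min_pos = if x1 <= x2 {
            x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2. * d2))
        } else {
            x1 - (x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2. * d2))
        };
        // analytic minimiser of the cubic, kept inside the bracket
        min_pos.clamp(xmin, xmax)
    } else {
        // no real minimiser: fall back to the midpoint (bisection)
        (xmin + xmax) / 2.
    }
}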
@@ -296,12 +273,14 @@
             // assign to temp vars:
             let (next_f, next_g, next_l2) = self.directional_evaluate(step_size, direction)?;
             // overwrite
-            f_new.set(&next_f)?;
-            g_new.set(&next_g)?;
+            f_new = next_f
+                .to_dtype(candle_core::DType::F64)?
+                .to_scalar::<f64>()?;
+
             l2_new = next_l2;
             ls_func_evals += 1;

-            gtd_new = g_new
+            gtd_new = next_g
                 .unsqueeze(0)?
                 .matmul(&(direction.unsqueeze(1)?))?
                 .to_dtype(candle_core::DType::F64)?
@@ -310,59 +289,39 @@
                 .to_scalar::<f64>()?;
             ls_iter += 1;

-            if f_new
-                .to_dtype(candle_core::DType::F64)?
-                .to_scalar::<f64>()?
-                + l2_new
-                > (loss.to_dtype(candle_core::DType::F64)?.to_scalar::<f64>()?
-                    + l2_init
-                    + c1 * step_size * directional_grad)
-                || f_new
-                    .to_dtype(candle_core::DType::F64)?
-                    .to_scalar::<f64>()?
-                    + l2_new
-                    >= bracket_f[low_pos]
-                        .to_dtype(candle_core::DType::F64)?
-                        .to_scalar::<f64>()?
-                        + bracket_l2[low_pos]
+            if f_new + l2_new > (scalar_loss + l2_init + c1 * step_size * directional_grad)
+                || f_new + l2_new >= bracket_f[low_pos] + bracket_l2[low_pos]
             {
                 // Armijo condition not satisfied or not lower than lowest point
                 bracket[high_pos] = step_size;
-                bracket_f[high_pos].set(&f_new)?;
-                bracket_g[high_pos].set(g_new.as_tensor())?;
+                bracket_f[high_pos] = f_new;
+                bracket_g[high_pos].set(&next_g)?;
                 bracket_l2[high_pos] = l2_new;
                 bracket_gtd[high_pos] = gtd_new;

-                (low_pos, high_pos) = if bracket_f[0]
-                    .to_dtype(candle_core::DType::F64)?
-                    .to_scalar::<f64>()?
-                    + bracket_l2[0]
-                    <= bracket_f[1]
-                        .to_dtype(candle_core::DType::F64)?
-                        .to_scalar::<f64>()?
-                        + bracket_l2[1]
-                {
-                    (0, 1)
-                } else {
-                    (1, 0)
-                };
+                (low_pos, high_pos) =
+                    if bracket_f[0] + bracket_l2[0] <= bracket_f[1] + bracket_l2[1] {
+                        (0, 1)
+                    } else {
+                        (1, 0)
+                    };
             } else {
                 if gtd_new.abs() <= -c2 * directional_grad {
                     // Wolfe conditions satisfied
                     done = true;
                 } else if gtd_new * (bracket[high_pos] - bracket[low_pos]) >= 0. {
                     // old low becomes new high
                     bracket[high_pos] = bracket[low_pos];
-                    bracket_f[high_pos].set(bracket_f[low_pos].as_tensor())?;
+                    bracket_f[high_pos] = bracket_f[low_pos];
                     bracket_g[high_pos].set(bracket_g[low_pos].as_tensor())?;
                     bracket_gtd[high_pos] = bracket_gtd[low_pos];
                     bracket_l2[high_pos] = bracket_l2[low_pos];
                 }

                 // new point becomes new low
                 bracket[low_pos] = step_size;
-                bracket_f[low_pos].set(f_new.as_tensor())?;
-                bracket_g[low_pos].set(g_new.as_tensor())?;
+                bracket_f[low_pos] = f_new;
+                bracket_g[low_pos].set(&next_g)?;
                 bracket_gtd[low_pos] = gtd_new;
                 bracket_l2[low_pos] = l2_new;
             }
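For orientation, the two tests this loop enforces are the standard strong Wolfe conditions. A minimal sketch with illustrative helper names (not from this crate), where `f0` and `gtd0` are the loss and directional derivative at the start of the search and `gtd0 < 0` for a descent direction:

/// Sufficient-decrease (Armijo) test: f(x + t*d) <= f(x) + c1 * t * (g·d).
fn armijo(f_new: f64, f0: f64, c1: f64, t: f64, gtd0: f64) -> bool {
    f_new <= f0 + c1 * t * gtd0
}

/// Curvature test: |g(x + t*d)·d| <= c2 * |g(x)·d| (gtd0 < 0, so -gtd0 = |gtd0|).
fn curvature(gtd_new: f64, c2: f64, gtd0: f64) -> bool {
    gtd_new.abs() <= -c2 * gtd0
}

The branch above shrinks the bracket from the high side whenever the sufficient-decrease test fails or the trial point is not below the current best; `done` is set once both tests hold.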
@@ -374,9 +333,19 @@
         let [f0, f1] = bracket_f;
         if low_pos == 1 {
             // if b is the lower value set a to b, else a should be returned
-            Ok((f1.into_inner(), g1.into_inner(), step_size, ls_func_evals))
+            Ok((
+                Tensor::from_slice(&[f1], shape, dev)?.to_dtype(dtype)?,
+                g1.into_inner(),
+                step_size,
+                ls_func_evals,
+            ))
         } else {
-            Ok((f0.into_inner(), g0.into_inner(), step_size, ls_func_evals))
+            Ok((
+                Tensor::from_slice(&[f0], shape, dev)?.to_dtype(dtype)?,
+                g0.into_inner(),
+                step_size,
+                ls_func_evals,
+            ))
         }
     }
