I'm building a differentiable iterative solver, in my case specifically for Bayesian regression. I started off with the code from the implicit layers website (http://implicit-layers-tutorial.org/implicit_functions/), but have added 1) a tolerance keyword and 2) a maximum number of iterations (see below).
This code works (well), but I'd like to make a minor change that I can't seem to pull off: in Bayesian regression (and by extension, sparse Bayesian learning), we're learning two parameters: the width of a prior and the noise precision. Right now I just stack these two into a device array of shape (2,), but I'd like to pass around a tuple of two scalars or, for SBL, a tuple of an array and a scalar. However, I'm running into issues in the backward pass, specifically with setting `z_init`. Using `jnp.zeros_like(z_star)` works for arrays, but not for tuples. If I change it to `z_init = tuple(jnp.zeros_like(elem) for elem in z_star)`, the tangents end up with the wrong shape. I've also tried `z_init = jax.lax.stop_gradient(z_star)` and even setting the shapes by hand with `z_init = (jnp.zeros(()), jnp.zeros(()))`, but none of it works.
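For concreteness, these are the `z_init` variants I've tried in `fixed_point_solver_bwd`, sketched under the assumption that the state is a tuple of two scalars:

```python
# z_init variants tried in fixed_point_solver_bwd when z_star is a
# tuple of two scalars instead of a shape-(2,) array
z_init = jnp.zeros_like(z_star)                          # fine for arrays, breaks for tuples
z_init = tuple(jnp.zeros_like(elem) for elem in z_star)  # tangents end up with the wrong shape
z_init = jax.lax.stop_gradient(z_star)                   # no luck either
z_init = (jnp.zeros(()), jnp.zeros(()))                  # nor with the shapes set by hand
```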
Again, my code works, so it's not really a blocker, but I wonder what I'm doing wrong and how I could fix it :-). Thanks!
```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import jit, lax


def fwd_solver(f, z_init, tol=1e-4, max_iter=300):
    def cond_fun(carry):
        iteration, z_prev, z = carry
        # we check the change in alpha (element 0 in z)
        # and the maximum number of iterations
        cond_norm = jnp.linalg.norm(z_prev[:-1] - z[:-1]) < tol
        cond_iter = iteration >= max_iter
        return ~jnp.logical_or(cond_norm, cond_iter)

    def body_fun(carry):
        iteration, _, z = carry
        return iteration + 1, z, f(z)

    init_carry = (0, z_init, f(z_init))  # first element is the iteration count
    _, _, z_star = lax.while_loop(cond_fun, body_fun, init_carry)
    return z_star


@partial(jax.custom_vjp, nondiff_argnums=(0,))
@partial(jit, static_argnums=(0,))
def fixed_point_solver(f, args, z_init, tol=1e-5, max_iter=300):
    z_star = fwd_solver(
        lambda z: f(z, *args), z_init=z_init, tol=tol, max_iter=max_iter
    )
    return z_star


@partial(jit, static_argnums=(0,))
def fixed_point_solver_fwd(f, args, z_init, tol, max_iter):
    z_star = fixed_point_solver(f, args, z_init, tol, max_iter)
    return z_star, (z_star, tol, max_iter, args)


@partial(jit, static_argnums=(0,))
def fixed_point_solver_bwd(f, res, z_star_bar):
    z_star, tol, max_iter, args = res
    _, vjp_a = jax.vjp(lambda args: f(z_star, *args), args)
    _, vjp_z = jax.vjp(lambda z: f(z, *args), z_star)
    # solve the adjoint fixed-point equation u = vjp_z(u) + z_star_bar
    res = vjp_a(
        fwd_solver(
            lambda u: vjp_z(u)[0] + z_star_bar,
            z_init=jnp.zeros_like(z_star),
            tol=tol,
            max_iter=max_iter,
        )
    )
    return (*res, None, None, None)  # None for z_init, tol and max_iter


fixed_point_solver.defvjp(fixed_point_solver_fwd, fixed_point_solver_bwd)
```
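For reference, this is roughly how I call it at the moment, with the two parameters stacked into a single shape-(2,) array; `dummy_update` here is just a made-up contraction standing in for the actual Bayesian-regression update:

```python
# toy stand-in for the real update step; it only illustrates the calling
# convention with z = [alpha, beta] stacked into one array
def dummy_update(z, a):
    return 0.5 * z + a  # fixed point is z* = 2 * a


a = jnp.array([0.1, 0.2])
z0 = jnp.zeros(2)  # stacked (alpha, beta) initial guess

z_star = fixed_point_solver(dummy_update, (a,), z0, 1e-5, 300)

# gradients flow through the custom VJP, e.g. of sum(z_star) w.r.t. a
grads = jax.grad(
    lambda a: fixed_point_solver(dummy_update, (a,), z0, 1e-5, 300).sum()
)(a)
```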