I was trying to implement a bisection algorithm and wanted to create a custom JVP rule for it:

```python
from functools import partial

import jax
import jax.numpy as np
from jaxopt import Bisection


@partial(jax.custom_jvp, nondiff_argnums=(0, 2, 3))
def find_root(f, des, a, b):  # a, b are the lower and upper bounds; des is an argument for f
    root = Bisection(f, a, b, check_bracket=False,
                     jit=True).run(design=des).params
    return root


@find_root.defjvp
def find_root_jvp(f, a, b, primals, tangents):
    des, = primals
    des_dot, = tangents
    root = find_root(f, des, a, b)
    # find gradient: grad[0] is df/dx at the root, grad[1] is df/ddes
    _, grad = jax.value_and_grad(f, argnums=(0, 1))(root, des)
    rhs = -1 * grad[0] / grad[1]
    root_dot = np.dot(rhs, des_dot)
    return root, root_dot
```

I am using the following optimality function and objective:

```python
from jax.test_util import check_grads


def wrapper(vol_frac):
    def opt_func(x, design):
        return jax.nn.sigmoid(design + x).mean() - vol_frac
    return opt_func  # f => optimality function, i.e. f(x, design) = 0 at x = root


def dummy_objective(design, reqd_vol_frac):
    # _construct_bounds_bisection (not shown) returns the lower/upper bracket
    lb, ub = _construct_bounds_bisection(reqd_vol_frac, design)
    f = wrapper(reqd_vol_frac)
    root = find_root(f, design, lb, ub)
    return root


print(dummy_objective(np.array([0.5, 0.1, 0.2]), 0.5))  # Gives -0.27
print(jax.grad(dummy_objective)(np.array([0.5, 0.1, 0.2]), 0.5))  # [-3.01, -2.99, -2.98]
check_grads(lambda design: dummy_objective(design, 0.5),
            (np.array([0.5, 0.1, 0.2]),), order=2, eps=1e-4)  # AssertionError!
```

I tried enabling float64, but the error still persists.
I'm having trouble running your code. What is …
This code can be run easily and is documented. Hopefully this helps!
If this is still outstanding, you might find … useful. It also includes a custom JVP rule for differentiating via the implicit function theorem.
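Written out for the question's optimality function $f(x, d) = \frac{1}{n}\sum_i \sigma(d_i + x) - v$ (here $d$ is `design`, $v$ is `vol_frac` and $\sigma$ is the sigmoid; this notation is introduced only for the derivation), the implicit function theorem gives the derivative of the root $x^\star(d)$:

$$
f\big(x^\star(d), d\big) = 0
\;\Longrightarrow\;
\frac{\partial f}{\partial x}\,\frac{\partial x^\star}{\partial d_i} + \frac{\partial f}{\partial d_i} = 0
\;\Longrightarrow\;
\frac{\partial x^\star}{\partial d_i}
= -\frac{\partial f / \partial d_i}{\partial f / \partial x}
= -\frac{\sigma'(d_i + x^\star)}{\sum_j \sigma'(d_j + x^\star)}
$$

These entries always sum to $-1$. Dividing the other way round, $-(\partial f/\partial x)/(\partial f/\partial d_i) = -\sum_j \sigma'(d_j + x^\star)/\sigma'(d_i + x^\star)$, gives values close to $-n$, i.e. about $-3$ for the three-element design in the question, which is consistent with the printed gradient `[-3.01, -2.99, -2.98]`.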
There could be two points causing the wrong result:

1. Change `rhs = -1 * grad[0] / grad[1]` into `rhs = -1 * grad[1] / grad[0]`. Since `grad[0]` is $\partial f / \partial \text{root}$ and `grad[1]` is $\partial f / \partial \text{des}$, it should be the other way around. See more in the implementation of `jax.lax.custom_root`: https://github.com/google/jax/blob/f498442daa927aeb5bdbf840bc28cc527f47d7f1/jax/_src/lax/control_flow/solves.py#L136-L143
2. `jaxopt.Bisection` may need a larger `maxiter` and a smaller `tol` to find a more accurate root. I find this works.
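Putting both points together, here is a minimal self-contained sketch, not a definitive implementation. To avoid depending on jaxopt it swaps in a hand-written bisection loop (`bisect` is a hypothetical helper, not part of JAX or jaxopt) that runs a fixed 60 iterations; each iteration halves the bracket, so the assumed bracket [-10, 10] shrinks to 20 / 2^60 ≈ 1.7e-17. Root accuracy matters because `check_grads` compares against central finite differences with step `eps`: an error δ in the root can shift that estimate by up to roughly δ / eps, so with `eps=1e-4` a root error of 1e-5 could already move it by about 0.1. The JVP rule uses the swapped ratio from point 1.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.test_util import check_grads

# float64 so both the bisection root and the finite differences are precise
jax.config.update("jax_enable_x64", True)


def bisect(f, des, lo, hi, num_iters=60):
    """Hypothetical stand-in for jaxopt.Bisection: fixed-iteration bisection of x -> f(x, des)."""
    dtype = jnp.result_type(des)
    lo, hi = jnp.asarray(lo, dtype), jnp.asarray(hi, dtype)
    f_lo = f(lo, des)

    def body(_, carry):
        lo, hi, f_lo = carry
        mid = 0.5 * (lo + hi)
        f_mid = f(mid, des)
        # keep the half-interval whose endpoints still have opposite signs
        move_lo = jnp.sign(f_mid) == jnp.sign(f_lo)
        return (jnp.where(move_lo, mid, lo),
                jnp.where(move_lo, hi, mid),
                jnp.where(move_lo, f_mid, f_lo))

    lo, hi, _ = jax.lax.fori_loop(0, num_iters, body, (lo, hi, f_lo))
    return 0.5 * (lo + hi)


@partial(jax.custom_jvp, nondiff_argnums=(0, 2, 3))
def find_root(f, des, lo, hi):
    return bisect(f, des, lo, hi)


@find_root.defjvp
def find_root_jvp(f, lo, hi, primals, tangents):
    (des,), (des_dot,) = primals, tangents
    root = find_root(f, des, lo, hi)
    df_dx, df_ddes = jax.grad(f, argnums=(0, 1))(root, des)
    # implicit function theorem: d root / d des = -(df/ddes) / (df/dx)
    root_dot = -jnp.vdot(df_ddes, des_dot) / df_dx
    return root, root_dot


def wrapper(vol_frac):
    def opt_func(x, design):
        return jax.nn.sigmoid(design + x).mean() - vol_frac
    return opt_func


f = wrapper(0.5)
des = jnp.array([0.5, 0.1, 0.2])
objective = lambda d: find_root(f, d, -10.0, 10.0)  # bracket [-10, 10] is an assumption

root = objective(des)
grad_des = jax.grad(objective)(des)  # entries sum to -1 for this f
check_grads(objective, (des,), order=2, eps=1e-4)  # raises AssertionError on a mismatch
```

With jaxopt itself, the analogous change would be passing a larger `maxiter` and smaller `tol` to `Bisection` while keeping the corrected JVP rule.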