
Custom JVP rule for Bisection algorithm #16924

Closed · Answered by anh-tong
SNMS95 asked this question in General

Two things could be causing the wrong result:

  1. This line should change from
rhs = -1* grad[0] / grad[1]

into

rhs = -1* grad[1] / grad[0]

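Since grad[0] is $\partial f / \partial \mathrm{root}$ and grad[1] is $\partial f / \partial \mathrm{des}$, the ratio should be the other way around: differentiating $f(\mathrm{root}(\mathrm{des}), \mathrm{des}) = 0$ with respect to des gives $\frac{d\,\mathrm{root}}{d\,\mathrm{des}} = -\frac{\partial f / \partial \mathrm{des}}{\partial f / \partial \mathrm{root}}$.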

See the derivation in the JVP rule of jax.lax.custom_root:

https://github.com/google/jax/blob/f498442daa927aeb5bdbf840bc28cc527f47d7f1/jax/_src/lax/control_flow/solves.py#L136-L143

  2. jaxopt.Bisection may need a larger maxiter and a smaller tol to return a more accurate root. I found this works:
root = Bisection(f, a, b, check_bracket=False, maxiter=1000, tol=1e-8,
                  jit=True).run(design=des).params
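For reference, here is a minimal sketch of the corrected rule from point 1 in plain JAX, with no jaxopt involved. It assumes a scalar residual f(root, design) that changes sign on a fixed bracket; the residual `f`, the bracket [0, 10], and the helper names `bisect` and `root_of` are hypothetical, chosen so the exact answer is known (root = sqrt(design)). The tangent uses the corrected ratio rhs = -grad[1] / grad[0].

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical bracket; f must change sign between these endpoints.
LOWER, UPPER = 0.0, 10.0


def f(x, design):
    # Hypothetical residual whose root is x* = sqrt(design) for design > 0.
    return x**2 - design


def bisect(fun, design, maxiter=100):
    # Fixed-iteration bisection written with lax.fori_loop so it can be jitted.
    def body(_, bounds):
        lo, hi = bounds
        mid = 0.5 * (lo + hi)
        # Replace the endpoint whose residual has the same sign as f(mid).
        same_sign = jnp.sign(fun(mid, design)) == jnp.sign(fun(lo, design))
        return jnp.where(same_sign, mid, lo), jnp.where(same_sign, hi, mid)

    init = (jnp.asarray(LOWER, jnp.float32), jnp.asarray(UPPER, jnp.float32))
    lo, hi = jax.lax.fori_loop(0, maxiter, body, init)
    return 0.5 * (lo + hi)


@partial(jax.custom_jvp, nondiff_argnums=(0,))
def root_of(fun, design):
    return bisect(fun, design)


@root_of.defjvp
def root_of_jvp(fun, primals, tangents):
    (design,), (design_dot,) = primals, tangents
    root = root_of(fun, design)
    # grad[0] = df/droot and grad[1] = df/ddesign, evaluated at the solution.
    grad = jax.grad(fun, argnums=(0, 1))(root, design)
    # Implicit function theorem: droot/ddesign = -grad[1] / grad[0].
    rhs = -grad[1] / grad[0]
    return root, rhs * design_dot
```

With this rule, `jax.grad(lambda d: root_of(f, d))(4.0)` should come out close to 1 / (2 * 2) = 0.25, the analytic derivative of sqrt(design) at 4. With the original ratio, -grad[0] / grad[1], the same call would give about 4.0 instead, since grad[0] is about 4 and grad[1] is -1 at the root.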
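Point 2 matters because the custom JVP is evaluated at whatever root the solver returns, so an inaccurate root also gives an inaccurate derivative. The rough sketch below uses a hypothetical residual (root = sqrt(design)) and a hand-written bisection rather than jaxopt.Bisection, whose exact stopping rule for tol may differ. For plain bisection, each step halves the bracket, so shrinking [lower, upper] to a width of tol takes about log2((upper - lower) / tol) iterations.

```python
import math

import jax
import jax.numpy as jnp


def f(x, design):
    # Hypothetical residual whose root is x* = sqrt(design) for design > 0.
    return x**2 - design


def iterations_needed(lower, upper, tol):
    # After n halvings the bracket width is (upper - lower) / 2**n,
    # so n must satisfy (upper - lower) / 2**n <= tol.
    return math.ceil(math.log2((upper - lower) / tol))


def bisect(fun, design, lower, upper, maxiter):
    # Eager fixed-iteration bisection (a Python loop is fine for a small demo).
    lo = jnp.asarray(lower, jnp.float32)
    hi = jnp.asarray(upper, jnp.float32)
    for _ in range(maxiter):
        mid = 0.5 * (lo + hi)
        same_sign = jnp.sign(fun(mid, design)) == jnp.sign(fun(lo, design))
        lo, hi = jnp.where(same_sign, mid, lo), jnp.where(same_sign, hi, mid)
    return 0.5 * (lo + hi)


def implicit_derivative(fun, root, design):
    # droot/ddesign = -(df/ddesign) / (df/droot), evaluated at `root`.
    df_droot, df_ddes = jax.grad(fun, argnums=(0, 1))(root, design)
    return -df_ddes / df_droot


for n in (5, 40):
    root = bisect(f, 4.0, 0.0, 10.0, maxiter=n)
    print(n, float(root), float(implicit_derivative(f, root, 4.0)))
```

Here `iterations_needed(0.0, 10.0, 1e-8)` is 30. With only 5 iterations the returned root is 2.03125 and the implicit derivative is about 0.246 rather than 0.25; with 40 iterations both agree with the exact values to float32 precision.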

Replies: 3 comments 3 replies

Answer selected by SNMS95
4 participants