-
-
Notifications
You must be signed in to change notification settings - Fork 142
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added JumpStepWrapper #484
base: main
Are you sure you want to change the base?
Conversation
78b122a
to
0eac356
Compare
I now also added the functionality to revisit rejected steps. In addition, I also imporved the runtime of Also I think there was a bug in the PID controller, where it would sometimes reject a step, but have diffrax/diffrax/_step_size_controller/adaptive.py Lines 569 to 574 in 501bed5
I think possibly something smaller than just self.safety would make even more sense, I feel like if a step is rejected the next step should be at least 0.5x smaller. But I'm not an expert.
I added a test for revisiting steps and it all seems to work. I also sprinkled in a bunch of I think I commented the code quite well, so hopefully you can easily notice if I made a mistake somewhere. P.S.: Sorry for bombarding you with PRs. As far as I'm concerned this one is very low priority, I can use the code even if it isn't merged into diffrax proper. |
d022ac1
to
4702380
Compare
Hi @patrick-kidger, diffrax/benchmarks/jump_step_timing.py Lines 126 to 128 in 345e23a
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, quick first pass at a review!
Thanks for the review! I made all the edits I could and I left some comments where I need guidance (no hurry though, this is not high priority for me). Also, should I get rid of |
0050fa2
to
c3c4dcf
Compare
If it's easy to do that in a separate commit afterwards then I would say yes. A separate commit just so it's easy to revert if it turns out we were wrong about something here :D |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, I'm really sorry for taking so long to get around to this one! Some other work projects got in the way for a bit. (But on the plus side I have a few more open source projects in the pipe, keep an eye out for those ;) ) This is a really useful PR that I very much want to see in.
I've just done another revivew, LMK what you think!
Thanks for the review, Patrick! I'll probably make the fixes sometime in the coming week. I am also making progress on the ML examples for the Single-seed paper, but it is slower now, due to my internship. |
I am very confused about what the correct value of Suppose there is a jump at t=2. I will present 2 possible scenarios, in both of which I think something goes wrong (although maybe diffeqsolve might correct for the issue in scenario B). I wrote them as if JSW and the controller are separate, but the same holds for just the old PID controller. ====== scenario A =======
====== scenario B =======
Another way of seeing this all is through this:
Hence setting |
So I think the Line 390 in daec89c
I made the decision to handle some of the step-rejection logic in the main So I think this fine? Do double-check my logic though! :p Other than that, one thing I am noticing is that this |
Great, that's exactly the line I was looking for (I must admit I looked in Thinking about it now, the Edit: I already implemented what I mentioned above and wrote the proof in a comment. If you're curious and have extra time (yes I know that's a very far tail event :)) you can find it on my |
e203b53
to
7325e74
Compare
Hi Patrick! I just pushed a new version of this PR, rebased on top of the most current main. I think I addressed everything you asked me to fix. As it stands this contains 3 commits, contatining:
I left some conversations unresloved. I did try to fix the things mentioned in those, but I am not sure whether what I did was the best way to tackle that so I wanted to hear your opinion. Also the test are failing because pyright doesn't know how to import PS: The linear search I added slows it down compared to the way I wrote it before, but it is still faster than the old implementation with binary search. In particular the times (as obtained by
Additionally, changing the length of |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay! I think I really like this.
First of all, I think I'm basically happy with pretty much everything outside of jump_step_wrapper.py
. The changes here are pleasingly simple ^^
For jump_step_wrapper.py
, I think my main question is around whether the rejected-step-buffer should actually be part of this wrapper at all -- since that handles SDEs with any kind of step rejection, which I think is completely orthogonal to clipping steps? (Not sure how I didn't notice this before!) I've also commented on a few other more minor points.
By the way, what did you think of the idea of moving next_made_jump
into _integrate.py
? It doesn't have to be now -- happy for that to be a separate PR -- just checking your thoughts on whether it is a generalisable thing.
Finally: merry Christmas, and a happy new year! :D
diffrax/_step_size_controller/pid.py
Outdated
at_dtmin = at_dtmin | (prev_dt <= self.dtmin) | ||
keep_step = keep_step | at_dtmin |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, does at_dtmin
need to be state? (I'm not sure it ever did.) I think we might just be able to have keep_step = keep_step | (prev_dt <= self.dtmin)
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, good point. Looking at the code, I don't really see a reason to keep it, but I'll get rid of it in a separate commit so we can roll it back easily.
The `step_ts` and `jump_ts` are used to force the solver to step to certain times. | ||
They mostly act in the same way, except that when we hit an element of `jump_ts`, | ||
the controller must return `made_jump = True`, so that the diffeqsolve function | ||
knows that the vector field has a discontinuity at that point, in which case it | ||
re-evaluates it right after the jump point. In addition, the | ||
exact time of the jump will be skipped using eqxi.prevbefore and eqxi.nextafter. | ||
So now to the explanation of the two (we will use `step_ts` as an example, but the | ||
same applies to `jump_ts`): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I often rewrite parts of docs after merging anyway, so feel free to ignore this for now -- but just a heads-up that this part is discussing a lot of implementation details: made_jump = True
and eqxi.{prevbefore,nextafter}
are not details familiar to most users.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I intended to keep that as a comment, not part of the docs, but you suggested I put it in the docstring and I wasn't sure what exactly you wanted. I don't have strong opinions here, so feel free to rewrite it however you wish.
i = jax.lax.while_loop(cond_up, lambda _i: _i + 1, i) | ||
i = jax.lax.while_loop(cond_down, lambda _i: _i - 1, i) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we have both of these loops? I think we only need a linear search in one direction: to find the next element of ts
to clip to?
(And if we do need a bidirectional search, then given a hint n
it's probably more efficient to search e.g. n / n+1 / n-1 / n+2 / n-2 / ...
etc back and forth?)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
At most one of the two loops will trigger, so I am pretty sure doing it this way is faster and cleaner than your second suggestion (unless I'm getting it completely wrong??). And this is probably the safest option, but yeah I think we can easily just have the upwards loop, everything should still work if my logic is correct. Well in fact everything worked perfectly without any loops at all, the reason we added this is to be extra sure there aren't any edge cases. So I'd say if we want to be safe, let's be completely safe and have both loops. But up to you.
|
||
# This is just a logging utility for testing purposes | ||
if self.callback_on_reject is not None: | ||
jax.debug.callback(self.callback_on_reject, keep_step, t1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I might suggest making this a pure_callback
or io_callback
, so that it will definitely be called in the right order across steps. JAX doesn't actually offer guarantees about the order in which multiple debug callbacks are called.
See for example how eqx.error_if
works, which does the same thing by requiring a token.
(There is actually jax.debug.callback(..., ordered=True)
, but this works by having JAX sneakily rewriting the jaxpr to thread a dummy argument through as a token so as to order things... and I think that edge cases, so I try to avoid it.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good! I don't think the order matters in the test I use this for, but I suppose might as well do it properly.
# Let's prove that the line below is correct. Say the inner controller is | ||
# itself a JumpStepWrapper (JSW) with some inner_jump_ts. Then, given that | ||
# it propsed (next_t0, original_next_t1), there cannot be any jumps in | ||
# inner_jump_ts between next_t0 and original_next_t1. So if the next_t1 | ||
# proposed by the outer JSW is different from the original_next_t1 then | ||
# next_t1 \in (next_t0, original_next_t1) and hence there cannot be a jump | ||
# in inner_jump_ts at next_t1. So the jump_at_next_t1 only depends on | ||
# jump_at_next_t1. | ||
# On the other hand if original_next_t1 == next_t1, then we just take an | ||
# OR of the two. | ||
jump_at_next_t1 = jnp.where( | ||
next_t1 == original_next_t1, | ||
jump_at_original_next_t1, | ||
jump_at_next_t1 | jump_at_original_next_t1, | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, I don't think I completely believe this. Can we have the following:
- the PID controller proposes
t1
. - the inner JSW wants to clip to a jump
b < t1
. - the outer JSW wants to clip to a step (not a jump!)
a < b
?
In this case then we will have next_t1 != original_next_t1
, an jump_at_original_next_t1 == True
... but overall we want made_jump == False
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+can we have a test for two tested JSW, including the above scenario? It doesn't need to be a full diffeqsolve, just directly calling adapt_step_size
and checking that we get the right output.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think I bungled the code completely, because it doesn't do what's written in the comment at all. But I believe the comment (if implemented correctly) is correct.
I think the code should be:
jump_at_next_t1 = jnp.where(
next_t1 == original_next_t1,
jump_at_next_t1 | jump_at_original_next_t1,
jump_at_next_t1,
)
And this indeed works in the case you brought up as well.
Given this, I don't think moving next_made_jump
into integrate is strictly necessary at this point.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, yes I'll add a test.
7325e74
to
0d446ad
Compare
05bbcb2
to
8d4212c
Compare
Hi Patrick! Sorry for the long silence, the last few weeks have been very busy. I made the corrections you suggested. Regarding splitting off revisiting steps from So unless you feel very strongly about separating these two features, I would prefer not to add extra hurdles to this PR and conclude it in the near future. I think Owen has been itching to get this done as fast as possible as well. |
Hi Patrick,
I factored the
jump_ts
andstep_ts
out of thePIDController
intoJumpStepWrapper
(I'm not very set on this name, lmk if you have ideas). I also made it behave as we discussed in #483. In particular, the following three rules are maintained:t1-t0 <= prev_dt
(this is checked viaeqx.error_if
), with inequality only if the step was clipped or if we hit the end of the integration interval (we do not explicitly check for that).next_dt
must be>=prev_dt
.next_dt
must be< t1-t0
.We achieve this in a very simple way here:
diffrax/diffrax/_step_size_controller/jump_step_wrapper.py
Lines 119 to 123 in 78b122a
The next step is to add a parameter
JumpStepWrapper.revisit_rejected_steps
which does what you expect. That will appear in a future commit in this same PR.