Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added JumpStepWrapper #484

Open
wants to merge 5 commits into
base: main
Choose a base branch
from

Conversation

andyElking
Copy link
Contributor

Hi Patrick,

I factored the jump_ts and step_ts out of the PIDController into JumpStepWrapper (I'm not very set on this name, lmk if you have ideas). I also made it behave as we discussed in #483. In particular, the following three rules are maintained:

  1. We always have t1-t0 <= prev_dt (this is checked via eqx.error_if), with inequality only if the step was clipped or if we hit the end of the integration interval (we do not explicitly check for that).
  2. If the step was accepted, then next_dt must be >=prev_dt.
  3. If the step was rejected, then next_dt must be < t1-t0.

We achieve this in a very simple way here:

dt_proposal = next_t1 - next_t0
dt_proposal = jnp.where(
keep_step, jnp.maximum(dt_proposal, prev_dt), dt_proposal
)
new_prev_dt = dt_proposal

The next step is to add a parameter JumpStepWrapper.revisit_rejected_steps which does what you expect. That will appear in a future commit in this same PR.

@andyElking
Copy link
Contributor Author

I now also added the functionality to revisit rejected steps. In addition, I also imporved the runtime of step_ts and jump_ts, because the controller no longer searches the whole array each time, but keeps an index of where in the array it was previously.

Also I think there was a bug in the PID controller, where it would sometimes reject a step, but have factor>1. To remedy this I modified the following:

factormax = jnp.where(keep_step, self.factormax, self.safety)
factor = jnp.clip(
self.safety * factor1 * factor2 * factor3,
min=factormin,
max=factormax,
)

I think possibly something smaller than just self.safety would make even more sense, I feel like if a step is rejected the next step should be at least 0.5x smaller. But I'm not an expert.

I added a test for revisiting steps and it all seems to work. I also sprinkled in a bunch of eqx.error_if statements to make sure the necessary invariants are always maintained. But this is a bit experimental, so maybe there are some bugs I didn't test for.

I think I commented the code quite well, so hopefully you can easily notice if I made a mistake somewhere.

P.S.: Sorry for bombarding you with PRs. As far as I'm concerned this one is very low priority, I can use the code even if it isn't merged into diffrax proper.

@andyElking andyElking force-pushed the jump_step_pr branch 2 times, most recently from d022ac1 to 4702380 Compare August 14, 2024 14:53
@andyElking
Copy link
Contributor Author

Hi @patrick-kidger,
I got rid of some eqx.error_ifs that I added to my JumpStepWrapper and redid the timing benchmarks. My new implementation was already faster than the old PIDController before, but now this is way more significant, especially when step_ts is long (think >100). Surprisingly, it is faster even when it has to revisit rejected steps. See

# ======= RESULTS =======
# New controller: 0.22829 s, Old controller: 0.31039 s
# Revisiting controller: 0.23212 s

Copy link
Owner

@patrick-kidger patrick-kidger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, quick first pass at a review!

diffrax/_step_size_controller/adaptive.py Show resolved Hide resolved
diffrax/_step_size_controller/jump_step_wrapper.py Outdated Show resolved Hide resolved
diffrax/_step_size_controller/jump_step_wrapper.py Outdated Show resolved Hide resolved
diffrax/_step_size_controller/jump_step_wrapper.py Outdated Show resolved Hide resolved
diffrax/_step_size_controller/jump_step_wrapper.py Outdated Show resolved Hide resolved
diffrax/_step_size_controller/jump_step_wrapper.py Outdated Show resolved Hide resolved
diffrax/_step_size_controller/jump_step_wrapper.py Outdated Show resolved Hide resolved
diffrax/_step_size_controller/jump_step_wrapper.py Outdated Show resolved Hide resolved
@andyElking
Copy link
Contributor Author

Thanks for the review! I made all the edits I could and I left some comments where I need guidance (no hurry though, this is not high priority for me). Also, should I get rid of prev_dt entirely, as you suggested in #483?

@patrick-kidger
Copy link
Owner

Also, should I get rid of prev_dt entirely, as you suggested in #483?

If it's easy to do that in a separate commit afterwards then I would say yes. A separate commit just so it's easy to revert if it turns out we were wrong about something here :D

Copy link
Owner

@patrick-kidger patrick-kidger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, I'm really sorry for taking so long to get around to this one! Some other work projects got in the way for a bit. (But on the plus side I have a few more open source projects in the pipe, keep an eye out for those ;) ) This is a really useful PR that I very much want to see in.

I've just done another revivew, LMK what you think!

diffrax/_step_size_controller/adaptive_base.py Outdated Show resolved Hide resolved
diffrax/_step_size_controller/jump_step_wrapper.py Outdated Show resolved Hide resolved
test/test_progress_meter.py Show resolved Hide resolved
diffrax/_step_size_controller/pid.py Outdated Show resolved Hide resolved
diffrax/_step_size_controller/jump_step_wrapper.py Outdated Show resolved Hide resolved
diffrax/_step_size_controller/jump_step_wrapper.py Outdated Show resolved Hide resolved
diffrax/_step_size_controller/jump_step_wrapper.py Outdated Show resolved Hide resolved
diffrax/_step_size_controller/jump_step_wrapper.py Outdated Show resolved Hide resolved
diffrax/_step_size_controller/jump_step_wrapper.py Outdated Show resolved Hide resolved
@andyElking
Copy link
Contributor Author

Thanks for the review, Patrick! I'll probably make the fixes sometime in the coming week. I am also making progress on the ML examples for the Single-seed paper, but it is slower now, due to my internship.

@andyElking
Copy link
Contributor Author

I am very confused about what the correct value of made_jump should be when the step was rejected. By my understanding the original PID controller also got this wrong. However, I'm not sure whether getting this wrong is actually a significant issue, so I want your thoughts on it.

Suppose there is a jump at t=2. I will present 2 possible scenarios, in both of which I think something goes wrong (although maybe diffeqsolve might correct for the issue in scenario B). I wrote them as if JSW and the controller are separate, but the same holds for just the old PID controller.

====== scenario A =======

  1. We start with a step from t0 = 0 to t1 = 1, the controller decides the step will be kept and makes a proposal next_t0 = 1, next_t1=2.5 which gets clipped to next_t1 = 2 and so the JSW writes jump_next_step=True in its state.
  2. We take the next step from t0 = 1 to t1 = 2. Suppose the step is rejected, so next_t0 = 1. But the way the code works now, it will still set made_jump = previous_state.jump_next_step, which has been set to True. But that is wrong because made_jump should reflect whether there is a jump at next_t0 so that the solver knows whether to reevaluate the VF at the start of the next step. Still this is not horrible, since it just causes one extra evaluation of the VF, but the full solution should still be correct. Right?

====== scenario B =======

  1. We start with a step from t0 = 0 to t1 = 1, the controller decides the step will be kept and makes a proposal next_t0 = 1, next_t1=2.5 which gets clipped to next_t1 = 2 and so the JSW writes jump_next_step=True in its state.
  2. We take the next step from t0 = 1 to t1 = 2. Suppose the step is accepted and the controller proposes next_t0 = 2 and next_t1 = 4 the controller returns (correctly) made_jump = previous_state.jump_next_step (=True) and ControllerState(..., jump_next_step=False).
  3. We take the next step from t0 = 2 to t1 = 4. Note that because made_jump was set to True last iteration, the FSAL VF was reevaluated at t = 2. Suppose the step is rejected and the new proposal is next_t0 = 2 and next_t1 = 3. The controller returns made_jump = previous_state.jump_next_step (=False). This is a problem because there is a jump at next_t0=2, so made_jump should have been True. However that might be fine if the value of the VF which was recomputed at the start of this step is kept for the next step (but I remember thinking that solver state is discarded when the step is rejected, so the recomputed VF also gets discarded).

Another way of seeing this all is through this:

  • jump_next_step = (there is a jump at next_t1)
  • previous_state.jump_next_step = (there is a jump at t1)
  • made_jump should be True iff there is a jump at next_t0

Hence setting made_jump = previous_state.jump_next_step only works when the current step is being accepted, so next_t0 = t1. When rejecting the step, the new made_jump should be equal to the previous step's made_jump which we no longer have access to. The solution would be to add yet another thing to the state, but first I wanted to confirm with you to see if I'm getting this right. What are your thoughts?

@patrick-kidger
Copy link
Owner

patrick-kidger commented Nov 29, 2024

So I think the made_jump-on-rejected-step is handled through this line, outside the stepsize controller:

made_jump = static_select(keep_step, made_jump, state.made_jump)

I made the decision to handle some of the step-rejection logic in the main integrate.py loop, on the basis that for those pieces it should be the same for all stepsize controllers.

So I think this fine? Do double-check my logic though! :p

Other than that, one thing I am noticing is that this next_made_jump business is pretty annoying to deal with in the stepsize controller, and I think could probably also be factored out to happen in the main integrate.py loop. (Not sure if that's doable in a backward-compatible way though.)

@andyElking
Copy link
Contributor Author

andyElking commented Dec 3, 2024

Great, that's exactly the line I was looking for (I must admit I looked in _integrate.py only briefly). But yes this is precisely the piece of the puzzle that my logic was missing, so I think we both independently arrived at the same conclusion that this is probably correct.

Thinking about it now, the next_made_jump situation might be way easier than I thought it was. I'll try to write the code and then I'll write a proof that it does the right thing for any number of stacked JSWs. Hopefully I can be done with that tomorrow.

Edit: I already implemented what I mentioned above and wrote the proof in a comment. If you're curious and have extra time (yes I know that's a very far tail event :)) you can find it on my pr_correction branch here. I haven't fixed all the other things yet, so I'll push everything together once it's all done.

@andyElking andyElking force-pushed the jump_step_pr branch 2 times, most recently from e203b53 to 7325e74 Compare December 5, 2024 21:32
@andyElking
Copy link
Contributor Author

andyElking commented Dec 5, 2024

Hi Patrick!

I just pushed a new version of this PR, rebased on top of the most current main. I think I addressed everything you asked me to fix.

As it stands this contains 3 commits, contatining:

  1. Everything except 2. and 3.
  2. The changes to at_dtmin and factormax in pid.py
  3. Removing prev_dt from JSW.

I left some conversations unresloved. I did try to fix the things mentioned in those, but I am not sure whether what I did was the best way to tackle that so I wanted to hear your opinion.

Also the test are failing because pyright doesn't know how to import typeguard, which has nothing to do with my changes.

PS: The linear search I added slows it down compared to the way I wrote it before, but it is still faster than the old implementation with binary search. In particular the times (as obtained by benchmarks/jump_step_timing.py) are as follows:

  • my previous implementation: 0.263 (without revisit rejected), 0.285 (with revisit rejected)
  • new implementation (linear search): 0.295 (without revisit rejected), 0.315 (with revisit rejected)
  • old pid with binary search: 0.332

Additionally, changing the length of rejected_buffer from 10 to 4096 amounts to a neglibigle slowdown. However, this might change depending on the problem setup.

Copy link
Owner

@patrick-kidger patrick-kidger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay! I think I really like this.

First of all, I think I'm basically happy with pretty much everything outside of jump_step_wrapper.py. The changes here are pleasingly simple ^^

For jump_step_wrapper.py, I think my main question is around whether the rejected-step-buffer should actually be part of this wrapper at all -- since that handles SDEs with any kind of step rejection, which I think is completely orthogonal to clipping steps? (Not sure how I didn't notice this before!) I've also commented on a few other more minor points.

By the way, what did you think of the idea of moving next_made_jump into _integrate.py? It doesn't have to be now -- happy for that to be a separate PR -- just checking your thoughts on whether it is a generalisable thing.

Finally: merry Christmas, and a happy new year! :D

docs/api/stepsize_controller.md Outdated Show resolved Hide resolved
diffrax/_step_size_controller/pid.py Outdated Show resolved Hide resolved
Comment on lines 515 to 513
at_dtmin = at_dtmin | (prev_dt <= self.dtmin)
keep_step = keep_step | at_dtmin
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, does at_dtmin need to be state? (I'm not sure it ever did.) I think we might just be able to have keep_step = keep_step | (prev_dt <= self.dtmin)?

Copy link
Contributor Author

@andyElking andyElking Jan 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, good point. Looking at the code, I don't really see a reason to keep it, but I'll get rid of it in a separate commit so we can roll it back easily.

Comment on lines +133 to +140
The `step_ts` and `jump_ts` are used to force the solver to step to certain times.
They mostly act in the same way, except that when we hit an element of `jump_ts`,
the controller must return `made_jump = True`, so that the diffeqsolve function
knows that the vector field has a discontinuity at that point, in which case it
re-evaluates it right after the jump point. In addition, the
exact time of the jump will be skipped using eqxi.prevbefore and eqxi.nextafter.
So now to the explanation of the two (we will use `step_ts` as an example, but the
same applies to `jump_ts`):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I often rewrite parts of docs after merging anyway, so feel free to ignore this for now -- but just a heads-up that this part is discussing a lot of implementation details: made_jump = True and eqxi.{prevbefore,nextafter} are not details familiar to most users.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I intended to keep that as a comment, not part of the docs, but you suggested I put it in the docstring and I wasn't sure what exactly you wanted. I don't have strong opinions here, so feel free to rewrite it however you wish.

Comment on lines +93 to +94
i = jax.lax.while_loop(cond_up, lambda _i: _i + 1, i)
i = jax.lax.while_loop(cond_down, lambda _i: _i - 1, i)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we have both of these loops? I think we only need a linear search in one direction: to find the next element of ts to clip to?

(And if we do need a bidirectional search, then given a hint n it's probably more efficient to search e.g. n / n+1 / n-1 / n+2 / n-2 / ... etc back and forth?)

Copy link
Contributor Author

@andyElking andyElking Jan 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At most one of the two loops will trigger, so I am pretty sure doing it this way is faster and cleaner than your second suggestion (unless I'm getting it completely wrong??). And this is probably the safest option, but yeah I think we can easily just have the upwards loop, everything should still work if my logic is correct. Well in fact everything worked perfectly without any loops at all, the reason we added this is to be extra sure there aren't any edge cases. So I'd say if we want to be safe, let's be completely safe and have both loops. But up to you.


# This is just a logging utility for testing purposes
if self.callback_on_reject is not None:
jax.debug.callback(self.callback_on_reject, keep_step, t1)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I might suggest making this a pure_callback or io_callback, so that it will definitely be called in the right order across steps. JAX doesn't actually offer guarantees about the order in which multiple debug callbacks are called.

See for example how eqx.error_if works, which does the same thing by requiring a token.

(There is actually jax.debug.callback(..., ordered=True), but this works by having JAX sneakily rewriting the jaxpr to thread a dummy argument through as a token so as to order things... and I think that edge cases, so I try to avoid it.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good! I don't think the order matters in the test I use this for, but I suppose might as well do it properly.

diffrax/_step_size_controller/jump_step_wrapper.py Outdated Show resolved Hide resolved
Comment on lines 430 to 439
# Let's prove that the line below is correct. Say the inner controller is
# itself a JumpStepWrapper (JSW) with some inner_jump_ts. Then, given that
# it propsed (next_t0, original_next_t1), there cannot be any jumps in
# inner_jump_ts between next_t0 and original_next_t1. So if the next_t1
# proposed by the outer JSW is different from the original_next_t1 then
# next_t1 \in (next_t0, original_next_t1) and hence there cannot be a jump
# in inner_jump_ts at next_t1. So the jump_at_next_t1 only depends on
# jump_at_next_t1.
# On the other hand if original_next_t1 == next_t1, then we just take an
# OR of the two.
jump_at_next_t1 = jnp.where(
next_t1 == original_next_t1,
jump_at_original_next_t1,
jump_at_next_t1 | jump_at_original_next_t1,
)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I don't think I completely believe this. Can we have the following:

  • the PID controller proposes t1.
  • the inner JSW wants to clip to a jump b < t1.
  • the outer JSW wants to clip to a step (not a jump!) a < b
    ?

In this case then we will have next_t1 != original_next_t1, an jump_at_original_next_t1 == True... but overall we want made_jump == False?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+can we have a test for two tested JSW, including the above scenario? It doesn't need to be a full diffeqsolve, just directly calling adapt_step_size and checking that we get the right output.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I bungled the code completely, because it doesn't do what's written in the comment at all. But I believe the comment (if implemented correctly) is correct.

I think the code should be:

jump_at_next_t1 = jnp.where(
    next_t1 == original_next_t1,
    jump_at_next_t1 | jump_at_original_next_t1,
    jump_at_next_t1,
)

And this indeed works in the case you brought up as well.

Given this, I don't think moving next_made_jump into integrate is strictly necessary at this point.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, yes I'll add a test.

This was referenced Jan 1, 2025
@andyElking
Copy link
Contributor Author

Hi Patrick!

Sorry for the long silence, the last few weeks have been very busy.

I made the corrections you suggested.

Regarding splitting off revisiting steps from JumpStepWrapper, I agree that this feature is perpendicular to jump_ts and step_ts, but I don't see the harm in implementing several different features with one class... Certainly we want step-revisiting to be part of some sort of wrapper which can wrap either PID or potentially a custom step controller (I had such a situation in the paper with James I think), so it shouldn't be part of PID either.
So then the alternative is to have two wrappers -- JumpStepWrapper and RevisitingWrapper -- which might work, but might also introduce more confusion and more potential edge cases.

So unless you feel very strongly about separating these two features, I would prefer not to add extra hurdles to this PR and conclude it in the near future. I think Owen has been itching to get this done as fast as possible as well.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants