🔪 Remaining Sharp Bit TODOs 🔪 #9952
Comments
Just adding context: the context manager for precision is defined here, and there are some words about it in #6143.
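For illustration, a minimal sketch of using that context manager (`jax.default_matmul_precision`), which raises matmul precision for everything dispatched in its scope:

```python
import jax
import jax.numpy as jnp

a = jnp.ones((512, 512))
b = jnp.ones((512, 512))

# On TPU (and some GPUs) matmuls default to reduced precision.
# Raise it for every matmul dispatched inside this scope:
with jax.default_matmul_precision("float32"):
    c = a @ b
```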
Other sharp bits:
Fantastic! On that first bullet, we could also mention
Nice, +1 on the
Also, making sure that the async set of JAX calls used in a training loop doesn't introduce blocking calls that will kill dispatch pipelining efficiency (e.g. a trivial host-side metrics fn or similar). That's one of the most common performance mistakes I see (maybe belongs in a separate performance gotchas doc... not sure). A sketch of the pattern is below.
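A minimal sketch of that pitfall, using a toy `train_step` (hypothetical; any jitted step function applies). Calling `float(loss)` forces a device-to-host transfer every iteration, which stalls async dispatch:

```python
import jax
import jax.numpy as jnp

@jax.jit
def train_step(params, batch):
    # Toy update, just so we have a jitted step function.
    loss = jnp.mean((params - batch) ** 2)
    return params - 0.01 * (params - batch), loss

params = jnp.ones(1024)
batch = jnp.zeros(1024)

for step in range(1000):
    params, loss = train_step(params, batch)
    # Anti-pattern: float(loss) blocks until the step finishes,
    # serializing host and device on every iteration:
    #   print(float(loss))
    # Better: keep values on device and only synchronize occasionally.
    if step % 100 == 0:
        print(step, jax.device_get(loss))
```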
I like the idea of having a new dedicated doc for performance tips and pitfalls
Regarding reworking the Sharp Bits doc, I recently added a section on miscellaneous divergences between numpy and JAX. It might be nice to rearrange things so all the differences between numpy and JAX are listed briefly under a single heading, perhaps with links to deeper discussion later in the doc.
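As one concrete example of such a divergence (my sketch, not from the doc): out-of-bounds indexing raises in NumPy but silently clamps in JAX, and JAX arrays are immutable:

```python
import numpy as np
import jax.numpy as jnp

x = np.arange(5)
# NumPy raises IndexError on out-of-bounds reads; JAX clamps the
# index instead (raising isn't possible under jit on accelerators):
print(jnp.arange(5)[10])  # prints 4 -- clamped, no error

# NumPy arrays are mutable; JAX arrays are not:
x[0] = 99                        # fine in NumPy
y = jnp.arange(5).at[0].set(99)  # JAX's functional update instead
```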
Regarding the "jit caching behavior", is there any chance you could cache the compiled result to the file system so that it can persist across runs? In my development cycle, I typically change some hyperparameters and re-run the experiment. It's a little frustrating that each time I have to wait for the JIT compilation, even if I have compiled the exact same code multiple times. I am under the impression that this won't be too hard to implement, since we already have a hashing/caching mechanism. All it takes is writing the emitted XLA program to the disk. Should I open a new issue for this?
@nalzok - there is currently an implementation of this, but only for TPU. See https://github.com/google/jax/tree/main/jax/experimental/compilation_cache for details, and #2490 where this kind of request is tracked.
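If I read that module correctly, enabling it looks roughly like this (a sketch; the cache path is arbitrary and the API is experimental, so check the linked tree for the current signature):

```python
# Experimental, TPU-only at the time of writing (see link above).
from jax.experimental.compilation_cache import compilation_cache as cc

# Compiled programs are written here and reused across runs.
cc.initialize_cache("/tmp/jax_compilation_cache")
```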
I have a fairly RNG generation-heavy workload that I am running on Cloud TPU and was googling around to try and understand the
@JeppeKlitgaard - yeah, it uses an ad hoc method of splitting keys that we don't have theoretical justification for (and in fact we don't really have well established statistical tests for split-chain decorrelation when it comes to splittable PRNG systems). That said, it compiles and runs fast, and it's almost certainly good enough for e.g. dropout masks in the context of SGD training of NNs (and we've used it for that with no observed ill effects for some time). I'd be a bit more careful if I were doing classic MCMC or something.
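For reference, the splitting pattern under discussion (standard `jax.random` usage; the shape and keep probability here are just illustrative):

```python
import jax

key = jax.random.PRNGKey(0)

# Split off a fresh subkey per random op; never reuse a key.
key, subkey = jax.random.split(key)
keep_mask = jax.random.bernoulli(subkey, p=0.9, shape=(128,))  # dropout mask
```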
We could do with sprucing up the Sharp Bits with common problems we've encountered in user code since it was first written.
Top of the list is documenting matmul / conv op precision issues:
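For instance (my sketch of the per-op knob, complementing the context manager mentioned above), precision can also be requested per call:

```python
import jax.numpy as jnp
from jax import lax

a = jnp.ones((256, 256))
b = jnp.ones((256, 256))

# Per-call override: request full-precision accumulation for this
# matmul only, leaving the global default untouched.
c = jnp.dot(a, b, precision=lax.Precision.HIGHEST)
```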
We should add some other ideas here.