I can't run this, because I don't know what … In general, though, JAX has a fixed-cost overhead for every computation that is on the order of tens of milliseconds, while NumPy does not. I suspect that if you ran your benchmarks on arrays large enough that this overhead is negligible, you'd see a truer comparison of how JAX and NumPy runtimes scale. Also, keep in mind for this sort of micro-benchmark that JAX executes asynchronously, so you should use the …
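The advice above can be sketched as a small benchmark. This is a minimal sketch, assuming a recent JAX version in which `jax.Array` is the array type; the function names (`jax_sum`, `time_call`, `compare`) and the array sizes are illustrative choices, not the asker's code. It excludes compilation with a warm-up call, waits for JAX's asynchronous result with `block_until_ready()`, and times both libraries over increasing sizes so the fixed per-call overhead can be told apart from the per-element cost.

```python
# Illustrative benchmark sketch; function names and sizes are assumptions.
import time

import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def jax_sum(x):
    # Plain reduction; XLA compiles it on the first call for each input shape/dtype.
    return jnp.sum(x)


def time_call(fn, x, repeats=10):
    """Return the best wall-clock time (seconds) over `repeats` calls."""
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        out = fn(x)
        # JAX dispatches work asynchronously; block so the timer covers the computation.
        if isinstance(out, jax.Array):
            out.block_until_ready()
        best = min(best, time.perf_counter() - start)
    return best


def compare(sizes=(10**3, 10**5, 10**7)):
    """Map each size n to (numpy_seconds, jax_seconds)."""
    results = {}
    for n in sizes:
        x_np = np.random.rand(n).astype(np.float32)
        x_jax = jnp.asarray(x_np)  # transfer to the default device once, outside the timer
        jax_sum(x_jax).block_until_ready()  # warm-up call: compilation is excluded from timing
        results[n] = (time_call(np.sum, x_np), time_call(jax_sum, x_jax))
    return results


if __name__ == "__main__":
    for n, (t_np, t_jax) in compare().items():
        print(f"n={n:>10}: numpy {t_np * 1e6:9.1f} us   jax {t_jax * 1e6:9.1f} us")
```

For small arrays the fixed overhead tends to dominate the JAX timing; where (and whether) JAX overtakes NumPy depends on the hardware and the operation, so it is worth measuring across sizes rather than at a single point.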
In Colab, I ran the following experiment to measure the speed of some basic addition computations. The code is very simple:
A sum and a conditional sum in JAX:
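The original JAX code did not survive extraction. As a hedged illustration only, a jitted sum and conditional sum might look like this; the function names and the `threshold` condition are hypothetical stand-ins for whatever the original code used.

```python
# Hypothetical reconstruction for illustration; not the asker's original code.
import jax
import jax.numpy as jnp


@jax.jit
def total(x):
    # Sum of all elements.
    return jnp.sum(x)


@jax.jit
def conditional_total(x, threshold):
    # Sum only the elements greater than `threshold`.
    # Boolean-mask indexing (x[x > threshold]) has a data-dependent shape and is not
    # allowed under jit, so jnp.where keeps the shape static and zeros out the rest.
    return jnp.sum(jnp.where(x > threshold, x, 0.0))


x = jnp.arange(10.0)
print(total(x))                   # 45.0
print(conditional_total(x, 5.0))  # 6 + 7 + 8 + 9 = 30.0
```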
The same, in NumPy:
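The NumPy counterpart is likewise missing. A hedged sketch of what it might look like, again with hypothetical names and condition:

```python
# Hypothetical reconstruction for illustration; not the asker's original code.
import numpy as np


def total_np(x):
    # Sum of all elements.
    return np.sum(x)


def conditional_total_np(x, threshold):
    # Boolean-mask indexing is fine in NumPy: nothing is traced or compiled,
    # so x[x > threshold] may have any length.
    return np.sum(x[x > threshold])


x = np.arange(10.0)
print(total_np(x))                   # 45.0
print(conditional_total_np(x, 5.0))  # 30.0
```

Each NumPy call runs eagerly with no compilation step, which is one reason it can look faster than JAX on tiny inputs.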
For CPU only:
With the GPU accelerator:
It seems that plain NumPy wins by orders of magnitude. Why is that? I already used @jit, and the code is straightforward. Am I doing something wrong? Where can we get a gain from using JAX?