Example results on Google Colab

Results are as reported by this notebook. To re-run these experiments, just head over to Google Colab, upload the notebook, and run the cells one by one.

Hardware at the time of writing (Oct 2021):

  • Intel(R) Xeon(R) CPU @ 2.30GHz (1 core, 2 threads)
  • 12.6 GB of RAM
  • NVIDIA Tesla K80 GPU with 12 GB of memory

Caveat: JAX does not (yet) support 64-bit floating-point precision on TPU architectures. Therefore, the JAX + TPU results are not bit-identical to those of the other backends and devices, so this is not quite an apples-to-apples comparison.
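To see why 32-bit results cannot be bit-identical to 64-bit ones, here is a minimal NumPy sketch (not part of the benchmark suite): the same elementary-math expression evaluated in both precisions agrees only up to single-precision roundoff.

```python
import numpy as np

# The same elementary-math expression, evaluated in double and single precision.
x64 = np.linspace(0.0, 1.0, 1000, dtype=np.float64)
x32 = x64.astype(np.float32)

result64 = np.sin(x64) * np.exp(x64) + x64 ** 3
result32 = np.sin(x32) * np.exp(x32) + x32 ** 3

# float32 carries relative roundoff of order 1e-7, so the two results
# agree to single precision, but not bit-for-bit.
max_err = np.max(np.abs(result64 - result32.astype(np.float64)))
```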

Contents

  • Equation of state
  • Isoneutral mixing
  • Turbulent kinetic energy

Equation of state

An equation consisting of more than 100 terms with no data dependencies and only elementary math. This benchmark represents a best-case scenario for vector instructions and GPU performance.
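To illustrate the shape of such a kernel, here is a hypothetical, heavily simplified stand-in (not the real TEOS-10 polynomial used in the benchmark): a purely elementwise expression where every output element depends only on the inputs at the same position, so the whole computation maps directly onto vector instructions or GPU threads.

```python
import numpy as np

def toy_equation_of_state(s, t, p):
    """Toy density polynomial; the actual benchmark kernel has the same
    elementwise structure but over 100 terms."""
    return (
        999.84 + 0.068 * t - 0.0091 * t**2
        + (0.824 - 0.004 * t) * s
        + (4.9e-5 - 1.0e-7 * t) * p
    )

# Inputs mirror the smallest benchmark size (4,096 grid points).
s = np.full(4096, 35.0)              # salinity
t = np.linspace(-2.0, 30.0, 4096)    # temperature
p = np.zeros(4096)                   # pressure
rho = toy_equation_of_state(s, t, p)
```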

CPU

$ taskset -c 0 python run.py benchmarks/equation_of_state/

benchmarks.equation_of_state
============================
Running on CPU

size          backend     calls     mean      stdev     min       25%       median    75%       max       Δ
------------------------------------------------------------------------------------------------------------------
       4,096  pytorch       10,000     0.000     0.000     0.000     0.000     0.000     0.000     0.015     5.605
       4,096  jax           10,000     0.000     0.000     0.000     0.000     0.000     0.000     0.014     5.167
       4,096  numba         10,000     0.001     0.000     0.000     0.000     0.001     0.001     0.013     3.178
       4,096  aesara        10,000     0.001     0.000     0.000     0.001     0.001     0.001     0.015     2.637
       4,096  tensorflow    10,000     0.001     0.000     0.001     0.001     0.001     0.001     0.009     2.143
       4,096  numpy         10,000     0.002     0.000     0.001     0.002     0.002     0.002     0.010     1.000

      16,384  pytorch       10,000     0.001     0.000     0.001     0.001     0.001     0.001     0.017     6.284
      16,384  jax           10,000     0.002     0.000     0.001     0.001     0.002     0.002     0.019     5.396
      16,384  tensorflow     1,000     0.002     0.000     0.002     0.002     0.002     0.002     0.005     4.161
      16,384  numba         10,000     0.002     0.000     0.002     0.002     0.002     0.002     0.010     3.816
      16,384  aesara        10,000     0.002     0.000     0.002     0.002     0.002     0.002     0.017     3.520
      16,384  numpy          1,000     0.009     0.001     0.007     0.008     0.009     0.009     0.012     1.000

      65,536  pytorch        1,000     0.005     0.001     0.005     0.005     0.005     0.006     0.015    16.182
      65,536  jax            1,000     0.006     0.001     0.005     0.005     0.006     0.006     0.009    15.457
      65,536  tensorflow     1,000     0.006     0.001     0.005     0.006     0.006     0.006     0.021    14.052
      65,536  numba          1,000     0.009     0.001     0.008     0.008     0.009     0.009     0.017    10.105
      65,536  aesara         1,000     0.009     0.001     0.008     0.009     0.009     0.009     0.015     9.394
      65,536  numpy            100     0.088     0.003     0.079     0.086     0.088     0.090     0.097     1.000

     262,144  pytorch        1,000     0.018     0.001     0.015     0.017     0.017     0.018     0.028    10.783
     262,144  jax            1,000     0.020     0.002     0.017     0.019     0.019     0.020     0.035     9.667
     262,144  tensorflow     1,000     0.021     0.001     0.018     0.020     0.021     0.022     0.031     8.949
     262,144  numba            100     0.032     0.002     0.029     0.031     0.031     0.033     0.044     5.930
     262,144  aesara           100     0.034     0.002     0.032     0.033     0.033     0.034     0.042     5.666
     262,144  numpy            100     0.190     0.003     0.177     0.188     0.190     0.192     0.200     1.000

   1,048,576  pytorch          100     0.075     0.003     0.068     0.073     0.074     0.077     0.083    21.187
   1,048,576  jax              100     0.086     0.004     0.079     0.083     0.085     0.088     0.098    18.447
   1,048,576  tensorflow       100     0.087     0.004     0.080     0.085     0.087     0.089     0.099    18.140
   1,048,576  numba            100     0.132     0.004     0.125     0.129     0.132     0.134     0.145    11.976
   1,048,576  aesara           100     0.140     0.004     0.131     0.137     0.140     0.142     0.157    11.301
   1,048,576  numpy             10     1.585     0.015     1.568     1.573     1.579     1.595     1.612     1.000

   4,194,304  pytorch           10     0.297     0.006     0.285     0.294     0.297     0.302     0.307    12.408
   4,194,304  tensorflow        10     0.342     0.005     0.331     0.339     0.343     0.345     0.349    10.793
   4,194,304  jax               10     0.360     0.008     0.348     0.354     0.357     0.367     0.373    10.253
   4,194,304  numba             10     0.515     0.007     0.504     0.510     0.516     0.522     0.526     7.155
   4,194,304  aesara            10     0.556     0.009     0.543     0.547     0.558     0.563     0.569     6.634
   4,194,304  numpy             10     3.688     0.014     3.668     3.678     3.688     3.693     3.723     1.000

(time in wall seconds, less is better)

GPU

$ for backend in jax tensorflow pytorch cupy; do python run.py benchmarks/equation_of_state/ --device gpu -b $backend -b numpy; done

benchmarks.equation_of_state
============================
Running on GPU

size          backend     calls     mean      stdev     min       25%       median    75%       max       Δ
------------------------------------------------------------------------------------------------------------------
       4,096  jax           10,000     0.000     0.000     0.000     0.000     0.000     0.000     0.004    12.584
       4,096  numpy         10,000     0.002     0.000     0.001     0.002     0.002     0.002     0.011     1.000

      16,384  jax           10,000     0.000     0.000     0.000     0.000     0.000     0.000     0.003    61.389
      16,384  numpy          1,000     0.008     0.001     0.007     0.008     0.008     0.009     0.017     1.000

      65,536  jax            1,000     0.000     0.000     0.000     0.000     0.000     0.000     0.000   250.282
      65,536  numpy            100     0.047     0.002     0.044     0.046     0.046     0.048     0.053     1.000

     262,144  jax            1,000     0.000     0.000     0.000     0.000     0.000     0.000     0.004   699.509
     262,144  numpy            100     0.309     0.011     0.256     0.304     0.310     0.314     0.336     1.000

   1,048,576  jax              100     0.002     0.000     0.001     0.001     0.001     0.002     0.004   542.418
   1,048,576  numpy             10     0.818     0.009     0.805     0.808     0.819     0.824     0.831     1.000

   4,194,304  jax              100     0.006     0.001     0.005     0.005     0.005     0.005     0.012   544.624
   4,194,304  numpy             10     3.153     0.014     3.123     3.152     3.156     3.163     3.173     1.000

(time in wall seconds, less is better)

benchmarks.equation_of_state
============================
Running on GPU

size          backend     calls     mean      stdev     min       25%       median    75%       max       Δ
------------------------------------------------------------------------------------------------------------------
       4,096  tensorflow    10,000     0.000     0.000     0.000     0.000     0.000     0.000     0.006     3.709
       4,096  numpy         10,000     0.002     0.000     0.001     0.002     0.002     0.002     0.012     1.000

      16,384  tensorflow    10,000     0.000     0.000     0.000     0.000     0.000     0.000     0.006    18.838
      16,384  numpy          1,000     0.008     0.001     0.007     0.008     0.008     0.009     0.014     1.000

      65,536  tensorflow    10,000     0.000     0.000     0.000     0.000     0.000     0.000     0.006   513.398
      65,536  numpy            100     0.228     0.008     0.203     0.224     0.227     0.233     0.256     1.000

     262,144  tensorflow     1,000     0.000     0.000     0.000     0.000     0.000     0.000     0.004   747.237
     262,144  numpy            100     0.343     0.012     0.274     0.338     0.343     0.350     0.372     1.000

   1,048,576  tensorflow     1,000     0.001     0.000     0.000     0.000     0.000     0.001     0.006  1657.587
   1,048,576  numpy             10     0.873     0.012     0.851     0.866     0.875     0.881     0.890     1.000

   4,194,304  tensorflow       100     0.001     0.000     0.001     0.001     0.001     0.001     0.001  4226.591
   4,194,304  numpy             10     3.175     0.014     3.153     3.164     3.175     3.183     3.197     1.000

(time in wall seconds, less is better)

benchmarks.equation_of_state
============================
Running on GPU

size          backend     calls     mean      stdev     min       25%       median    75%       max       Δ
------------------------------------------------------------------------------------------------------------------
       4,096  pytorch       10,000     0.000     0.000     0.000     0.000     0.000     0.000     0.008    15.199
       4,096  numpy         10,000     0.002     0.000     0.001     0.002     0.002     0.002     0.010     1.000

      16,384  pytorch       10,000     0.000     0.000     0.000     0.000     0.000     0.000     0.008    69.659
      16,384  numpy          1,000     0.009     0.001     0.007     0.008     0.009     0.009     0.016     1.000

      65,536  pytorch       10,000     0.000     0.000     0.000     0.000     0.000     0.000     0.009  1393.452
      65,536  numpy            100     0.286     0.088     0.126     0.151     0.331     0.338     0.397     1.000

     262,144  pytorch        1,000     0.000     0.000     0.000     0.000     0.000     0.000     0.007   989.724
     262,144  numpy            100     0.418     0.106     0.220     0.251     0.474     0.482     0.521     1.000

   1,048,576  pytorch        1,000     0.001     0.000     0.001     0.001     0.001     0.001     0.010   716.353
   1,048,576  numpy             10     0.970     0.201     0.721     0.728     1.101     1.144     1.160     1.000

   4,194,304  pytorch          100     0.005     0.000     0.005     0.005     0.005     0.005     0.005   708.456
   4,194,304  numpy             10     3.402     0.017     3.371     3.389     3.400     3.417     3.428     1.000

(time in wall seconds, less is better)

benchmarks.equation_of_state
============================
Running on GPU

size          backend     calls     mean      stdev     min       25%       median    75%       max       Δ
------------------------------------------------------------------------------------------------------------------
       4,096  numpy         10,000     0.002     0.000     0.001     0.002     0.002     0.002     0.005     1.000
       4,096  cupy           1,000     0.007     0.002     0.005     0.006     0.006     0.009     0.018     0.223

      16,384  cupy           1,000     0.008     0.002     0.006     0.006     0.007     0.009     0.020     1.085
      16,384  numpy          1,000     0.008     0.001     0.007     0.008     0.008     0.009     0.011     1.000

      65,536  cupy           1,000     0.008     0.002     0.006     0.006     0.007     0.009     0.017     5.290
      65,536  numpy            100     0.040     0.002     0.038     0.039     0.040     0.041     0.046     1.000

     262,144  cupy           1,000     0.016     0.001     0.015     0.015     0.015     0.017     0.019     9.686
     262,144  numpy            100     0.154     0.003     0.148     0.152     0.154     0.157     0.166     1.000

   1,048,576  cupy             100     0.058     0.004     0.053     0.054     0.054     0.060     0.065    12.664
   1,048,576  numpy             10     0.728     0.012     0.710     0.725     0.726     0.733     0.753     1.000

   4,194,304  cupy              10     0.208     0.009     0.203     0.203     0.204     0.207     0.233    14.708
   4,194,304  numpy             10     3.062     0.014     3.039     3.053     3.066     3.073     3.083     1.000

(time in wall seconds, less is better)

TPU

$ python run.py benchmarks/equation_of_state -b jax -b numpy --device tpu

benchmarks.equation_of_state
============================
Running on TPU

size          backend     calls     mean      stdev     min       25%       median    75%       max       Δ
------------------------------------------------------------------------------------------------------------------
       4,096  jax            1,000     0.002     0.001     0.001     0.002     0.002     0.003     0.007     1.044
       4,096  numpy         10,000     0.002     0.001     0.002     0.002     0.002     0.003     0.021     1.000

      16,384  jax            1,000     0.002     0.001     0.001     0.002     0.002     0.003     0.007     4.138
      16,384  numpy          1,000     0.010     0.002     0.008     0.009     0.009     0.010     0.052     1.000

      65,536  jax            1,000     0.002     0.001     0.002     0.002     0.002     0.003     0.007    56.663
      65,536  numpy            100     0.139     0.009     0.101     0.137     0.140     0.144     0.158     1.000

     262,144  jax              100     0.002     0.000     0.002     0.002     0.002     0.003     0.004   105.074
     262,144  numpy            100     0.255     0.013     0.227     0.250     0.253     0.261     0.319     1.000

   1,048,576  jax              100     0.003     0.001     0.002     0.003     0.003     0.003     0.008   359.453
   1,048,576  numpy             10     1.075     0.025     1.041     1.057     1.069     1.085     1.125     1.000

   4,194,304  jax               10     0.004     0.000     0.004     0.004     0.004     0.004     0.005   737.921
   4,194,304  numpy             10     3.200     0.033     3.142     3.182     3.199     3.210     3.266     1.000

(time in wall seconds, less is better)

Isoneutral mixing

A more balanced routine with many data dependencies (stencil operations) and tensor shapes of up to 5 dimensions. This is the most expensive part of Veros, so in a way this is the benchmark that interests me the most.
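As a sketch of what "stencil operations" means here (a generic illustration, not the benchmark's actual mixing kernel): each output element combines neighbouring input elements via shifted array views, which introduces the data dependencies that make such kernels harder to fuse and parallelize than pure elementwise math.

```python
import numpy as np

def laplacian_2d(field):
    """5-point Laplacian stencil via shifted slices (interior points only).

    Each output element reads five neighbouring input elements, so the
    computation is no longer independent per element, unlike the
    equation-of-state benchmark.
    """
    return (
        field[:-2, 1:-1] + field[2:, 1:-1]
        + field[1:-1, :-2] + field[1:-1, 2:]
        - 4.0 * field[1:-1, 1:-1]
    )

rng = np.random.default_rng(0)
field = rng.random((64, 64))
out = laplacian_2d(field)
```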

CPU

$ taskset -c 0 python run.py benchmarks/isoneutral_mixing/

benchmarks.isoneutral_mixing
============================
Running on CPU

size          backend     calls     mean      stdev     min       25%       median    75%       max       Δ
------------------------------------------------------------------------------------------------------------------
       4,096  jax            1,000     0.001     0.001     0.001     0.001     0.001     0.001     0.016     3.293
       4,096  numba          1,000     0.002     0.002     0.001     0.001     0.001     0.001     0.050     2.904
       4,096  aesara         1,000     0.003     0.003     0.002     0.003     0.003     0.003     0.059     1.334
       4,096  numpy          1,000     0.004     0.002     0.004     0.004     0.004     0.004     0.063     1.000
       4,096  pytorch        1,000     0.004     0.002     0.003     0.004     0.004     0.005     0.052     0.981

      16,384  jax            1,000     0.006     0.001     0.005     0.006     0.006     0.006     0.021     2.664
      16,384  numba          1,000     0.007     0.002     0.006     0.006     0.006     0.007     0.054     2.461
      16,384  aesara         1,000     0.012     0.001     0.010     0.011     0.011     0.012     0.026     1.433
      16,384  pytorch        1,000     0.012     0.003     0.010     0.011     0.011     0.012     0.061     1.424
      16,384  numpy          1,000     0.017     0.002     0.015     0.016     0.016     0.017     0.043     1.000

      65,536  jax              100     0.029     0.001     0.026     0.028     0.028     0.029     0.034     2.597
      65,536  numba            100     0.030     0.003     0.026     0.028     0.029     0.030     0.050     2.494
      65,536  pytorch          100     0.050     0.002     0.046     0.048     0.049     0.051     0.059     1.502
      65,536  aesara           100     0.050     0.002     0.047     0.049     0.050     0.051     0.057     1.483
      65,536  numpy            100     0.075     0.002     0.070     0.073     0.075     0.077     0.080     1.000

     262,144  jax               10     0.111     0.004     0.105     0.108     0.111     0.114     0.118     2.408
     262,144  numba            100     0.116     0.004     0.108     0.113     0.115     0.118     0.130     2.314
     262,144  pytorch           10     0.178     0.004     0.173     0.176     0.178     0.179     0.184     1.503
     262,144  aesara            10     0.190     0.004     0.183     0.187     0.190     0.194     0.197     1.408
     262,144  numpy             10     0.268     0.009     0.254     0.262     0.267     0.274     0.285     1.000

   1,048,576  numba             10     0.480     0.004     0.473     0.476     0.479     0.483     0.488     2.524
   1,048,576  jax               10     0.599     0.007     0.592     0.593     0.597     0.604     0.615     2.020
   1,048,576  aesara            10     0.834     0.011     0.816     0.828     0.833     0.835     0.862     1.451
   1,048,576  pytorch           10     0.863     0.080     0.786     0.799     0.806     0.944     0.983     1.403
   1,048,576  numpy             10     1.210     0.169     1.134     1.147     1.160     1.165     1.718     1.000

   4,194,304  numba             10     1.947     0.011     1.926     1.939     1.953     1.956     1.958     2.739
   4,194,304  jax               10     2.477     0.096     2.422     2.441     2.445     2.461     2.761     2.154
   4,194,304  aesara            10     3.620     0.017     3.592     3.610     3.620     3.630     3.647     1.473
   4,194,304  pytorch           10     3.668     0.026     3.631     3.658     3.663     3.675     3.730     1.454
   4,194,304  numpy             10     5.334     0.042     5.271     5.297     5.333     5.374     5.388     1.000

(time in wall seconds, less is better)

GPU

$ for backend in jax pytorch cupy; do python run.py benchmarks/isoneutral_mixing/ --device gpu -b $backend -b numpy; done

benchmarks.isoneutral_mixing
============================
Running on GPU

size          backend     calls     mean      stdev     min       25%       median    75%       max       Δ
------------------------------------------------------------------------------------------------------------------
       4,096  jax            1,000     0.001     0.000     0.001     0.001     0.001     0.001     0.009     4.187
       4,096  numpy          1,000     0.004     0.001     0.004     0.004     0.004     0.004     0.013     1.000

      16,384  jax            1,000     0.001     0.001     0.001     0.001     0.001     0.001     0.008    13.768
      16,384  numpy          1,000     0.017     0.001     0.015     0.016     0.016     0.017     0.024     1.000

      65,536  jax              100     0.003     0.000     0.003     0.003     0.004     0.004     0.004    21.820
      65,536  numpy            100     0.075     0.004     0.070     0.073     0.074     0.076     0.094     1.000

     262,144  jax              100     0.014     0.001     0.012     0.012     0.015     0.015     0.020    19.799
     262,144  numpy             10     0.274     0.009     0.260     0.272     0.274     0.274     0.293     1.000

   1,048,576  jax               10     0.057     0.005     0.052     0.052     0.054     0.062     0.063    21.834
   1,048,576  numpy             10     1.239     0.009     1.226     1.231     1.237     1.246     1.254     1.000

   4,194,304  jax               10     0.200     0.011     0.192     0.192     0.195     0.207     0.223    25.440
   4,194,304  numpy             10     5.097     0.033     5.054     5.071     5.088     5.124     5.153     1.000

(time in wall seconds, less is better)

benchmarks.isoneutral_mixing
============================
Running on GPU

size          backend     calls     mean      stdev     min       25%       median    75%       max       Δ
------------------------------------------------------------------------------------------------------------------
       4,096  numpy          1,000     0.004     0.001     0.004     0.004     0.004     0.004     0.013     1.000
       4,096  pytorch        1,000     0.006     0.001     0.005     0.005     0.005     0.007     0.014     0.746

      16,384  pytorch        1,000     0.006     0.001     0.005     0.005     0.005     0.007     0.017     2.667
      16,384  numpy          1,000     0.016     0.001     0.014     0.016     0.016     0.017     0.027     1.000

      65,536  pytorch          100     0.007     0.001     0.006     0.007     0.007     0.008     0.015    12.932
      65,536  numpy            100     0.097     0.007     0.080     0.094     0.097     0.100     0.125     1.000

     262,144  pytorch          100     0.016     0.002     0.014     0.015     0.015     0.016     0.021    17.586
     262,144  numpy             10     0.274     0.005     0.267     0.270     0.273     0.277     0.281     1.000

   1,048,576  pytorch           10     0.051     0.003     0.048     0.050     0.050     0.050     0.060    25.531
   1,048,576  numpy             10     1.292     0.011     1.276     1.284     1.292     1.296     1.316     1.000

   4,194,304  pytorch           10     0.192     0.011     0.182     0.182     0.184     0.202     0.211    25.674
   4,194,304  numpy             10     4.923     0.013     4.901     4.917     4.920     4.929     4.954     1.000

(time in wall seconds, less is better)

benchmarks.isoneutral_mixing
============================
Running on GPU

size          backend     calls     mean      stdev     min       25%       median    75%       max       Δ
------------------------------------------------------------------------------------------------------------------
       4,096  numpy          1,000     0.004     0.001     0.003     0.004     0.004     0.004     0.013     1.000
       4,096  cupy           1,000     0.013     0.002     0.010     0.011     0.012     0.015     0.026     0.343

      16,384  cupy           1,000     0.013     0.002     0.010     0.011     0.012     0.015     0.024     1.273
      16,384  numpy          1,000     0.017     0.001     0.015     0.016     0.016     0.017     0.027     1.000

      65,536  cupy             100     0.013     0.002     0.011     0.012     0.012     0.015     0.025     5.723
      65,536  numpy            100     0.075     0.005     0.068     0.072     0.074     0.077     0.086     1.000

     262,144  cupy             100     0.021     0.002     0.018     0.019     0.023     0.023     0.027    13.102
     262,144  numpy             10     0.279     0.007     0.272     0.274     0.276     0.286     0.292     1.000

   1,048,576  cupy              10     0.071     0.006     0.067     0.068     0.069     0.069     0.083    17.415
   1,048,576  numpy             10     1.240     0.020     1.191     1.232     1.250     1.252     1.263     1.000

   4,194,304  cupy              10     0.270     0.012     0.259     0.260     0.264     0.280     0.291    18.798
   4,194,304  numpy             10     5.071     0.045     4.962     5.048     5.089     5.096     5.124     1.000

(time in wall seconds, less is better)

TPU

$ python run.py benchmarks/isoneutral_mixing -b jax -b numpy --device tpu

benchmarks.isoneutral_mixing
============================
Running on TPU

size          backend     calls     mean      stdev     min       25%       median    75%       max       Δ
------------------------------------------------------------------------------------------------------------------
       4,096  jax              100     0.004     0.003     0.003     0.003     0.004     0.004     0.033     1.603
       4,096  numpy          1,000     0.007     0.003     0.005     0.006     0.006     0.006     0.037     1.000

      16,384  jax              100     0.005     0.005     0.003     0.004     0.004     0.005     0.041     4.733
      16,384  numpy            100     0.024     0.006     0.020     0.021     0.022     0.023     0.065     1.000

      65,536  jax              100     0.005     0.002     0.004     0.004     0.005     0.006     0.018    20.059
      65,536  numpy             10     0.106     0.009     0.096     0.101     0.103     0.113     0.126     1.000

     262,144  jax               10     0.007     0.001     0.006     0.006     0.006     0.007     0.009    68.206
     262,144  numpy             10     0.458     0.034     0.364     0.460     0.470     0.473     0.490     1.000

   1,048,576  jax               10     0.016     0.002     0.015     0.015     0.015     0.015     0.022    97.621
   1,048,576  numpy             10     1.522     0.035     1.471     1.500     1.520     1.540     1.601     1.000

   4,194,304  jax               10     0.056     0.009     0.050     0.050     0.051     0.060     0.073   109.384
   4,194,304  numpy             10     6.156     0.077     6.071     6.089     6.138     6.195     6.306     1.000

(time in wall seconds, less is better)

Turbulent kinetic energy

This routine consists of some stencil operations and some linear algebra (a tridiagonal matrix solver), which cannot be fully vectorized because each row of the solve depends on the previous one.
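To sketch why a tridiagonal solve resists vectorization, here is the classic Thomas algorithm (a plausible stand-in for the solver here, not necessarily the benchmark's exact implementation): both the forward sweep and the back substitution carry a sequential dependency from one row to the next.

```python
import numpy as np

def thomas_solve(a, b, c, d):
    """Solve a tridiagonal system with sub-diagonal a, diagonal b,
    super-diagonal c, and right-hand side d (1-D arrays of length n).

    The loops below depend on the previous iteration's result, so the
    solve cannot be expressed as one vectorized operation along n
    (although many independent systems can still be solved in parallel).
    """
    n = len(d)
    cp = np.zeros(n)
    dp = np.zeros(n)
    cp[0] = c[0] / b[0]
    dp[0] = d[0] / b[0]
    for i in range(1, n):            # forward sweep: depends on row i - 1
        denom = b[i] - a[i] * cp[i - 1]
        cp[i] = c[i] / denom
        dp[i] = (d[i] - a[i] * dp[i - 1]) / denom
    x = np.zeros(n)
    x[-1] = dp[-1]
    for i in range(n - 2, -1, -1):   # back substitution: depends on row i + 1
        x[i] = dp[i] - cp[i] * x[i + 1]
    return x

# A small diagonally dominant test system.
n = 5
a = np.full(n, -1.0); a[0] = 0.0
c = np.full(n, -1.0); c[-1] = 0.0
b = np.full(n, 4.0)
d = np.arange(1.0, n + 1.0)
x = thomas_solve(a, b, c, d)
```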

CPU

$ taskset -c 0 python run.py benchmarks/turbulent_kinetic_energy/

benchmarks.turbulent_kinetic_energy
===================================
Running on CPU

size          backend     calls     mean      stdev     min       25%       median    75%       max       Δ
------------------------------------------------------------------------------------------------------------------
       4,096  jax            1,000     0.001     0.000     0.000     0.000     0.000     0.001     0.004     4.918
       4,096  numba          1,000     0.001     0.000     0.001     0.001     0.001     0.001     0.005     2.312
       4,096  pytorch        1,000     0.002     0.001     0.001     0.002     0.002     0.002     0.008     1.227
       4,096  numpy          1,000     0.003     0.001     0.002     0.002     0.002     0.003     0.009     1.000

      16,384  jax            1,000     0.002     0.000     0.002     0.002     0.002     0.002     0.008     3.708
      16,384  numba          1,000     0.004     0.001     0.003     0.003     0.004     0.004     0.009     2.265
      16,384  pytorch        1,000     0.005     0.001     0.004     0.004     0.005     0.005     0.009     1.803
      16,384  numpy          1,000     0.008     0.001     0.007     0.008     0.008     0.009     0.020     1.000

      65,536  jax              100     0.009     0.000     0.008     0.009     0.009     0.009     0.012     4.015
      65,536  numba            100     0.013     0.001     0.012     0.013     0.013     0.013     0.019     2.801
      65,536  pytorch          100     0.018     0.001     0.016     0.017     0.018     0.019     0.024     2.047
      65,536  numpy            100     0.038     0.001     0.035     0.036     0.037     0.038     0.044     1.000

     262,144  jax              100     0.040     0.002     0.037     0.039     0.039     0.041     0.047     3.173
     262,144  numba            100     0.046     0.003     0.042     0.044     0.045     0.047     0.057     2.745
     262,144  pytorch           10     0.064     0.002     0.061     0.062     0.063     0.064     0.068     1.992
     262,144  numpy             10     0.127     0.002     0.123     0.125     0.127     0.129     0.130     1.000

   1,048,576  numba             10     0.187     0.003     0.183     0.185     0.187     0.189     0.191     3.046
   1,048,576  jax               10     0.237     0.003     0.232     0.235     0.236     0.238     0.241     2.408
   1,048,576  pytorch           10     0.297     0.005     0.289     0.294     0.296     0.302     0.304     1.918
   1,048,576  numpy             10     0.570     0.007     0.559     0.564     0.569     0.577     0.579     1.000

   4,194,304  numba             10     0.737     0.010     0.721     0.730     0.739     0.743     0.751     3.447
   4,194,304  jax               10     1.212     0.012     1.193     1.204     1.210     1.220     1.232     2.097
   4,194,304  pytorch           10     1.404     0.006     1.395     1.400     1.403     1.410     1.415     1.809
   4,194,304  numpy             10     2.540     0.014     2.519     2.529     2.545     2.549     2.557     1.000

(time in wall seconds, less is better)

GPU

$ for backend in jax pytorch; do python run.py benchmarks/turbulent_kinetic_energy/ --device gpu -b $backend -b numpy; done

benchmarks.turbulent_kinetic_energy
===================================
Running on GPU

size          backend     calls     mean      stdev     min       25%       median    75%       max       Δ
------------------------------------------------------------------------------------------------------------------
       4,096  jax            1,000     0.001     0.000     0.001     0.001     0.001     0.001     0.004     2.625
       4,096  numpy          1,000     0.002     0.000     0.002     0.002     0.002     0.003     0.006     1.000

      16,384  jax            1,000     0.001     0.000     0.001     0.001     0.001     0.001     0.003     6.924
      16,384  numpy          1,000     0.008     0.001     0.007     0.008     0.008     0.009     0.013     1.000

      65,536  jax              100     0.002     0.000     0.002     0.002     0.003     0.003     0.003    15.079
      65,536  numpy            100     0.038     0.002     0.035     0.036     0.037     0.038     0.047     1.000

     262,144  jax              100     0.010     0.001     0.009     0.010     0.011     0.011     0.011    12.195
     262,144  numpy             10     0.128     0.003     0.123     0.127     0.128     0.129     0.132     1.000

   1,048,576  jax               10     0.043     0.003     0.040     0.041     0.043     0.046     0.046    12.451
   1,048,576  numpy             10     0.540     0.006     0.525     0.538     0.541     0.544     0.545     1.000

   4,194,304  jax               10     0.111     0.008     0.099     0.105     0.111     0.119     0.120    20.741
   4,194,304  numpy             10     2.309     0.008     2.296     2.303     2.309     2.316     2.320     1.000

(time in wall seconds, less is better)

benchmarks.turbulent_kinetic_energy
===================================
Running on GPU

size          backend     calls     mean      stdev     min       25%       median    75%       max       Δ
------------------------------------------------------------------------------------------------------------------
       4,096  numpy          1,000     0.003     0.000     0.002     0.002     0.002     0.003     0.006     1.000
       4,096  pytorch        1,000     0.003     0.001     0.003     0.003     0.003     0.004     0.007     0.790

      16,384  pytorch        1,000     0.004     0.001     0.003     0.003     0.003     0.004     0.008     2.273
      16,384  numpy          1,000     0.008     0.001     0.007     0.008     0.008     0.009     0.012     1.000

      65,536  pytorch          100     0.005     0.001     0.004     0.004     0.005     0.005     0.008     8.471
      65,536  numpy            100     0.039     0.002     0.036     0.038     0.039     0.040     0.045     1.000

     262,144  pytorch          100     0.008     0.001     0.007     0.007     0.008     0.008     0.011    16.245
     262,144  numpy             10     0.126     0.002     0.123     0.124     0.126     0.128     0.132     1.000

   1,048,576  pytorch           10     0.027     0.002     0.025     0.025     0.027     0.027     0.031    20.552
   1,048,576  numpy             10     0.549     0.008     0.540     0.545     0.548     0.553     0.567     1.000

   4,194,304  pytorch           10     0.108     0.008     0.096     0.101     0.108     0.114     0.123    21.209
   4,194,304  numpy             10     2.290     0.008     2.277     2.286     2.289     2.295     2.302     1.000

(time in wall seconds, less is better)

TPU

$ python run.py benchmarks/turbulent_kinetic_energy -b jax -b numpy --device tpu

benchmarks.turbulent_kinetic_energy
===================================
Running on TPU

size          backend     calls     mean      stdev     min       25%       median    75%       max       Δ
------------------------------------------------------------------------------------------------------------------
       4,096  jax              100     0.003     0.001     0.002     0.003     0.003     0.003     0.015     1.132
       4,096  numpy          1,000     0.004     0.001     0.003     0.003     0.003     0.004     0.035     1.000

      16,384  jax              100     0.003     0.001     0.003     0.003     0.003     0.004     0.007     3.322
      16,384  numpy          1,000     0.011     0.002     0.010     0.010     0.011     0.011     0.041     1.000

      65,536  jax              100     0.004     0.004     0.003     0.003     0.004     0.004     0.031    11.957
      65,536  numpy            100     0.050     0.004     0.045     0.048     0.050     0.051     0.065     1.000

     262,144  jax               10     0.004     0.000     0.004     0.004     0.004     0.004     0.005    40.486
     262,144  numpy             10     0.178     0.005     0.168     0.176     0.177     0.182     0.185     1.000

   1,048,576  jax               10     0.008     0.000     0.008     0.008     0.008     0.008     0.009    95.165
   1,048,576  numpy             10     0.803     0.040     0.750     0.766     0.797     0.834     0.872     1.000

   4,194,304  jax               10     0.022     0.000     0.022     0.022     0.022     0.022     0.023   121.268
   4,194,304  numpy             10     2.679     0.349     2.423     2.482     2.511     2.745     3.577     1.000

(time in wall seconds, less is better)