From d962947e292571e6ed95746c8fc801a641208712 Mon Sep 17 00:00:00 2001 From: ASEM000 Date: Sat, 10 Jun 2023 06:08:04 +0900 Subject: [PATCH] Update README.md --- README.md | 146 ++++++++++++------------------------------------------ 1 file changed, 32 insertions(+), 114 deletions(-) diff --git a/README.md b/README.md index 76a9a29..8c8afc8 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,6 @@
-

Differentiable Stencil computations in JAX

[**Installation**](#Installation)
@@ -12,8 +11,8 @@
|[**Benchmarking**](#Benchmarking)

![Tests](https://github.com/ASEM000/kernex/actions/workflows/tests.yml/badge.svg)
-![pyver](https://img.shields.io/badge/python-3.7%203.8%203.9%203.10-red)
-![codestyle](https://img.shields.io/badge/codestyle-black-lightgrey)
+![pyver](https://img.shields.io/badge/python-3.8%203.9%203.10%203.11-red)
+![codestyle](https://img.shields.io/badge/codestyle-black-black)
[![Downloads](https://pepy.tech/badge/kernex)](https://pepy.tech/project/kernex)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/14UEqKzIyZsDzQ9IMeanvztXxbbbatTYV?usp=sharing)
[![codecov](https://codecov.io/gh/ASEM000/kernex/branch/main/graph/badge.svg?token=3KLL24Z94I)](https://codecov.io/gh/ASEM000/kernex)

@@ -29,7 +28,7 @@ pip install kernex

## 📖 Description

-Kernex extends `jax.vmap` and `jax.lax.scan` with `kmap` and `kscan` for general stencil computations.
+Kernex extends `jax.vmap`/`jax.lax.map`/`jax.pmap` with `kmap`, and `jax.lax.scan` with `kscan`, for general stencil computations.

The prime motivation for this package is to blend the solution process of PDEs into a NN setting.

@@ -45,8 +44,8 @@ The prime motivation for this package is to blend the solution process of PDEs i

```python
-import kernex as kex
-import jax.numpy as jnp
+import kernex as kex
+import jax.numpy as jnp

@kex.kmap(kernel_size=(3,))
def sum_all(x):

@@ -56,22 +55,23 @@ def sum_all(x):
>>> print(sum_all(x))
[ 6 9 12]
```
+
```python
import kernex as kex
-import jax.numpy as jnp
+import jax.numpy as jnp

@kex.kscan(kernel_size=(3,))
def sum_all(x):
    return jnp.sum(x)

>>> x = jnp.array([1,2,3,4,5])
>>> print(sum_all(x))
[ 6 13 22]

```

Unlike `kmap`, which maps the kernel over windows of the unchanged input, `kscan` writes each window's result back into the array before the next window is evaluated, so later windows see earlier results: 1+2+3=6, then 6+3+4=13, then 13+4+5=22.

@@ -193,7 +193,6 @@ See Linear convection in **More examples** section

## 🔢 More examples

-
1️⃣ Convolution operation

@@ -208,9 +207,10 @@ import kernex as kex
    kernel_size= (3,3,3),
    padding = ('valid','same','same'))
def kernex_conv2d(x,w):
-    # JAX channel first conv2d with 3x3x3 kernel_size
+    # JAX channel-first conv2d with 3x3x3 kernel_size
    return jnp.sum(x*w)
-````
+```
+
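A quick sanity check of the example above, as a minimal sketch; the shapes are hypothetical, and `kernex_conv2d` is the decorated function from the previous block:

```python
import jax

# hypothetical shapes: a 3-channel 32x32 input and a matching 3x3x3 kernel
x = jax.random.normal(jax.random.PRNGKey(0), (3, 32, 32))
w = jax.random.normal(jax.random.PRNGKey(1), (3, 3, 3))

# 'valid' over the channel axis collapses it to size 1; 'same' keeps spatial dims
print(kernex_conv2d(x, w).shape)  # (1, 32, 32)
```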
@@ -316,7 +316,6 @@ $\Large u_i^{n} = u_i^{n-1} - c \frac{\Delta t}{\Delta x}(u_i^{n-1}-u_{i-1}^{n-1
-

```python

import jax
@@ -379,12 +378,11 @@ for line in kx_solution[::20]:

-
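The full solver lives in the README section above; as a minimal sketch of how the update rule maps onto a kernex window (the constants and the `'same'` zero padding here are assumptions for illustration, not the README's exact setup):

```python
import jax.numpy as jnp
import kernex as kex

c, dx, dt = 1.0, 0.1, 0.05  # assumed constants for illustration

@kex.kmap(kernel_size=(3,), padding="same")
def convection_step(u):
    # u is the local window [u_{i-1}, u_i, u_{i+1}]; boundaries see zero padding
    return u[1] - c * dt / dx * (u[1] - u[0])

u = jnp.sin(jnp.linspace(0, 2 * jnp.pi, 64))
u = convection_step(u)  # one explicit time step of the scheme above
```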
5️⃣ Gaussian blur

```python
-
-import jax
+
+import jax
import jax.numpy as jnp
import kernex as kex

@@ -396,19 +394,18 @@ def gaussian_blur(image, sigma, kernel_size):
    @kex.kmap(kernel_size=(kernel_size, kernel_size), padding="same")
    def conv(x):
-        return jnp.sum(x * w)
-
+        return jnp.sum(x * w)
+
    return conv(image)
-
-
```
-
-
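A hypothetical call sketch for the function above (the image shape and parameter values are assumptions):

```python
import jax

image = jax.random.uniform(jax.random.PRNGKey(0), (64, 64))
blurred = gaussian_blur(image, sigma=1.0, kernel_size=5)
print(blurred.shape)  # (64, 64): 'same' padding preserves the spatial size
```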
+
6️⃣ Depthwise convolution

```python
import jax
import jax.numpy as jnp
import kernex as kex

@@ -417,13 +414,10 @@ import kernex as kex
@jax.jit
@jax.vmap
@kex.kmap(
    kernel_size= (3,3),
    padding = ('same','same'))
def kernex_depthwise_conv2d(x,w):
    # Channel-first depthwise convolution
    # jax.debug.print("x=\n{a}\nw=\n{b} \n\n",a=x, b=w)
    return jnp.sum(x*w)

h,w,c = 5,5,2
k=3

@@ -431,8 +425,9 @@ k=3
x = jnp.arange(1,h*w*c+1).reshape(c,h,w)
w = jnp.arange(1,k*k*c+1).reshape(c,k,k)
print(kernex_depthwise_conv2d(x,w))
```
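As a hedged cross-check (not part of the original example), the same depthwise result should be reproducible with XLA's grouped convolution; `jax_depthwise_conv2d` is a name introduced here, `x` and `w` come from the block above, and the float cast is a precaution since XLA convolutions typically expect floating-point inputs:

```python
import jax
import jax.numpy as jnp
import numpy as np

def jax_depthwise_conv2d(x, w):
    # one filter per input channel: feature_group_count equals the channel count
    return jax.lax.conv_general_dilated(
        lhs=x[None],          # (1, c, h, w)
        rhs=w[:, None],       # (c, 1, k, k)
        window_strides=(1, 1),
        padding="SAME",
        dimension_numbers=("NCHW", "OIHW", "NCHW"),
        feature_group_count=x.shape[0],
    )[0]

xf, wf = x.astype(jnp.float32), w.astype(jnp.float32)
np.testing.assert_allclose(
    kernex_depthwise_conv2d(xf, wf), jax_depthwise_conv2d(xf, wf), atol=1e-3
)
```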
7️⃣ Maxpooling2D and Averagepooling2D

@@ -449,13 +444,10 @@ def maxpool_2d(x):
def avgpool_2d(x):
    # define the kernel for the Average pool operation over the spatial dimensions
    return jnp.mean(x)
-```
-
+```
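A minimal usage sketch; the `strides` argument and the 2x2 window here are assumptions for illustration:

```python
import jax
import jax.numpy as jnp
import kernex as kex

# `strides` is assumed here to make the pooling windows non-overlapping
@jax.jit
@kex.kmap(kernel_size=(2, 2), strides=(2, 2), padding="valid")
def maxpool_2x2(x):
    # reduce each 2x2 window to its maximum
    return jnp.max(x)

x = jnp.arange(16.0).reshape(4, 4)
print(maxpool_2x2(x))
# [[ 5.  7.]
#  [13. 15.]]
```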
- -
8️⃣ Runge-Kutta integration

```python

@@ -463,7 +455,7 @@
# lets solve dydt = y, where y0 = 1 and y(t)=e^t
# using Runge-Kutta 4th order method
# f(t,y) = y
-import jax.numpy as jnp
+import jax.numpy as jnp
import matplotlib.pyplot as plt
import kernex as kex

@@ -514,77 +506,3 @@ plt.legend()

![img](assets/rk4.svg)
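Since $f(t,y)=y$ here, one RK4 step collapses to a closed form that matches the Taylor series of $e^{\Delta t}$ through fourth order, which is why the numerical and exact curves coincide in the plot:

$$k_1 = y_n,\quad k_2 = y_n\left(1+\tfrac{\Delta t}{2}\right),\quad k_3 = y_n\left(1+\tfrac{\Delta t}{2}+\tfrac{\Delta t^2}{4}\right),\quad k_4 = y_n\left(1+\Delta t+\tfrac{\Delta t^2}{2}+\tfrac{\Delta t^3}{4}\right)$$

$$y_{n+1} = y_n + \tfrac{\Delta t}{6}\left(k_1 + 2k_2 + 2k_3 + k_4\right) = y_n\left(1+\Delta t+\tfrac{\Delta t^2}{2}+\tfrac{\Delta t^3}{6}+\tfrac{\Delta t^4}{24}\right)$$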
-
-## ⌛ Benchmarking
-
Conv2D

```python
-
-# testing and benchmarking convolution
-# for complete benchmarking check /tests_and_benchmark
-
-# 3x1024x1024 Input
-C,H = 3,1024
-
-@jax.jit
-def jax_conv2d(x,w):
-    return jax.lax.conv_general_dilated(
-        lhs = x,
-        rhs = w,
-        window_strides = (1,1),
-        padding = 'SAME',
-        dimension_numbers = ('NCHW', 'OIHW', 'NCHW'),)[0]
-
-
-x = jax.random.normal(jax.random.PRNGKey(0),(C,H,H))
-xx = x[None]
-w = jax.random.normal(jax.random.PRNGKey(0),(C,3,3))
-ww = w[None]
-
-# assert equal
-np.testing.assert_allclose(kernex_conv2d(x,w),jax_conv2d(xx,ww),atol=1e-3)
-
-# Mac M1 CPU
-# check tests_and_benchmark folder for more.
-
-%timeit kernex_conv2d(x,w).block_until_ready()
-# 3.96 ms ± 272 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
-
-%timeit jax_conv2d(xx,ww).block_until_ready()
-# 27.5 ms ± 993 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
```
- -
get_patches

```python
-
-# benchmarking `get_patches` with `jax.lax.conv_general_dilated_patches`
-# On Mac M1 CPU
-
-@jax.jit
-@kex.kmap(kernel_size=(3,),padding='same')
-def get_patches(x):
-    return x
-
-@jax.jit
-def jax_get_patches(x):
-    return jax.lax.conv_general_dilated_patches(x,(3,),(1,),padding='same')
-
-x = jnp.ones([1_000_000])
-xx = jnp.ones([1,1,1_000_000])
-
-np.testing.assert_allclose(
-    get_patches(x),
-    jax_get_patches(xx).reshape(-1,1_000_000).T)
-
->> %timeit get_patches(x).block_until_ready()
->> %timeit jax_get_patches(xx).block_until_ready()
-
-1.73 ms ± 92.7 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
-10.6 ms ± 337 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```