<div align = "center">
<img width=400 src="assets/kernexlogo.svg" align="center">

<h3 align="center">Differentiable stencil computations in JAX</h3>

[**Installation**](#Installation)
|[**Description**](#Description)
|[**More examples**](#MoreExamples)
|[**Benchmarking**](#Benchmarking)

![Tests](https://github.com/ASEM000/kernex/actions/workflows/tests.yml/badge.svg)
![pyver](https://img.shields.io/badge/python-3.8%203.9%203.10%203.11-red)
![codestyle](https://img.shields.io/badge/codestyle-black-black)
[![Downloads](https://pepy.tech/badge/kernex)](https://pepy.tech/project/kernex)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/14UEqKzIyZsDzQ9IMeanvztXxbbbatTYV?usp=sharing)
[![codecov](https://codecov.io/gh/ASEM000/kernex/branch/main/graph/badge.svg?token=3KLL24Z94I)](https://codecov.io/gh/ASEM000/kernex)
## 🛠️ Installation<a id="Installation"></a>

```bash
pip install kernex
```

## 📖 Description<a id="Description"></a>

Kernex extends `jax.vmap`/`jax.lax.map`/`jax.pmap` with `kmap` and `jax.lax.scan` with `kscan` for general stencil computations.

The prime motivation for this package is to blend the solution process of PDEs into a neural network setting.
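As a mental model, `kmap` behaves like `jax.vmap` of the kernel function over all sliding windows of the input. A hand-rolled sketch in plain JAX (illustrative only, not kernex's implementation):

```python
import jax
import jax.numpy as jnp

x = jnp.array([1, 2, 3, 4, 5])
# materialize all length-3 sliding windows: (1,2,3), (2,3,4), (3,4,5)
windows = jnp.stack([x[i:i + 3] for i in range(x.size - 2)])
# `kmap(f)` is conceptually `vmap(f)` over these windows
print(jax.vmap(jnp.sum)(windows))  # [ 6  9 12]
```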

## ⏩ Quick example<a id="QuickExample"></a>

<table>
<tr>
<td align="center"><code>kmap</code></td>
<td align="center"><code>kscan</code></td>
</tr>
<tr>
<td>

```python

import kernex as kex
import jax.numpy as jnp

@kex.kmap(kernel_size=(3,))
def sum_all(x):
    return jnp.sum(x)

>>> x = jnp.array([1,2,3,4,5])
>>> print(sum_all(x))
[ 6  9 12]
```

</td>
<td>

```python
import kernex as kex
import jax.numpy as jnp

@kex.kscan(kernel_size=(3,))
def sum_all(x):
    return jnp.sum(x)

>>> x = jnp.array([1,2,3,4,5])
>>> print(sum_all(x))
[ 6 13 22]
```
</td>
</tr>
</table>
`kmap` applies the kernel to every window in parallel, while `kscan` applies it sequentially and writes each result back before the next window is read: 6 = 1+2+3, then 13 = 6+3+4, then 22 = 13+4+5. See Linear convection in the **More examples** section.

## 🔢 More examples<a id="MoreExamples"></a>

<details>
<summary>1️⃣ Convolution operation</summary>

```python
import jax
import jax.numpy as jnp
import kernex as kex

@jax.jit
@kex.kmap(
    kernel_size= (3,3,3),
    padding = ('valid','same','same'))
def kernex_conv2d(x,w):
    # JAX channel-first conv2d with a 3x3x3 kernel
    return jnp.sum(x*w)
```
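
A usage sketch (the input shapes here are illustrative assumptions):

```python
x = jnp.ones((3, 32, 32))  # channel-first input
w = jnp.ones((3, 3, 3))    # one 3x3 kernel per input channel
# 'valid' over channels collapses them; 'same' preserves the spatial dims
print(kernex_conv2d(x, w).shape)  # (1, 32, 32)
```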

</details>

<details><summary>4️⃣ Linear convection</summary>

$\Large u_i^{n} = u_i^{n-1} - c \frac{\Delta t}{\Delta x}(u_i^{n-1}-u_{i-1}^{n-1})$


```python

import jax
# ... (kernex convection solver elided in this diff view) ...

for line in kx_solution[::20]:
    ...
```

</details>

<details><summary>5️⃣ Gaussian blur</summary>

```python
import jax
import jax.numpy as jnp
import kernex as kex

def gaussian_blur(image, sigma, kernel_size):
    # normalized 2D Gaussian kernel `w`
    # (an assumed reconstruction; the original body is collapsed in this diff)
    ax = jnp.linspace(-(kernel_size - 1) / 2.0, (kernel_size - 1) / 2.0, kernel_size)
    w = jnp.exp(-0.5 * jnp.square(ax) / jnp.square(sigma))
    w = jnp.outer(w, w)
    w = w / w.sum()

    @kex.kmap(kernel_size=(kernel_size, kernel_size), padding="same")
    def conv(x):
        return jnp.sum(x * w)

    return conv(image)
```
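
A usage sketch (image shape assumed for illustration):

```python
image = jnp.ones((64, 64))
blurred = gaussian_blur(image, sigma=1.0, kernel_size=5)
print(blurred.shape)  # (64, 64): "same" padding preserves the spatial shape
```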

</details>

<details><summary>6️⃣ Depthwise convolution</summary>

```python
import jax
import jax.numpy as jnp
import kernex as kex

@jax.jit
@jax.vmap
@kex.kmap(
    kernel_size= (3,3),
    padding = ('same','same'))
def kernex_depthwise_conv2d(x,w):
    # channel-first depthwise convolution
    # jax.debug.print("x=\n{a}\nw=\n{b} \n\n",a=x, b=w)
    return jnp.sum(x*w)

h,w,c = 5,5,2
k=3

x = jnp.arange(1,h*w*c+1).reshape(c,h,w)
w = jnp.arange(1,k*k*c+1).reshape(c,k,k)
print(kernex_depthwise_conv2d(x,w))
```
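
A quick shape check on the arrays above: `jax.vmap` pairs channel `i` of `x` with kernel `i` of `w`, so the output keeps one 5×5 map per channel.

```python
# 'same' padding preserves the 5x5 spatial shape for each of the 2 channels
assert kernex_depthwise_conv2d(x, w).shape == (2, 5, 5)
```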

</details>

<details> <summary>7️⃣ Maxpooling2D and Averagepooling2D </summary>

```python
import jax
import jax.numpy as jnp
import kernex as kex

# kernel_size/strides below are assumed where this diff is collapsed
@jax.jit
@kex.kmap(kernel_size=(3,3), strides=(2,2), padding='valid')
def maxpool_2d(x):
    # define the kernel for the Max pool operation over the spatial dimensions
    return jnp.max(x)

@jax.jit
@kex.kmap(kernel_size=(3,3), strides=(2,2), padding='valid')
def avgpool_2d(x):
    # define the kernel for the Average pool operation over the spatial dimensions
    return jnp.mean(x)
```
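
A small usage sketch (assuming the 3×3 window / stride-2 configuration above):

```python
x = jnp.arange(1.0, 26.0).reshape(5, 5)
print(maxpool_2d(x))  # shape (2, 2): 3x3 windows, stride 2, no padding
print(avgpool_2d(x))  # same layout, averaging instead of max
```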

</details>

<details><summary>8️⃣ Runge-Kutta integration</summary>

```python

# let's solve dy/dt = y, where y(0) = 1 and y(t) = e^t
# using the Runge-Kutta 4th-order method
# f(t,y) = y
import jax.numpy as jnp
import matplotlib.pyplot as plt
import kernex as kex

# ... (RK4 stepping code elided in this diff view) ...

plt.legend()
```
![img](assets/rk4.svg)
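
The kernex stepping loop is collapsed in this diff; as a reference for the scheme itself, here is a minimal plain-JAX sketch (hypothetical helper names, not the README's kernex version):

```python
import jax.numpy as jnp

def f(t, y):
    # dy/dt = y, with exact solution y(t) = e^t
    return y

def rk4_step(y, t, dt):
    # classic fourth-order Runge-Kutta update
    k1 = f(t, y)
    k2 = f(t + dt / 2, y + dt * k1 / 2)
    k3 = f(t + dt / 2, y + dt * k2 / 2)
    k4 = f(t + dt, y + dt * k3)
    return y + (dt / 6) * (k1 + 2 * k2 + 2 * k3 + k4)

dt = 0.1
ts = jnp.arange(10) * dt   # t = 0.0, 0.1, ..., 0.9
ys = [jnp.array(1.0)]      # y(0) = 1
for t in ts[:-1]:
    ys.append(rk4_step(ys[-1], t, dt))
print(jnp.allclose(jnp.stack(ys), jnp.exp(ts), atol=1e-5))  # True
```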

</details>

## ⌛ Benchmarking<a id="Benchmarking"></a>

<details><summary>Conv2D</summary>

```python

# testing and benchmarking convolution
# for complete benchmarking check /tests_and_benchmark
import jax
import jax.numpy as jnp
import numpy as np
# `kernex_conv2d` is the kmap-based convolution from the example above

# 3x1024x1024 input
C,H = 3,1024

@jax.jit
def jax_conv2d(x,w):
    return jax.lax.conv_general_dilated(
        lhs = x,
        rhs = w,
        window_strides = (1,1),
        padding = 'SAME',
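        # dimension_numbers = (input, kernel, output) layouts:
        # NCHW = batch, channels, height, width; OIHW = out-ch, in-ch, kh, kw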
        dimension_numbers = ('NCHW', 'OIHW', 'NCHW'),)[0]


x = jax.random.normal(jax.random.PRNGKey(0),(C,H,H))
xx = x[None]
w = jax.random.normal(jax.random.PRNGKey(0),(C,3,3))
ww = w[None]

# assert equal
np.testing.assert_allclose(kernex_conv2d(x,w),jax_conv2d(xx,ww),atol=1e-3)

# Mac M1 CPU
# check tests_and_benchmark folder for more.

%timeit kernex_conv2d(x,w).block_until_ready()
# 3.96 ms ± 272 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

%timeit jax_conv2d(xx,ww).block_until_ready()
# 27.5 ms ± 993 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
```

</details>

<details><summary>get_patches</summary>

```python

# benchmarking `get_patches` against `jax.lax.conv_general_dilated_patches`
# on a Mac M1 CPU
import jax
import jax.numpy as jnp
import numpy as np
import kernex as kex

@jax.jit
@kex.kmap(kernel_size=(3,),padding='same')
def get_patches(x):
    return x

@jax.jit
def jax_get_patches(x):
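    # extracts every length-3 window of x; for x of shape (1, 1, N) the
    # result has shape (1, 3, N): one "channel" per in-window position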
    return jax.lax.conv_general_dilated_patches(x,(3,),(1,),padding='same')

x = jnp.ones([1_000_000])
xx = jnp.ones([1,1,1_000_000])

np.testing.assert_allclose(
    get_patches(x),
    jax_get_patches(xx).reshape(-1,1_000_000).T)

%timeit get_patches(x).block_until_ready()
# 1.73 ms ± 92.7 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

%timeit jax_get_patches(xx).block_until_ready()
# 10.6 ms ± 337 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```

</details>
