Skip to content

Commit

Permalink
Merge pull request #4 from ddrous/dev
Browse files Browse the repository at this point in the history
Merge Lineax integration for patch 1.0.2
  • Loading branch information
ddrous authored May 28, 2024
2 parents c341c81 + 7c48d7b commit 0ee5c69
Show file tree
Hide file tree
Showing 9 changed files with 102 additions and 31 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: Updec CI/CD
name: Updes CI/CD

on: [push]

Expand Down
7 changes: 6 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ pip install updes
```

The example below illustrates how to solve the Laplace equation with Dirichlet and Neumann boundary conditions:
<p align="center">
<img src="docs/assets/LaplacePDE.png" width="250">
</p>

```python
import updes
import jax.numpy as jnp
Expand Down Expand Up @@ -70,6 +74,7 @@ cloud.visualize_field(sol.vals, cmap="jet", projection="3d", title="RBF solution

## To-Dos
1. Logo, contributors guide, and developer documentation
2. Improved ill-conditioned linear systems for RBF-FD (i.e. `support_size != "max"`)
2. More introductory examples in the documentation :
- Integration with neural networks and [Equinox](https://github.com/patrick-kidger/equinox)
- Non-linear and multi-dimensional PDEs
Expand All @@ -83,7 +88,7 @@ We welcome contributions from the community. Please feel free to open an issue o


## Dependencies
- **Core**: [JAX](https://github.com/google/jax) - [GMSH](https://pypi.org/project/gmsh/) - [Matplotlib](https://github.com/matplotlib/matplotlib) - [Seaborn](https://github.com/mwaskom/seaborn) - [Scikit-Learn](https://github.com/scikit-learn/scikit-learn)
- **Core**: [JAX](https://github.com/google/jax) - [GMSH](https://pypi.org/project/gmsh/) - [Lineax](https://github.com/patrick-kidger/lineax) - [Matplotlib](https://github.com/matplotlib/matplotlib) - [Seaborn](https://github.com/mwaskom/seaborn) - [Scikit-Learn](https://github.com/scikit-learn/scikit-learn)
- **Optional**: [PyVista](https://github.com/pyvista/pyvista) - [FFMPEG](https://github.com/kkroening/ffmpeg-python) - [QuartoDoc](https://github.com/machow/quartodoc/)

See the `pyproject.toml` file the specific versions of the dependencies.
Expand Down
98 changes: 74 additions & 24 deletions demos/Laplace/30_laplace_super_scaled.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
"""

# os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = "false"

import pstats
from updes import *

import time

import jax
Expand All @@ -16,29 +20,48 @@
import matplotlib.pyplot as plt
import seaborn as sns

from updes import *
import cProfile


DATAFOLDER = "./data/TempFolder/"

RBF = partial(polyharmonic, a=1)
# RBF = partial(gaussian, eps=1e-1)
# RBF = partial(gaussian, eps=1e1)
# RBF = partial(thin_plate, a=3)
MAX_DEGREE = 0
MAX_DEGREE = 1

Nx = Ny = 20
Nx = Ny = 50
# SUPPORT_SIZE = "max"
SUPPORT_SIZE = 20*2
SUPPORT_SIZE = 2
facet_types={"South":"d", "West":"d", "North":"d", "East":"d"}

start = time.time()

## benchmarking with cprofile
# res = cProfile.run("cloud = SquareCloud(Nx=Nx, Ny=Ny, facet_types=facet_types, support_size=SUPPORT_SIZE)")
cloud = SquareCloud(Nx=Nx, Ny=Ny, facet_types=facet_types, support_size=SUPPORT_SIZE)

## Print results sorted by cumulative time
# p = pstats.Stats(res)
# p.sort_stats('cumulative').print_stats(10)


## Only print the top 10 high-level function
# p.print_callers(10)



# cloud = SquareCloud(Nx=Nx, Ny=Ny, facet_types=facet_types, support_size=SUPPORT_SIZE)



walltime = time.time() - start

print(f"Cloud generation walltime: {walltime:.2f} seconds")

# cloud.visualize_cloud(s=0.5, figsize=(7,6));

## %%
# %%

def my_diff_operator(x, center=None, rbf=None, monomial=None, fields=None):
return nodal_laplacian(x, center, rbf, monomial)
Expand Down Expand Up @@ -80,8 +103,35 @@ def my_rhs_operator(x, centers=None, rbf=None, fields=None):



# import jax
# os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = "false"
# import jax.numpy as jnp
# import jax.random as jr
# import lineax as lx

# # size = 15000
# # matrix_key, vector_key = jr.split(jr.PRNGKey(0))
# # matrix = jr.normal(matrix_key, (size, size))
# # vector = jr.normal(vector_key, (size,))
# # operator = lx.MatrixLinearOperator(matrix)
# # solution = lx.linear_solve(operator, vector, solver=lx.QR())
# # solution.value

# # size = 8000
# # matrix_key, vector_key = jr.split(jr.PRNGKey(0))
# # matrix = jr.normal(matrix_key, (size, size))
# # vector = jr.normal(vector_key, (size,))
# # solution = jnp.linalg.solve(matrix, vector)
# # solution


# size = 15000
# matrix_key, vector_key = jr.split(jr.PRNGKey(0))
# matrix = jr.normal(matrix_key, (size, size))
# vector = jr.normal(vector_key, (size,))
# solution = jnp.linalg.lstsq(matrix, vector)


# %%


# ## Observing the sparsity patten of the matrices involved
Expand All @@ -98,30 +148,30 @@ def my_rhs_operator(x, centers=None, rbf=None, fields=None):



M = compute_nb_monomials(MAX_DEGREE, 2)
A = assemble_A(cloud, RBF, M)
mat1 = jnp.abs(A)
# M = compute_nb_monomials(MAX_DEGREE, 2)
# A = assemble_A(cloud, RBF, M)
# mat1 = jnp.abs(A)

inv_A = assemble_invert_A(cloud, RBF, M)
mat2 = jnp.abs(inv_A)
# inv_A = assemble_invert_A(cloud, RBF, M)
# mat2 = jnp.abs(inv_A)

## Matrix B for the linear system
mat3 = sol.mat
# ## Matrix B for the linear system
# mat3 = sol.mat

## 3 figures
fig, ax = plt.subplots(1, 3, figsize=(15,5))
# ## 3 figures
# fig, ax = plt.subplots(1, 3, figsize=(15,5))

sns.heatmap(jnp.abs(mat1), ax=ax[0], cmap="grey", cbar=True, square=True, xticklabels=False, yticklabels=False)
ax[0].set_title("Collocation Matrix")
# sns.heatmap(jnp.abs(mat1), ax=ax[0], cmap="grey", cbar=True, square=True, xticklabels=False, yticklabels=False)
# ax[0].set_title("Collocation Matrix")

sns.heatmap(jnp.abs(mat2), ax=ax[1], cmap="grey", cbar=True, square=True, xticklabels=False, yticklabels=False)
ax[1].set_title("Inverse of Collocation Matrix")
# sns.heatmap(jnp.abs(mat2), ax=ax[1], cmap="grey", cbar=True, square=True, xticklabels=False, yticklabels=False)
# ax[1].set_title("Inverse of Collocation Matrix")

sns.heatmap(jnp.abs(mat3), ax=ax[2], cmap="grey", cbar=True, square=True, xticklabels=False, yticklabels=False)
ax[2].set_title("Linear System Matrix (B)")
# sns.heatmap(jnp.abs(mat3), ax=ax[2], cmap="grey", cbar=True, square=True, xticklabels=False, yticklabels=False)
# ax[2].set_title("Linear System Matrix (B)")

# plt.title("Sparsity Pattern of the Collocation Matrix")
plt.show()
# # plt.title("Sparsity Pattern of the Collocation Matrix")
# plt.show()


#%%
Expand Down
Binary file added docs/assets/LaplacePDE.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 2 additions & 1 deletion docs/assets/NextRelease.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
For the next release v1.1.0
For the next release v1.0.2
- [X] Added colorbar to animate fields
- [X] Fixed the args inputs to construct the local matrix for nodal_div_grad: (if array, else, etc.)
- [X] Implemented the Darcy flow problem
- [X] Faster linear solves with Lineax
2 changes: 1 addition & 1 deletion docs/assets/README_PyPI.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ cloud.visualize_field(sol.vals, cmap="jet", projection="3d", title="RBF solution


## Dependencies
- **Core**: JAX - GMSH - Matplotlib - Seaborn - Scikit-Learn
- **Core**: JAX - GMSH - Lineax - Matplotlib - Seaborn - Scikit-Learn
- **Optional**: PyVista - FFMPEG - QuartoDoc

See the `pyproject.toml` file the specific versions of the dependencies.
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ keywords = [

dependencies = [
"jax >= 0.3.4",
"lineax",
"gmsh",
"pytest",
"matplotlib>=3.4.0",
Expand Down
6 changes: 5 additions & 1 deletion updes/cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,13 @@ def define_local_supports(self):
assert self.support_size > 0, "Support size must be strictly greater than 0"
assert self.support_size <= self.N, "Support size must be strictly less than or equal the number of nodes"

# ## If support size == coords.shape[0], then we are using all the nodes
# if self.support_size > coords.shape[0]:
# self.local_supports = {renumb_map[i]:list(range(self.N)) for i in range(self.N)}
# else:
## Use BallTree for fast nearest neighbors search
# ball_tree = KDTree(coords, leaf_size=40, metric='euclidean')
ball_tree = BallTree(coords, leaf_size=40, metric='euclidean')
ball_tree = BallTree(coords, leaf_size=1, metric='euclidean')
for i in range(self.N):
_, neighbours = ball_tree.query(self.nodes[i][jnp.newaxis], k=self.support_size)
neighbours = neighbours[0][1:] ## Result is a 2d list, without the first el (the node itself)
Expand Down
14 changes: 12 additions & 2 deletions updes/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import jax
import jax.numpy as jnp
from jax.tree_util import Partial
import lineax as lx

from functools import cache

Expand Down Expand Up @@ -601,8 +602,17 @@ def pde_solver( diff_operator:callable,
B1 = assemble_B(diff_operator, cloud, rbf, nb_monomials, diff_args, robin_coeffs)
rhs = assemble_q(rhs_operator, boundary_conditions, cloud, rbf, nb_monomials, rhs_args)

## Solve the linear system
sol_vals = jnp.linalg.solve(B1, rhs)
## Solve the linear system using JAX's direct solver
# sol_vals = jnp.linalg.solve(B1, rhs)

## Solve the linear system using Scipy's iterative solver
# sol_vals = jax.scipy.sparse.linalg.gmres(B1, rhs, tol=1e-5)[0]

## Solve the linear system using Lineax
operator = lx.MatrixLinearOperator(B1)
sol_vals = lx.linear_solve(operator, rhs, solver=lx.QR()).value
# sol_vals = lx.linear_solve(operator, rhs, solver=lx.GMRES(rtol=1e-3, atol=1e-3)).value

sol_coeffs = core_compute_coefficients(sol_vals, cloud, rbf, nb_monomials)

return SteadySol(sol_vals, sol_coeffs, B1)
Expand Down

0 comments on commit 0ee5c69

Please sign in to comment.