Skip to content

Commit

Permalink
Incorporated Lineax for faster solves on GPU
Browse files Browse the repository at this point in the history
  • Loading branch information
ddrous committed May 25, 2024
1 parent f96b05f commit 2ad2147
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 29 deletions.
101 changes: 75 additions & 26 deletions demos/Laplace/30_laplace_super_scaled.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
"""

# os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = "false"

import pstats
from updes import *

import time

import jax
Expand All @@ -16,29 +20,47 @@
import matplotlib.pyplot as plt
import seaborn as sns

from updes import *
import cProfile


DATAFOLDER = "./data/TempFolder/"

RBF = partial(polyharmonic, a=1)
# RBF = partial(gaussian, eps=1e-1)
# RBF = partial(gaussian, eps=1e1)
# RBF = partial(thin_plate, a=3)
MAX_DEGREE = 0
MAX_DEGREE = 1

Nx = Ny = 20
# SUPPORT_SIZE = "max"
SUPPORT_SIZE = 20*2
Nx = Ny = 15
SUPPORT_SIZE = "max"
# SUPPORT_SIZE = 50*1
facet_types={"South":"d", "West":"d", "North":"d", "East":"d"}

start = time.time()
cloud = SquareCloud(Nx=Nx, Ny=Ny, facet_types=facet_types, support_size=SUPPORT_SIZE)

## benchmarking with cprofile
res = cProfile.run("cloud = SquareCloud(Nx=Nx, Ny=Ny, facet_types=facet_types, support_size=SUPPORT_SIZE)")

## Print results sorted by cumulative time
p = pstats.Stats(res)
p.sort_stats('cumulative').print_stats(10)


## Only print the top 10 high-level function
# p.print_callers(10)



# cloud = SquareCloud(Nx=Nx, Ny=Ny, facet_types=facet_types, support_size=SUPPORT_SIZE)



walltime = time.time() - start

print(f"Cloud generation walltime: {walltime:.2f} seconds")

# cloud.visualize_cloud(s=0.5, figsize=(7,6));

## %%
# %%

def my_diff_operator(x, center=None, rbf=None, monomial=None, fields=None):
return nodal_laplacian(x, center, rbf, monomial)
Expand Down Expand Up @@ -80,8 +102,35 @@ def my_rhs_operator(x, centers=None, rbf=None, fields=None):



# import jax
# os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = "false"
# import jax.numpy as jnp
# import jax.random as jr
# import lineax as lx

# # size = 15000
# # matrix_key, vector_key = jr.split(jr.PRNGKey(0))
# # matrix = jr.normal(matrix_key, (size, size))
# # vector = jr.normal(vector_key, (size,))
# # operator = lx.MatrixLinearOperator(matrix)
# # solution = lx.linear_solve(operator, vector, solver=lx.QR())
# # solution.value

# # size = 8000
# # matrix_key, vector_key = jr.split(jr.PRNGKey(0))
# # matrix = jr.normal(matrix_key, (size, size))
# # vector = jr.normal(vector_key, (size,))
# # solution = jnp.linalg.solve(matrix, vector)
# # solution


# size = 15000
# matrix_key, vector_key = jr.split(jr.PRNGKey(0))
# matrix = jr.normal(matrix_key, (size, size))
# vector = jr.normal(vector_key, (size,))
# solution = jnp.linalg.lstsq(matrix, vector)


# %%


# ## Observing the sparsity patten of the matrices involved
Expand All @@ -98,30 +147,30 @@ def my_rhs_operator(x, centers=None, rbf=None, fields=None):



M = compute_nb_monomials(MAX_DEGREE, 2)
A = assemble_A(cloud, RBF, M)
mat1 = jnp.abs(A)
# M = compute_nb_monomials(MAX_DEGREE, 2)
# A = assemble_A(cloud, RBF, M)
# mat1 = jnp.abs(A)

inv_A = assemble_invert_A(cloud, RBF, M)
mat2 = jnp.abs(inv_A)
# inv_A = assemble_invert_A(cloud, RBF, M)
# mat2 = jnp.abs(inv_A)

## Matrix B for the linear system
mat3 = sol.mat
# ## Matrix B for the linear system
# mat3 = sol.mat

## 3 figures
fig, ax = plt.subplots(1, 3, figsize=(15,5))
# ## 3 figures
# fig, ax = plt.subplots(1, 3, figsize=(15,5))

sns.heatmap(jnp.abs(mat1), ax=ax[0], cmap="grey", cbar=True, square=True, xticklabels=False, yticklabels=False)
ax[0].set_title("Collocation Matrix")
# sns.heatmap(jnp.abs(mat1), ax=ax[0], cmap="grey", cbar=True, square=True, xticklabels=False, yticklabels=False)
# ax[0].set_title("Collocation Matrix")

sns.heatmap(jnp.abs(mat2), ax=ax[1], cmap="grey", cbar=True, square=True, xticklabels=False, yticklabels=False)
ax[1].set_title("Inverse of Collocation Matrix")
# sns.heatmap(jnp.abs(mat2), ax=ax[1], cmap="grey", cbar=True, square=True, xticklabels=False, yticklabels=False)
# ax[1].set_title("Inverse of Collocation Matrix")

sns.heatmap(jnp.abs(mat3), ax=ax[2], cmap="grey", cbar=True, square=True, xticklabels=False, yticklabels=False)
ax[2].set_title("Linear System Matrix (B)")
# sns.heatmap(jnp.abs(mat3), ax=ax[2], cmap="grey", cbar=True, square=True, xticklabels=False, yticklabels=False)
# ax[2].set_title("Linear System Matrix (B)")

# plt.title("Sparsity Pattern of the Collocation Matrix")
plt.show()
# # plt.title("Sparsity Pattern of the Collocation Matrix")
# plt.show()


#%%
Expand Down
6 changes: 5 additions & 1 deletion updes/cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,13 @@ def define_local_supports(self):
assert self.support_size > 0, "Support size must be strictly greater than 0"
assert self.support_size <= self.N, "Support size must be strictly less than or equal the number of nodes"

# ## If support size == coords.shape[0], then we are using all the nodes
# if self.support_size > coords.shape[0]:
# self.local_supports = {renumb_map[i]:list(range(self.N)) for i in range(self.N)}
# else:
## Use BallTree for fast nearest neighbors search
# ball_tree = KDTree(coords, leaf_size=40, metric='euclidean')
ball_tree = BallTree(coords, leaf_size=40, metric='euclidean')
ball_tree = BallTree(coords, leaf_size=1, metric='euclidean')
for i in range(self.N):
_, neighbours = ball_tree.query(self.nodes[i][jnp.newaxis], k=self.support_size)
neighbours = neighbours[0][1:] ## Result is a 2d list, without the first el (the node itself)
Expand Down
14 changes: 12 additions & 2 deletions updes/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import jax
import jax.numpy as jnp
from jax.tree_util import Partial
import lineax as lx

from functools import cache

Expand Down Expand Up @@ -601,8 +602,17 @@ def pde_solver( diff_operator:callable,
B1 = assemble_B(diff_operator, cloud, rbf, nb_monomials, diff_args, robin_coeffs)
rhs = assemble_q(rhs_operator, boundary_conditions, cloud, rbf, nb_monomials, rhs_args)

## Solve the linear system
sol_vals = jnp.linalg.solve(B1, rhs)
## Solve the linear system using JAX's direct solver
# sol_vals = jnp.linalg.solve(B1, rhs)

## Solve the linear system using Scipy's iterative solver
# sol_vals = jax.scipy.sparse.linalg.gmres(B1, rhs, tol=1e-5)[0]

## Solve the linear system using Lineax
operator = lx.MatrixLinearOperator(B1)
sol_vals = lx.linear_solve(operator, rhs, solver=lx.QR()).value
# sol_vals = lx.linear_solve(operator, rhs, solver=lx.GMRES(rtol=1e-3, atol=1e-3)).value

sol_coeffs = core_compute_coefficients(sol_vals, cloud, rbf, nb_monomials)

return SteadySol(sol_vals, sol_coeffs, B1)
Expand Down

0 comments on commit 2ad2147

Please sign in to comment.