Skip to content

Commit

Permalink
Periodic BCs for gray-scott, and interop with Nodax
Browse files Browse the repository at this point in the history
  • Loading branch information
rdes committed Mar 14, 2024
1 parent c6f8253 commit 81bea44
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 2 deletions.
4 changes: 2 additions & 2 deletions demos/Advection/01_adv_diff_periodic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# %%
%load_ext autoreload
%autoreload 2
# %load_ext autoreload
# %autoreload 2

# %%

Expand Down
147 changes: 147 additions & 0 deletions demos/Gray-Scott/00_gray-scott_with_rbf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
# %%

"""
Test of the Updec package on the Advection-Diffusion equation with RBFs:
PDE here: https://en.wikipedia.org/wiki/Convection%E2%80%93diffusion_equation
"""

import time

import jax
import jax.numpy as jnp

# jax.config.update('jax_platform_name', 'cpu')
jax.config.update("jax_enable_x64", True)

from updec import *
# key = jax.random.PRNGKey(13)
key = None

# from torch.utils.tensorboard import SummaryWriter


RUN_NAME = "TempFolder"
DATAFOLDER = "./data/" + RUN_NAME +"/"
make_dir(DATAFOLDER)

RBF = partial(polyharmonic, a=1)
MAX_DEGREE = 1

DT = 1e-4
NB_TIMESTEPS = 100
PLOT_EVERY = 10

VEL = jnp.array([100.0, 0.0])
## Diffusive constant
K = 0.08

Nx = 40
Ny = 20
SUPPORT_SIZE = "max"

facet_types={"South":"p0", "North":"p0", "West":"p1", "East":"p1"}
cloud = SquareCloud(Nx=Nx, Ny=Ny, facet_types=facet_types, noise_key=key, support_size=SUPPORT_SIZE)

cloud.visualize_cloud(s=0.1, figsize=(7,3));

# print("Local supports:", cloud.local_supports[0])


# %%

# def my_diff_operator(x, center=None, rbf=None, monomial=None, fields=None):
# val = nodal_value(x, center, rbf, monomial)
# return val

# def my_rhs_operator(x, centers=None, rbf=None, fields=None):
# lap = value(x, fields[:,0], centers, rbf)
# return lap


def my_diff_operator(x, center=None, rbf=None, monomial=None, fields=None):
val = nodal_value(x, center, rbf, monomial)
grad = nodal_gradient(x, center, rbf, monomial)
lap = nodal_laplacian(x, center, rbf, monomial)
return (val/DT) + jnp.dot(VEL, grad) - K*lap

def my_rhs_operator(x, centers=None, rbf=None, fields=None):
return value(x, fields[:,0], centers, RBF) / DT ## TODO value ?



d_zero = lambda x: 0.
boundary_conditions = {"South":d_zero, "West":d_zero, "North":d_zero, "East":d_zero}


## u0 is zero everywhere except at a point in the middle
u0 = jnp.zeros(cloud.N)
source_id = int(cloud.N*0.5)
source_neighbors = jnp.array(cloud.local_supports[source_id][:cloud.N//40])
u0 = u0.at[source_neighbors].set(0.95)


##

## Begin timestepping for 100 steps

# fig = plt.figure(figsize=(6,3))
# ax1= fig.add_subplot(1, 1, 1, projection='3d')
# ax = fig.add_subplot(1, 1, 1)


u = u0.copy()
ulist = [u]

start = time.time()

for i in range(1, NB_TIMESTEPS+1):
ufield = pde_solver_jit(diff_operator=my_diff_operator,
rhs_operator = my_rhs_operator,
rhs_args=[u],
cloud = cloud,
boundary_conditions = boundary_conditions,
rbf=RBF,
max_degree=MAX_DEGREE,)

u = ufield.vals
ulist.append(u)

if i<=3 or i%PLOT_EVERY==0:
print(f"Step {i}")
# plt.cla()
# cloud.visualize_field(u, cmap="jet", projection="3d", title=f"Step {i}")
ax, _ = cloud.visualize_field(u, cmap="jet", title=f"Step {i}", vmin=0, vmax=1, figsize=(6,3),colorbar=False)
# plt.draw()
plt.show()


walltime = time.time() - start

minutes = walltime // 60 % 60
seconds = walltime % 60
print(f"Walltime: {minutes} minutes {seconds:.2f} seconds")



# %%

# ax = plt.gca()
filename = DATAFOLDER + "advection_diffusion_rbf.gif"
cloud.animate_fields([ulist], cmaps="jet", filename=filename, figsize=(7,3), titles=["Advection-Diffusion with RBFs"])



# %%


## Write stuff to tensorboard
# run_name = str(datetime.datetime.now())[:19] ##For tensorboard
# writer = SummaryWriter("runs/"+run_name, comment='-Laplace')
# hparams_dict = {"rbf":RBF.__name__, "max_degree":MAX_DEGREE, "nb_nodes":Nx*Ny, "support_size":SUPPORT_SIZE}
# metrics_dict = {"metrics/mse_error":float(error), "metrics/wall_time":walltime}
# writer.add_hparams(hparams_dict, metrics_dict, run_name="hp_params")
# writer.add_figure("plots", fig)
# writer.flush()
# writer.close()

# %%

0 comments on commit 81bea44

Please sign in to comment.