-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Super-scaling investigation of Laplace
- Loading branch information
Showing
2 changed files
with
245 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
# %% | ||
|
||
""" | ||
Super-Scaled Updes on the Laplace equation with RBFs | ||
""" | ||
|
||
# os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = "false" | ||
import time | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
|
||
# jax.config.update('jax_platform_name', 'cpu') | ||
# jax.config.update("jax_enable_x64", True) | ||
|
||
import matplotlib.pyplot as plt | ||
import seaborn as sns | ||
|
||
from updes import * | ||
|
||
DATAFOLDER = "./data/TempFolder/" | ||
|
||
RBF = partial(polyharmonic, a=1) | ||
MAX_DEGREE = 0 | ||
|
||
Nx = Ny = 10 | ||
SUPPORT_SIZE = "max" | ||
# SUPPORT_SIZE = 9*1 | ||
facet_types={"South":"n", "West":"d", "North":"d", "East":"d"} | ||
|
||
start = time.time() | ||
cloud = SquareCloud(Nx=Nx, Ny=Ny, facet_types=facet_types, support_size=SUPPORT_SIZE) | ||
walltime = time.time() - start | ||
|
||
print(f"Cloud generation walltime: {walltime:.2f} seconds") | ||
|
||
# cloud.visualize_cloud(s=0.5, figsize=(7,6)); | ||
|
||
## %% | ||
|
||
def my_diff_operator(x, center=None, rbf=None, monomial=None, fields=None): | ||
return nodal_laplacian(x, center, rbf, monomial) | ||
|
||
def my_rhs_operator(x, centers=None, rbf=None, fields=None): | ||
return -0.0 | ||
|
||
d_north = lambda node: jnp.sin(jnp.pi * node[0]) | ||
d_zero = lambda node: 0.0 | ||
boundary_conditions = {"South":d_zero, "West":d_zero, "North":d_north, "East":d_zero} | ||
|
||
start = time.time() | ||
sol = pde_solver_jit(diff_operator=my_diff_operator, | ||
rhs_operator = my_rhs_operator, | ||
cloud = cloud, | ||
boundary_conditions = boundary_conditions, | ||
rbf=RBF, | ||
max_degree=MAX_DEGREE) | ||
walltime = time.time() - start | ||
|
||
minutes = walltime // 60 % 60 | ||
seconds = walltime % 60 | ||
print(f"Walltime: {minutes} minutes {seconds:.2f} seconds") | ||
|
||
## RBF solution | ||
rbf_sol = sol.vals | ||
|
||
|
||
fig = plt.figure(figsize=(6*1,5)) | ||
ax= fig.add_subplot(1, 1, 1, projection='3d') | ||
cloud.visualize_field(rbf_sol, cmap="jet", projection="3d", title="Laplace with RBFs", ax=ax); | ||
plt.show() | ||
## Savefig | ||
fig.savefig(DATAFOLDER+"super_scaled.png", dpi=300) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
# %% | ||
|
||
""" | ||
Control of Laplace equation with differentiable physics | ||
""" | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
import optax | ||
|
||
import matplotlib.pyplot as plt | ||
from tqdm import tqdm | ||
import tracemalloc, time | ||
|
||
from updes import * | ||
|
||
#%% | ||
|
||
|
||
DATAFOLDER = "./data/TempFolder/" | ||
|
||
RBF = polyharmonic | ||
MAX_DEGREE = 1 | ||
|
||
Nx = 100 | ||
Ny = Nx | ||
|
||
LR = 1e-3 | ||
GAMMA = 1 ### LR decay rate | ||
EPOCHS = 50 | ||
|
||
|
||
facet_types={"North":"d", "South":"d", "West":"d", "East":"d"} | ||
train_cloud = SquareCloud(Nx=Nx, Ny=Ny, facet_types=facet_types, noise_key=None, support_size="max") | ||
|
||
train_cloud.visualize_cloud(s=0.1, title="Training cloud", figsize=(5,4)); | ||
|
||
#%% | ||
|
||
start = time.process_time() | ||
tracemalloc.start() | ||
|
||
|
||
## For the cost function | ||
north_ids = jnp.array(train_cloud.facet_nodes["North"]) | ||
xy_north = train_cloud.sorted_nodes[north_ids, :] | ||
x_north = xy_north[:, 0] | ||
q_cost = jax.vmap(lambda x: jnp.cos(2*jnp.pi * x))(x_north) | ||
|
||
|
||
## Exact solution | ||
def laplace_exact_sol(xy): | ||
PI = jnp.pi | ||
x, y = xy | ||
|
||
a = 0.5 * jnp.sin(2*PI*x) * (jnp.exp(2*PI*(y-1)) + jnp.exp(2*PI*(1-y))) / jnp.cosh(2*PI) | ||
b = jnp.cos(2*PI*x) * (jnp.exp(2*PI*y) + jnp.exp(-2*PI*y)) / (4*PI*jnp.cosh(2*PI)) | ||
|
||
return a+b | ||
|
||
def laplace_exact_control(x): | ||
PI = jnp.pi | ||
return (jnp.sin(2*PI*x)/jnp.cosh(2*PI)) + (jnp.cos(2*PI*x)*jnp.tanh(2*PI)/(2*PI)) | ||
|
||
|
||
exact_sol = jax.vmap(laplace_exact_sol)(train_cloud.sorted_nodes) | ||
exact_control = jax.vmap(laplace_exact_control)(x_north) | ||
|
||
|
||
#%% | ||
def my_diff_operator(x, center=None, rbf=None, monomial=None, fields=None): | ||
return nodal_laplacian(x, center, rbf, monomial) | ||
|
||
def my_rhs_operator(x, centers=None, rbf=None, fields=None): | ||
return 0.0 | ||
|
||
|
||
|
||
### Optimisation start ### | ||
d_south = jax.jit(lambda x: jnp.sin(2*jnp.pi * x[0])) | ||
d_east = jax.jit(lambda x: jnp.sinh(2*jnp.pi*x[1]) / (2*jnp.pi * jnp.cosh(2*jnp.pi))) | ||
d_west = d_east | ||
|
||
# @jax.jit | ||
def loss_fn(bcn): | ||
sol = pde_solver(diff_operator=my_diff_operator, | ||
rhs_operator = my_rhs_operator, | ||
cloud = train_cloud, | ||
boundary_conditions = {"South":d_south, "West":d_west, "North":bcn, "East":d_east}, | ||
rbf=RBF, | ||
max_degree=MAX_DEGREE) | ||
|
||
grad_n_y = gradient_vec(xy_north, sol.coeffs, train_cloud.sorted_nodes, RBF)[...,1] | ||
|
||
loss_cost = (grad_n_y - q_cost)**2 | ||
return jnp.trapezoid(loss_cost, x=x_north) | ||
|
||
|
||
@jax.jit | ||
def update_step(bcn, opt_state): | ||
loss, grad = jax.value_and_grad(loss_fn)(bcn) | ||
updates, opt_state = optimiser.update(grad, opt_state) | ||
bcn = optax.apply_updates(bcn, updates) | ||
|
||
north_loss = jnp.mean((bcn-exact_control)**2) | ||
|
||
return bcn, opt_state, loss, north_loss | ||
|
||
# grad_loss_fn = jax.value_and_grad(loss_fn) | ||
|
||
|
||
# %% | ||
|
||
optimal_bcn = jnp.zeros((north_ids.shape[0])) | ||
history_cost = [] | ||
north_mse = [] | ||
|
||
scheduler = optax.piecewise_constant_schedule(init_value=LR, | ||
boundaries_and_scales={int(EPOCHS*0.5):0.1, int(EPOCHS*0.75):0.1}) | ||
optimiser = optax.adam(learning_rate=scheduler) | ||
opt_state = optimiser.init(optimal_bcn) | ||
|
||
for step in tqdm(range(1, EPOCHS+1)): | ||
|
||
optimal_bcn, opt_state, loss, north_error = update_step(optimal_bcn, opt_state) | ||
|
||
history_cost.append(loss) | ||
north_mse.append(north_error) | ||
|
||
if step<=3 or step%10==0: | ||
print("Epoch: %-5d InitLR: %.4f Loss: %.8f TestError: %.6f" % (step, LR, loss, north_error)) | ||
|
||
mem_usage = tracemalloc.get_traced_memory()[1] | ||
exec_time = time.process_time() - start | ||
|
||
print("A few script details:") | ||
print(" Peak memory usage: ", mem_usage, 'bytes') | ||
print(' CPU execution time:', exec_time, 'seconds') | ||
|
||
tracemalloc.stop() | ||
|
||
|
||
### Visualisation at north | ||
ax = plot(x_north, exact_control, "-", label="Analytical", x_label=r"$x$", figsize=(5,3), ylim=(-.2,.2)) | ||
plot(x_north, optimal_bcn, "--", label="Diff. Physics", ax=ax, title=f"Optimised north solution / MSE = {north_error:.4f}"); | ||
|
||
|
||
ax = plot(history_cost, label='Cost objective', x_label='epochs', title="Loss", y_scale="log"); | ||
plot(north_mse, label='Test Error at North', x_label='epochs', title="Loss", y_scale="log", ax=ax); | ||
|
||
|
||
# %% | ||
|
||
############# Just for fun ########## TODO do this outside the loop | ||
|
||
optimal_conditions = {"South":d_south, "West":d_west, "North":optimal_bcn, "East":d_east} | ||
sol = pde_solver(diff_operator=my_diff_operator, | ||
rhs_operator = my_rhs_operator, | ||
cloud = train_cloud, | ||
boundary_conditions = optimal_conditions, | ||
rbf=RBF, | ||
max_degree=MAX_DEGREE) | ||
|
||
### Optional visualisation of whole solution | ||
fig, (ax1, ax2) = plt.subplots(1,2, figsize=(6*2,5)) | ||
train_cloud.visualize_field(sol.vals, cmap="jet", projection="2d", title="Optimized solution", ax=ax1, vmin=-1, vmax=1) | ||
train_cloud.visualize_field(jnp.abs(sol.vals-exact_sol), cmap="magma", projection="2d", title="Absolute error", ax=ax2, vmin=0, vmax=1); | ||
plt.savefig(DATAFOLDER+"solution_"+str(step)+".png", transparent=True) | ||
|
||
|
||
|
||
# %% |