forked from neuraloperator/physics_informed
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
356 additions
and
55 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,4 +3,6 @@ | |
## Requirements | ||
- Pytorch 1.8.0 or later | ||
- wandb | ||
- tqdm | ||
- tqdm | ||
|
||
## Description |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
import numpy as np | ||
import torch | ||
from torch.utils.data import Dataset | ||
from .utils import get_4dgrid, get_2dgird, concat, get_3dgrid | ||
|
||
|
||
class BelflowData(object): | ||
def __init__(self, npt_col=11, npt_boundary=31, npt_init=11): | ||
self.Ne = npt_col ** 4 | ||
self.Nb = npt_boundary ** 2 * 6 | ||
self.col_xyzt, self.col_uvwp = self.get_collocation(npt_col) | ||
self.bd_xyzt, self.bd_uvwp = self.sample_boundary(npt_boundary) | ||
self.ini_xyzt, self.ini_uvwp = self.get_init(npt_init) | ||
|
||
@staticmethod | ||
def get_collocation(num=11): | ||
xyzt = get_4dgrid(num) | ||
uvwp = BelflowData.cal_uvwp(xyzt) | ||
return xyzt, uvwp | ||
|
||
@staticmethod | ||
def get_init(num=11): | ||
xyz = get_3dgrid(num) | ||
ts = np.zeros((xyz.shape[0], 1)) | ||
coord = np.hstack((xyz, ts)) | ||
uvwp = BelflowData.cal_uvwp(coord) | ||
return coord, uvwp | ||
|
||
@staticmethod | ||
def sample_boundary(num=31): | ||
''' | ||
Sample boundary data on each face | ||
Args: | ||
num: | ||
Returns: | ||
''' | ||
samples = get_2dgird(num) | ||
dataList = [] | ||
offset = range(3) | ||
z = np.ones((samples.shape[0], 1)) | ||
signs = [-1, 1] | ||
for i in offset: | ||
for sign in signs: | ||
dataList.append(concat(samples, z*sign, offset=i)) | ||
bd_xyzt = np.vstack(dataList) | ||
bd_uvwp = BelflowData.cal_uvwp(bd_xyzt) | ||
return bd_xyzt, bd_uvwp | ||
|
||
@staticmethod | ||
def cal_uvwp(xyzt, a=1, d=1): | ||
x, y, z = xyzt[:, 0:1], xyzt[:, 1:2], xyzt[:, 2:3] | ||
t = xyzt[:, -1:] | ||
comp_x = a * x + d * y | ||
comp_y = a * y + d * z | ||
comp_z = a * z + d * x | ||
u = -a * np.exp(- d ** 2 * t) * (np.exp(a * x) * np.sin(comp_y) | ||
+ np.exp(a * z) * np.cos(comp_x)) | ||
v = -a * np.exp(- d ** 2 * t) * (np.exp(a * y) * np.sin(comp_z) | ||
+ np.exp(a * x) * np.cos(comp_y)) | ||
w = -a * np.exp(- d ** 2 * t) * (np.exp(a * z) * np.sin(comp_x) | ||
+ np.exp(a * y) * np.cos(comp_z)) | ||
p = - 0.5 * a ** 2 * np.exp(-2 * d ** 2 * t) \ | ||
* (np.exp(2 * a * x) + np.exp(2 * a * y) + np.exp(2 * a * z) + | ||
2 * np.sin(comp_x) * np.cos(comp_z) * np.exp(a * (y + z)) + | ||
2 * np.sin(comp_y) * np.cos(comp_x) * np.exp(a * (z + x)) + | ||
2 * np.sin(comp_z) * np.cos(comp_y) * np.exp(a * (x + y))) | ||
return np.hstack((u, v, w, p)) | ||
|
||
|
||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
import torch | ||
import torch.autograd as autograd | ||
from train_utils.utils import save_checkpoint, zero_grad | ||
from train_utils.losses import LpLoss | ||
from .utils import cal_mixgrad | ||
from tqdm import tqdm | ||
|
||
try: | ||
import wandb | ||
except ImportError: | ||
wandb = None | ||
|
||
|
||
class Baselinetrainer(object): | ||
def __init__(self, model, | ||
device=torch.device('cpu'), | ||
log=False, log_args=None): | ||
self.model = model.to(device) | ||
self.device = device | ||
self.log_init(log, log_args) | ||
|
||
def prepare_data(self, dataset): | ||
# collocation points | ||
self.col_xyzt = torch.from_numpy(dataset.col_xyzt).to(self.device).float() | ||
self.col_uvwp = torch.from_numpy(dataset.col_uvwp).to(self.device).float() | ||
# boundary points | ||
self.bd_xyzt = torch.from_numpy(dataset.bd_xyzt).to(self.device).float() | ||
self.bd_uvwp = torch.from_numpy(dataset.bd_uvwp).to(self.device).float() | ||
# initial condition | ||
self.ini_xyzt = torch.from_numpy(dataset.ini_xyzt).to(self.device).float() | ||
self.ini_uvwp = torch.from_numpy(dataset.ini_uvwp).to(self.device).float() | ||
|
||
def train_LBFGS(self, dataset, | ||
optimizer): | ||
pass | ||
|
||
def train_adam(self, | ||
optimizer, | ||
alpha=100.0, beta=100.0, | ||
iter_num=10, | ||
path='beltrami', name='test.pt', | ||
scheduler=None, re=1.0): | ||
self.model.train() | ||
self.col_xyzt.requires_grad = True | ||
mse = torch.nn.MSELoss() | ||
pbar = tqdm(range(iter_num), dynamic_ncols=True, smoothing=0.01) | ||
for e in pbar: | ||
optimizer.zero_grad() | ||
zero_grad(self.col_xyzt) | ||
|
||
pred_bd_uvwp = self.model(self.bd_xyzt) | ||
bd_loss = mse(pred_bd_uvwp[0:3], self.bd_uvwp[0:3]) | ||
|
||
pred_ini_uvwp = self.model(self.ini_xyzt) | ||
ini_loss = mse(pred_ini_uvwp[0:3], self.ini_uvwp[0:3]) | ||
|
||
pred_col_uvwp = self.model(self.col_xyzt) | ||
f_loss = self.loss_f(pred_col_uvwp, self.col_xyzt, re=re) | ||
|
||
total_loss = alpha * bd_loss + beta * ini_loss + f_loss | ||
total_loss.backward() | ||
optimizer.step() | ||
if scheduler is not None: | ||
scheduler.step() | ||
|
||
pbar.set_description( | ||
( | ||
f'Total loss: {total_loss.item():.6f}, f loss: {f_loss.item():.7f} ' | ||
f'Boundary loss : {bd_loss.item():.7f}, initial loss: {ini_loss.item():.7f}' | ||
) | ||
) | ||
if e % 500 == 0: | ||
u_err, v_err, w_err = self.eval_error() | ||
print(f'u error: {u_err}, v error: {v_err}, w error: {w_err}') | ||
save_checkpoint(path, name, self.model) | ||
|
||
def eval_error(self): | ||
lploss = LpLoss() | ||
self.model.eval() | ||
with torch.no_grad(): | ||
pred_uvwp = self.model(self.col_xyzt) | ||
u_error = lploss(pred_uvwp[:, 0], self.col_uvwp[:, 0]) | ||
v_error = lploss(pred_uvwp[:, 1], self.col_uvwp[:, 1]) | ||
w_error = lploss(pred_uvwp[:, 2], self.col_uvwp[:, 2]) | ||
return u_error.item(), v_error.item(), w_error.item() | ||
|
||
@staticmethod | ||
def log_init(log, log_args): | ||
if wandb and log: | ||
wandb.init(project=log_args['project'], | ||
entity='hzzheng-pino', | ||
config=log_args, | ||
tags=['BelflowData']) | ||
|
||
@staticmethod | ||
def loss_f(uvwp, xyzt, re=1.0): | ||
''' | ||
Index table | ||
u: 0, v: 1, w: 2, p: 3 | ||
x: 0, y: 1, z: 2, t: 3 | ||
Args: | ||
uvwp: output of model - (u, v, w, p) | ||
xyzt: input of model - (x, y, z, t) | ||
re: Reynolds number | ||
Returns: | ||
residual of NS | ||
''' | ||
u_xyzt, u_xx, u_yy, u_zz = cal_mixgrad(uvwp[:, 0], xyzt) | ||
v_xyzt, v_xx, v_yy, v_zz = cal_mixgrad(uvwp[:, 1], xyzt) | ||
w_xyzt, w_xx, w_yy, w_zz = cal_mixgrad(uvwp[:, 2], xyzt) | ||
p_xyzt, = autograd.grad(outputs=[uvwp[:, 3].sum()], inputs=xyzt, | ||
create_graph=True) | ||
|
||
evp4 = u_xyzt[:, 0] + v_xyzt[:, 1] + w_xyzt[:, 2] | ||
|
||
evp1 = u_xyzt[:, 3] + torch.sum(uvwp[:, :3] * u_xyzt[:, :3], dim=1) \ | ||
+ p_xyzt[:, 0] - (u_xx + u_yy + u_zz) / re | ||
evp2 = v_xyzt[:, 3] + torch.sum(uvwp[:, :3] * v_xyzt[:, :3], dim=1) \ | ||
+ p_xyzt[:, 1] - (v_xx + v_yy + v_zz) / re | ||
evp3 = w_xyzt[:, 3] + torch.sum(uvwp[:, :3] * w_xyzt[:, :3], dim=1) \ | ||
+ p_xyzt[:, 2] - (w_xx + w_yy + w_zz) / re | ||
|
||
return torch.mean(evp1 ** 2 + evp2 ** 2 + evp3 ** 2 + evp4 ** 2) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
import numpy as np | ||
|
||
import torch | ||
import torch.autograd as autograd | ||
|
||
|
||
def get_2dgird(num=31): | ||
x = np.linspace(-1, 1, num) | ||
y = np.linspace(-1, 1, num) | ||
gridx, gridy = np.meshgrid(x, y) | ||
xs = gridx.reshape(-1, 1) | ||
ys = gridy.reshape(-1, 1) | ||
result = np.hstack((xs, ys)) | ||
return result | ||
|
||
|
||
def get_3dgrid(num=11): | ||
x = np.linspace(-1, 1, num) | ||
y = np.linspace(-1, 1, num) | ||
z = np.linspace(-1, 1, num) | ||
gridx, gridy, gridz = np.meshgrid(x, y, z) | ||
xs = gridx.reshape(-1, 1) | ||
ys = gridy.reshape(-1, 1) | ||
zs = gridz.reshape(-1, 1) | ||
return np.hstack((xs, ys, zs)) | ||
|
||
|
||
def get_4dgrid(num=11): | ||
''' | ||
4-D meshgrid | ||
Args: | ||
num: number of collocation points of each dimension | ||
Returns: | ||
(num**4, 4) tensor | ||
''' | ||
t = np.linspace(0, 1, num) | ||
x = np.linspace(-1, 1, num) | ||
y = np.linspace(-1, 1, num) | ||
z = np.linspace(-1, 1, num) | ||
gridx, gridy, gridz, gridt = np.meshgrid(x, y, z, t) | ||
xs = gridx.reshape(-1, 1) | ||
ys = gridy.reshape(-1, 1) | ||
zs = gridz.reshape(-1, 1) | ||
ts = gridt.reshape(-1, 1) | ||
result = np.hstack((xs, ys, zs, ts)) | ||
return result | ||
|
||
|
||
def vel2vor(u, v, x, y): | ||
u_y, = autograd.grad(outputs=[u.sum()], inputs=[y], create_graph=True) | ||
v_x, = autograd.grad(outputs=[v.sum()], inputs=[x], create_graph=True) | ||
vorticity = - u_y + v_x | ||
return vorticity | ||
|
||
|
||
def net_NS(x, y, t, model): | ||
out = model(torch.cat([x, y, t], dim=1)) | ||
u = out[:, 0] | ||
v = out[:, 1] | ||
p = out[:, 2] | ||
return u, v, p | ||
|
||
|
||
def sub_mse(vec): | ||
''' | ||
Compute mse of two parts of a vector | ||
Args: | ||
vec: | ||
Returns: | ||
''' | ||
length = vec.shape[0] // 2 | ||
diff = (vec[:length] - vec[length: 2 * length]) ** 2 | ||
return diff.mean() | ||
|
||
|
||
def get_sample(npt=100): | ||
num = npt // 2 | ||
bc1_y_sample = torch.rand(size=(num, 1)).repeat(2, 1) | ||
bc1_t_sample = torch.rand(size=(num, 1)).repeat(2, 1) | ||
|
||
bc1_x_sample = torch.cat([torch.zeros(num, 1), torch.ones(num, 1)], dim=0) | ||
|
||
bc2_x_sample = torch.rand(size=(num, 1)).repeat(2, 1) | ||
bc2_t_sample = torch.rand(size=(num, 1)).repeat(2, 1) | ||
|
||
bc2_y_sample = torch.cat([torch.zeros(num, 1), torch.ones(num, 1)], dim=0) | ||
return bc1_x_sample, bc1_y_sample, bc1_t_sample, \ | ||
bc2_x_sample, bc2_y_sample, bc2_t_sample | ||
|
||
|
||
def concat(xy, z, t=0.0, offset=0): | ||
''' | ||
Args: | ||
xy: (N, 2) | ||
z: (N, 1) | ||
t: (N, 1) | ||
offset: start index of xy | ||
Returns: | ||
(N, 4) array | ||
''' | ||
output = np.zeros((z.shape[0], 4)) * t | ||
if offset < 2: | ||
output[:, offset: offset+2] = xy | ||
output[:, (offset+2) % 3: (offset+2) % 3 + 1] = z | ||
else: | ||
output[:, 2:] = xy[:, 0:1] | ||
output[:, 0:1] = xy[:, 1:] | ||
output[:, 1:2] = z | ||
return output | ||
|
||
|
||
def cal_mixgrad(outputs, inputs): | ||
out_grad, = autograd.grad(outputs=[outputs.sum()], inputs=[inputs], create_graph=True) | ||
out_x2, = autograd.grad(outputs=[out_grad[:, 0].sum()], inputs=[inputs], create_graph=True) | ||
out_xx = out_x2[:, 0] | ||
out_y2, = autograd.grad(outputs=[out_grad[:, 1].sum()], inputs=[inputs], create_graph=True) | ||
out_yy = out_y2[:, 1] | ||
out_z2, = autograd.grad(outputs=[out_grad[:, 2].sum()], inputs=[inputs], create_graph=True) | ||
out_zz = out_z2[:, 2] | ||
return out_grad, out_xx, out_yy, out_zz |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
from baselines.train import Baselinetrainer | ||
from baselines.data import BelflowData | ||
from models.FCN import FCNet | ||
|
||
import torch | ||
from torch.optim import Adam, LBFGS | ||
|
||
|
||
if __name__ == '__main__': | ||
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') | ||
print(f'Device: {device}') | ||
layers = [4, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 4] | ||
model = FCNet(layers) | ||
trainer = Baselinetrainer(model=model, device=device) | ||
dataset = BelflowData(npt_col=11, npt_boundary=31, npt_init=11) | ||
|
||
alpha = 100.0 | ||
beta = 100.0 | ||
optimizer = Adam(model.parameters(), lr=1e-3) | ||
|
||
trainer.prepare_data(dataset) | ||
trainer.train_adam(optimizer, alpha, beta, | ||
iter_num=5000) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from ptflops import get_model_complexity_info |
Oops, something went wrong.