Skip to content

Commit

Permalink
Add Pytorch code for SAPINN
Browse files Browse the repository at this point in the history
  • Loading branch information
devzhk committed Nov 15, 2021
1 parent 60c5ae2 commit 5aeed9d
Show file tree
Hide file tree
Showing 7 changed files with 65 additions and 8 deletions.
2 changes: 1 addition & 1 deletion baselines/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def __init__(self, datapath1,
self.data = torch.cat((part1, part2), dim=0)
else:
self.data = part1
self.vor = self.data[offset: offset + num, :, :, :]
self.vor = self.data[offset: offset + num, :, :, :].cpu()
if vel:
self.vel_u, self.vel_v = vor2vel(self.vor) # Compute velocity from vorticity

Expand Down
22 changes: 22 additions & 0 deletions baselines/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import torch
import torch.nn as nn
from models.FCN import DenseNet
from typing import List
from .utils import weighted_mse


class DeepONet(nn.Module):
Expand Down Expand Up @@ -29,3 +31,23 @@ def forward(self, u0, grid):
b = self.trunk(grid)
# N x width
return torch.einsum('bi,ni->bn', a, b)


class SAWeight(nn.Module):
    """Trainable self-adaptive loss weights (SA-PINN style).

    Keeps one learnable weight tensor of shape ``(num_points, out_dim)`` for
    every initial-condition term and every collocation (PDE residual) term.
    ``forward`` returns the total self-weighted MSE of all terms against a
    zero target.
    """

    def __init__(self, out_dim, num_init: List, num_collo: List):
        super().__init__()
        # One (n, out_dim) weight matrix per initial-condition term.
        self.init_param = nn.ParameterList()
        for n in num_init:
            self.init_param.append(nn.Parameter(torch.rand(n, out_dim)))
        # One (n, out_dim) weight matrix per residual term.
        self.collo_param = nn.ParameterList()
        for n in num_collo:
            self.collo_param.append(nn.Parameter(torch.rand(n, out_dim)))

    def forward(self, init_cond: List, residual: List):
        """Sum of self-weighted MSE losses over all terms (zero targets)."""
        total = 0.0
        # Initial-condition terms first, then residual terms — same order as
        # the parameters were registered.
        for weight, err in zip(self.init_param, init_cond):
            total = total + weighted_mse(err, 0, weight)
        for weight, err in zip(self.collo_param, residual):
            total = total + weighted_mse(err, 0, weight)
        return total
25 changes: 24 additions & 1 deletion baselines/sapinns.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
import csv
import random
from timeit import default_timer
from tqdm import tqdm
import deepxde as dde
import numpy as np
from baselines.data import NSdata
import torch

from tensordiffeq.boundaries import DomainND, periodicBC
from .tqd_utils import PointsIC
from .model import SAWeight

from models.FCN import DenseNet

Re = 500

Expand Down Expand Up @@ -103,11 +107,30 @@ def train_sapinn(offset, config, args):
# Periodic 2*pi x 2*pi spatial domain, discretized at the dataset resolution,
# plus a time axis of length `time_interval` with T steps.
domain.add('x', [0.0, 2 * np.pi], dataset.S)
domain.add('y', [0.0, 2 * np.pi], dataset.S)
domain.add('t', [0.0, data_config['time_interval']], dataset.T)
num_collo = config['train']['num_domain']
domain.generate_collocation_points(num_collo)
init_vals = dataset.get_init_cond()
num_inits = config['train']['num_init']
# Cap the number of initial-condition points at the S x S grid size so we
# never sample more points than the initial snapshot contains.
if num_inits > dataset.S ** 2:
num_inits = dataset.S ** 2
init_cond = PointsIC(domain, init_vals, var=['x', 'y'], n_values=num_inits)
bd_cond = periodicBC(domain, ['x', 'y'], n_values=config['train']['num_boundary'])

# prepare initial condition inputs
init_input = torch.tensor(init_cond.input, requires_grad=True)
# NOTE(review): `requires_grad=True` on the *target* values looks unnecessary —
# gradients are normally only needed w.r.t. the inputs; confirm intent.
init_val = torch.tensor(init_cond.val, requires_grad=True)

# prepare boundary condition inputs


# Self-adaptive weights: one set per IC term, four sets for the collocation
# residual terms. NOTE(review): out_dim=3 and the factor 4 presumably match
# the number of PDE output channels / residual equations — confirm.
weight_net = SAWeight(out_dim=3, num_init=[num_inits], num_collo=[num_collo] * 4)
net = DenseNet(config['model']['layers'], config['model']['activation'])

loader = tqdm(range(config['train']['epochs']), dynamic_ncols=True)

# TODO: training loop is a stub — no forward pass, loss, or optimizer step yet.
for e in loader:
pass




4 changes: 2 additions & 2 deletions baselines/tqd_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ def create_target(self, values):
# inp = flatten_and_stack(multimesh(arg_list))
# fun_vals.append(self.fun[i](*inp.T))
if self.n_values is not None:
self.val = convertTensor(np.reshape(values, (-1, 3))[self.nums])
self.val = np.reshape(values, (-1, 3))[self.nums]
else:
self.val = convertTensor(np.reshape(values, (-1, 3)))
self.val = np.reshape(values, (-1, 3))

def loss(self):
    # Mean-squared error between the stored predictions and target values.
    # NOTE(review): `MSE` comes from the tensordiffeq utilities — presumably
    # reduces to a scalar tensor; confirm against its definition.
    return MSE(self.preds, self.val)
7 changes: 7 additions & 0 deletions baselines/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,13 @@
import torch.autograd as autograd


def weighted_mse(pred, target, weight=None):
    """Mean squared error between `pred` and `target`.

    When `weight` is given, each squared error is scaled elementwise by
    `weight` before averaging (used for self-adaptive loss weighting).
    """
    sq_err = (pred - target) ** 2
    if weight is None:
        return sq_err.mean()
    return (weight * sq_err).mean()


def get_3dboundary_points(num_x, # number of points on x axis
num_y, # number of points on y axis
num_t, # number of points on t axis
Expand Down
9 changes: 7 additions & 2 deletions models/FCN.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,14 @@ def __init__(self, layers, nonlinearity, out_nonlinearity=None, normalize=False)
super(DenseNet, self).__init__()

self.n_layers = len(layers) - 1

assert self.n_layers >= 1

# Map an activation name to its nn.Module class; non-string values are
# assumed to already be activation classes and pass through untouched.
if isinstance(nonlinearity, str):
    if nonlinearity == 'tanh':
        nonlinearity = nn.Tanh
    elif nonlinearity == 'relu':
        # BUG FIX: was `nonlinearity == nn.ReLU` — a comparison whose result
        # was discarded, leaving `nonlinearity` as the string 'relu' and
        # breaking layer construction below.
        nonlinearity = nn.ReLU
    else:
        raise ValueError(f'{nonlinearity} is not supported')
self.layers = nn.ModuleList()

for j in range(self.n_layers):
Expand Down
4 changes: 2 additions & 2 deletions pinns.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from baselines.pinns_ns_05s import train
from baselines.pinns_ns_50s import train_longtime
from baselines.tqd_sapinns import train_sa
from baselines.sapinns import train_sapinn
import csv


Expand All @@ -28,7 +28,7 @@
if 'time_scale' in config['data']:
train_longtime(i, config, args)
elif config['log']['group'] == 'SA-PINNs':
train_sa(i, config, args)
train_sapinn(i, config, args)
else:
train(i, config, args)

Expand Down

0 comments on commit 5aeed9d

Please sign in to comment.