Skip to content

Commit

Permalink
【PPSCI Export&Infer No.2】Add export & inference for DeepONet (PaddleP…
Browse files Browse the repository at this point in the history
…addle#901)

* add DeepONet export and infer

* update docstring of geometry

* Update deeponet.py

* Update deeponet.py
  • Loading branch information
HydrogenSulfate authored May 16, 2024
1 parent 25725c6 commit f874bf2
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 73 deletions.
12 changes: 12 additions & 0 deletions docs/zh/examples/deeponet.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,18 @@
python deeponet.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/deeponet/deeponet_pretrained.pdparams
```

=== "模型导出命令"

``` sh
python deeponet.py mode=export
```

=== "模型推理命令"

``` sh
python deeponet.py mode=infer
```

| 预训练模型 | 指标 |
|:--| :--|
| [deeponet_pretrained.pdparams](https://paddle-org.bj.bcebos.com/paddlescience/models/deeponet/deeponet_pretrained.pdparams) | loss(G_eval): 0.00003<br>L2Rel.G(G_eval): 0.01799 |
Expand Down
18 changes: 18 additions & 0 deletions examples/operator_learning/conf/deeponet.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,21 @@ TRAIN:
EVAL:
pretrained_model_path: null
eval_with_no_grad: true

# inference settings
INFER:
pretrained_model_path: "https://paddle-org.bj.bcebos.com/paddlescience/models/deeponet/deeponet_pretrained.pdparams"
export_path: ./inference/deeponet
pdmodel_path: ${INFER.export_path}.pdmodel
pdiparams_path: ${INFER.export_path}.pdiparams
device: gpu
engine: native
precision: fp32
onnx_path: ${INFER.export_path}.onnx
ir_optim: true
min_subgraph_size: 10
gpu_mem: 4000
gpu_id: 0
max_batch_size: 128
num_cpu_threads: 4
batch_size: 128
120 changes: 59 additions & 61 deletions examples/operator_learning/deeponet.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,66 +89,10 @@ def train(cfg: DictConfig):
# evaluate after finished training
solver.eval()

# visualize prediction for different functions u and corresponding G(u)
dtype = paddle.get_default_dtype()

def generate_y_u_G_ref(
u_func: Callable, G_u_func: Callable
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Generate discretized data of given function u and corresponding G(u).
Args:
u_func (Callable): Function u.
G_u_func (Callable): Function G(u).
def predict_func(input_dict):
return solver.predict(input_dict, return_numpy=True)[cfg.MODEL.G_key]

Returns:
Tuple[np.ndarray, np.ndarray, np.ndarray]: Discretized data of u, y and G(u).
"""
x = np.linspace(0, 1, cfg.MODEL.num_loc, dtype=dtype).reshape(
[1, cfg.MODEL.num_loc]
)
u = u_func(x)
u = np.tile(u, [cfg.NUM_Y, 1])

y = np.linspace(0, 1, cfg.NUM_Y, dtype=dtype).reshape([cfg.NUM_Y, 1])
G_ref = G_u_func(y)
return u, y, G_ref

func_u_G_pair = [
# (title_string, func_u, func_G(u)), s.t. dG/dx == u and G(u)(0) = 0
(r"$u=\cos(x), G(u)=sin(x$)", lambda x: np.cos(x), lambda y: np.sin(y)), # 1
(
r"$u=sec^2(x), G(u)=tan(x$)",
lambda x: (1 / np.cos(x)) ** 2,
lambda y: np.tan(y),
), # 2
(
r"$u=sec(x)tan(x), G(u)=sec(x) - 1$",
lambda x: (1 / np.cos(x) * np.tan(x)),
lambda y: 1 / np.cos(y) - 1,
), # 3
(
r"$u=1.5^x\ln{1.5}, G(u)=1.5^x-1$",
lambda x: 1.5**x * np.log(1.5),
lambda y: 1.5**y - 1,
), # 4
(r"$u=3x^2, G(u)=x^3$", lambda x: 3 * x**2, lambda y: y**3), # 5
(r"$u=4x^3, G(u)=x^4$", lambda x: 4 * x**3, lambda y: y**4), # 6
(r"$u=5x^4, G(u)=x^5$", lambda x: 5 * x**4, lambda y: y**5), # 7
(r"$u=6x^5, G(u)=x^6$", lambda x: 5 * x**4, lambda y: y**5), # 8
(r"$u=e^x, G(u)=e^x-1$", lambda x: np.exp(x), lambda y: np.exp(y) - 1), # 9
]

os.makedirs(os.path.join(cfg.output_dir, "visual"), exist_ok=True)
for i, (title, u_func, G_func) in enumerate(func_u_G_pair):
u, y, G_ref = generate_y_u_G_ref(u_func, G_func)
G_pred = solver.predict({"u": u, "y": y}, return_numpy=True)["G"]
plt.plot(y, G_pred, label=r"$G(u)(y)_{ref}$")
plt.plot(y, G_ref, label=r"$G(u)(y)_{pred}$")
plt.legend()
plt.title(title)
plt.savefig(os.path.join(cfg.output_dir, "visual", f"func_{i}_result.png"))
plt.clf()
plot(cfg, predict_func)


def evaluate(cfg: DictConfig):
Expand Down Expand Up @@ -189,6 +133,50 @@ def evaluate(cfg: DictConfig):
)
solver.eval()

def predict_func(input_dict):
return solver.predict(input_dict, return_numpy=True)[cfg.MODEL.G_key]

plot(cfg, predict_func)


def export(cfg: DictConfig):
# set model
model = ppsci.arch.DeepONet(**cfg.MODEL)

# initialize solver
solver = ppsci.solver.Solver(
model,
pretrained_model_path=cfg.INFER.pretrained_model_path,
)

# export model
from paddle.static import InputSpec

input_spec = [
{
model.input_keys[0]: InputSpec(
[None, 1000], "float32", name=model.input_keys[0]
),
model.input_keys[1]: InputSpec(
[None, 1], "float32", name=model.input_keys[1]
),
}
]
solver.export(input_spec, cfg.INFER.export_path)


def inference(cfg: DictConfig):
from deploy import python_infer

predictor = python_infer.GeneralPredictor(cfg)

def predict_func(input_dict):
return next(iter(predictor.predict(input_dict).values()))

plot(cfg, predict_func)


def plot(cfg: DictConfig, predict_func: Callable):
# visualize prediction for different functions u and corresponding G(u)
dtype = paddle.get_default_dtype()

Expand Down Expand Up @@ -242,13 +230,17 @@ def generate_y_u_G_ref(
os.makedirs(os.path.join(cfg.output_dir, "visual"), exist_ok=True)
for i, (title, u_func, G_func) in enumerate(func_u_G_pair):
u, y, G_ref = generate_y_u_G_ref(u_func, G_func)
G_pred = solver.predict({"u": u, "y": y}, return_numpy=True)["G"]
G_pred = predict_func({"u": u, "y": y})
plt.plot(y, G_pred, label=r"$G(u)(y)_{ref}$")
plt.plot(y, G_ref, label=r"$G(u)(y)_{pred}$")
plt.legend()
plt.title(title)
plt.savefig(os.path.join(cfg.output_dir, "visual", f"func_{i}_result.png"))
logger.message(
f"Saved result of function {i} to {cfg.output_dir}/visual/func_{i}_result.png"
)
plt.clf()
plt.close()


@hydra.main(version_base=None, config_path="./conf", config_name="deeponet.yaml")
Expand All @@ -257,8 +249,14 @@ def main(cfg: DictConfig):
train(cfg)
elif cfg.mode == "eval":
evaluate(cfg)
elif cfg.mode == "export":
export(cfg)
elif cfg.mode == "infer":
inference(cfg)
else:
raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'")
raise ValueError(
f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'"
)


if __name__ == "__main__":
Expand Down
45 changes: 33 additions & 12 deletions ppsci/geometry/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

import numpy as np
import paddle
from typing_extensions import Literal

from ppsci.utils import logger
from ppsci.utils import misc
Expand Down Expand Up @@ -129,17 +130,22 @@ def uniform_points(self, n: int, boundary: bool = True) -> np.ndarray:
def sample_interior(
self,
n: int,
random: str = "pseudo",
criteria: Optional[Callable] = None,
random: Literal["pseudo", "Halton", "LHS"] = "pseudo",
criteria: Optional[Callable[..., np.ndarray]] = None,
evenly: bool = False,
compute_sdf_derivatives: bool = False,
) -> Dict[str, np.ndarray]:
"""Sample random points in the geometry and return those meet criteria.
Args:
n (int): Number of points.
random (str): Random method. Defaults to "pseudo".
criteria (Optional[Callable]): Criteria function. Defaults to None.
random (Literal["pseudo", "Halton", "LHS"]): Random method. Defaults to "pseudo".
pseudo: Pseudo random.
Halton: Halton sequence.
LHS: Latin Hypercube Sampling.
criteria (Optional[Callable[..., np.ndarray]]): Criteria function. Given
coords from differnet dimension and return a boolean array with shape [n,].
Defaults to None.
evenly (bool): Evenly sample points. Defaults to False.
compute_sdf_derivatives (bool): Compute SDF derivatives. Defaults to False.
Expand Down Expand Up @@ -226,16 +232,21 @@ def sample_interior(
def sample_boundary(
self,
n: int,
random: str = "pseudo",
criteria: Optional[Callable] = None,
random: Literal["pseudo", "Halton", "LHS"] = "pseudo",
criteria: Optional[Callable[..., np.ndarray]] = None,
evenly: bool = False,
) -> Dict[str, np.ndarray]:
"""Compute the random points in the geometry and return those meet criteria.
Args:
n (int): Number of points.
random (str): Random method. Defaults to "pseudo".
criteria (Optional[Callable]): Criteria function. Defaults to None.
random (Literal["pseudo", "Halton", "LHS"]): Random method. Defaults to "pseudo".
pseudo: Pseudo random.
Halton: Halton sequence.
LHS: Latin Hypercube Sampling.
criteria (Optional[Callable[..., np.ndarray]]): Criteria function. Given
coords from differnet dimension and return a boolean array with shape [n,].
Defaults to None.
evenly (bool): Evenly sample points. Defaults to False.
Returns:
Expand Down Expand Up @@ -332,12 +343,17 @@ def sample_boundary(
return {**x_dict, **normal_dict}

@abc.abstractmethod
def random_points(self, n: int, random: str = "pseudo") -> np.ndarray:
def random_points(
self, n: int, random: Literal["pseudo", "Halton", "LHS"] = "pseudo"
) -> np.ndarray:
"""Compute the random points in the geometry.
Args:
n (int): Number of points.
random (str): Random method. Defaults to "pseudo".
random (Literal["pseudo", "Halton", "LHS"]): Random method. Defaults to "pseudo".
pseudo: Pseudo random.
Halton: Halton sequence.
LHS: Latin Hypercube Sampling.
Returns:
np.ndarray: Random points in the geometry. The shape is [N, D].
Expand Down Expand Up @@ -379,12 +395,17 @@ def uniform_boundary_points(self, n: int) -> np.ndarray:
return self.random_boundary_points(n)

@abc.abstractmethod
def random_boundary_points(self, n: int, random: str = "pseudo") -> np.ndarray:
def random_boundary_points(
self, n: int, random: Literal["pseudo", "Halton", "LHS"] = "pseudo"
) -> np.ndarray:
"""Compute the random points on the boundary.
Args:
n (int): Number of points.
random (str): Random method. Defaults to "pseudo".
random (Literal["pseudo", "Halton", "LHS"]): Random method. Defaults to "pseudo".
pseudo: Pseudo random.
Halton: Halton sequence.
LHS: Latin Hypercube Sampling.
Returns:
np.ndarray: Random points on the boundary. The shape is [N, D].
Expand Down

0 comments on commit f874bf2

Please sign in to comment.