Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fea] Support derivative nodes fusing for any order #745

Merged
26 changes: 13 additions & 13 deletions docs/zh/examples/biharmonic2d.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,17 +114,17 @@ examples/biharmonic2d/biharmonic2d.py:93:95

``` py linenums="97"
--8<--
examples/biharmonic2d/biharmonic2d.py:97:107
examples/biharmonic2d/biharmonic2d.py:97:108
--8<--
```

#### 3.4.1 内部约束

以作用在背板内部点的 `InteriorConstraint` 为例,代码如下:

``` py linenums="206"
``` py linenums="207"
--8<--
examples/biharmonic2d/biharmonic2d.py:206:215
examples/biharmonic2d/biharmonic2d.py:207:216
--8<--
```

Expand Down Expand Up @@ -160,27 +160,27 @@ examples/biharmonic2d/conf/biharmonic2d.yaml:60:62

如 [2 问题定义](#2) 中所述,$x=0$ 处的挠度 $w$ 为 0,有如下边界条件,其他 7 个边界条件也与之类似:

``` py linenums="110"
``` py linenums="111"
--8<--
examples/biharmonic2d/biharmonic2d.py:110:119
examples/biharmonic2d/biharmonic2d.py:111:120
--8<--
```

在方程约束、边界约束构建完毕之后,以刚才的命名为关键字,封装到一个字典中,方便后续访问。

``` py linenums="216"
``` py linenums="217"
--8<--
examples/biharmonic2d/biharmonic2d.py:216:227
examples/biharmonic2d/biharmonic2d.py:217:228
--8<--
```

### 3.5 优化器构建

训练过程会调用优化器来更新模型参数,此处选择使用 `Adam` 先进行少量训练后,再使用 `LBFGS` 优化器精调。

``` py linenums="80"
``` py linenums="81"
--8<--
examples/biharmonic2d/biharmonic2d.py:80:82
examples/biharmonic2d/biharmonic2d.py:81:83
--8<--
```

Expand All @@ -198,19 +198,19 @@ examples/biharmonic2d/conf/biharmonic2d.yaml:46:56

完成上述设置之后,只需要将上述实例化的对象按顺序传递给 `ppsci.solver.Solver`,然后启动训练,注意两个优化过程需要分别构建 `Solver`。

``` py linenums="229"
``` py linenums="230"
--8<--
examples/biharmonic2d/biharmonic2d.py:229:268
examples/biharmonic2d/biharmonic2d.py:230:269
--8<--
```

### 3.8 模型评估和可视化

训练完成后,可以在 `eval` 模式中对训练好的模型进行评估和可视化。由于案例的特殊性,不需构建评估器和可视化器,而是使用自定义代码。

``` py linenums="271"
``` py linenums="272"
--8<--
examples/biharmonic2d/biharmonic2d.py:271:351
examples/biharmonic2d/biharmonic2d.py:272:352
--8<--
```

Expand Down
15 changes: 5 additions & 10 deletions docs/zh/examples/bracket.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,16 +85,11 @@ examples/bracket/bracket.py:15:19

Bracket 案例涉及到以下线弹性方程,使用 PaddleScience 内置的 `LinearElasticity` 即可。

$$
\begin{cases}
stress\_disp_{xx} = \lambda(\dfrac{\partial u}{\partial x} + \dfrac{\partial v}{\partial y} + \dfrac{\partial w}{\partial z}) + 2\mu \dfrac{\partial u}{\partial x} - \sigma_{xx} \\
stress\_disp_{yy} = \lambda(\dfrac{\partial u}{\partial x} + \dfrac{\partial v}{\partial y} + \dfrac{\partial w}{\partial z}) + 2\mu \dfrac{\partial v}{\partial y} - \sigma_{yy} \\
stress\_disp_{zz} = \lambda(\dfrac{\partial u}{\partial x} + \dfrac{\partial v}{\partial y} + \dfrac{\partial w}{\partial z}) + 2\mu \dfrac{\partial w}{\partial z} - \sigma_{zz} \\
traction_{x} = n_x \sigma_{xx} + n_y \sigma_{xy} + n_z \sigma_{xz} \\
traction_{y} = n_y \sigma_{yx} + n_y \sigma_{yy} + n_z \sigma_{yz} \\
traction_{z} = n_z \sigma_{zx} + n_y \sigma_{zy} + n_z \sigma_{zz} \\
\end{cases}
$$
--8<--
ppsci/equation/pde/linear_elasticity.py:30:42
--8<--

对应的方程实例化代码如下:

``` py linenums="32"
--8<--
Expand Down
3 changes: 2 additions & 1 deletion examples/biharmonic2d/biharmonic2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ def train(cfg: DictConfig):
"drop_last": True,
"shuffle": True,
},
"num_workers": 1,
"num_workers": 0,
"auto_collation": False,
}

# set constraint
Expand Down
42 changes: 34 additions & 8 deletions ppsci/autodiff/ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
from __future__ import annotations

from typing import Dict
from typing import List
from typing import Optional
from typing import Union

import paddle

Expand All @@ -36,14 +38,19 @@ class _Jacobian:
xs (paddle.Tensor): Input Tensor of shape [batch_size, dim_x].
"""

def __init__(self, ys: "paddle.Tensor", xs: "paddle.Tensor"):
def __init__(
self,
ys: "paddle.Tensor",
xs: "paddle.Tensor",
J: Optional[Dict[int, paddle.Tensor]] = None,
):
self.ys = ys
self.xs = xs

self.dim_y = ys.shape[1]
self.dim_x = xs.shape[1]

self.J: Dict[str, paddle.Tensor] = {}
self.J: Dict[int, paddle.Tensor] = {} if J is None else J

def __call__(
self,
Expand Down Expand Up @@ -87,12 +94,12 @@ def __init__(self):
def __call__(
self,
ys: "paddle.Tensor",
xs: "paddle.Tensor",
xs: Union["paddle.Tensor", List["paddle.Tensor"]],
i: int = 0,
j: Optional[int] = None,
retain_graph: Optional[bool] = None,
create_graph: bool = True,
) -> "paddle.Tensor":
) -> Union["paddle.Tensor", List["paddle.Tensor"]]:
"""Compute jacobians for given ys and xs.

Args:
Expand Down Expand Up @@ -121,10 +128,29 @@ def __call__(
>>> y = x * x
>>> dy_dx = ppsci.autodiff.jacobian(y, x)
"""
key = (ys, xs)
if key not in self.Js:
self.Js[key] = _Jacobian(ys, xs)
return self.Js[key](i, j, retain_graph, create_graph)
if not isinstance(xs, (list, tuple)):
key = (ys, xs)
if key not in self.Js:
self.Js[key] = _Jacobian(ys, xs)
return self.Js[key](i, j, retain_graph, create_graph)
else:
grads = paddle.grad(
ys,
xs,
create_graph=create_graph,
retain_graph=retain_graph,
)
Js_list = []
for k, xs_ in enumerate(xs):
key = (ys, xs_)
assert xs_.shape[-1] == 1, (
f"The last dim of each xs should be 1, but xs[{k}] has shape "
f"{xs_.shape}"
)
if key not in self.Js:
self.Js[key] = _Jacobian(ys, xs_, {0: grads[k]})
Js_list.append(self.Js[key](i, j, retain_graph, create_graph))
return Js_list

def _clear(self):
"""Clear cached Jacobians."""
Expand Down
47 changes: 17 additions & 30 deletions ppsci/equation/pde/heat_exchanger.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

from typing import Union

from ppsci.autodiff import jacobian
from ppsci.equation.pde import base


Expand Down Expand Up @@ -69,37 +68,25 @@ def __init__(
w_c: Union[float, str],
):
super().__init__()
x, t, qm_h, qm_c, qm_w = self.create_symbols("x t qm_h qm_c qm_w")

def heat_boundary_fun(out):
x, t, qm_h = out["x"], out["t"], out["qm_h"]
T_h, T_w = out["T_h"], out["T_w"]
T_h_x = jacobian(T_h, x)
T_h_t = jacobian(T_h, t)
T_h = self.create_function("T_h", (x, t, qm_h))
T_c = self.create_function("T_c", (x, t, qm_c))
T_w = self.create_function("T_w", (x, t, qm_w))

beta_h = (alpha_h * v_h) / qm_h
heat_boundary = T_h_t + v_h * T_h_x - beta_h * (T_w - T_h)
return heat_boundary
T_h_x = T_h.diff(x)
T_h_t = T_h.diff(t)
T_c_x = T_c.diff(x)
T_c_t = T_c.diff(t)
T_w_t = T_w.diff(t)

self.add_equation("heat_boundary", heat_boundary_fun)
beta_h = (alpha_h * v_h) / qm_h
beta_c = (alpha_c * v_c) / qm_c

def cold_boundary_fun(out):
x, t, qm_c = out["x"], out["t"], out["qm_c"]
T_c, T_w = out["T_c"], out["T_w"]
T_c_x = jacobian(T_c, x)
T_c_t = jacobian(T_c, t)
heat_boundary = T_h_t + v_h * T_h_x - beta_h * (T_w - T_h)
cold_boundary = T_c_t - v_c * T_c_x - beta_c * (T_w - T_c)
wall = T_w_t - w_h * (T_h - T_w) - w_c * (T_c - T_w)

beta_c = (alpha_c * v_c) / qm_c
cold_boundary = T_c_t - v_c * T_c_x - beta_c * (T_w - T_c)
return cold_boundary

self.add_equation("cold_boundary", cold_boundary_fun)

def wall_fun(out):
t = out["t"]
T_c, T_h, T_w = out["T_c"], out["T_h"], out["T_w"]
T_w_t = jacobian(T_w, t)

wall = T_w_t - w_h * (T_h - T_w) - w_c * (T_c - T_w)
return wall

self.add_equation("wall", wall_fun)
self.add_equation("heat_boundary", heat_boundary)
self.add_equation("cold_boundary", cold_boundary)
self.add_equation("wall", wall)
9 changes: 6 additions & 3 deletions ppsci/equation/pde/linear_elasticity.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,12 @@ class LinearElasticity(base.PDE):
stress\_disp_{xx} = \lambda(\dfrac{\partial u}{\partial x} + \dfrac{\partial v}{\partial y} + \dfrac{\partial w}{\partial z}) + 2\mu \dfrac{\partial u}{\partial x} - \sigma_{xx} \\
stress\_disp_{yy} = \lambda(\dfrac{\partial u}{\partial x} + \dfrac{\partial v}{\partial y} + \dfrac{\partial w}{\partial z}) + 2\mu \dfrac{\partial v}{\partial y} - \sigma_{yy} \\
stress\_disp_{zz} = \lambda(\dfrac{\partial u}{\partial x} + \dfrac{\partial v}{\partial y} + \dfrac{\partial w}{\partial z}) + 2\mu \dfrac{\partial w}{\partial z} - \sigma_{zz} \\
traction_{x} = n_x \sigma_{xx} + n_y \sigma_{xy} + n_z \sigma_{xz} \\
traction_{y} = n_y \sigma_{yx} + n_y \sigma_{yy} + n_z \sigma_{yz} \\
traction_{z} = n_z \sigma_{zx} + n_y \sigma_{zy} + n_z \sigma_{zz} \\
stress\_disp_{xy} = \mu(\dfrac{\partial u}{\partial y} + \dfrac{\partial v}{\partial x}) - \sigma_{xy} \\
stress\_disp_{xz} = \mu(\dfrac{\partial u}{\partial z} + \dfrac{\partial w}{\partial x}) - \sigma_{xz} \\
stress\_disp_{yz} = \mu(\dfrac{\partial v}{\partial z} + \dfrac{\partial w}{\partial y}) - \sigma_{yz} \\
equilibrium_{x} = \rho \dfrac{\partial^2 u}{\partial t^2} - (\dfrac{\partial \sigma_{xx}}{\partial x} + \dfrac{\partial \sigma_{xy}}{\partial y} + \dfrac{\partial \sigma_{xz}}{\partial z}) \\
equilibrium_{y} = \rho \dfrac{\partial^2 u}{\partial t^2} - (\dfrac{\partial \sigma_{xy}}{\partial x} + \dfrac{\partial \sigma_{yy}}{\partial y} + \dfrac{\partial \sigma_{yz}}{\partial z}) \\
equilibrium_{z} = \rho \dfrac{\partial^2 u}{\partial t^2} - (\dfrac{\partial \sigma_{xz}}{\partial x} + \dfrac{\partial \sigma_{yz}}{\partial y} + \dfrac{\partial \sigma_{zz}}{\partial z}) \\
\end{cases}
$$

Expand Down
26 changes: 18 additions & 8 deletions ppsci/solver/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,14 +367,24 @@ def convert_expr(
]
) -> None:
for container in container_dict.values():
for name, expr in container.output_expr.items():
if isinstance(expr, sp.Basic):
container.output_expr[name] = ppsci.lambdify(
expr,
self.model,
extra_parameters,
# osp.join(self.output_dir, "symbolic_graph_visual", container.name, name), # HACK: Activate it for DEBUG.
)
exprs = [
expr
for expr in container.output_expr.values()
if isinstance(expr, sp.Basic)
]
if len(exprs) > 0:
funcs = ppsci.lambdify(
exprs,
self.model,
extra_parameters=extra_parameters,
fuse_derivative=True,
# graph_filename=osp.join(self.output_dir, "symbolic_graph_visual") # HACK: Activate it for DEBUG.
)
ind = 0
for name in container.output_expr:
if isinstance(container.output_expr[name], sp.Basic):
container.output_expr[name] = funcs[ind]
ind += 1

if self.constraint:
convert_expr(self.constraint)
Expand Down
Loading