
Physics Informed Neural Networks

1 Introduction

1.0.1 Finding the inverse function of a parabola

Given a function \(\mathcal{P}: y \mapsto y^2\), where \(y ∈ [0, 1]\), find an unknown function \(f\) that satisfies \(\mathcal{P}(f(x)) = x,\ ∀ x ∈ [0, 1]\).
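
For concreteness, the data set \(Ω\) can be sampled directly from the parabola; a minimal sketch in JAX (array sizes and names are illustrative, not from the original):

import jax.numpy as jnp

# Points (x, y) on the parabola x = y², y ∈ [0, 1], so the unknown
# inverse satisfies P(f(x)) = f(x)² = x.
y = jnp.linspace(0.0, 1.0, 100)
x = y ** 2
omega = (x, y)  # the data set Ω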

1.0.1.1 Figure

./p1.png

1.0.2 Classical

1.0.2.1 MLP

A classical approach is to use a neural network to approximate the data points \(Ω\):

\[f_θ(x) ≈ y,\ ∀ (x, y) ∈ Ω\]

where \(θ\) are the parameters of the neural network.

However, the fitted network ends up nowhere near the true inverse.
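
A minimal sketch of this supervised objective, assuming a `net.apply(theta, x)` function in the style of the JAX examples later in the document (the name is illustrative):

import jax
import jax.numpy as jnp

def data_loss(theta, x, y):
    # Mean squared error: push f_θ(x) towards y for all (x, y) ∈ Ω.
    return jnp.mean((net.apply(theta, x) - y) ** 2)

grad_fn = jax.grad(data_loss)  # gradient w.r.t. θ for training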

1.0.2.2 Results

./p2.png

1.0.3 Physics approach

1.0.3.1 PINN

\[\mathcal{P}(f_θ(x)) ≈ x,\ ∀ (x, y) ∈ Ω\]

Now, minimize the error between \(f_θ(x)^2\) and \(x\).
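
A minimal sketch of this physics-informed objective, under the same illustrative `net.apply` assumption as above:

import jax.numpy as jnp

def physics_loss(theta, x, y):
    # Push P(f_θ(x)) = f_θ(x)² towards x instead of fitting y directly.
    f = net.apply(theta, x)
    return jnp.mean((f ** 2 - x) ** 2)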

1.0.3.2 Results

./p3.png

1.0.4 Physics Informed Neural Networks (PINNs) Definition

Neural networks that are trained to solve supervised learning tasks while respecting any given law of physics described by general nonlinear partial differential equations (PDEs).

./p4.png

2 Partial differential equations (PDE)

2.1 Introduction

2.1.1 What is a PDE?

  • An equation containing unknown functions and their partial derivatives.
  • It describes the relationship between independent variables, unknown functions, and their partial derivatives.

2.1.1.1 Example

  • \(f(x, y) = ax + by + c\), where \(a, b, c\) are unknown parameters.
  • \(u(x, y) = α u(x, y) + β f(x, y)\), where \(u\) is the unknown function.
  • \(u_x(x, y) = α u_y(x, y) + β f_{xy}(x, y)\), where \(u_x\) is the partial derivative of \(u\) with respect to \(x\), \(u_y\) is the partial derivative of \(u\) with respect to \(y\), and \(f_{xy}\) is the partial derivative of \(f\) with respect to \(x\) and \(y\).

2.1.2 Notations

  • \(\dot{u} = \frac{∂ u}{∂ t}\)
  • \(u_{xy} = \frac{∂^2 u}{∂ y\,∂ x}\)
  • \(∇ u(x, y, z) = (u_x, u_y, u_z)\)
  • \(∇ ⋅ ∇ u(x, y, z) = Δ u(x, y, z) = u_{xx} + u_{yy} + u_{zz}\)
  • \(∇\): nabla, or del.
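
These operators map directly onto JAX transformations; for example, a minimal sketch of the Laplacian as the trace of the Hessian (the function `u` here is a made-up test case):

import jax
import jax.numpy as jnp

def laplacian(u):
    # Δu = u_xx + u_yy + u_zz: the trace of the Hessian of u.
    hess = jax.hessian(u)
    return lambda point: jnp.trace(hess(point))

u = lambda p: jnp.sum(p ** 2)     # u(x, y, z) = x² + y² + z²
print(laplacian(u)(jnp.ones(3)))  # 2 + 2 + 2 = 6.0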

2.2 PDE in the real world

2.2.1 Laplace’s equation

\[Δ \varphi = 0\]

or

\[∇ ⋅ ∇ \varphi = 0\]

or, in a 3D space:

\[\frac{∂^2 \varphi}{∂ x^2}+\frac{∂^2 \varphi}{∂ y^2}+\frac{∂^2 \varphi}{∂ z^2} = 0\]

2.2.2 Poisson’s equation

\[Δ \varphi = f\]

./p5.png

2.2.3 Heat equation

\[\dot{u} = \frac {∂ u}{∂ t} = α Δ u\]

where \(α\) is the thermal diffusivity.

2.2.4 Wave equation

\[\ddot{u} = c^2 ∇^2 u\]

where \(c\) is the wave speed.

2.2.5 Burgers’ equation

\[u_t + u u_x = ν u_{xx}\]

  • \(t\): temporal coordinate
  • \(x\): spatial coordinate
  • \(u(x, t)\): speed of the fluid at the indicated spatial and temporal coordinates
  • \(ν\): viscosity of the fluid

2.3 Boundary conditions

2.3.1 Boundary conditions

For an equation \(∇^2 y + y = 0\) in a domain \(Ω\):

  • Dirichlet boundary condition: \(y(x) = f(x)\quad ∀ x ∈ ∂Ω\)
  • Neumann boundary condition: \(\frac{∂ y}{∂ \mathbf{n}}(\mathbf{x}) = f(\mathbf{x})\quad ∀ \mathbf{x} ∈ ∂Ω\)
    • Here \(f\) is a known scalar function defined on the boundary \(∂Ω\), and \(\mathbf{n}\) denotes the (typically exterior) normal to the boundary.
    • The normal derivative on the left side is defined as \(\frac{∂ y}{∂ \mathbf{n}}(\mathbf{x}) = ∇ y(\mathbf{x}) ⋅ \mathbf{\hat{n}}(\mathbf{x})\), where \(\mathbf{\hat{n}}\) is the unit normal.
  • Robin boundary condition
    • A linear combination of the Dirichlet and Neumann conditions (see the form below).
  • Periodic boundary condition
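
For reference, the Robin condition combines the two with known coefficient functions \(a\) and \(b\) on the boundary:

\[a(\mathbf{x})\, y(\mathbf{x}) + b(\mathbf{x})\, \frac{∂ y}{∂ \mathbf{n}}(\mathbf{x}) = f(\mathbf{x}) \quad ∀\, \mathbf{x} ∈ ∂Ω\]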

3 PINNs

3.0.1 Paper

  • Physics Informed Deep Learning (Part I): Data-driven Solutions of Nonlinear Partial Differential Equations (Raissi, Perdikaris & Karniadakis, 2017)
  • Physics Informed Deep Learning (Part II): Data-driven Discovery of Nonlinear Partial Differential Equations (Raissi, Perdikaris & Karniadakis, 2017)

3.0.2 Problem

  • Data-driven solution and data-driven discovery
  • Continuous time and discrete time models

3.1 Data-driven solution with continuous time

3.1.1 Data-driven solution with continuous time

General PDE Form:

\[u_t + \mathcal{N}[u] = 0,\ x ∈ Ω, \ t∈[0,T]\]

where:

  • \(\mathcal{N}[u]\): nonlinear differential operator
  • \(u(t, x)\): unknown function (solution)
  • \(Ω\): spatial domain
  • \(t\): time

3.1.2 Physics informed neural network

  • A neural network \(u_θ ≈ u\), where \(θ\) are the parameters of the neural network.
  • A physics informed neural network \(f_θ := (u_θ)_t + \mathcal{N}[u_θ]\).
  • Target: \(f_θ ≈ u_t + \mathcal{N}[u] = 0\) and \(u_θ ≈ u\).
    • \(\mathcal{L} = \mathcal{L}_f + \mathcal{L}_u\)

./p6.png

3.1.3 Example (Burgers’ Equation)

The equation:

\[u_t + u u_x = ν u_{xx}\]

Here, we already know \(ν = 0.01/π\), \(x ∈ [-1, 1]\), \(t ∈ [0, 1]\).

Thus,

\[u_t + u u_x - (0.01/π)\, u_{xx} = 0\]

And the initial condition along with the Dirichlet boundary conditions can be written as:

  • \(u(0, x) = -\sin(π x)\)
  • \(u(t, -1) = u(t, 1) = 0\)

3.1.4 Target

  • Data:
    • Boundary-only data, from the initial and boundary conditions.
  • Input: \(\{t, x\}\)
  • Output: \(u(t, x)\)
  • Target: \(f_θ ≈ u_t + \mathcal{N}[u] = 0\) and \(u_θ ≈ u\).
    • \(\mathcal{L} = \mathcal{L}_f + \mathcal{L}_u\)

3.1.5 Example (Burgers’ Equation) with codes

def u_theta(theta, t, x):
    # u_theta(theta, t, x) approximates u(t, x);
    # net is a transformed Haiku network (see the JAX section below).
    return net.apply(theta, t, x)

def f_theta(theta, t, x):
    # Derivatives by autodiff; see the autodiff cookbook:
    # https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html
    u = u_theta(theta, t, x)
    u_t = jax.jacrev(u_theta, argnums=1)(theta, t, x)
    u_x = jax.jacrev(u_theta, argnums=2)(theta, t, x)
    u_xx = jax.hessian(u_theta, argnums=2)(theta, t, x)
    # or: jax.jacfwd(jax.jacrev(u_theta, argnums=2), argnums=2)
    f = u_t + u * u_x - (0.01 / jnp.pi) * u_xx  # ν = 0.01/π
    return f
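
The two pieces can then be combined into \(\mathcal{L} = \mathcal{L}_u + \mathcal{L}_f\); a minimal sketch, assuming boundary/initial data (t_u, x_u, u_target) and collocation points (t_f, x_f) have been sampled beforehand (these names are illustrative):

def loss(theta, batch):
    t_u, x_u, u_target, t_f, x_f = batch
    # L_u: data misfit on the initial/boundary points.
    u_pred = jax.vmap(u_theta, in_axes=(None, 0, 0))(theta, t_u, x_u)
    mse_u = jnp.mean((u_pred - u_target) ** 2)
    # L_f: PDE residual at the collocation points, pushed towards zero.
    f_pred = jax.vmap(f_theta, in_axes=(None, 0, 0))(theta, t_f, x_f)
    mse_f = jnp.mean(f_pred ** 2)
    return mse_u + mse_f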

3.1.6 Train and Results

  • Train MLPs with the L-BFGS solver (a quasi-Newton method).
  • Use tanh rather than ReLU activations: the second-order derivative of ReLU is zero almost everywhere, so the \(u_{xx}\) term of the residual vanishes, as the sketch below shows.
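
A quick check of the ReLU problem using standard JAX APIs:

import jax
import jax.numpy as jnp

d2_relu = jax.grad(jax.grad(jax.nn.relu))
d2_tanh = jax.grad(jax.grad(jnp.tanh))

print(d2_relu(2.0))  # 0.0 — no second-order signal for the residual
print(d2_tanh(0.5))  # non-zero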

./p7.png

3.2 Data-driven discovery with continuous time

3.2.1 Data-driven discovery with continuous time

General PDE Form:

\[u_t + \mathcal{N}[u;λ] = 0,\ x ∈ Ω, \ t∈[0,T]\]

where:

  • \(\mathcal{N}[u; λ]\): nonlinear differential operator with parameters \(λ\)
  • \(u(t, x)\): unknown function (solution)
  • \(Ω\): spatial domain
  • \(t\): time

3.2.2 Example (Incompressible Navier-Stokes Equations, a convection–diffusion system)

The equations:

\[u_t + λ_1 (u u_x + v u_y) = -p_x + λ_2 (u_{xx} + u_{yy}),\]
\[v_t + λ_1 (u v_x + v v_y) = -p_y + λ_2 (v_{xx} + v_{yy}),\]

where:

  • \(u(t, x, y)\): \(x\)-component of the velocity field
  • \(v(t, x, y)\): \(y\)-component of the velocity field
  • \(p(t, x, y)\): pressure
  • \(λ_1, λ_2\): the unknown parameters

Additional physical constraints:

  • Solutions to the Navier-Stokes equations are searched for in the set of divergence-free functions, i.e.:
    • \(u_x + v_y = 0\),
    • which describes the conservation of mass of the fluid.
  • \(u\) and \(v\) can be written in terms of a latent function \(ψ(t, x, y)\) with an assumption that makes this constraint automatic (shown below):
    • \(u = ψ_y,\ v = -ψ_x\)
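
The ansatz satisfies the divergence-free constraint by construction, since mixed partial derivatives commute:

\[u_x + v_y = ψ_{yx} - ψ_{xy} = 0\]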

3.2.3 NS Equation figure

./p8.png

3.2.4 Example (Navier-Stokes Equation) – Target

  • The neural network equations:
    • \(f := u_t + λ_1 (u u_x + v u_y) + p_x - λ_2 (u_{xx} + u_{yy}),\)
    • \(g := v_t + λ_1 (u v_x + v v_y) + p_y - λ_2 (v_{xx} + v_{yy})\)
  • Input: noisy measurements \(\{t, x, y, u, v\}\).
  • Output: \((ψ(t, x, y), p(t, x, y))\).
  • Target (see the sketch below):
    • \(f_θ ≈ f\)
    • \(g_θ ≈ g\)
    • \(u_θ ≈ u\)
    • \(v_θ ≈ v\)
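
Mirroring the Burgers code above, \(u\) and \(v\) can be recovered from the latent output by autodiff; a rough sketch, assuming a network psi_p(theta, t, x, y) that returns the pair \((ψ, p)\) (an illustrative name, not from the paper):

def uvp(theta, t, x, y):
    # u = ψ_y and v = -ψ_x give a divergence-free velocity field.
    psi = lambda t, x, y: psi_p(theta, t, x, y)[0]
    p = psi_p(theta, t, x, y)[1]
    u = jax.grad(psi, argnums=2)(t, x, y)
    v = -jax.grad(psi, argnums=1)(t, x, y)
    return u, v, p

The residuals \(f_θ\) and \(g_θ\) then follow from further jax.grad / jax.hessian calls, exactly as in the Burgers example.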

3.2.5 Results

./p9.png

4 JAX

4.0.1 Introduction

JAX is Autograd and XLA, brought together for high-performance numerical computing and machine learning research. It provides composable transformations of Python+NumPy programs: differentiate, vectorize, parallelize, Just-In-Time compile to GPU/TPU, and more.
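
For example, the transformations compose freely (standard JAX APIs):

import jax
import jax.numpy as jnp

f = lambda x: jnp.sin(x) ** 2
df = jax.jit(jax.vmap(jax.grad(f)))   # differentiate, vectorize, compile
print(df(jnp.linspace(0.0, 1.0, 4)))  # equals 2·sin(x)·cos(x)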

4.0.2 Pure functional

  • A pure function always returns the same output for the same input: \(f(x) = y\), always.
  • Examples of non-pure behavior:
    • IO operations, e.g. print
    • random functions without an explicit seed (JAX's key-based alternative is sketched below)
    • reading the current time
    • runtime errors
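
JAX therefore makes randomness explicit: every random call takes a PRNG key, so the function stays pure:

import jax

key = jax.random.PRNGKey(0)          # explicit seed, no global state
key, subkey = jax.random.split(key)  # derive fresh keys deterministically
x = jax.random.normal(subkey, (3,))  # same key ⇒ same sample, always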

4.0.3 Ecosystem

  • JAX (jax, jaxlib)
    • jax
    • jax.numpy
  • Haiku (dm-haiku) from DeepMind
    • Modules
  • Optax (optax) from DeepMind
    • Lightweight
    • First-order (gradient-based) optimizers
  • JAXopt (jaxopt)
    • Other optimizers (e.g. the quasi-Newton L-BFGS)
  • Jraph (jraph)
    • Standardized data structures for graphs.
  • JAX, M.D. (jax-md)
    • JAX and Molecular Dynamics
  • RLax (rlax), and Coax (coax)
    • Reinforcement Learning

4.0.4 Example (def)

import jax
import jax.numpy as jnp
import haiku as hk

def _u(t, x):
    # MLP with hidden sizes [10, 10] and a scalar output, applied to (t, x).
    return hk.nets.MLP([10, 10, 1])(jnp.concatenate([t, x], axis=-1))

u = hk.transform_with_state(_u)

4.0.5 Example (init)

fake_t = jnp.ones([batch, size])
fake_x = jnp.ones([batch, size])

rng = jax.random.PRNGKey(42)

# params: the network parameters θ
# state:  mutable network state
# rng:    random number generator key
params, state = u.init(rng, fake_t, fake_x)

# Print a summary table of the modules
print(hk.experimental.tabulate(u)(fake_t, fake_x))

4.0.6 Example (loss)

def loss_fn(config, ...):

    def _loss(params, state, t, x):
        # transform_with_state: apply takes (params, state, rng, inputs)
        u_hat, state = u.apply(params, state, None, t, x)
        ...
        loss = _f  # e.g. data misfit plus PDE residual
        return loss

    return _loss

loss = loss_fn(config, ...)

4.0.7 Example (optim)

import optax

lr = optax.linear_schedule(
    init_value=0.001,
    end_value=0.001 / 10,
    transition_steps=1,    # take one step to reach the final value
    transition_begin=150,  # start the decay after 150 steps
)

opt = optax.adam(learning_rate=lr)
# or
opt = optax.adamax(learning_rate=lr)

4.0.8 Example (solver)

import jaxopt

# First-order solver wrapping the Optax optimizer
solver = jaxopt.OptaxSolver(
    loss,
    opt,
    maxiter=epochs,
    ...
)

# Quasi-Newton solver (L-BFGS)
solver = jaxopt.LBFGS(
    loss,
    maxiter=epochs,
    ...
)

opt_state = solver.init_state(params, state)
update = solver.update

4.0.9 Example (train)

# Defined above: params, state, opt_state, update

for batch in data:
    params, opt_state = update(params, opt_state, batch)

4.0.10 Example (parallel)

# Use pjit
import numpy as np

from jax.experimental.maps import Mesh, ResourceEnv, thread_resources
from jax.experimental.pjit import PartitionSpec, pjit

mesh = Mesh(np.asarray(jax.devices(), dtype=object), ["data", ...])
thread_resources.env = ResourceEnv(physical_mesh=mesh, loops=())

update = pjit(
    solver.update,
    in_axis_resources=[
        None,  # params
        None,  # state
        PartitionSpec("data"),  # batch
    ],
    out_axis_resources=None,
)

5 Conclusion

5.0.1 Conclusion

  • Find an inverse function of a parabola
    • Classical
    • Physics informed
  • PDE
    • PDE example
    • PDE boundary
  • PINNs
    • Data-driven solution with continuous time
      • Burgers’ equation
    • Data-driven discovery with continuous time
      • Navier-Stokes equation

6 Refs

6.0.1 Refs

  • Raissi, M., Perdikaris, P., & Karniadakis, G. E. (2017). Physics Informed Deep Learning (Part I): Data-driven Solutions of Nonlinear Partial Differential Equations. arXiv:1711.10561.
  • Raissi, M., Perdikaris, P., & Karniadakis, G. E. (2017). Physics Informed Deep Learning (Part II): Data-driven Discovery of Nonlinear Partial Differential Equations. arXiv:1711.10566.