Add Paddle as a new backend of DeepXDE (#562)
AndPuQing authored May 20, 2022
1 parent bd56d8d commit 770e7c1
Showing 36 changed files with 428 additions and 35 deletions.
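Before the file-by-file diff, a minimal sketch of how a user would opt into the new backend once this change is installed. It uses only the selection mechanisms touched in this commit (the DDEBACKEND environment variable and ~/.deepxde/config.json); treat it as an assumption-laden sketch, not documentation.

```python
# Minimal sketch: select the Paddle backend before the first deepxde import.
# Assumes PaddlePaddle is installed and that the DDEBACKEND environment
# variable is honored at import time, as set_default_backend.py's message states.
import os

os.environ["DDEBACKEND"] = "paddle"  # must be set before importing deepxde

import deepxde as dde  # backend/__init__.py should report "Using backend: paddle"
```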
2 changes: 1 addition & 1 deletion .gitignore
@@ -9,4 +9,4 @@ docs/_build/
.ipynb_checkpoints

# VSCode
- .vscode/
+ .vscode/
4 changes: 2 additions & 2 deletions deepxde/backend/__init__.py
@@ -23,7 +23,7 @@ def _missing_api(*args, **kwargs):


def load_backend(mod_name):
- if mod_name not in ["tensorflow.compat.v1", "tensorflow", "pytorch", "jax"]:
+ if mod_name not in ["tensorflow.compat.v1", "tensorflow", "pytorch", "jax", "paddle"]:
raise NotImplementedError("Unsupported backend: %s" % mod_name)

print("Using backend: %s\n" % mod_name, file=sys.stderr, flush=True)
@@ -72,7 +72,7 @@ def get_preferred_backend():
config_dict = json.load(config_file)
backend_name = config_dict.get("backend", "").lower()

- if backend_name in ["tensorflow.compat.v1", "tensorflow", "pytorch", "jax"]:
+ if backend_name in ["tensorflow.compat.v1", "tensorflow", "pytorch", "jax", "paddle"]:
return backend_name
print(
"Deepxde backend not selected or invalid. Assuming tensorflow.compat.v1 for now.",
1 change: 1 addition & 0 deletions deepxde/backend/backend.py
@@ -20,6 +20,7 @@
tf = None
torch = None
jax = None
+ paddle = None

###############################################################################
# Tensor, data type and context interfaces
1 change: 1 addition & 0 deletions deepxde/backend/paddle/__init__.py
@@ -0,0 +1 @@
from .tensor import * # pylint: disable=redefined-builtin
110 changes: 110 additions & 0 deletions deepxde/backend/paddle/tensor.py
@@ -0,0 +1,110 @@
"""paddle backend implementation"""
import paddle


if paddle.device.is_compiled_with_cuda():
paddle.device.set_device("gpu")

lib = paddle


def data_type_dict():
return {
"float16": paddle.float16,
"float32": paddle.float32,
"float64": paddle.float64,
"uint8": paddle.uint8,
"int8": paddle.int8,
"int16": paddle.int16,
"int32": paddle.int32,
"int64": paddle.int64,
"bool": paddle.bool,
}


def is_tensor(obj):
return paddle.is_tensor(obj)


def shape(input_tensor):
return input_tensor.shape


def ndim(input_tensor):
return input_tensor.ndim


def Variable(initial_value, dtype=None):
return paddle.to_tensor(initial_value, dtype=dtype, stop_gradient=False)


def as_tensor(data, dtype=None):
if paddle.is_tensor(data):
if dtype is None or data.dtype == dtype:
return data
return data.astype(dtype)
return paddle.to_tensor(data, dtype=dtype)


def from_numpy(np_array):
return paddle.to_tensor(np_array)


def to_numpy(input_tensor):
return input_tensor.detach().cpu().numpy()


def elu(x):
return paddle.nn.functional.elu(x)


def relu(x):
return paddle.nn.functional.relu(x)


def selu(x):
return paddle.nn.functional.selu(x)


def sigmoid(x):
return paddle.nn.functional.sigmoid(x)


def silu(x):
return paddle.nn.functional.silu(x)


def sin(x):
return paddle.sin(x)


def square(x):
return paddle.square(x)


def tanh(x):
return paddle.tanh(x)


def mean(input_tensor, dim, keepdims=False):
return paddle.mean(input_tensor, axis=dim, keepdim=keepdims)


def reduce_mean(input_tensor):
return paddle.mean(input_tensor)


def sum(input_tensor, dim, keepdims=False):
return paddle.sum(input_tensor, axis=dim, keepdim=keepdims)


def reduce_sum(input_tensor):
return paddle.sum(input_tensor)


def zeros(shape, dtype):
return paddle.zeros(shape, dtype=dtype)


def zeros_like(input_tensor):
return paddle.zeros_like(input_tensor)
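A rough usage sketch of the new wrappers above. It assumes the Paddle backend has been selected and that, like the other backends, these functions are re-exported through deepxde.backend once loaded; that re-export path is an assumption, not shown in this diff.

```python
# Sketch (assumes DDEBACKEND=paddle and PaddlePaddle installed): the wrapper
# functions defined in deepxde/backend/paddle/tensor.py exercised end to end.
import numpy as np
import deepxde.backend as bkd

x = bkd.as_tensor(np.array([[1.0, 2.0], [3.0, 4.0]], dtype="float32"))
print(bkd.is_tensor(x), bkd.ndim(x))  # True 2
y = bkd.reduce_mean(bkd.square(x))    # mean of [1, 4, 9, 16]
print(bkd.to_numpy(y))                # 7.5
```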
4 changes: 2 additions & 2 deletions deepxde/backend/set_default_backend.py
@@ -13,7 +13,7 @@ def set_default_backend(backend_name):
print(
'Setting the default backend to "{}". You can change it in the '
"~/.deepxde/config.json file or export the DDEBACKEND environment variable. "
"Valid options are: tensorflow.compat.v1, tensorflow, pytorch, jax (all lowercase)".format(
"Valid options are: tensorflow.compat.v1, tensorflow, pytorch, jax, paddle (all lowercase)".format(
backend_name
)
)
@@ -25,7 +25,7 @@ def set_default_backend(backend_name):
"backend",
nargs=1,
type=str,
choices=["tensorflow.compat.v1", "tensorflow", "pytorch", "jax"],
choices=["tensorflow.compat.v1", "tensorflow", "pytorch", "jax", "paddle"],
help="Set default backend",
)
args = parser.parse_args()
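For completeness, a sketch of persisting the new choice through the helper changed above; it assumes set_default_backend writes ~/.deepxde/config.json, as its own message states.

```python
# Sketch: persist "paddle" as the default backend for future runs.
# Assumes the function writes ~/.deepxde/config.json as its message states.
from deepxde.backend.set_default_backend import set_default_backend

set_default_backend("paddle")  # "paddle" is now an accepted choice
```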
2 changes: 1 addition & 1 deletion deepxde/callbacks.py
@@ -310,7 +310,7 @@ def on_train_begin(self):
self.value = self.model.sess.run(self.var_list)
elif backend_name == "tensorflow":
self.value = [var.numpy() for var in self.var_list]
elif backend_name == "pytorch":
elif backend_name in ["pytorch", "paddle"]:
self.value = [var.detach().item() for var in self.var_list]
print(
self.model.train_state.epoch,
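The shared branch above works because Paddle tensors expose the same detach()/item() pair as PyTorch tensors; a tiny check with a made-up scalar variable:

```python
# Tiny check of the shared pytorch/paddle branch above: reading a scalar
# trainable variable. The tensor c is hypothetical.
import paddle

c = paddle.to_tensor(1.5, stop_gradient=False)
print(c.detach().item())  # 1.5
```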
4 changes: 3 additions & 1 deletion deepxde/config.py
@@ -3,7 +3,7 @@

import numpy as np

- from .backend import backend_name, tf, torch
+ from .backend import backend_name, tf, torch, paddle
from .real import Real

random_seed = None
@@ -62,5 +62,7 @@ def set_random_seed(seed):
elif backend_name == "jax":
global jax_random_seed
jax_random_seed = seed
elif backend_name == "paddle":
paddle.seed(seed)
global random_seed
random_seed = seed
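With the Paddle backend active, seeding now also flows through paddle.seed. A short sketch, assuming set_random_seed is reachable as dde.config.set_random_seed:

```python
# Sketch: reproducible runs on the Paddle backend. With backend "paddle",
# set_random_seed(seed) now calls paddle.seed(seed) in addition to the
# existing global seed bookkeeping.
import deepxde as dde

dde.config.set_random_seed(42)
```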
3 changes: 2 additions & 1 deletion deepxde/data/pde.py
@@ -125,7 +125,7 @@ def __init__(
self.test()

def losses(self, targets, outputs, loss_fn, inputs, model, aux=None):
if backend_name in ["tensorflow.compat.v1", "tensorflow", "pytorch"]:
if backend_name in ["tensorflow.compat.v1", "tensorflow", "pytorch", "paddle"]:
outputs_pde = outputs
elif backend_name == "jax":
# JAX requires pure functions
@@ -152,6 +152,7 @@ def losses(self, targets, outputs, loss_fn, inputs, model, aux=None):
)

bcs_start = np.cumsum([0] + self.num_bcs)
+ bcs_start = list(map(int, bcs_start))
error_f = [fi[bcs_start[-1] :] for fi in f]
losses = [
loss_fn[i](bkd.zeros_like(error), error) for i, error in enumerate(error_f)
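The list(map(int, ...)) line is the only functional addition in this hunk. A standalone illustration of what it does; the num_bcs values are made up, and the stated motivation (keeping slice offsets as plain Python ints for all backends) is an inference rather than text from the commit.

```python
# Illustration of the added cast: np.cumsum returns numpy integer scalars;
# converting them to plain ints keeps the downstream slicing
# (fi[bcs_start[-1]:]) safe across backends. num_bcs is hypothetical.
import numpy as np

num_bcs = [10, 20, 15]
bcs_start = np.cumsum([0] + num_bcs)   # array([ 0, 10, 30, 45])
bcs_start = list(map(int, bcs_start))  # [0, 10, 30, 45] as plain ints
print(bcs_start)
```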
15 changes: 9 additions & 6 deletions deepxde/gradients.py
@@ -1,6 +1,6 @@
__all__ = ["jacobian", "hessian"]

- from .backend import backend_name, tf, torch, jax
+ from .backend import backend_name, tf, torch, jax, paddle


class Jacobian:
@@ -18,7 +18,7 @@ def __init__(self, ys, xs):
self.ys = ys
self.xs = xs

if backend_name in ["tensorflow.compat.v1", "tensorflow", "pytorch"]:
if backend_name in ["tensorflow.compat.v1", "tensorflow", "pytorch", "paddle"]:
self.dim_y = ys.shape[1]
elif backend_name == "jax":
# For backend jax, a tuple of a jax array and a callable is passed as one of
@@ -50,6 +50,9 @@ def __call__(self, i=0, j=None):
self.J[i] = torch.autograd.grad(
y, self.xs, grad_outputs=torch.ones_like(y), create_graph=True
)[0]
elif backend_name == "paddle":
y = self.ys[:, i : i + 1] if self.dim_y > 1 else self.ys
self.J[i] = paddle.grad(y, self.xs, create_graph=True)[0]
elif backend_name == "jax":
# Here, we use jax.grad to compute the gradient of a function. This is
# different from TensorFlow and PyTorch that the input of a function is
@@ -67,7 +70,7 @@ def __call__(self, i=0, j=None):
grad_fn = jax.grad(lambda x: self.ys[1](x)[i])
self.J[i] = (jax.vmap(grad_fn)(self.xs), grad_fn)

if backend_name in ["tensorflow.compat.v1", "tensorflow", "pytorch"]:
if backend_name in ["tensorflow.compat.v1", "tensorflow", "pytorch", "paddle"]:
return (
self.J[i] if j is None or self.dim_x == 1 else self.J[i][:, j : j + 1]
)
@@ -141,7 +144,7 @@ def __call__(self, ys, xs, i=0, j=None):
# f(x)
if backend_name in ["tensorflow.compat.v1", "tensorflow"]:
key = (ys.ref(), xs.ref())
elif backend_name == "pytorch":
elif backend_name in ["pytorch", "paddle"]:
key = (ys, xs)
elif backend_name == "jax":
key = (id(ys[0]), id(xs))
@@ -197,7 +200,7 @@ class Hessian:
"""

def __init__(self, y, xs, component=None, grad_y=None):
if backend_name in ["tensorflow.compat.v1", "tensorflow", "pytorch"]:
if backend_name in ["tensorflow.compat.v1", "tensorflow", "pytorch", "paddle"]:
dim_y = y.shape[1]
elif backend_name == "jax":
dim_y = y[0].shape[0]
@@ -239,7 +242,7 @@ def __init__(self):
def __call__(self, y, xs, component=None, i=0, j=0, grad_y=None):
if backend_name in ["tensorflow.compat.v1", "tensorflow"]:
key = (y.ref(), xs.ref(), component)
elif backend_name == "pytorch":
elif backend_name in ["pytorch", "paddle"]:
key = (y, xs, component)
elif backend_name == "jax":
key = (id(y[0]), id(xs), component)
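A small standalone sketch of the reverse-mode call that the new Paddle branch of Jacobian.__call__ relies on, mirroring the torch.autograd.grad branch shown above; the tensors here are made up for illustration.

```python
# Sketch of the paddle.grad call used in the new Jacobian branch above.
# paddle.grad returns a list of gradients, one per input tensor.
import paddle

x = paddle.to_tensor([[1.0], [2.0], [3.0]], stop_gradient=False)
y = paddle.square(x)                             # y = x**2, column vector
(dy_dx,) = paddle.grad(y, x, create_graph=True)  # gradient of sum(y): 2x
print(dy_dx.numpy())                             # [[2.], [4.], [6.]]
```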
2 changes: 1 addition & 1 deletion deepxde/icbc/boundary_conditions.py
@@ -242,7 +242,7 @@ def wrapper_cache_auxiliary(X, beg, end, aux_var):
return wrapper_nocache
if utils.get_num_args(func) == 2:
return wrapper_nocache_auxiliary
if backend_name == "pytorch":
if backend_name in ["pytorch", "paddle"]:
if utils.get_num_args(func) == 1:
return wrapper_cache
if utils.get_num_args(func) == 2: