Skip to content

Commit

Permalink
Docstrings for utils
Browse files Browse the repository at this point in the history
  • Loading branch information
ddrous committed Apr 14, 2024
1 parent a921820 commit f358d45
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 51 deletions.
11 changes: 9 additions & 2 deletions docs/_quarto.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,17 @@ quartodoc:
sidebar: _sidebar.yml

sections:
- title: Some functions
- title: Utility Functions
desc: Functions to inspect docstrings.
contents:
# the functions being documented in the package.
# you can refer to anything: class methods, modules, etc..
- operators.pde_solver
- utils.RK4
- cloud.Cloud
- title: Cloud
desc: Functions to inspect docstrings.
contents:
# the functions being documented in the package.
# you can refer to anything: class methods, modules, etc..
- operators.pde_solver
- cloud.Cloud
2 changes: 1 addition & 1 deletion updes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import updes.config as UPDEC
import updes.config as UPDES

from updes.utils import *
from updes.cloud import *
Expand Down
20 changes: 10 additions & 10 deletions updes/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from functools import cache, lru_cache, partial

# from updec.config import RBF, MAX_DEGREE, DIM
import updes.config as UPDEC
import updes.config as UPDES
from updes.utils import compute_nb_monomials, SteadySol, make_all_monomials
from updes.cloud import Cloud, SquareCloud
from updes.assembly import assemble_B, assemble_q, core_compute_coefficients
Expand Down Expand Up @@ -562,10 +562,10 @@ def pde_solver( diff_operator:callable,
diff_operator = jax.jit(diff_operator, static_argnums=[2,3])
rhs_operator = jax.jit(rhs_operator, static_argnums=2)

UPDEC.RBF = rbf
UPDES.RBF = rbf
### For rememmering purposes
UPDEC.MAX_DEGREE = max_degree
UPDEC.DIM = cloud.dim
UPDES.MAX_DEGREE = max_degree
UPDES.DIM = cloud.dim


## Build robin coeffs
Expand Down Expand Up @@ -682,10 +682,10 @@ def pde_multi_solver( diff_operators:list,
diff_operators = [jax.jit(diff_op, static_argnums=[2,3]) for diff_op in diff_operators]
rhs_operators = [jax.jit(rhs_op, static_argnums=2) for rhs_op in rhs_operators]

UPDEC.RBF = rbf
UPDES.RBF = rbf
### For rememmering purposes
UPDEC.MAX_DEGREE = max_degree
UPDEC.DIM = cloud.dim
UPDES.MAX_DEGREE = max_degree
UPDES.DIM = cloud.dim


## Build robin coeffs
Expand Down Expand Up @@ -751,10 +751,10 @@ def pde_multi_solver_unbounded( diff_operators:list,
diff_operators = [jax.jit(diff_op, static_argnums=[2,3]) for diff_op in diff_operators]
rhs_operators = [jax.jit(rhs_op, static_argnums=2) for rhs_op in rhs_operators]

UPDEC.RBF = rbf
UPDES.RBF = rbf
### For rememmering purposes
UPDEC.MAX_DEGREE = max_degree
UPDEC.DIM = cloud.dim
UPDES.MAX_DEGREE = max_degree
UPDES.DIM = cloud.dim


## Build robin coeffs
Expand Down
93 changes: 55 additions & 38 deletions updes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,58 +16,50 @@
import random


# def periodic_distance_squre(node1, node2, W, H):
# dx = jnp.abs(node1[0] - node2[0])
# dy = jnp.abs(node1[1] - node2[1])
# # dx = jnp.where(dx > W/2, W - dx, dx)
# # dy = jnp.where(dy > H/2, H - dy, dy)
# dx = jnp.minimum(dx, W - dx)
# dy = jnp.minimum(dy, H - dy)
# return jnp.sqrt(dx**2 + dy**2)


## Euclidian distance
def distance(node1, node2):
diff = node1 - node2
# return jnp.sum(diff*diff) ## Squared distance
# return periodic_distance_squre(node1, node2, 1., 1.)
# return jnp.linalg.norm(node1 - node2) ## Carefull: not differentiable at 0
return jnp.sqrt(diff.T @ diff)


## Print each item of a dictionary in a new line
def print_line_by_line(dictionary):
for k, v in dictionary.items():
print("\t", k,":",v)


## Hardy's Multiquadric RBF
def multiquadric_func(r, eps):
return jnp.sqrt(1 + (eps*r)**2)
@jax.jit
def multiquadric(x, center, eps=1.):
return multiquadric_func(distance(x, center), eps)

## Inverse Multiquadric RBF
def inv_multiquadric_func(r, eps):
return 1./ jnp.sqrt(1 + (eps*r)**2)
@jax.jit
def inverse_multiquadric(x, center, eps=1.):
return inv_multiquadric_func(distance(x, center), eps)

## Gaussian RBF
def gaussian_func(r, eps):
return jnp.exp(-(eps * r)**2)
def gaussian(x, center, eps=1.):
return gaussian_func(distance(x, center), eps)

## Polyharmonic Spline RBF
def polyharmonic_func(r, a):
return r**(2*a+1)
@jax.jit
def polyharmonic(x, center, a=1):
return polyharmonic_func(distance(x, center), a)
@jax.jit
def polyharmonic_grad(x, center):
return 3 * distance(x, center) * (x - center)


## Gradient of Polyharmonic Spline RBF
# @jax.jit
# def polyharmonic_grad(x, center):
# return 3 * distance(x, center) * (x - center)

## Thin Plate Spline RBF
def thin_plate_func(r, a):
# return jnp.log(r) * r**(2*a)
return jnp.nan_to_num(jnp.log(r) * r**(2*a), neginf=0., posinf=0.)
Expand All @@ -76,23 +68,37 @@ def thin_plate(x, center, a=1):
return thin_plate_func(distance(x, center), a)


# @jax.jit
@Partial(jax.jit, static_argnums=2)
def make_nodal_rbf(x, node, rbf):
""" Gives the tuned rbf function """
# """ Gives the tuned rbf function """
"""A function that returns the value of the RBF at a given point x, with respect to a given node. The RBF is tuned to the given node.
Args:
x (Float[Array, "dim"]): The point at which the RBF is to be evaluated.
node (Float[Array, "dim"]): The centroid with respect to which the RBF is evaluated.
rbf (Callable): The RBF function to be used, with signature rbf(r) where r is the Euclidean distance between the two points
Returns:
float: The scalar value of the RBF at the given point x, with respect to the given node.
"""
if rbf==None:
func = polyharmonic
else:
func = rbf
# return jnp.where(jnp.all(x==node), 0., func(distance(x, node))) ## TODO Bad attempt to avoid differentiability
return func(distance(x, node))


@Partial(jax.jit, static_argnums=1)
def make_monomial(x, id):
""" Easy way to keep track of all monomials """
## x is a 2D vector
if id == 0:
"""A function that returns the value of a monomial at a given point x.
Args:
x (Float[Array, "dim"]): The point at which the monomial is to be evaluated.
id (int): The id of the monomial to be evaluated.
Returns:
float: The value of the monomial at the given point x.
"""
return 1.0
elif id == 1:
return x[0]
Expand Down Expand Up @@ -123,25 +129,25 @@ def make_monomial(x, id):
elif id == 14:
return x[1]**4
else:
pass ## Higher order monomials not yet supported
pass ## TODO: support higher order monomials !

@cache
def make_all_monomials(nb_monomials):
# return jnp.array([Partial(make_monomial, id=j) for j in range(nb_monomials)])
"""A function that returns up to a certain number of monomials"""
return [partial(make_monomial, id=j) for j in range(nb_monomials)]



def compute_nb_monomials(max_degree, problem_dimension):
"""Computes the number of monomials of dregree less than 'max_degree', in dimension 'problem_dimension'"""
return math.comb(max_degree+problem_dimension, max_degree)


## This stores both the RBF coefficients, the values, and its matrix after solving a PDE
SteadySol = namedtuple('PDESolution', ['vals', 'coeffs', 'mat'])



def random_name(length=5):
"Make random names to identify runs"
"""Make up a random name to identify a run"""
name = ""
for _ in range(length):
name += str(random.randint(0, 9))
Expand All @@ -151,16 +157,10 @@ def random_name(length=5):
def make_dir(path):
"Make a directory if it doesn't exist"
if not os.path.exists(path):
# os.system("rm -rf " + path)
os.mkdir(path)


# plt.style.use('bmh')
# sns.set(context='notebook', style='ticks',
# font='sans-serif', font_scale=1, color_codes=True, rc={"lines.linewidth": 2})

## Wrapper function for matplotlib and seaborn
def plot(*args, ax=None, figsize=(6,3.5), x_label=None, y_label=None, title=None, x_scale='linear', y_scale='linear', xlim=None, ylim=None, **kwargs):
"""Wrapper function for matplotlib and seaborn"""
if ax==None:
_, ax = plt.subplots(1, 1, figsize=figsize)
# sns.despine(ax=ax)
Expand All @@ -182,7 +182,7 @@ def plot(*args, ax=None, figsize=(6,3.5), x_label=None, y_label=None, title=None
plt.tight_layout()
return ax


## A dataloader for mini-batch training if needed
def dataloader(array, batch_size, key):
dataset_size = array.shape[0]
indices = jnp.arange(dataset_size)
Expand All @@ -198,7 +198,24 @@ def dataloader(array, batch_size, key):


def RK4(fun, t_span, y0, *args, t_eval=None, subdivisions=1, **kwargs):
""" Perform numerical integration with a time step divided by the evaluation subdivision factor (Not necessarily equally spaced). If we get NaNs, we can try to increasing the subdivision factor for finer time steps."""
"""Numerical integration with a time interval subdivisions
Args:
fun (Callable): The function to be integrated.
y0 (Float[Array]): The initial condition.
t_span (Tuple): The time interval for which the integration is to be performed.
t_eval (Float[Array]): The time points at which the solution is to be evaluated.
subdivisions (int): To improve stability, each interval in t_eval is divided into this many subdivisions. Consider increasing this if you obtain NaNs.
*args: Additional arguments to be passed to the function.
**kwargs: Additional keyword arguments to be passed to the function.
Raises:
Warning: if t_span[0] is None.
ValueError: if t_eval is None and t_span[1] is None.
Returns:
Float[Array, "nb_time_steps"]: The solution at the time points in t_eval.
"""
if t_eval is None:
if t_span[0] is None:
t_eval = jnp.array([t_span[1]])
Expand Down

0 comments on commit f358d45

Please sign in to comment.