diff --git a/pydiffuser/models/aoup.py b/pydiffuser/models/aoup.py index 36409db..f53161f 100644 --- a/pydiffuser/models/aoup.py +++ b/pydiffuser/models/aoup.py @@ -1,9 +1,12 @@ +from functools import partial from typing import List import jax.numpy as jnp from jax import Array +from numpy.random import uniform from pydiffuser.models.core import OverdampedLangevin, OverdampedLangevinConfig +from pydiffuser.utils import jitted _GENERATE_HOOKS = ["ornstein_uhlenbeck_process"] @@ -50,9 +53,11 @@ def __init__( ): """ Consider an Ornstein-Uhlenbeck process for a self-propulsion velocity p: + ``` dp ____ ── = - μ x p + √2Dou Γ(t), dt + ``` where Γ(t) is a Gaussian white noise with zero mean and unit variance. Note that p is coupled with the overdamped Langevin equation written in `pydiffuser.models.core.sde.OverdampedLangevin`. @@ -72,7 +77,12 @@ def __init__( def ornstein_uhlenbeck_process(self) -> Array: realization, length, dimension, _ = self.generate_info.values() - p = jnp.ones((realization, 1, dimension)) # init + # TODO patternize + p = jitted.get_noise( # init + generator=partial(uniform, -1, 1), + size=realization * dimension, + shape=(realization, 1, dimension), + ) dp = self.get_diff_from_white_noise( diffusivity=self.diffusion_coefficient, shape=(realization, (length - 2), dimension),