diff --git a/src/emcee/moves/de.py b/src/emcee/moves/de.py index 32e60cfa..27105e0c 100644 --- a/src/emcee/moves/de.py +++ b/src/emcee/moves/de.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +from functools import lru_cache import numpy as np @@ -13,7 +14,7 @@ class DEMove(RedBlueMove): This `Differential evolution proposal `_ is implemented following `Nelson et al. (2013) - `_. + `_. Args: sigma (float): The standard deviation of the Gaussian used to stretch @@ -27,8 +28,7 @@ class DEMove(RedBlueMove): def __init__(self, sigma=1.0e-5, gamma0=None, **kwargs): self.sigma = sigma self.gamma0 = gamma0 - kwargs["nsplits"] = 3 - super(DEMove, self).__init__(**kwargs) + super().__init__(**kwargs) def setup(self, coords): self.g0 = self.gamma0 @@ -38,14 +38,40 @@ def setup(self, coords): self.g0 = 2.38 / np.sqrt(2 * ndim) def get_proposal(self, s, c, random): - Ns = len(s) - Nc = list(map(len, c)) - ndim = s.shape[1] - q = np.empty((Ns, ndim), dtype=np.float64) - f = self.sigma * random.randn(Ns) - for i in range(Ns): - w = np.array([c[j][random.randint(Nc[j])] for j in range(2)]) - random.shuffle(w) - g = np.diff(w, axis=0) * self.g0 + f[i] - q[i] = s[i] + g - return q, np.zeros(Ns, dtype=np.float64) + c = np.concatenate(c, axis=0) + ns, ndim = s.shape + nc = c.shape[0] + + # Get the pair indices + pairs = _get_nondiagonal_pairs(nc) + + # Sample from the pairs + indices = random.choice(pairs.shape[0], size=ns, replace=True) + pairs = pairs[indices] + + # Compute diff vectors + diffs = np.diff(c[pairs], axis=1).squeeze(axis=1) # (ns, ndim) + + # Sample a gamma value for each walker following Nelson et al. (2013) + gamma = self.g0 * (1 + self.sigma * random.randn(ns, 1)) # (ns, 1) + + # In this way, sigma is the standard deviation of the distribution of gamma, + # instead of the standard deviation of the distribution of the proposal as proposed by Ter Braak (2006). + # Otherwise, sigma should be tuned for each dimension, which confronts the idea of affine-invariance. + + q = s + gamma * diffs + + return q, np.zeros(ns, dtype=np.float64) + + +@lru_cache(maxsize=1) +def _get_nondiagonal_pairs(n: int) -> np.ndarray: + """Get the indices of a square matrix with size n, excluding the diagonal.""" + rows, cols = np.tril_indices(n, -1) # -1 to exclude diagonal + + # Combine rows-cols and cols-rows pairs + pairs = np.column_stack( + [np.concatenate([rows, cols]), np.concatenate([cols, rows])] + ) + + return pairs