diff --git a/jVMC/sampler.py b/jVMC/sampler.py index 8904de5..cc8c993 100644 --- a/jVMC/sampler.py +++ b/jVMC/sampler.py @@ -39,6 +39,30 @@ def propose_spin_flip_Z2(key, s, info): doFlip = random.randint(flipKey, (1,), 0, 5)[0] return jax.lax.cond(doFlip == 0, lambda x: 1 - x, lambda x: x, s) +def propose_spin_flip_zeroMag(key, s, info): + # propose spin flips that stay in the zero magnetization sector + + idxKeyUp, idxKeyDown, flipKey = jax.random.split(key, num=3) + + # can't use jnp.where because then it is not jit-compilable + # find indices based on cumsum + bound_up = jax.random.randint(idxKeyUp, (1,), 1, s.shape[0] * s.shape[1] / 2 + 1) + bound_down = jax.random.randint(idxKeyDown, (1,), 1, s.shape[0] * s.shape[1] / 2 + 1) + + id_up = jnp.searchsorted(jnp.cumsum(s), bound_up) + id_down = jnp.searchsorted(jnp.cumsum(1 - s), bound_down) + + idx_up = jnp.unravel_index(id_up, s.shape) + idx_down = jnp.unravel_index(id_down, s.shape) + + s = s.at[idx_up[0], idx_up[1]].set(0) + s = s.at[idx_down[0], idx_down[1]].set(1) + + # On average, do a global spin flip every 30 updates to + # reflect Z_2 symmetry + doFlip = random.randint(flipKey, (1,), 0, 5)[0] + return jax.lax.cond(doFlip == 0, lambda x: 1 - x, lambda x: x, s) + class MCSampler: """A sampler class.