Skip to content

Commit

Permalink
Merge pull request #38 from markusschmitt/Zero-Magnetization-Sampler
Browse files Browse the repository at this point in the history
Zero-Magnetization Sampler
  • Loading branch information
markusschmitt authored Jan 24, 2023
2 parents ba98869 + 1a597b6 commit f4f319f
Showing 1 changed file with 24 additions and 0 deletions.
24 changes: 24 additions & 0 deletions jVMC/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,30 @@ def propose_spin_flip_Z2(key, s, info):
doFlip = random.randint(flipKey, (1,), 0, 5)[0]
return jax.lax.cond(doFlip == 0, lambda x: 1 - x, lambda x: x, s)

def propose_spin_flip_zeroMag(key, s, info):
# propose spin flips that stay in the zero magnetization sector

idxKeyUp, idxKeyDown, flipKey = jax.random.split(key, num=3)

# can't use jnp.where because then it is not jit-compilable
# find indices based on cumsum
bound_up = jax.random.randint(idxKeyUp, (1,), 1, s.shape[0] * s.shape[1] / 2 + 1)
bound_down = jax.random.randint(idxKeyDown, (1,), 1, s.shape[0] * s.shape[1] / 2 + 1)

id_up = jnp.searchsorted(jnp.cumsum(s), bound_up)
id_down = jnp.searchsorted(jnp.cumsum(1 - s), bound_down)

idx_up = jnp.unravel_index(id_up, s.shape)
idx_down = jnp.unravel_index(id_down, s.shape)

s = s.at[idx_up[0], idx_up[1]].set(0)
s = s.at[idx_down[0], idx_down[1]].set(1)

# On average, do a global spin flip every 30 updates to
# reflect Z_2 symmetry
doFlip = random.randint(flipKey, (1,), 0, 5)[0]
return jax.lax.cond(doFlip == 0, lambda x: 1 - x, lambda x: x, s)


class MCSampler:
"""A sampler class.
Expand Down

0 comments on commit f4f319f

Please sign in to comment.