Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Zero-Magnetization Sampler #38

Merged
merged 1 commit into from
Jan 24, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions jVMC/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,30 @@ def propose_spin_flip_Z2(key, s, info):
doFlip = random.randint(flipKey, (1,), 0, 5)[0]
return jax.lax.cond(doFlip == 0, lambda x: 1 - x, lambda x: x, s)

def propose_spin_flip_zeroMag(key, s, info):
# propose spin flips that stay in the zero magnetization sector

idxKeyUp, idxKeyDown, flipKey = jax.random.split(key, num=3)

# can't use jnp.where because then it is not jit-compilable
# find indices based on cumsum
bound_up = jax.random.randint(idxKeyUp, (1,), 1, s.shape[0] * s.shape[1] / 2 + 1)
bound_down = jax.random.randint(idxKeyDown, (1,), 1, s.shape[0] * s.shape[1] / 2 + 1)

id_up = jnp.searchsorted(jnp.cumsum(s), bound_up)
id_down = jnp.searchsorted(jnp.cumsum(1 - s), bound_down)

idx_up = jnp.unravel_index(id_up, s.shape)
idx_down = jnp.unravel_index(id_down, s.shape)

s = s.at[idx_up[0], idx_up[1]].set(0)
s = s.at[idx_down[0], idx_down[1]].set(1)

# On average, do a global spin flip every 30 updates to
# reflect Z_2 symmetry
doFlip = random.randint(flipKey, (1,), 0, 5)[0]
return jax.lax.cond(doFlip == 0, lambda x: 1 - x, lambda x: x, s)


class MCSampler:
"""A sampler class.
Expand Down