Skip to content

Commit

Permalink
fix repeated jit
Browse files Browse the repository at this point in the history
  • Loading branch information
tszoldra committed Dec 12, 2024
1 parent 129ea53 commit b2a93bd
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions jVMC/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def __init__(self, net, sampleShape, key, updateProposer=None, numChains=1, upda

# Make sure that net is initialized
self.net(self.states)
self.sampler_net, _ = self.net.get_sampler_net()

self.logProbFactor = logProbFactor
self.mu = mu
Expand Down Expand Up @@ -150,6 +151,17 @@ def __init__(self, net, sampleShape, key, updateProposer=None, numChains=1, upda
self._get_samples_jitd = {} # will hold a jit'd function for each number of samples
self._randomize_samples_jitd = {} # will hold a jit'd function for each number of samples


# pmap'd helper function
self._logAccProb_pmapd = global_defs.pmap_for_my_devices(self._logAccProb,
in_axes=(0, None, None, None),
static_broadcasted_argnums=(2,))

def _logAccProb(self, x, mu, sampler_net, netParams):
# vmap is over parallel MC chains
return jax.vmap(lambda y: mu * jnp.real(sampler_net(netParams, y)), in_axes=(0,))(x)


def set_number_of_samples(self, N):
"""Set default number of samples.
Expand Down Expand Up @@ -347,11 +359,7 @@ def update(acc, old, new):

def _mc_init(self, netParams):

# Initialize logAccProb
net, _ = self.net.get_sampler_net()
self.logAccProb = global_defs.pmap_for_my_devices(
lambda x: jax.vmap(lambda y: self.mu * jnp.real(net(netParams, y)), in_axes=(0,))(x)
)(self.states)
self.logAccProb = self._logAccProb_pmapd(self.states, self.mu, self.sampler_net, netParams)

shape = (global_defs.device_count(),) + (1,)

Expand Down

0 comments on commit b2a93bd

Please sign in to comment.