diff --git a/documentation/source/index.rst b/documentation/source/index.rst index 6bf8f4c..85ad713 100644 --- a/documentation/source/index.rst +++ b/documentation/source/index.rst @@ -16,17 +16,18 @@ parallelism .. toctree:: - :hidden: - :glob: - :maxdepth: 2 - :caption: API documentation - - vqs - operator - sampler - nets - mpi - util + :hidden: + :glob: + :maxdepth: 2 + :caption: API documentation + + vqs + operator + sampler + nets + mpi + stats + util .. toctree:: :hidden: diff --git a/documentation/source/stats.rst b/documentation/source/stats.rst new file mode 100644 index 0000000..ff3e830 --- /dev/null +++ b/documentation/source/stats.rst @@ -0,0 +1,21 @@ +.. _stats: + +Sample statistics module +======================== + +The ``SampledObs`` class provides functionality to conveniently compute sample statistics. + +**Example:** + + Assuming that ``sampler`` is an instance of a sampler class and ``psi`` is a variational quantum state, + the quantum geometric tensor + :math:`S_{k,k'}=\langle(\partial_{\theta_k}\log\psi_\theta)^*\partial_{\theta_{k'}}\log\psi_\theta\rangle_c` + can be computed as + + >>> s, logPsi, p = sampler.sample() + >>> grads = SampledObs( psi.gradients(s), p) + >>> S = grads.covar() + +.. automodule:: jVMC.stats + :members: + :special-members: __call__ diff --git a/jVMC/__init__.py b/jVMC/__init__.py index a150c47..51f025a 100644 --- a/jVMC/__init__.py +++ b/jVMC/__init__.py @@ -5,6 +5,7 @@ from . import mpi_wrapper from . import vqs from . import sampler +from . import stats from .version import __version__ from .global_defs import set_pmap_devices diff --git a/jVMC/global_defs.py b/jVMC/global_defs.py index 52a62f3..0fc7b51 100644 --- a/jVMC/global_defs.py +++ b/jVMC/global_defs.py @@ -10,6 +10,7 @@ import jax from functools import partial +import collections try: myDevice = jax.devices()[MPI.COMM_WORLD.Get_rank() % len(jax.devices())] @@ -21,7 +22,11 @@ myDeviceCount = len(myPmapDevices) pmap_for_my_devices = partial(jax.pmap, devices=myPmapDevices) -import collections +pmapDevices = None +def pmap_devices_updated(): + if collections.Counter(pmapDevices) == collections.Counter(myPmapDevices): + return False + return True def get_iterable(x): diff --git a/jVMC/mpi_wrapper.py b/jVMC/mpi_wrapper.py index af2ad9f..e6cb8f7 100644 --- a/jVMC/mpi_wrapper.py +++ b/jVMC/mpi_wrapper.py @@ -17,38 +17,28 @@ communicationTime = 0.
-def _cov_helper_with_p(data, p): +def _cov_helper(data, p): return jnp.expand_dims( jnp.matmul(jnp.conj(jnp.transpose(data)), jnp.multiply(p[:, None], data)), axis=0 ) -def _cov_helper_without_p(data): - return jnp.expand_dims( - jnp.matmul(jnp.conj(jnp.transpose(data)), data), - axis=0 - ) - - _sum_up_pmapd = None _sum_sq_pmapd = None -_sum_sq_withp_pmapd = None mean_helper = None -cov_helper_with_p = None -cov_helper_without_p = None - -pmapDevices = None +cov_helper = None -import collections +#pmapDevices = None -def pmap_devices_updated(): +import collections - if collections.Counter(pmapDevices) == collections.Counter(global_defs.myPmapDevices): - return False - return True +# def pmap_devices_updated(): +# if collections.Counter(pmapDevices) == collections.Counter(global_defs.myPmapDevices): +# return False +# return True def jit_my_stuff(): @@ -57,19 +47,16 @@ def jit_my_stuff(): global _sum_up_pmapd global _sum_sq_pmapd - global _sum_sq_withp_pmapd global mean_helper - global cov_helper_with_p - global cov_helper_without_p + global cov_helper global pmapDevices - if pmap_devices_updated(): + if global_defs.pmap_devices_updated(): _sum_up_pmapd = global_defs.pmap_for_my_devices(lambda x: jax.lax.psum(jnp.sum(x, axis=0), 'i'), axis_name='i') - _sum_sq_pmapd = global_defs.pmap_for_my_devices(lambda data, mean: jax.lax.psum(jnp.sum(jnp.conj(data - mean) * (data - mean), axis=0), 'i'), axis_name='i', in_axes=(0, None)) - _sum_sq_withp_pmapd = global_defs.pmap_for_my_devices(lambda data, mean, p: jax.lax.psum(jnp.conj(data - mean).dot(p * (data - mean)), 'i'), axis_name='i', in_axes=(0, None, 0)) + # _sum_sq_pmapd = global_defs.pmap_for_my_devices(lambda data, mean, p: jax.lax.psum(jnp.conj(data - mean).dot(p[..., None] * (data - mean)), 'i'), axis_name='i', in_axes=(0, None, 0)) + _sum_sq_pmapd = global_defs.pmap_for_my_devices(lambda data, mean, p: jnp.einsum('ij, ij, i -> j', jnp.conj(data - mean[None, ...]), (data - mean[None, ...]), p), in_axes=(0, None, 0)) mean_helper = global_defs.pmap_for_my_devices(lambda data, p: jnp.expand_dims(jnp.dot(p, data), axis=0), in_axes=(0, 0)) - cov_helper_with_p = global_defs.pmap_for_my_devices(_cov_helper_with_p, in_axes=(0, 0)) - cov_helper_without_p = global_defs.pmap_for_my_devices(_cov_helper_without_p) + cov_helper = global_defs.pmap_for_my_devices(_cov_helper, in_axes=(0, 0)) pmapDevices = global_defs.myPmapDevices @@ -111,7 +98,7 @@ def distribute_sampling(numSamples, localDevices=None, numChainsPerDevice=1) -> numChainsPerProcess = localDevices * numChainsPerDevice def spc(spp): - return int( (spp + numChainsPerProcess - 1) // numChainsPerProcess ) + return int((spp + numChainsPerProcess - 1) // numChainsPerProcess) a = numSamples % commSize globNumSamples = (a * spc(1 + numSamples // commSize) + (commSize - a) * spc(numSamples // commSize)) * numChainsPerProcess @@ -171,7 +158,7 @@ def global_sum(data): return jax.device_put(res, global_defs.myDevice) -def global_mean(data, p=None): +def global_mean(data, p): """ Computes the mean of input data across MPI processes and device/batch dimensions. On each MPI process the input data is assumed to be a ``jax.numpy.array`` with a leading @@ -179,11 +166,7 @@ def global_mean(data, p=None): along device and batch dimensions as well as accross MPI processes. Hence, the result is an array of shape ``data.shape[2:]``. 
- If no probabilities ``p`` are given, the empirical mean is computed, i.e., - - :math:`\\langle X\\rangle=\\frac{1}{N_S}\sum_{j=1}^{N_S} X_j` - - Otherwise, the mean is computed using the given probabilities, i.e., + The mean is computed using the given probabilities, i.e., :math:`\\langle X\\rangle=\sum_{j=1}^{N_S} p_jX_j` @@ -195,17 +178,14 @@ Mean of data across MPI processes and device/batch dimensions. """ - jit_my_stuff() - if p is not None: - return global_sum(mean_helper(data, p)) - global globNumSamples + jit_my_stuff() - return global_sum(data) / globNumSamples + return global_sum(mean_helper(data, p)) -def global_variance(data, p=None): +def global_variance(data, p): """ Computes the variance of input data across MPI processes and device/batch dimensions. On each MPI process the input data is assumed to be a ``jax.numpy.array`` with a leading @@ -213,11 +193,7 @@ along device and batch dimensions as well as accross MPI processes. Hence, the result is an array of shape ``data.shape[2:]``. - If no probabilities ``p`` are given, the empirical element-wise variance is computed, i.e., - - :math:`\\text{Var}(X)=\\frac{1}{N_S}\sum_{j=1}^{N_S} |X_j-\\langle X\\rangle|^2` - - Otherwise, the mean is computed using the given probabilities, i.e., + The variance is computed using the given probabilities, i.e., :math:`\\text{Var}(X)=\sum_{j=1}^{N_S} p_j |X_j-\\langle X\\rangle|^2` @@ -234,15 +210,8 @@ data.block_until_ready() mean = global_mean(data, p) - # Compute sum locally - localSum = None - if p is not None: - localSum = np.array(_sum_sq_withp_pmapd(data, mean, p)[0]) - else: - res = _sum_sq_pmapd(data, mean)[0] - res.block_until_ready() - localSum = np.array(res) + localSum = np.array(_sum_sq_pmapd(data, mean, p)[0]) # Allocate memory for result res = np.empty_like(localSum, dtype=localSum.dtype) @@ -256,15 +225,10 @@ global communicationTime communicationTime += time.perf_counter() - t0 - if p is not None: - # return jnp.array(res) - return jax.device_put(res, global_defs.myDevice) - else: - # return jnp.array(res) / globNumSamples - return jax.device_put(res / globNumSamples, global_defs.myDevice) + return jax.device_put(res, global_defs.myDevice) -def global_covariance(data, p=None): +def global_covariance(data, p): """ Computes the covariance matrix of input data across MPI processes and device/batch dimensions. On each MPI process the input data is assumed to be a ``jax.numpy.array`` with a leading @@ -273,11 +237,7 @@ matrix along device and batch dimensions as well as accross MPI processes. Hence, the result is an array of shape ``data.shape[2]`` :math:`\\times` ``data.shape[2]``.
- If no probabilities ``p`` are given, the empirical covariance is computed, i.e., - - :math:`\\text{Cov}(X)=\\frac{1}{N_S}\sum_{j=1}^{N_S} X_j\\cdot X_j^\\dagger - \\bigg(\\frac{1}{N_S}\sum_{j=1}^{N_S} X_j\\bigg)\\cdot\\bigg(\\frac{1}{N_S}\sum_{j=1}^{N_S}X_j^\\dagger\\bigg)` - - Otherwise, the mean is computed using the given probabilities, i.e., + The covariance is computed using the given probabilities, i.e., :math:`\\text{Cov}(X)=\sum_{j=1}^{N_S} p_jX_j\\cdot X_j^\\dagger - \\bigg(\sum_{j=1}^{N_S} p_jX_j\\bigg)\\cdot\\bigg(\sum_{j=1}^{N_S}p_jX_j^\\dagger\\bigg)` @@ -291,11 +251,7 @@ jit_my_stuff() - if p is not None: - - return global_sum(cov_helper_with_p(data, p)) - - return global_mean(cov_helper_without_p(data)) + return global_sum(cov_helper(data, p)) def bcast_unknown_size(data, root=0): diff --git a/jVMC/operator/povm.py b/jVMC/operator/povm.py index 70e2cd6..53c5502 100644 --- a/jVMC/operator/povm.py +++ b/jVMC/operator/povm.py @@ -35,27 +35,16 @@ def measure_povm(povm, sampler, sampleConfigs=None, probs=None, observables=None for name, ops in observables.items(): results = povm.evaluate_observable(ops, sampleConfigs) result[name] = {} - if probs is not None: - result[name]["mean"] = jnp.array(mpi.global_mean(results[0], probs)) - result[name]["variance"] = jnp.array(mpi.global_variance(results[0], probs)) - result[name]["MC_error"] = jnp.array(0) - - else: - result[name]["mean"] = jnp.array(mpi.global_mean(results[0])) - result[name]["variance"] = jnp.array(mpi.global_variance(results[0])) - result[name]["MC_error"] = jnp.array(result[name]["variance"] / jnp.sqrt(sampler.get_last_number_of_samples())) + result[name]["mean"] = jnp.array(mpi.global_mean(results[0][..., None], probs)[0]) + result[name]["variance"] = jnp.array(mpi.global_variance(results[0][..., None], probs)[0]) + result[name]["MC_error"] = jnp.array(result[name]["variance"] / jnp.sqrt(sampler.get_last_number_of_samples())) for key, value in results[1].items(): result_name = name + "_corr_L" + str(key) result[result_name] = {} - if probs is not None: - result[result_name]["mean"] = jnp.array(mpi.global_mean(value, probs) - result[name]["mean"]**2) - result[result_name]["variance"] = jnp.array(mpi.global_variance(value, probs)) - result[result_name]["MC_error"] = jnp.array(0.) - else: - result[result_name]["mean"] = jnp.array(mpi.global_mean(value) - result[name]["mean"]**2) - result[result_name]["variance"] = jnp.array(mpi.global_variance(value)) - result[result_name]["MC_error"] = jnp.array(result[result_name]["variance"] / jnp.sqrt(sampler.get_last_number_of_samples())) + result[result_name]["mean"] = jnp.array(mpi.global_mean(value[..., None], probs)[0] - result[name]["mean"]**2) + result[result_name]["variance"] = jnp.array(mpi.global_variance(value[..., None], probs)[0]) + result[result_name]["MC_error"] = jnp.array(result[result_name]["variance"] / jnp.sqrt(sampler.get_last_number_of_samples())) return result diff --git a/jVMC/sampler.py b/jVMC/sampler.py index 2817462..807083d 100644 --- a/jVMC/sampler.py +++ b/jVMC/sampler.py @@ -40,6 +40,7 @@ def propose_spin_flip_Z2(key, s, info): doFlip = random.randint(flipKey, (1,), 0, 5)[0] return jax.lax.cond(doFlip == 0, lambda x: 1 - x, lambda x: x, s) + def propose_spin_flip_zeroMag(key, s, info): # propose spin flips that stay in the zero magnetization sector @@ -69,8 +70,13 @@ class MCSampler: """A sampler class.
This class provides functionality to sample computational basis states from \ - the probability distribution induced by the variational wave function, \ - :math:`|\\psi(s)|^2`. + the distribution + + :math:`p_{\\mu}(s)=\\frac{|\\psi(s)|^{\\mu}}{\\sum_s|\\psi(s)|^{\\mu}}`. + + For :math:`\\mu=2` this corresponds to sampling from the Born distribution. \ + :math:`0\\leq\\mu<2` can be used to perform importance sampling \ + (see `[arXiv:2108.08631] <https://arxiv.org/abs/2108.08631>`_). Sampling is automatically distributed accross MPI processes and locally available \ devices. @@ -90,10 +96,11 @@ class MCSampler: * ``numSamples``: Default number of samples to be returned by the ``sample()`` member function. * ``thermalizationSweeps``: Number of sweeps to perform for thermalization of the Markov chain. * ``sweepSteps``: Number of proposed updates per sweep. + * ``mu``: Parameter for the distribution :math:`p_{\\mu}(s)`, see above. """ def __init__(self, net, sampleShape, key, updateProposer=None, numChains=1, updateProposerArg=None, - numSamples=100, thermalizationSweeps=10, sweepSteps=10, initState=None): + numSamples=100, thermalizationSweeps=10, sweepSteps=10, initState=None, mu=2): """Initializes the MCSampler class. """ @@ -110,11 +117,14 @@ def __init__(self, net, sampleShape, key, updateProposer=None, numChains=1, upda stateShape = (global_defs.device_count(), numChains) + sampleShape if initState is None: initState = jnp.zeros(sampleShape, dtype=np.int32) - self.states = jnp.stack([initState] * (global_defs.device_count()*numChains), axis=0).reshape(stateShape) + self.states = jnp.stack([initState] * (global_defs.device_count() * numChains), axis=0).reshape(stateShape) # Make sure that net is initialized self.net(self.states) + self.mu = mu + if mu < 0 or mu > 2: + raise ValueError("mu must be in the range [0, 2]") self.updateProposer = updateProposer self.updateProposerArg = updateProposerArg @@ -174,7 +184,7 @@ def sample(self, parameters=None, numSamples=None, multipleOf=1): If supported by ``net``, direct sampling is peformed. Otherwise, MCMC is run \ to generate the desired number of samples. For direct sampling the real part \ of ``net`` needs to provide a ``sample()`` member function that generates \ - samples from :math:`|\\psi(s)|^2`. + samples from :math:`p_{\\mu}(s)`. Sampling is automatically distributed accross MPI processes and available \ devices. In that case the number of samples returned might exceed ``numSamples``. @@ -191,21 +201,19 @@ a way that the number of samples per processor is identical for each processor. Returns: - A sample of computational basis configurations drawn from :math:`|\\psi(s)|^2`. + A sample of computational basis configurations drawn from :math:`p_{\\mu}(s)`.
""" if numSamples is None: numSamples = self.numSamples if self.net.is_generator: - configs = self._get_samples_gen(parameters, numSamples, multipleOf) - - return configs, self.net(configs), None + return configs, self.net(configs), jnp.ones(configs.shape[:2]) / jnp.prod(jnp.asarray(configs.shape[:2])) configs, logPsi = self._get_samples_mcmc(parameters, numSamples, multipleOf) - - return configs, logPsi, None + p = jnp.exp((2 - self.mu) * jnp.real(logPsi)) + return configs, logPsi, p / mpi.global_sum(p) def _randomize_samples(self, samples, key, orbit): """ For a given set of samples apply a random symmetry transformation to each sample @@ -252,9 +260,9 @@ def _get_samples_mcmc(self, params, numSamples, multipleOf=1): static_broadcasted_argnums=(1, 2, 3, 9, 11), in_axes=(None, None, None, None, 0, 0, 0, 0, 0, None, None, None)) - (self.states, self.logPsiSq, self.key, self.numProposed, self.numAccepted), configs =\ + (self.states, self.logAccProb, self.key, self.numProposed, self.numAccepted), configs =\ self._get_samples_jitd[numSamplesStr](params, numSamples, self.thermalizationSweeps, self.sweepSteps, - self.states, self.logPsiSq, self.key, self.numProposed, self.numAccepted, + self.states, self.logAccProb, self.key, self.numProposed, self.numAccepted, self.updateProposer, self.updateProposerArg, self.sampleShape) # return configs, None @@ -262,29 +270,29 @@ def _get_samples_mcmc(self, params, numSamples, multipleOf=1): def _get_samples(self, params, numSamples, thermSweeps, sweepSteps, - states, logPsiSq, key, + states, logAccProb, key, numProposed, numAccepted, updateProposer, updateProposerArg, sampleShape, sweepFunction=None): # Thermalize - states, logPsiSq, key, numProposed, numAccepted =\ - sweepFunction(states, logPsiSq, key, numProposed, numAccepted, params, thermSweeps * sweepSteps, updateProposer, updateProposerArg) + states, logAccProb, key, numProposed, numAccepted =\ + sweepFunction(states, logAccProb, key, numProposed, numAccepted, params, thermSweeps * sweepSteps, updateProposer, updateProposerArg) # Collect samples def scan_fun(c, x): - states, logPsiSq, key, numProposed, numAccepted =\ + states, logAccProb, key, numProposed, numAccepted =\ sweepFunction(c[0], c[1], c[2], c[3], c[4], params, sweepSteps, updateProposer, updateProposerArg) - return (states, logPsiSq, key, numProposed, numAccepted), states + return (states, logAccProb, key, numProposed, numAccepted), states - meta, configs = jax.lax.scan(scan_fun, (states, logPsiSq, key, numProposed, numAccepted), None, length=numSamples) + meta, configs = jax.lax.scan(scan_fun, (states, logAccProb, key, numProposed, numAccepted), None, length=numSamples) # return meta, configs.reshape((configs.shape[0]*configs.shape[1], -1)) return meta, configs.reshape((configs.shape[0] * configs.shape[1],) + sampleShape) - def _sweep(self, states, logPsiSq, key, numProposed, numAccepted, params, numSteps, updateProposer, updateProposerArg, net=None): + def _sweep(self, states, logAccProb, key, numProposed, numAccepted, params, numSteps, updateProposer, updateProposerArg, net=None): def perform_mc_update(i, carry): @@ -294,8 +302,8 @@ def perform_mc_update(i, carry): newStates = vmap(updateProposer, in_axes=(0, 0, None))(newKeys[:len(carry[0])], carry[0], updateProposerArg) # Compute acceptance probabilities - newLogPsiSq = jax.vmap(lambda y: 2. 
* jnp.real(net(params, y)), in_axes=(0,))(newStates) - P = jnp.exp(newLogPsiSq - carry[1]) + newLogAccProb = jax.vmap(lambda y: self.mu * jnp.real(net(params, y)), in_axes=(0,))(newStates) + P = jnp.exp(newLogAccProb - carry[1]) # Roll dice newKey, carryKey = random.split(carryKey,) @@ -310,23 +318,22 @@ def update(acc, old, new): return jax.lax.cond(acc, lambda x: x[1], lambda x: x[0], (old, new)) carryStates = vmap(update, in_axes=(0, 0, 0))(accepted, carry[0], newStates) - carryLogPsiSq = jnp.where(accepted == True, newLogPsiSq, carry[1]) + carryLogAccProb = jnp.where(accepted == True, newLogAccProb, carry[1]) - return (carryStates, carryLogPsiSq, carryKey, numProposed, numAccepted) + return (carryStates, carryLogAccProb, carryKey, numProposed, numAccepted) - (states, logPsiSq, key, numProposed, numAccepted) =\ - jax.lax.fori_loop(0, numSteps, perform_mc_update, (states, logPsiSq, key, numProposed, numAccepted)) + (states, logAccProb, key, numProposed, numAccepted) =\ + jax.lax.fori_loop(0, numSteps, perform_mc_update, (states, logAccProb, key, numProposed, numAccepted)) - return states, logPsiSq, key, numProposed, numAccepted + return states, logAccProb, key, numProposed, numAccepted def _mc_init(self, netParams): - # Initialize logPsiSq - #self.logPsiSq = 2. * net.real_coefficients(self.states) + # Initialize logAccProb net, _ = self.net.get_sampler_net() - self.logPsiSq = global_defs.pmap_for_my_devices( - lambda x: jax.vmap(lambda y: 2. * jnp.real(net(netParams, y)), in_axes=(0,))(x) - )(self.states) + self.logAccProb = global_defs.pmap_for_my_devices( + lambda x: jax.vmap(lambda y: self.mu * jnp.real(net(netParams, y)), in_axes=(0,))(x) + )(self.states) shape = (global_defs.device_count(),) + (1,) @@ -377,7 +384,7 @@ def __init__(self, net, sampleShape, lDim=2, logProbFactor=0.5): self._normalize_pmapd = global_defs.pmap_for_my_devices(self._normalize, in_axes=(0, None)) self.get_basis() - + # Make sure that net params are initialized self.psi(self.basis) @@ -488,4 +495,7 @@ def sample(self, parameters=None, numSamples=None, multipleOf=None): def set_number_of_samples(self, N): pass + def get_last_number_of_samples(self): + return jnp.inf + # ** end class ExactSampler diff --git a/jVMC/stats.py b/jVMC/stats.py new file mode 100644 index 0000000..a6f0ebb --- /dev/null +++ b/jVMC/stats.py @@ -0,0 +1,138 @@ +import jax +import jax.numpy as jnp + +import jVMC +import jVMC.mpi_wrapper as mpi +from jVMC.global_defs import pmap_for_my_devices + + +class SampledObs(): + """This class implements the computation of statistics from Monte Carlo or exact samples. + + Initializer arguments: + * ``observations``: Observations :math:`O_n` in the sample. The array must have a leading device \ + dimension plus a batch dimension. + * ``weights``: Weights :math:`w_n` associated with observation :math:`O_n`. + """ + + def __init__(self, observations, weights): + """Initializes SampledObs class. + + Args: + * ``observations``: Observations :math:`O_n` in the sample. The array must have a leading device \ + dimension plus a batch dimension. + * ``weights``: Weights :math:`w_n` associated with observation :math:`O_n`. + """ + + self.jit_my_stuff() + + if len(observations.shape) == 2: + observations = observations[...,None] + + self._weights = weights + self._mean = mpi.global_sum( self._mean_helper(observations,self._weights)[None,...] ) + self._data = self._data_prep(observations, self._mean) + + + def mean(self): + """Returns the mean. 
+ """ + + return self._mean + + + def covar(self, other=None): + """Returns the covariance. + + Args: + * ``other`` [optional]: Another instance of `SampledObs`. + """ + + self.jit_my_stuff() + + if other is None: + other = self + + return mpi.global_sum( self._covar_helper(self._data, other._data, self._weights)[None,...] ) + + + def var(self): + """Returns the variance. + """ + + return mpi.global_sum( self._mean_helper(jnp.abs(self._data)**2, self._weights)[None,...] ) + + + def covar_data(self, other=None): + """Returns the covariance. + + Args: + * ``other`` [optional]: Another instance of `SampledObs`. + """ + + self.jit_my_stuff() + + if other is None: + other = self + return SampledObs( self._covar_data_helper(self._data, other._data), self._weights ) + + + def covar_var(self, other=None): + """Returns the variance of the covariance. + + Args: + * ``other`` [optional]: Another instance of `SampledObs`. + """ + + self.jit_my_stuff() + + if other is None: + other = self + + return mpi.global_sum( self._covar_var_helper(self._data, other._data, self._weights)[None,...] ) \ + - jnp.abs(self.covar(other))**2 + + + def transform(self, fun=lambda x: x): + """Returns a `SampledObs` for the transformed data. + + Args: + * ``fun``: A function. + """ + + return SampledObs( self._trafo_helper(self._data, self._mean, fun), self._weights ) + + + def jit_my_stuff(self): + # This is a helper function to make sure that pmap'd functions work with the actual choice of devices + # at all times. + + if jVMC.global_defs.pmap_devices_updated(): + self._mean_helper = pmap_for_my_devices(lambda data, w: jnp.tensordot(w, data, axes=(0,0)), in_axes=(0, 0)) + self._data_prep = pmap_for_my_devices(lambda data, mean: data - mean, in_axes=(0, None)) + self._covar_helper = pmap_for_my_devices( + lambda data1, data2, w: + jnp.tensordot( + jnp.conj( + jax.vmap(lambda a,b: a*b, in_axes=(0,0))(w, data1) + ), + data2, axes=(0,0)), + in_axes=(0, 0, 0) + ) + self._covar_var_helper = pmap_for_my_devices( + lambda data1, data2, w: + jnp.sum( + w[...,None,None] * + jnp.abs( + jax.vmap(lambda a,b: jnp.outer(a,b))(jnp.conj(data1), data2), + )**2, + axis=0), + in_axes=(0, 0, 0) + ) + self._covar_data_helper = pmap_for_my_devices(lambda data1, data2: jax.vmap(lambda a,b: jnp.outer(a,b))(jnp.conj(data1), data2), in_axes=(0, 0)) + self._trafo_helper = pmap_for_my_devices(lambda data, mean, f: f(data + mean), in_axes=(0, None), static_broadcasted_argnums=(2,)) + + + + + \ No newline at end of file diff --git a/jVMC/util/tdvp.py b/jVMC/util/tdvp.py index d9dee0d..abfda45 100644 --- a/jVMC/util/tdvp.py +++ b/jVMC/util/tdvp.py @@ -2,12 +2,10 @@ import jax.numpy as jnp import numpy as np +import jVMC import jVMC.mpi_wrapper as mpi import jVMC.global_defs as global_defs - -from functools import partial - -import time +from jVMC.stats import SampledObs def realFun(x): @@ -27,7 +25,7 @@ class TDVP: and the quantum Fisher matrix - :math:`S_{k,k'} = \langle \mathcal O_{\\theta_k} (\mathcal O_{\\theta_{k'}})^*\\rangle_c` + :math:`S_{k,k'} = \langle (\mathcal O_{\\theta_k})^* \mathcal O_{\\theta_{k'}}\\rangle_c` and for real parameters :math:`\\theta\in\mathbb R`, the TDVP equation reads @@ -66,7 +64,6 @@ class TDVP: """ def __init__(self, sampler, snrTol=2, svdTol=1e-14, makeReal='imag', rhsPrefactor=1.j, diagonalShift=0., crossValidation=False, diagonalizeOnDevice=True): - self.sampler = sampler self.snrTol = snrTol self.svdTol = svdTol @@ -83,13 +80,6 @@ def __init__(self, sampler, snrTol=2, svdTol=1e-14, makeReal='imag', rhsPrefacto 
self.makeReal = imagFun # pmap'd member functions - self.subtract_helper_Eloc = global_defs.pmap_for_my_devices(lambda x, y: x - y, in_axes=(0, None)) - self.subtract_helper_grad = global_defs.pmap_for_my_devices(lambda x, y: x - y, in_axes=(0, None)) - self.get_EO = global_defs.pmap_for_my_devices(lambda f, Eloc, grad: -f * jnp.multiply(Eloc[:, None], jnp.conj(grad)), - in_axes=(None, 0, 0, 0), static_broadcasted_argnums=(0)) - self.get_EO_p = global_defs.pmap_for_my_devices(lambda f, p, Eloc, grad: -f * jnp.multiply((p * Eloc)[:, None], jnp.conj(grad)), - in_axes=(None, 0, 0, 0), static_broadcasted_argnums=(0)) - self.transform_EO = global_defs.pmap_for_my_devices(lambda eo, v: jnp.matmul(eo, jnp.conj(v)), in_axes=(0, None)) self.makeReal_pmapd = global_defs.pmap_for_my_devices(jax.vmap(lambda x: self.makeReal(x))) def set_diagonal_shift(self, delta): @@ -130,42 +120,27 @@ def get_S(self): return self.S - def get_tdvp_equation(self, Eloc, gradients, p=None): - - self.ElocMean = mpi.global_mean(Eloc, p) - self.ElocVar = jnp.real(mpi.global_variance(Eloc, p)) - Eloc = self.subtract_helper_Eloc(Eloc, self.ElocMean) - gradientsMean = mpi.global_mean(gradients, p) - gradients = self.subtract_helper_grad(gradients, gradientsMean) - - if p is None: - - EOdata = self.get_EO(self.rhsPrefactor, Eloc, gradients) + def get_tdvp_equation(self, Eloc, gradients): - self.F0 = mpi.global_mean(EOdata) - - else: - - EOdata = self.get_EO_p(self.rhsPrefactor, p, Eloc, gradients) - - self.F0 = mpi.global_sum(EOdata) + self.ElocMean = Eloc.mean() + self.ElocVar = Eloc.var() + self.F0 = (-self.rhsPrefactor) * gradients.covar(Eloc).ravel() #* EOdata.mean() F = self.makeReal(self.F0) - self.S0 = mpi.global_covariance(gradients, p) + self.S0 = gradients.covar() S = self.makeReal(self.S0) if self.diagonalShift > 1e-10: S = S + jnp.diag(self.diagonalShift * jnp.diag(S)) - return S, F, EOdata + return S, F def get_sr_equation(self, Eloc, gradients): - return self.get_tdvp_equation(Eloc, gradients, rhsPrefactor=1.) - def transform_to_eigenbasis(self, S, F, EOdata): - + def _transform_to_eigenbasis(self, S, F): + if self.diagonalizeOnDevice: try: self.ev, self.V = jnp.linalg.eigh(S) @@ -184,20 +159,26 @@ def transform_to_eigenbasis(self, S, F, EOdata): self.VtF = jnp.dot(jnp.transpose(jnp.conj(self.V)), F) - EOdata = self.transform_EO(self.makeReal_pmapd(EOdata), self.V) - EOdata.block_until_ready() - self.rhoVar = mpi.global_variance(EOdata) + def _get_snr(self, Eloc, gradients): - self.snr = jnp.sqrt(jnp.abs(mpi.globNumSamples * (jnp.conj(self.VtF) * self.VtF) / self.rhoVar)) + EO = gradients.covar_data(Eloc).transform( + fun=lambda x: jnp.matmul( + jnp.transpose(jnp.conj(self.V)), + jVMC.util.imagFun((-self.rhsPrefactor) * x) + ) + ) + self.rhoVar = EO.var().ravel() - def solve(self, Eloc, gradients, p=None): + self.snr = jnp.sqrt(jnp.abs(mpi.globNumSamples * (jnp.conj(self.VtF) * self.VtF) / self.rhoVar)).ravel() + def solve(self, Eloc, gradients): # Get TDVP equation from MC data - self.S, F, Fdata = self.get_tdvp_equation(Eloc, gradients, p) + self.S, F = self.get_tdvp_equation(Eloc, gradients) F.block_until_ready() - # Transform TDVP equation to eigenbasis - self.transform_to_eigenbasis(self.S, F, Fdata) + # Transform TDVP equation to eigenbasis and compute SNR + self._transform_to_eigenbasis(self.S, F) #, Fdata) + self._get_snr(Eloc, gradients) # Discard eigenvalues below numerical precision self.invEv = jnp.where(jnp.abs(self.ev / self.ev[-1]) > 1e-14, 1. / self.ev, 0.) 
@@ -205,7 +186,7 @@ def solve(self, Eloc, gradients, p=None): # Set regularizer for singular value cutoff regularizer = 1. / (1. + (self.svdTol / jnp.abs(self.ev / self.ev[-1]))**6) - if p is None: + if not isinstance(self.sampler, jVMC.sampler.ExactSampler): # Construct a soft cutoff based on the SNR regularizer *= 1. / (1. + (self.snrTol / self.snr)**6) @@ -273,14 +254,16 @@ def stop_timing(outp, name, waitFor=None): start_timing(outp, "compute Eloc") Eloc = hamiltonian.get_O_loc(sampleConfigs, psi, sampleLogPsi, t) stop_timing(outp, "compute Eloc", waitFor=Eloc) + Eloc = SampledObs( Eloc, p) # Evaluate gradients start_timing(outp, "compute gradients") sampleGradients = psi.gradients(sampleConfigs) stop_timing(outp, "compute gradients", waitFor=sampleGradients) + sampleGradients = SampledObs( sampleGradients, p) start_timing(outp, "solve TDVP eqn.") - update, solverResidual = self.solve(Eloc, sampleGradients, p) + update, solverResidual = self.solve(Eloc, sampleGradients) stop_timing(outp, "solve TDVP eqn.") if outp is not None: @@ -303,7 +286,7 @@ def stop_timing(outp, name, waitFor=None): if self.crossValidation: - if p != None: + if isinstance(self.sampler, jVMC.sampler.ExactSampler): update_1, _ = self.solve(Eloc[:, 0::2], sampleGradients[:, 0::2], p[:, 0::2]) S2, F2, _ = self.get_tdvp_equation(Eloc[:, 1::2], sampleGradients[:, 1::2], p[:, 1::2]) else: diff --git a/jVMC/util/util.py b/jVMC/util/util.py index 8794053..d174bb1 100644 --- a/jVMC/util/util.py +++ b/jVMC/util/util.py @@ -133,21 +133,16 @@ def measure(observables, psi, sampler, numSamples=None): for op in get_iterable(ops): - args=() + args = () if isinstance(op, collections.abc.Iterable): args = tuple(op[1:]) op = op[0] Oloc = op.get_O_loc(sampleConfigs, psi, sampleLogPsi, *args) - if p is not None: - tmpMeans.append(mpi.global_mean(Oloc, p)) - tmpVariances.append(mpi.global_variance(Oloc, p)) - tmpErrors.append(0.) 
- else: - tmpMeans.append(mpi.global_mean(Oloc)) - tmpVariances.append(mpi.global_variance(Oloc)) - tmpErrors.append(jnp.sqrt(tmpVariances[-1]) / jnp.sqrt(sampler.get_last_number_of_samples())) + tmpMeans.append(mpi.global_mean(Oloc[..., None], p)[0]) + tmpVariances.append(mpi.global_variance(Oloc[..., None], p)[0]) + tmpErrors.append(jnp.sqrt(tmpVariances[-1]) / jnp.sqrt(sampler.get_last_number_of_samples())) result[name] = {} result[name]["mean"] = jnp.real(jnp.array(tmpMeans)) diff --git a/tests/mpi_wrapper_t.py b/tests/mpi_wrapper_t.py index 6d14199..b4fbe4a 100644 --- a/tests/mpi_wrapper_t.py +++ b/tests/mpi_wrapper_t.py @@ -8,38 +8,43 @@ import numpy as np +import sys +sys.path.append(sys.path[0] + '/../') import jVMC import jVMC.mpi_wrapper as mpi import jVMC.global_defs as global_defs + def get_shape(shape): return (global_defs.device_count(),) + shape + class TestMPI(unittest.TestCase): def test_mean(self): - - data=jnp.array(np.arange(720*4*global_defs.device_count()).reshape((global_defs.device_count()*720,4))) - myNumSamples = mpi.distribute_sampling(global_defs.device_count()*720) - myData=data[mpi.rank*myNumSamples:(mpi.rank+1)*myNumSamples].reshape(get_shape((-1,4))) + data = jnp.array(np.arange(720 * 4 * global_defs.device_count()).reshape((global_defs.device_count() * 720, 4))) + myNumSamples = mpi.distribute_sampling(global_defs.device_count() * 720) + + myData = data[mpi.rank * myNumSamples:(mpi.rank + 1) * myNumSamples].reshape(get_shape((-1, 4))) - self.assertTrue( jnp.sum(mpi.global_mean(myData)-jnp.mean(data,axis=0)) < 1e-10 ) + self.assertTrue(jnp.sum(mpi.global_mean(myData, jnp.ones(myData.shape[:2]) / jnp.prod(jnp.asarray(myData.shape[:2]))) - jnp.mean(data, axis=0)) < 1e-10) def test_var(self): - - data=jnp.array(np.arange(720*4*global_defs.device_count()).reshape((global_defs.device_count()*720,4))) - myNumSamples = mpi.distribute_sampling(global_defs.device_count()*720) - myData=data[mpi.rank*myNumSamples:(mpi.rank+1)*myNumSamples].reshape(get_shape((-1,4))) + data = jnp.array(np.arange(720 * 4 * global_defs.device_count()).reshape((global_defs.device_count() * 720, 4))) + myNumSamples = mpi.distribute_sampling(global_defs.device_count() * 720) - self.assertTrue( jnp.sum(mpi.global_variance(myData)-jnp.var(data,axis=0)) < 1e-10 ) + myData = data[mpi.rank * myNumSamples:(mpi.rank + 1) * myNumSamples].reshape(get_shape((-1, 4))) + + self.assertTrue(jnp.sum(mpi.global_variance(myData, jnp.ones(myData.shape[:2]) / jnp.prod(jnp.asarray(myData.shape[:2]))) - jnp.var(data, axis=0)) < 1e-10) def test_bcast(self): - data=np.zeros(10, dtype=np.int32) + data = np.zeros(10, dtype=np.int32) with self.assertRaises(TypeError) as context: mpi.bcast_unknown_size(data) + if __name__ == "__main__": unittest.main() diff --git a/tests/sampler_t.py b/tests/sampler_t.py index 2e75684..57987bc 100644 --- a/tests/sampler_t.py +++ b/tests/sampler_t.py @@ -8,6 +8,9 @@ import numpy as np +import sys +sys.path.append(sys.path[0] + '/../') + import jVMC import jVMC.nets as nets from jVMC.vqs import NQS @@ -59,7 +62,7 @@ def test_MCMC_sampling(self): # Get samples from MCMC sampler numSamples = 500000 - smc, _, _ = mcSampler.sample(numSamples=numSamples) + smc, _, p = mcSampler.sample(numSamples=numSamples) smc = smc.reshape((smc.shape[0] * smc.shape[1], -1)) @@ -67,10 +70,56 @@ def test_MCMC_sampling(self): # Compute histogram of sampled configurations smcInt = jax.vmap(state_to_int)(smc) - pmc, _ = np.histogram(smcInt, bins=np.arange(0, 17)) + pmc, _ = np.histogram(smcInt, 
bins=np.arange(0, 17), weights=p[0]) # Compare histogram to exact probabilities - self.assertTrue(jnp.max(jnp.abs(pmc / mcSampler.get_last_number_of_samples() - pex.reshape((-1,))[:16])) < 2e-3) + self.assertTrue(jnp.max(jnp.abs(pmc - pex.reshape((-1,))[:16])) < 2e-3) + + def test_MCMC_sampling_with_mu(self): + mu = 1 + + L = 4 + + weights = jnp.array( + [0.23898957, 0.12614753, 0.19479055, 0.17325271, 0.14619853, 0.21392751, + 0.19648707, 0.17103704, -0.15457255, 0.10954413, 0.13228065, -0.14935214, + -0.09963073, 0.17610707, 0.13386381, -0.14836467] + ) + + # Set up variational wave function + rbm = nets.CpxRBM(numHidden=2, bias=False) + orbit = jVMC.util.symmetries.get_orbit_1d(L, translation=False, reflection=False, z2sym=False) + net = nets.sym_wrapper.SymNet(net=rbm, orbit=orbit) + psi = NQS(net) + + # Set up exact sampler + exactSampler = sampler.ExactSampler(psi, L) + + # Set up MCMC sampler + mcSampler = sampler.MCSampler(psi, (L,), random.PRNGKey(0), updateProposer=jVMC.sampler.propose_spin_flip, numChains=777, mu=mu) + + psi.set_parameters(weights) + + # Compute exact probabilities + _, _, pex = exactSampler.sample() + + # Get samples from MCMC sampler + numSamples = 500000 + smc, logPsi, weighting_probs = mcSampler.sample(numSamples=numSamples) + + smc = smc.reshape((smc.shape[0] * smc.shape[1], -1)) + + self.assertTrue(smc.shape[0] >= numSamples) + + # Compute histogram of sampled configurations + smcInt = jax.vmap(state_to_int)(smc) + + psi_log_basis = psi(exactSampler.basis) + + pmc, _ = np.histogram(smcInt, bins=np.arange(0, 17), weights=weighting_probs[0, :]) + + # Compare histogram to exact probabilities + self.assertTrue(jnp.max(jnp.abs(pmc - pex.reshape((-1,))[:16])) < 2e-3) def test_MCMC_sampling_with_two_nets(self): L = 4 @@ -102,7 +151,7 @@ def test_MCMC_sampling_with_two_nets(self): # Get samples from MCMC sampler numSamples = 500000 - smc, _, _ = mcSampler.sample(numSamples=numSamples) + smc, _, p = mcSampler.sample(numSamples=numSamples) smc = smc.reshape((smc.shape[0] * smc.shape[1], -1)) @@ -110,10 +159,10 @@ def test_MCMC_sampling_with_two_nets(self): # Compute histogram of sampled configurations smcInt = jax.vmap(state_to_int)(smc) - pmc, _ = np.histogram(smcInt, bins=np.arange(0, 17)) + pmc, _ = np.histogram(smcInt, bins=np.arange(0, 17), weights=p[0]) # Compare histogram to exact probabilities - self.assertTrue(jnp.max(jnp.abs(pmc / mcSampler.get_last_number_of_samples() - pex.reshape((-1,))[:16])) < 2e-3) + self.assertTrue(jnp.max(jnp.abs(pmc - pex.reshape((-1,))[:16])) < 2e-3) def test_MCMC_sampling_with_integer_key(self): L = 4 @@ -143,7 +192,7 @@ def test_MCMC_sampling_with_integer_key(self): # Get samples from MCMC sampler numSamples = 500000 - smc, _, _ = mcSampler.sample(numSamples=numSamples) + smc, _, p = mcSampler.sample(numSamples=numSamples) smc = smc.reshape((smc.shape[0] * smc.shape[1], -1)) @@ -151,10 +200,10 @@ def test_MCMC_sampling_with_integer_key(self): # Compute histogram of sampled configurations smcInt = jax.vmap(state_to_int)(smc) - pmc, _ = np.histogram(smcInt, bins=np.arange(0, 17)) + pmc, _ = np.histogram(smcInt, bins=np.arange(0, 17), weights=p[0]) # Compare histogram to exact probabilities - self.assertTrue(jnp.max(jnp.abs(pmc / mcSampler.get_last_number_of_samples() - pex.reshape((-1,))[:16])) < 2e-3) + self.assertTrue(jnp.max(jnp.abs(pmc - pex.reshape((-1,))[:16])) < 2e-3) def test_autoregressive_sampling(self): @@ -182,9 +231,9 @@ def test_autoregressive_sampling(self): _, _, pex = exactSampler.sample() numSamples = 1000000 - 
smc, p, _ = mcSampler.sample(numSamples=numSamples) + smc, logPsi, p = mcSampler.sample(numSamples=numSamples) - self.assertTrue(jnp.max(jnp.abs(jnp.real(psi(smc) - p))) < 1e-12) + self.assertTrue(jnp.max(jnp.abs(jnp.real(psi(smc) - logPsi))) < 1e-12) smc = smc.reshape((smc.shape[0] * smc.shape[1], -1)) @@ -192,9 +241,9 @@ def test_autoregressive_sampling(self): # Compute histogram of sampled configurations smcInt = jax.vmap(state_to_int)(smc) - pmc, _ = np.histogram(smcInt, bins=np.arange(0, 17)) + pmc, _ = np.histogram(smcInt, bins=np.arange(0, 17), weights=p[0]) - self.assertTrue(jnp.max(jnp.abs(pmc / mcSampler.get_last_number_of_samples() - pex.reshape((-1,))[:16])) < 1.1e-3) + self.assertTrue(jnp.max(jnp.abs(pmc - pex.reshape((-1,))[:16])) < 1.1e-3) def test_autoregressive_sampling_with_symmetries(self): @@ -219,9 +268,9 @@ def test_autoregressive_sampling_with_symmetries(self): _, logPsi, pex = exactSampler.sample() numSamples = 1000000 - smc, p, _ = mcSampler.sample(numSamples=numSamples) + smc, logPsi, p = mcSampler.sample(numSamples=numSamples) - self.assertTrue(jnp.max(jnp.abs(jnp.real(psi(smc) - p))) < 1e-12) + self.assertTrue(jnp.max(jnp.abs(jnp.real(psi(smc) - logPsi))) < 1e-12) smc = smc.reshape((smc.shape[0] * smc.shape[1], -1)) @@ -229,9 +278,9 @@ def test_autoregressive_sampling_with_symmetries(self): # Compute histogram of sampled configurations smcInt = jax.vmap(state_to_int)(smc) - pmc, _ = np.histogram(smcInt, bins=np.arange(0, 17)) + pmc, _ = np.histogram(smcInt, bins=np.arange(0, 17), weights=p[0]) - self.assertTrue(jnp.max(jnp.abs(pmc / mcSampler.get_last_number_of_samples() - pex.reshape((-1,))[:16])) < 2e-3) + self.assertTrue(jnp.max(jnp.abs(pmc - pex.reshape((-1,))[:16])) < 2e-3) def test_autoregressive_sampling_with_lstm(self): @@ -259,9 +308,9 @@ def test_autoregressive_sampling_with_lstm(self): _, logPsi, pex = exactSampler.sample() numSamples = 1000000 - smc, p, _ = mcSampler.sample(numSamples=numSamples) + smc, logPsi, p = mcSampler.sample(numSamples=numSamples) - self.assertTrue(jnp.max(jnp.abs(jnp.real(psi(smc) - p))) < 1e-12) + self.assertTrue(jnp.max(jnp.abs(jnp.real(psi(smc) - logPsi))) < 1e-12) smc = smc.reshape((smc.shape[0] * smc.shape[1], -1)) @@ -269,9 +318,9 @@ def test_autoregressive_sampling_with_lstm(self): # Compute histogram of sampled configurations smcInt = jax.vmap(state_to_int)(smc) - pmc, _ = np.histogram(smcInt, bins=np.arange(0, 17)) + pmc, _ = np.histogram(smcInt, bins=np.arange(0, 17), weights=p[0]) - self.assertTrue(jnp.max(jnp.abs(pmc / mcSampler.get_last_number_of_samples() - pex.reshape((-1,))[:16])) < 1e-3) + self.assertTrue(jnp.max(jnp.abs(pmc - pex.reshape((-1,))[:16])) < 1e-3) def test_autoregressive_sampling_with_gru(self): @@ -299,9 +348,9 @@ def test_autoregressive_sampling_with_gru(self): _, logPsi, pex = exactSampler.sample() numSamples = 1000000 - smc, p, _ = mcSampler.sample(numSamples=numSamples) + smc, logPsi, p = mcSampler.sample(numSamples=numSamples) - self.assertTrue(jnp.max(jnp.abs(jnp.real(psi(smc) - p))) < 1e-12) + self.assertTrue(jnp.max(jnp.abs(jnp.real(psi(smc) - logPsi))) < 1e-12) smc = smc.reshape((smc.shape[0] * smc.shape[1], -1)) @@ -309,9 +358,9 @@ def test_autoregressive_sampling_with_gru(self): # Compute histogram of sampled configurations smcInt = jax.vmap(state_to_int)(smc) - pmc, _ = np.histogram(smcInt, bins=np.arange(0, 17)) + pmc, _ = np.histogram(smcInt, bins=np.arange(0, 17), weights=p[0]) - self.assertTrue(jnp.max(jnp.abs(pmc / mcSampler.get_last_number_of_samples() - 
pex.reshape((-1,))[:16])) < 1e-3) + self.assertTrue(jnp.max(jnp.abs(pmc - pex.reshape((-1,))[:16])) < 1e-3) def test_autoregressive_sampling_with_rnn2d(self): @@ -373,9 +422,9 @@ def test_autoregressive_sampling_with_rnn2d_symmetric(self): self.assertTrue(jnp.abs(jnp.sum(pex) - 1.) < 1e-12) numSamples = 1000000 - smc, p, _ = mcSampler.sample(numSamples=numSamples) + smc, logPsi, p = mcSampler.sample(numSamples=numSamples) - self.assertTrue(jnp.max(jnp.abs(jnp.real(psi(smc) - p))) < 1e-12) + self.assertTrue(jnp.max(jnp.abs(jnp.real(psi(smc) - logPsi))) < 1e-12) smc = smc.reshape((smc.shape[0] * smc.shape[1], -1)) @@ -383,9 +432,9 @@ def test_autoregressive_sampling_with_rnn2d_symmetric(self): # Compute histogram of sampled configurations smcInt = jax.vmap(state_to_int)(smc) - pmc, _ = np.histogram(smcInt, bins=np.arange(0, 17)) + pmc, _ = np.histogram(smcInt, bins=np.arange(0, 17), weights=p[0]) - self.assertTrue(jnp.max(jnp.abs(pmc / mcSampler.get_last_number_of_samples() - pex.reshape((-1,))[:16])) < 1e-3) + self.assertTrue(jnp.max(jnp.abs(pmc - pex.reshape((-1,))[:16])) < 1e-3) def test_autoregressive_sampling_with_lstm2d(self): @@ -407,9 +456,9 @@ def test_autoregressive_sampling_with_lstm2d(self): self.assertTrue(jnp.abs(jnp.sum(pex) - 1.) < 1e-12) numSamples = 1000000 - smc, p, _ = mcSampler.sample(numSamples=numSamples) + smc, logPsi, p = mcSampler.sample(numSamples=numSamples) - self.assertTrue(jnp.max(jnp.abs(jnp.real(psi(smc) - p))) < 1e-12) + self.assertTrue(jnp.max(jnp.abs(jnp.real(psi(smc) - logPsi))) < 1e-12) smc = smc.reshape((smc.shape[0] * smc.shape[1], -1)) @@ -417,9 +466,9 @@ def test_autoregressive_sampling_with_lstm2d(self): # Compute histogram of sampled configurations smcInt = jax.vmap(state_to_int)(smc) - pmc, _ = np.histogram(smcInt, bins=np.arange(0, 17)) + pmc, _ = np.histogram(smcInt, bins=np.arange(0, 17), weights=p[0]) - self.assertTrue(jnp.max(jnp.abs(pmc / mcSampler.get_last_number_of_samples() - pex.reshape((-1,))[:16])) < 1e-3) + self.assertTrue(jnp.max(jnp.abs(pmc - pex.reshape((-1,))[:16])) < 1e-3) if __name__ == "__main__": diff --git a/tests/stats_t.py b/tests/stats_t.py new file mode 100644 index 0000000..62e12bd --- /dev/null +++ b/tests/stats_t.py @@ -0,0 +1,32 @@ +import unittest + +import jax +import jax.numpy as jnp + +from jVMC.stats import SampledObs +import jVMC.mpi_wrapper as mpi + + +class TestStats(unittest.TestCase): + + def test_sampled_obs(self): + + Obs1Loc = jnp.array([[1,2,3]]) + Obs2Loc = jnp.array([[[1,4],[2,5],[3,7]]]) + p = (1./3) * jnp.ones(3)[None,...] + + obs1 = SampledObs(Obs1Loc, p) + obs2 = SampledObs(Obs2Loc, p) + + self.assertAlmostEqual(obs1.mean()[0], 2.) 
+ self.assertAlmostEqual(obs1.var()[0], 2./3) + + self.assertTrue(jnp.allclose(obs2.covar(), jnp.array([[2./3, 1],[1.,14./9]]))) + + self.assertTrue(jnp.allclose(obs2.mean(), jnp.array([2,16./3]))) + self.assertTrue(jnp.allclose(obs1.covar(obs2), jnp.array([2./3,1.]))) + + self.assertTrue(jnp.allclose(obs1.covar(obs2), obs1.covar_data(obs2).mean())) + + self.assertTrue(jnp.allclose(obs1.covar_var(obs2), obs1.covar_data(obs2).var())) + diff --git a/tests/tdvp_t.py b/tests/tdvp_t.py index 82e0143..b8e6e92 100644 --- a/tests/tdvp_t.py +++ b/tests/tdvp_t.py @@ -17,6 +17,7 @@ from jVMC.vqs import NQS import jVMC.operator as op import jVMC.sampler as sampler +import jVMC.mpi_wrapper as mpi from jVMC.util import measure, ground_state_search @@ -128,5 +129,142 @@ def test_time_evolution(self): self.assertTrue(np.max(np.abs(netZZ - refZZ[:len(netZZ)])) < 1e-3) +class TestTimeEvolutionMCSampler(unittest.TestCase): + def test_time_evolution(self): + L = 4 + J = -1.0 + hx = -0.3 + + weights = jnp.array( + [0.23898957, 0.12614753, 0.19479055, 0.17325271, 0.14619853, 0.21392751, + 0.19648707, 0.17103704, -0.15457255, 0.10954413, 0.13228065, -0.14935214, + -0.09963073, 0.17610707, 0.13386381, -0.14836467] + ) + + # Set up variational wave function + rbm = nets.CpxRBM(numHidden=2, bias=False) + orbit = jVMC.util.symmetries.get_orbit_1d(L, translation=False, reflection=False, z2sym=False) + net = nets.sym_wrapper.SymNet(net=rbm, orbit=orbit) + psi = NQS(net, batchSize=5000) + psi(jnp.array([[[1, 1, 1, 1]]])) + psi.set_parameters(weights) + + # Set up hamiltonian for time evolution + hamiltonian = op.BranchFreeOperator() + for l in range(L): + hamiltonian.add(op.scal_opstr(J, (op.Sz(l), op.Sz((l + 1) % L)))) + hamiltonian.add(op.scal_opstr(hx, (op.Sx(l), ))) + + # Set up ZZ observable + ZZ = op.BranchFreeOperator() + for l in range(L): + ZZ.add((op.Sz(l), op.Sz((l + 1) % L))) + + # Set up exact sampler + MCsampler = sampler.MCSampler(psi, (L,), jax.random.PRNGKey(0), numSamples=50000, updateProposer=sampler.propose_spin_flip, mu=1, numChains=500) + + # Set up adaptive time stepper + stepper = jVMCstepper.AdaptiveHeun(timeStep=1e-3, tol=1e-4) + + tdvpEquation = jVMC.util.TDVP(MCsampler, snrTol=1, svdTol=1e-8, rhsPrefactor=1.j, diagonalShift=0., makeReal='imag') + + t = 0 + obs = [] + times = [] + times.append(t) + newMeas = measure({'E': hamiltonian, 'ZZ': ZZ}, psi, MCsampler) + obs.append([newMeas['E']['mean'], newMeas['ZZ']['mean']]) + while t < 0.2: + dp, dt = stepper.step(0, tdvpEquation, psi.get_parameters(), hamiltonian=hamiltonian, psi=psi, numSamples=5000) + psi.set_parameters(dp) + t += dt + times.append(t) + newMeas = measure({'E': [(hamiltonian, t)], 'ZZ': ZZ}, psi, MCsampler) + obs.append([newMeas['E']['mean'], newMeas['ZZ']['mean']]) + + obs = np.array(jnp.asarray(obs)) + + # Check energy conservation + obs[:, 0] = np.abs((obs[:, 0] - obs[0, 0]) / obs[0, 0]) + self.assertTrue(np.max(obs[:, 0]) < 1e-1) + + # Check observable dynamics + zz = interp1d(np.array(times), obs[:, 1, 0]) + refTimes = np.arange(0, 0.2, 0.05) + netZZ = zz(refTimes) + refZZ = np.array( + [0.882762129306284, 0.8936168721790617, 0.9257753299594491, 0.9779836185039352, 1.0482156449061142, + 1.1337654450614298, 1.231369697427413, 1.337354107391303, 1.447796176316155, 1.558696104640795, + 1.666147269524912, 1.7664978782554912, 1.8564960156892512, 1.9334113379450693, 1.9951280521882777, + 2.0402054805651546, 2.067904337137255, 2.078178742959828, 2.071635856483114, 2.049466698269522, 2.049466698269522] + ) + 
self.assertTrue(np.max(np.abs(netZZ - refZZ[:len(netZZ)])) < 2e-2) + + +class TestSNRConsistency(unittest.TestCase): + def test_snr_consistency(self): + L = 4 + J = -1.0 + hx = -0.3 + + weights = jnp.array( + [0.23898957, 0.12614753, 0.19479055, 0.17325271, 0.14619853, 0.21392751, + 0.19648707, 0.17103704, -0.15457255, 0.10954413, 0.13228065, -0.14935214, + -0.09963073, 0.17610707, 0.13386381, -0.14836467] + ) + + # Set up variational wave function + rbm = nets.CpxRBM(numHidden=2, bias=False) + orbit = jVMC.util.symmetries.get_orbit_1d(L, translation=False, reflection=False, z2sym=False) + net = nets.sym_wrapper.SymNet(net=rbm, orbit=orbit) + psi = NQS(net, batchSize=5000) + psi(jnp.array([[[1, 1, 1, 1]]])) + psi.set_parameters(weights) + + # Set up hamiltonian for time evolution + hamiltonian = op.BranchFreeOperator() + for l in range(L): + hamiltonian.add(op.scal_opstr(J, (op.Sz(l), op.Sz((l + 1) % L)))) + hamiltonian.add(op.scal_opstr(hx, (op.Sx(l), ))) + + # Set up exact sampler + sampler = jVMC.sampler.MCSampler(psi, (L,), jax.random.PRNGKey(0), numSamples=10, updateProposer=jVMC.sampler.propose_spin_flip, mu=2, numChains=500) + + # Get sample + sampleConfigs, sampleLogPsi, p = sampler.sample() + + # Evaluate local energy + Eloc_old = hamiltonian.get_O_loc(sampleConfigs, psi, sampleLogPsi, 0.0) + Eloc = jVMC.stats.SampledObs( Eloc_old, p) + + # Evaluate gradients + sampleGradients_old = psi.gradients(sampleConfigs) + sampleGradients = jVMC.stats.SampledObs( sampleGradients_old, p) + + self.F0 = (-1.j) * sampleGradients.covar(Eloc).ravel() + S = jVMC.util.imagFun( sampleGradients.covar() ) + + ev, V = jnp.linalg.eigh(S) + + # old version + subtract_helper_Eloc = global_defs.pmap_for_my_devices(lambda x, y: x - y, in_axes=(0, None)) + subtract_helper_grad = global_defs.pmap_for_my_devices(lambda x, y: x - y, in_axes=(0, None)) + transform_EO = global_defs.pmap_for_my_devices(lambda eo, v: jnp.matmul(eo, jnp.conj(v)), in_axes=(0, None)) + get_EO = global_defs.pmap_for_my_devices(lambda f, p, Eloc, grad: -f * jnp.multiply((p * Eloc)[:, None], jnp.conj(grad)), + in_axes=(None, 0, 0, 0), static_broadcasted_argnums=(0)) + Eloc_old = subtract_helper_Eloc(Eloc_old, mpi.global_mean(Eloc_old, p)) + sampleGradients_old = subtract_helper_grad(sampleGradients_old, mpi.global_mean(sampleGradients_old, p)) + EOdata = get_EO(-1., p, Eloc_old, sampleGradients_old) * mpi.globNumSamples + EOdata = transform_EO(jVMC.util.imagFun(EOdata), V) + EOdata.block_until_ready() + rhoVar_old = mpi.global_variance(EOdata, jnp.ones(EOdata.shape[:2]) / mpi.globNumSamples) + + EO = sampleGradients.covar_data(Eloc).transform(fun=lambda x: jnp.matmul(jnp.transpose(jnp.conj(V)), jVMC.util.imagFun(x))) + rhoVar_new = EO.var().ravel() + + self.assertTrue( jnp.allclose(rhoVar_old, rhoVar_new) ) + + + if __name__ == "__main__": unittest.main()
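
For orientation, the snippet below collects the new workflow introduced by the changes above (weighted sampling via ``mu`` and statistics via ``SampledObs``). It is a minimal sketch assembled from the test setups in ``tests/sampler_t.py`` and ``tests/tdvp_t.py``; it is not part of the patch, and parameter values such as ``numChains=25`` or ``mu=1`` are arbitrary choices for illustration::

    import jax
    import jVMC
    import jVMC.nets as nets
    import jVMC.operator as op
    from jVMC.vqs import NQS
    from jVMC.stats import SampledObs

    L = 4

    # Variational wave function (same construction as in the tests)
    rbm = nets.CpxRBM(numHidden=2, bias=False)
    orbit = jVMC.util.symmetries.get_orbit_1d(L, translation=False, reflection=False, z2sym=False)
    psi = NQS(nets.sym_wrapper.SymNet(net=rbm, orbit=orbit))

    # MCMC sampler drawing from p_mu(s); mu<2 enables importance sampling
    mcSampler = jVMC.sampler.MCSampler(psi, (L,), jax.random.PRNGKey(0),
                                       updateProposer=jVMC.sampler.propose_spin_flip,
                                       numChains=25, numSamples=2000, mu=1)

    # Transverse-field Ising Hamiltonian providing local energies
    hamiltonian = op.BranchFreeOperator()
    for l in range(L):
        hamiltonian.add(op.scal_opstr(-1.0, (op.Sz(l), op.Sz((l + 1) % L))))
        hamiltonian.add(op.scal_opstr(-0.3, (op.Sx(l),)))

    # sample() now also returns the normalized weights p
    s, logPsi, p = mcSampler.sample()

    # Wrap sampled quantities as SampledObs and compute weighted statistics
    Eloc = SampledObs(hamiltonian.get_O_loc(s, psi, logPsi, 0.0), p)
    grads = SampledObs(psi.gradients(s), p)

    E_mean = Eloc.mean()     # weighted mean of the local energy
    E_var = Eloc.var()       # weighted variance
    S = grads.covar()        # quantum geometric tensor S_{k,k'}
    F = grads.covar(Eloc)    # connected correlator used for the TDVP force

The quantities ``S`` and ``F`` correspond to what ``TDVP.get_tdvp_equation`` now assembles via ``gradients.covar()`` and ``gradients.covar(Eloc)`` in place of the removed ``mpi.global_covariance``/``mpi.global_mean`` calls.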