Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix stats.py for MPI #68

Merged
merged 2 commits into from
Dec 11, 2023
Merged

Conversation

tszoldra
Copy link
Contributor

@tszoldra tszoldra commented Dec 1, 2023

There was an error in tensor indexing in the stats.SampledObs calculations of the global sums. Estimation of mean, variance etc. worked only if the leading dimension of observations and weights was equal to 1. Otherwise, for TestStats.test_sampled_obs test with the leading dimension of the input array equal to jVMC.global_defs.device_count() greater than 1 an error occurred:

Ran 1 test in 0.107s

FAILED (errors=1)

Error
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/c8888/dev/vmc_jax/tests/stats_test.py", line 18, in test_sampled_obs
    obs1 = SampledObs(Obs1Loc, p)
  File "/home/c8888/dev/vmc_jax/jVMC/stats.py", line 120, in __init__
    self._mean = mpi.global_sum(_mean_helper(observations, self._weights)[None, ...])
  File "/home/c8888/dev/vmc_jax/jVMC/mpi_wrapper.py", line 135, in global_sum
    localSum = np.array(_sum_up_pmapd(data)[0])
ValueError: Leading axis size of input to pmapped function must equal the number of local devices passed to pmap. Got axis_size=1, num_local_devices=16.
(Local devices available to pmap: TFRT_CPU_0, TFRT_CPU_1, TFRT_CPU_2, TFRT_CPU_3, TFRT_CPU_4, TFRT_CPU_5, TFRT_CPU_6, TFRT_CPU_7, TFRT_CPU_8, TFRT_CPU_9, TFRT_CPU_10, TFRT_CPU_11, TFRT_CPU_12, TFRT_CPU_13, TFRT_CPU_14, TFRT_CPU_15)

This made it impossible to run SR, minSR etc with multiple devices.

It is easy to reproduce these errors without many physical devices by setting the following environment variables before running the test to fake multiple devices:

XLA_FLAGS=--xla_force_host_platform_device_count=16;JAX_PLATFORM_NAME=cpu

In this pull request I fixed the indexing in stats.SampledObs and slightly modified the test for stats.

I am not sure if SampledObs.subset should also be changed.

@markusschmitt
Copy link
Owner

Here we have the same Jax version issues that break the tests as in PR #66. Can you fix the setup file for this one too?

@markusschmitt markusschmitt merged commit bffcb52 into markusschmitt:master Dec 11, 2023
1 check passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants