Skip to content

Commit

Permalink
Merge pull request #61 from markusschmitt/fix_device_assignment
Browse files Browse the repository at this point in the history
Pmap devices in `stats` class
  • Loading branch information
markusschmitt authored Jul 3, 2023
2 parents 4a4529d + e2b5b0e commit 82a5525
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 12 deletions.
22 changes: 11 additions & 11 deletions jVMC/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,16 @@ def jit_my_stuff():

statsPmapDevices = global_defs.myPmapDevices

_mean_helper = pmap_for_my_devices(lambda data, w: jnp.tensordot(w, data, axes=(0,0)), in_axes=(0, 0))
_data_prep = pmap_for_my_devices(lambda data, w, mean: jax.vmap(lambda d, w, m: jnp.sqrt(w) * (d - m), in_axes=(0,0,None))(data, w, mean), in_axes=(0, 0, None))
_covar_helper = pmap_for_my_devices(
_mean_helper = jVMC.global_defs.pmap_for_my_devices(lambda data, w: jnp.tensordot(w, data, axes=(0,0)), in_axes=(0, 0))
_data_prep = jVMC.global_defs.pmap_for_my_devices(lambda data, w, mean: jax.vmap(lambda d, w, m: jnp.sqrt(w) * (d - m), in_axes=(0,0,None))(data, w, mean), in_axes=(0, 0, None))
_covar_helper = jVMC.global_defs.pmap_for_my_devices(
lambda data1, data2:
jnp.tensordot(
jnp.conj(data1),
data2, axes=(0,0)),
in_axes=(0, 0)
)
_covar_var_helper = pmap_for_my_devices(
_covar_var_helper = jVMC.global_defs.pmap_for_my_devices(
lambda data1, data2, w:
jnp.sum(
jnp.abs(
Expand All @@ -60,14 +60,14 @@ def jit_my_stuff():
axis=0),
in_axes=(0, 0, 0)
)
_covar_data_helper = pmap_for_my_devices(lambda data1, data2, w: jax.vmap(lambda a,b,w: jnp.outer(a,b) / w)(jnp.conj(data1), data2, w), in_axes=(0, 0, 0))
_trafo_helper_1 = pmap_for_my_devices(
_covar_data_helper = jVMC.global_defs.pmap_for_my_devices(lambda data1, data2, w: jax.vmap(lambda a,b,w: jnp.outer(a,b) / w)(jnp.conj(data1), data2, w), in_axes=(0, 0, 0))
_trafo_helper_1 = jVMC.global_defs.pmap_for_my_devices(
lambda data, w, mean, f: f(
jax.vmap(lambda x,y: x/jnp.sqrt(y), in_axes=(0,0))(data, w)
+ mean
),
in_axes=(0, 0, None), static_broadcasted_argnums=(3,))
_trafo_helper_2 = pmap_for_my_devices(
_trafo_helper_2 = jVMC.global_defs.pmap_for_my_devices(
lambda data, w, mean, v, f:
jnp.matmul(v,
f(
Expand All @@ -76,10 +76,10 @@ def jit_my_stuff():
)
),
in_axes=(0, 0, None, None), static_broadcasted_argnums=(4,))
_select_helper = pmap_for_my_devices( lambda ix,g: jax.vmap(lambda ix,g: g[ix], in_axes=(None, 0))(ix,g), in_axes=(None, 0) )
_get_subset_helper = pmap_for_my_devices(lambda x, ixs: x[slice(*ixs)], in_axes=(0,), static_broadcasted_argnums=(1,))
_subset_mean_helper = pmap_for_my_devices(lambda d, w, m: jnp.tensordot(jnp.sqrt(w), d, axes=(0,0)) + m, in_axes=(0,0,None))
_subset_data_prep = pmap_for_my_devices(jax.vmap(lambda d, w, m1, m2: d+jnp.sqrt(w)*(m1-m2), in_axes=(0,0,None,None)), in_axes=(0,0,None,None))
_select_helper = jVMC.global_defs.pmap_for_my_devices( lambda ix,g: jax.vmap(lambda ix,g: g[ix], in_axes=(None, 0))(ix,g), in_axes=(None, 0) )
_get_subset_helper = jVMC.global_defs.pmap_for_my_devices(lambda x, ixs: x[slice(*ixs)], in_axes=(0,), static_broadcasted_argnums=(1,))
_subset_mean_helper = jVMC.global_defs.pmap_for_my_devices(lambda d, w, m: jnp.tensordot(jnp.sqrt(w), d, axes=(0,0)) + m, in_axes=(0,0,None))
_subset_data_prep = jVMC.global_defs.pmap_for_my_devices(jax.vmap(lambda d, w, m1, m2: d+jnp.sqrt(w)*(m1-m2), in_axes=(0,0,None,None)), in_axes=(0,0,None,None))


class SampledObs():
Expand Down
2 changes: 1 addition & 1 deletion jVMC/version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
"""Current jVMC version at head on Github."""
__version__ = "1.2.2"
__version__ = "1.2.3"

0 comments on commit 82a5525

Please sign in to comment.