diff --git a/tests/vqs_test.py b/tests/vqs_test.py index 446a207..639fdff 100644 --- a/tests/vqs_test.py +++ b/tests/vqs_test.py @@ -67,6 +67,19 @@ def __call__(self, s): # ** end class Simple_nonHolomorphic +class MatrixMultiplication_NonHolomorphic(nn.Module): + holo: bool = False + + @nn.compact + def __call__(self, s): + layer1 = nn.Dense(1, use_bias=False, **jVMC.nets.initializers.init_fn_args(dtype=global_defs.tCpx)) + out = layer1(2 * s.ravel() - 1) + if not self.holo: + out = out + 1e-1 * jnp.real(out) + return jnp.sum(out) + +# ** end class MatrixMultiplication_NonHolomorphic + class TestGradients(unittest.TestCase): @@ -241,6 +254,20 @@ def test_gradients_complex_nonholomorphic(self): with self.subTest(i=j): self.assertTrue( jnp.max( jnp.abs( Gfd - G[...,j] ) ) < 1e-2 ) + for ds in dlist: + + global_defs.set_pmap_devices(ds) + + model = MatrixMultiplication_NonHolomorphic(holo=False) + + s=jnp.zeros(get_shape((1,4)),dtype=np.int32) + + psi = NQS(model) + psi0 = psi(s) + G = psi.gradients(s) + ref = jnp.array([-1.1+0.j, -1.1+0.j, -1.1+0.j, -1.1+0.j, -0.-1.j, -0.-1.j, -0.-1.j, -0.-1.j]) + self.assertTrue( jnp.allclose( G.ravel(), ref ) ) + def test_gradient_dict(self):