Skip to content

Commit

Permalink
Merge pull request #53 from markusschmitt/test_cpx_nonholo
Browse files Browse the repository at this point in the history
Added test for complex non-holomorphic net.
  • Loading branch information
markusschmitt authored Jun 11, 2023
2 parents df3e622 + f7b98b7 commit 1edd881
Showing 1 changed file with 27 additions and 0 deletions.
27 changes: 27 additions & 0 deletions tests/vqs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,19 @@ def __call__(self, s):

# ** end class Simple_nonHolomorphic

class MatrixMultiplication_NonHolomorphic(nn.Module):
holo: bool = False

@nn.compact
def __call__(self, s):
layer1 = nn.Dense(1, use_bias=False, **jVMC.nets.initializers.init_fn_args(dtype=global_defs.tCpx))
out = layer1(2 * s.ravel() - 1)
if not self.holo:
out = out + 1e-1 * jnp.real(out)
return jnp.sum(out)

# ** end class MatrixMultiplication_NonHolomorphic


class TestGradients(unittest.TestCase):

Expand Down Expand Up @@ -241,6 +254,20 @@ def test_gradients_complex_nonholomorphic(self):
with self.subTest(i=j):
self.assertTrue( jnp.max( jnp.abs( Gfd - G[...,j] ) ) < 1e-2 )

for ds in dlist:

global_defs.set_pmap_devices(ds)

model = MatrixMultiplication_NonHolomorphic(holo=False)

s=jnp.zeros(get_shape((1,4)),dtype=np.int32)

psi = NQS(model)
psi0 = psi(s)
G = psi.gradients(s)
ref = jnp.array([-1.1+0.j, -1.1+0.j, -1.1+0.j, -1.1+0.j, -0.-1.j, -0.-1.j, -0.-1.j, -0.-1.j])
self.assertTrue( jnp.allclose( G.ravel(), ref ) )


def test_gradient_dict(self):

Expand Down

0 comments on commit 1edd881

Please sign in to comment.