Skip to content

Commit

Permalink
sample() tests
Browse files Browse the repository at this point in the history
  • Loading branch information
gplepage committed Mar 5, 2024
1 parent ee07055 commit aea62e1
Showing 1 changed file with 34 additions and 34 deletions.
68 changes: 34 additions & 34 deletions tests/test_gvar.py
Original file line number Diff line number Diff line change
Expand Up @@ -1376,40 +1376,40 @@ def test_sample(self):
sample(gvar([1,1], [[1,1],[1,1]]))


def test_batch_sample(self): ### makes no sense!
" sample(g, nbatch=...) raniter(g, nbatch=...)"
# dictionary
g = gvar(BufferDict(s='1(1)', a=[['1(1)','1(1)','1(1)']]))
nbatch = 5
ranseed(1)
sl = sample(g, nbatch=nbatch, mode='lbatch')
ranseed(1)
sr = sample(g, nbatch=nbatch, mode='rbatch')
for k in g:
self.assertTrue(sl[k].shape[0] == sr[k].shape[-1] == nbatch)
np.testing.assert_allclose(np.sum(sl.flat), np.sum(sr.flat))
for s in sl.batch_iter('lbatch'):
self.assertLess(chi2(s, g) / g.size, 10.)
for s in sr.batch_iter('rbatch'):
self.assertLess(chi2(s, g) / g.size, 10.)
# array
ranseed(1)
sl = sample(g['a'], nbatch=nbatch, mode='lbatch')
ranseed(1)
sr = sample(g['a'], nbatch=nbatch, mode='rbatch')
self.assertTrue(sl.shape[0] == sr.shape[-1] == nbatch)
np.testing.assert_allclose(np.sum(sl.flat), np.sum(sr.flat))
for s in sl:
self.assertLess(chi2(s, g['a']) / g['a'].size, 10.)
# gvar
ranseed(1)
sl = sample(g['s'], nbatch=nbatch, mode='lbatch')
ranseed(1)
sr = sample(g['s'], nbatch=nbatch, mode='rbatch')
self.assertTrue(sl.shape[0] == sr.shape[-1] == nbatch)
self.assertEqual(list(sl), list(sr))
for s in sl:
self.assertLess(chi2(s, g['s']), 10.)
# def test_batch_sample(self): ### makes no sense!
# " sample(g, nbatch=...) raniter(g, nbatch=...)"
# # dictionary
# g = gvar(BufferDict(s='1(1)', a=[['1(1)','1(1)','1(1)']]))
# nbatch = 5
# ranseed(1)
# sl = sample(g, nbatch=nbatch, mode='lbatch')
# ranseed(1)
# sr = sample(g, nbatch=nbatch, mode='rbatch')
# for k in g:
# self.assertTrue(sl[k].shape[0] == sr[k].shape[-1] == nbatch)
# np.testing.assert_allclose(np.sum(sl.flat), np.sum(sr.flat))
# for s in sl.batch_iter('lbatch'):
# self.assertLess(chi2(s, g) / g.size, 10.)
# for s in sr.batch_iter('rbatch'):
# self.assertLess(chi2(s, g) / g.size, 10.)
# # array
# ranseed(1)
# sl = sample(g['a'], nbatch=nbatch, mode='lbatch')
# ranseed(1)
# sr = sample(g['a'], nbatch=nbatch, mode='rbatch')
# self.assertTrue(sl.shape[0] == sr.shape[-1] == nbatch)
# np.testing.assert_allclose(np.sum(sl.flat), np.sum(sr.flat))
# for s in sl:
# self.assertLess(chi2(s, g['a']) / g['a'].size, 10.)
# # gvar
# ranseed(1)
# sl = sample(g['s'], nbatch=nbatch, mode='lbatch')
# ranseed(1)
# sr = sample(g['s'], nbatch=nbatch, mode='rbatch')
# self.assertTrue(sl.shape[0] == sr.shape[-1] == nbatch)
# self.assertEqual(list(sl), list(sr))
# for s in sl:
# self.assertLess(chi2(s, g['s']), 10.)

@unittest.skipIf(FAST,"skipping test_gvar_from_sample for speed")
def test_gvar_from_sample(self):
Expand Down

0 comments on commit aea62e1

Please sign in to comment.