Skip to content

Commit

Permalink
Merge pull request #7040 from jakevdp:fix-random-test
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 380658330
  • Loading branch information
jax authors committed Jun 21, 2021
2 parents 3f768ee + 71bf410 commit 3787b56
Showing 1 changed file with 24 additions and 12 deletions.
36 changes: 24 additions & 12 deletions tests/random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,20 +61,30 @@ def _CheckKolmogorovSmirnovCDF(self, samples, cdf):

def _CheckChiSquared(self, samples, pmf):
alpha = 0.01 # significance level, threshold for p-value
values, actual_freq = np.unique(samples, return_counts=True)

# scipy.stats.chisquare requires the sum of expected and actual to
# match; this is only the case if we compute the expected frequency
# at *all* nonzero values of the pmf. We don't know this a priori,
# so we add extra values past the largest observed value. The number
# below is empirically enough to get full coverage for the current set
# of tests. If a new test is added where this is not enough, chisquare()
# below will error due to the sums of the inputs not matching.
extra_values = 100
actual_freq = np.bincount(samples, minlength=samples.max() + extra_values)
values = np.arange(len(actual_freq))

expected_freq = pmf(values) * samples.size
# per scipy: "A typical rule is that all of the observed and expected
# frequencies should be at least 5."
valid = (actual_freq > 5) & (expected_freq > 5)
self.assertGreater(valid.sum(), 1,
msg='not enough valid frequencies for chi-squared test')
_, p_value = scipy.stats.chisquare(
actual_freq[valid], expected_freq[valid])

valid = expected_freq > 0
actual_freq = actual_freq[valid]
expected_freq = expected_freq[valid]

_, p_value = scipy.stats.chisquare(actual_freq, expected_freq)
self.assertGreater(
p_value, alpha,
msg=f'Failed chi-squared test with p={p_value}.\n'
'Expected vs. actual frequencies:\n'
f'{expected_freq[valid]}\n{actual_freq[valid]}')
f'{expected_freq}\n{actual_freq}')

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
Expand Down Expand Up @@ -390,9 +400,11 @@ def testCategorical(self, p, axis, dtype, sample_shape):
if len(p.shape[:-1]) > 0:
ps = np.transpose(p, (1, 0)) if axis == 0 else p
for cat_samples, cat_p in zip(samples.transpose(), ps):
self._CheckChiSquared(cat_samples, pmf=lambda x: cat_p[x])
pmf = lambda x: np.where(x < len(cat_p), cat_p[np.minimum(len(cat_p) - 1, x)], 0.0)
self._CheckChiSquared(cat_samples, pmf=pmf)
else:
self._CheckChiSquared(samples, pmf=lambda x: p[x])
pmf = lambda x: np.where(x < len(p), p[np.minimum(len(p) - 1, x)], 0.0)
self._CheckChiSquared(samples, pmf=pmf)

def testBernoulliShape(self):
key = random.PRNGKey(0)
Expand Down Expand Up @@ -538,7 +550,7 @@ def testPoisson(self, lam, dtype):
self.assertAllClose(samples.var(), lam, rtol=0.03, check_dtypes=False)

def testPoissonBatched(self):
key = random.PRNGKey(0)
key = random.PRNGKey(1)
lam = jnp.concatenate([2 * jnp.ones(10000), 20 * jnp.ones(10000)])
samples = random.poisson(key, lam, shape=(20000,))
self._CheckChiSquared(samples[:10000], scipy.stats.poisson(2.0).pmf)
Expand Down

0 comments on commit 3787b56

Please sign in to comment.