From 71bf410055d2682c4338dedfd7345b097fd9b710 Mon Sep 17 00:00:00 2001
From: Jake VanderPlas
Date: Mon, 21 Jun 2021 13:44:54 -0700
Subject: [PATCH] Fix random test for scipy>=1.7.0

---
 tests/random_test.py | 36 ++++++++++++++++++++++++------------
 1 file changed, 24 insertions(+), 12 deletions(-)

diff --git a/tests/random_test.py b/tests/random_test.py
index c3c64a6a04ca..c43025a0f85c 100644
--- a/tests/random_test.py
+++ b/tests/random_test.py
@@ -61,20 +61,30 @@ def _CheckKolmogorovSmirnovCDF(self, samples, cdf):
 
   def _CheckChiSquared(self, samples, pmf):
     alpha = 0.01  # significance level, threshold for p-value
-    values, actual_freq = np.unique(samples, return_counts=True)
+
+    # scipy.stats.chisquare requires the sum of expected and actual to
+    # match; this is only the case if we compute the expected frequency
+    # at *all* nonzero values of the pmf. We don't know this a priori,
+    # so we add extra values past the largest observed value. The number
+    # below is empirically enough to get full coverage for the current set
+    # of tests. If a new test is added where this is not enough, chisquare()
+    # below will error due to the sums of the inputs not matching.
+    extra_values = 100
+    actual_freq = np.bincount(samples, minlength=samples.max() + extra_values)
+    values = np.arange(len(actual_freq))
+
     expected_freq = pmf(values) * samples.size
-    # per scipy: "A typical rule is that all of the observed and expected
-    # frequencies should be at least 5."
-    valid = (actual_freq > 5) & (expected_freq > 5)
-    self.assertGreater(valid.sum(), 1,
-                       msg='not enough valid frequencies for chi-squared test')
-    _, p_value = scipy.stats.chisquare(
-        actual_freq[valid], expected_freq[valid])
+
+    valid = expected_freq > 0
+    actual_freq = actual_freq[valid]
+    expected_freq = expected_freq[valid]
+
+    _, p_value = scipy.stats.chisquare(actual_freq, expected_freq)
     self.assertGreater(
         p_value, alpha,
         msg=f'Failed chi-squared test with p={p_value}.\n'
             'Expected vs. actual frequencies:\n'
-            f'{expected_freq[valid]}\n{actual_freq[valid]}')
+            f'{expected_freq}\n{actual_freq}')
 
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
@@ -390,9 +400,11 @@ def testCategorical(self, p, axis, dtype, sample_shape):
     if len(p.shape[:-1]) > 0:
       ps = np.transpose(p, (1, 0)) if axis == 0 else p
       for cat_samples, cat_p in zip(samples.transpose(), ps):
-        self._CheckChiSquared(cat_samples, pmf=lambda x: cat_p[x])
+        pmf = lambda x: np.where(x < len(cat_p), cat_p[np.minimum(len(cat_p) - 1, x)], 0.0)
+        self._CheckChiSquared(cat_samples, pmf=pmf)
     else:
-      self._CheckChiSquared(samples, pmf=lambda x: p[x])
+      pmf = lambda x: np.where(x < len(p), p[np.minimum(len(p) - 1, x)], 0.0)
+      self._CheckChiSquared(samples, pmf=pmf)
 
   def testBernoulliShape(self):
     key = random.PRNGKey(0)
@@ -538,7 +550,7 @@ def testPoisson(self, lam, dtype):
     self.assertAllClose(samples.var(), lam, rtol=0.03, check_dtypes=False)
 
   def testPoissonBatched(self):
-    key = random.PRNGKey(0)
+    key = random.PRNGKey(1)
     lam = jnp.concatenate([2 * jnp.ones(10000), 20 * jnp.ones(10000)])
     samples = random.poisson(key, lam, shape=(20000,))
     self._CheckChiSquared(samples[:10000], scipy.stats.poisson(2.0).pmf)
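
Note (not part of the patch): below is a minimal standalone sketch of the chi-squared check as rewritten above, assuming only numpy and scipy>=1.7. The helper name check_chi_squared and the use of numpy's Poisson sampler in place of jax.random.poisson are illustrative only.

import numpy as np
import scipy.stats


def check_chi_squared(samples, pmf, extra_values=100):
  # Count occurrences of every integer from 0 up to well past the largest
  # observed sample, so the expected frequencies cover (essentially) the full
  # support of the pmf and their sum matches the sum of the actual counts,
  # as scipy.stats.chisquare requires.
  actual_freq = np.bincount(samples, minlength=samples.max() + extra_values)
  values = np.arange(len(actual_freq))
  expected_freq = pmf(values) * samples.size

  # Keep only bins where the pmf is nonzero; chisquare would otherwise
  # divide by zero.
  valid = expected_freq > 0
  _, p_value = scipy.stats.chisquare(actual_freq[valid], expected_freq[valid])
  return p_value


# Example mirroring testPoissonBatched, but drawing Poisson(2.0) samples with
# numpy rather than jax.random.poisson.
samples = np.random.default_rng(0).poisson(lam=2.0, size=10000)
p_value = check_chi_squared(samples, scipy.stats.poisson(2.0).pmf)
print(f"p = {p_value:.3f}")  # the test asserts p > alpha with alpha = 0.01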