-
Notifications
You must be signed in to change notification settings - Fork 2.8k
/
random_test.py
1623 lines (1380 loc) · 67.1 KB
/
random_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from unittest import SkipTest, skipIf
from typing import Any, Tuple, NamedTuple, Optional
import zlib
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
import scipy.linalg
import scipy.special
import scipy.stats
import jax
from jax import core
from jax import dtypes
from jax import grad
from jax import lax
from jax import numpy as jnp
from jax import prng
from jax import random
from jax._src import test_util as jtu
from jax import vmap
from jax.interpreters import xla
import jax._src.random
from jax.config import config
# Presumably lets absl command-line flags set JAX config options for this
# test binary — confirm against jax.config.
config.parse_flags_with_absl()

# Dtype groups used to parameterize the LaxRandomTest cases in this file.
float_dtypes = jtu.dtypes.all_floating
complex_dtypes = jtu.dtypes.complex
int_dtypes = jtu.dtypes.all_integer
uint_dtypes = jtu.dtypes.all_unsigned
def _prng_key_as_array(key):
  """Return the raw array behind a PRNG key.

  When ``jax_enable_custom_prng`` is on, keys are wrapper objects and are
  unwrapped with ``unsafe_raw_array()``. Otherwise ``key`` already is the raw
  array and is returned unchanged.
  """
  # TODO(frostig): remove once we upgrade to always enable_custom_prng
  if not config.jax_enable_custom_prng:
    return key
  return key.unsafe_raw_array()
# (name, implementation) pairs for every PRNG implementation under test.
# The names are passed to jax.default_prng_impl(...). Tests then check, by
# identity, that the matching implementation object was selected.
PRNG_IMPLS = [('threefry2x32', prng.threefry_prng_impl),
              ('rbg', prng.rbg_prng_impl),
              ('unsafe_rbg', prng.unsafe_rbg_prng_impl)]
class RandomValuesCase(NamedTuple):
  """One golden-value regression case for a `jax.random` sampler.

  Fields:
    name: name of the sampler, looked up with ``getattr(jax.random, name)``.
    prng_impl: PRNG implementation name, e.g. "threefry2x32" or "rbg".
    shape: sample shape, passed as ``shape=``; may have any rank.
    dtype: dtype passed as ``dtype=``; ``None`` means the sampler is called
      without a ``dtype`` argument.
    params: extra keyword arguments for the sampler.
    expected: the golden sample values.
    skip_on_x64: skip the case when 64-bit mode is enabled.
    atol, rtol: optional comparison tolerances (``None`` = default).
  """
  name: str
  prng_impl: str
  shape: Tuple[int, ...]  # variadic: a shape of any rank, not a 1-tuple
  dtype: Any
  params: dict
  expected: np.ndarray
  skip_on_x64: bool = False
  atol: Optional[float] = None
  rtol: Optional[float] = None

  def _testname(self):
    """Return a parameterized-test name suffix describing this case."""
    if self.dtype is None:
      shape_dtype = str(self.shape)
    else:
      shape_dtype = jtu.format_shape_dtype_string(self.shape, self.dtype)
    name = f"_{self.name}_{self.prng_impl}_{shape_dtype}"
    if self.params:
      # Strip whitespace so array-valued params give a compact, valid name.
      fmt = lambda x: str(x).replace(' ', '').replace('\n', '')
      name += "_" + "_".join(f"{k}={fmt(v)}" for k, v in self.params.items())
    return name

  def _seed(self):
    """Return a deterministic 32-bit seed derived from name and prng_impl."""
    # adler32 of the concatenated strings: stable across runs and processes.
    return zlib.adler32((self.name + self.prng_impl).encode())
# Golden sample values, one entry per (distribution, PRNG implementation)
# pair. testRandomDistributionValues draws `shape` samples with the entry's
# params from a key seeded by RandomValuesCase._seed() and compares them with
# `expected`. A mismatch means the generated random sequences have changed.
_RANDOM_VALUES_CASES = [
    # TODO(jakevdp) add coverage for other distributions.
    RandomValuesCase("bernoulli", "threefry2x32", (5,), None, {'p': 0.5},
        np.array([False, True, True, True, False]), skip_on_x64=True),
    RandomValuesCase("bernoulli", "rbg", (5,), None, {'p': 0.5},
        np.array([True, True, True, True, True]), skip_on_x64=True),
    RandomValuesCase("beta", "threefry2x32", (5,), np.float32, {'a': 0.8, 'b': 0.9},
        np.array([0.533685, 0.843179, 0.063495, 0.573444, 0.459514], dtype='float32')),
    RandomValuesCase("beta", "rbg", (5,), np.float32, {'a': 0.8, 'b': 0.9},
        np.array([0.841308, 0.669989, 0.731763, 0.985127, 0.022745], dtype='float32')),
    RandomValuesCase("cauchy", "threefry2x32", (5,), np.float32, {},
        np.array([ -0.088416, -10.169713, 3.49677, -1.18056, 0.34556], dtype='float32'), rtol=1E-5),
    RandomValuesCase("cauchy", "rbg", (5,), np.float32, {},
        np.array([0.008389, 0.108793, -0.031826, -0.01876, 0.963218], dtype='float32')),
    RandomValuesCase("dirichlet", "threefry2x32", (2,), np.float32, {'alpha': np.array([0.5, 0.6, 0.7], dtype='float32')},
        np.array([[0.556287, 0.304219, 0.139494], [0.15221 , 0.632251, 0.21554]], dtype='float32')),
    RandomValuesCase("dirichlet", "rbg", (2,), np.float32, {'alpha': np.array([0.5, 0.6, 0.7], dtype='float32')},
        np.array([[0.024769, 0.002189, 0.973041], [0.326, 0.00244, 0.67156]], dtype='float32')),
    RandomValuesCase("double_sided_maxwell", "threefry2x32", (5,), np.float32, {"loc": 1, "scale": 2},
        np.array([-2.408914, -3.370437, 3.235352, -0.907734, -1.708732], dtype='float32'), skip_on_x64=True),
    RandomValuesCase("double_sided_maxwell", "rbg", (5,), np.float32, {"loc": 1, "scale": 2},
        np.array([4.957495, 3.003086, 5.33935, 2.942878, -1.203524], dtype='float32'), skip_on_x64=True),
    RandomValuesCase("exponential", "threefry2x32", (5,), np.float32, {},
        np.array([0.526067, 0.043046, 0.039932, 0.46427 , 0.123886], dtype='float32')),
    RandomValuesCase("exponential", "rbg", (5,), np.float32, {},
        np.array([0.231303, 0.684814, 0.017181, 0.089552, 0.345087], dtype='float32')),
    RandomValuesCase("gamma", "threefry2x32", (5,), np.float32, {'a': 0.8},
        np.array([0.332641, 0.10187 , 1.816109, 0.023457, 0.487853], dtype='float32')),
    RandomValuesCase("gamma", "rbg", (5,), np.float32, {'a': 0.8},
        np.array([0.235293, 0.446747, 0.146372, 0.79252 , 0.294762], dtype='float32')),
    RandomValuesCase("gumbel", "threefry2x32", (5,), np.float32, {},
        np.array([2.06701, 0.911726, 0.145736, 0.185427, -0.00711], dtype='float32')),
    RandomValuesCase("gumbel", "rbg", (5,), np.float32, {},
        np.array([-0.099308, -1.123809, 1.007618, -0.077968, 3.421349], dtype='float32')),
    RandomValuesCase("laplace", "threefry2x32", (5,), np.float32, {},
        np.array([0.578939, -0.204902, 0.555733, 0.911053, -0.96456], dtype='float32')),
    RandomValuesCase("laplace", "rbg", (5,), np.float32, {},
        np.array([-2.970422, 1.925082, -0.757887, -4.444797, 0.561983], dtype='float32')),
    RandomValuesCase("loggamma", "threefry2x32", (5,), np.float32, {'a': 0.8},
        np.array([-0.899633, -0.424083, 0.631593, 0.102374, -1.07189], dtype='float32')),
    RandomValuesCase("loggamma", "rbg", (5,), np.float32, {'a': 0.8},
        np.array([-1.333825, 0.287259, -0.343074, -0.998258, -0.773598], dtype='float32')),
    RandomValuesCase("logistic", "threefry2x32", (5,), np.float32, {},
        np.array([0.19611, -1.709053, -0.274093, -0.208322, -1.675489], dtype='float32')),
    RandomValuesCase("logistic", "rbg", (5,), np.float32, {},
        np.array([-0.234923, -0.545184, 0.700992, -0.708609, -1.474884], dtype='float32')),
    RandomValuesCase("maxwell", "threefry2x32", (5,), np.float32, {},
        np.array([3.070779, 0.908479, 1.521317, 0.875551, 1.306137], dtype='float32')),
    RandomValuesCase("maxwell", "rbg", (5,), np.float32, {},
        np.array([2.048746, 0.470027, 1.053105, 1.01969, 2.710645], dtype='float32')),
    RandomValuesCase("multivariate_normal", "threefry2x32", (2,), np.float32, {"mean": np.ones((1, 3)), "cov": np.eye(3)},
        np.array([[ 1.067826, 1.215599, 0.234166], [-0.237534, 1.32591, 1.413987]], dtype='float32'), skip_on_x64=True),
    RandomValuesCase("multivariate_normal", "rbg", (2,), np.float32, {"mean": np.ones((1, 3)), "cov": np.eye(3)},
        np.array([[-0.036897, 0.770969, 0.756959], [1.755091, 2.350553, 0.627142]], dtype='float32'), skip_on_x64=True),
    RandomValuesCase("normal", "threefry2x32", (5,), np.float32, {},
        np.array([-1.173234, -1.511662, 0.070593, -0.099764, 1.052845], dtype='float32')),
    RandomValuesCase("normal", "rbg", (5,), np.float32, {},
        np.array([-0.479658, 0.565747, -1.065106, 0.997962, -1.478002], dtype='float32')),
    RandomValuesCase("pareto", "threefry2x32", (5,), np.float32, {"b": 0.5},
        np.array([2.751398, 1.281863, 87.85448, 1.254542, 2.824487], dtype='float32')),
    RandomValuesCase("pareto", "rbg", (5,), np.float32, {"b": 0.5},
        np.array([1.241914, 1.521864, 5.615384, 1911.502, 1.816702], dtype='float32')),
    RandomValuesCase("poisson", "threefry2x32", (5,), np.int32, {"lam": 5},
        np.array([7, 3, 6, 11, 6], dtype='int32')),
    # Note: poisson not implemented for rbg sampler.
    RandomValuesCase("rademacher", "threefry2x32", (5,), np.int32, {},
        np.array([-1, -1, -1, -1, 1], dtype='int32'), skip_on_x64=True),
    RandomValuesCase("rademacher", "rbg", (5,), np.int32, {},
        np.array([1, 1, 1, -1, -1], dtype='int32'), skip_on_x64=True),
    RandomValuesCase("randint", "threefry2x32", (5,), np.int32, {"minval": 0, "maxval": 10},
        np.array([0, 5, 7, 7, 5], dtype='int32')),
    RandomValuesCase("randint", "rbg", (5,), np.int32, {"minval": 0, "maxval": 10},
        np.array([7, 1, 8, 5, 8], dtype='int32')),
    RandomValuesCase("truncated_normal", "threefry2x32", (5,), np.float32, {"lower": 0, "upper": 2},
        np.array([0.582807, 1.709771, 0.159513, 0.861376, 0.36148], dtype='float32')),
    RandomValuesCase("truncated_normal", "rbg", (5,), np.float32, {"lower": 0, "upper": 2},
        np.array([0.770068, 1.516464, 0.710406, 0.762801, 1.305324], dtype='float32')),
    RandomValuesCase("uniform", "threefry2x32", (5,), np.float32, {},
        np.array([0.298671, 0.073213, 0.873356, 0.260549, 0.412797], dtype='float32')),
    RandomValuesCase("uniform", "rbg", (5,), np.float32, {},
        np.array([0.477161, 0.706508, 0.656261, 0.432547, 0.057772], dtype='float32')),
    RandomValuesCase("weibull_min", "threefry2x32", (5,), np.float32, {"scale": 1, "concentration": 1},
        np.array([1.605863, 0.841809, 0.224218, 0.4826 , 0.027901], dtype='float32')),
    RandomValuesCase("weibull_min", "rbg", (5,), np.float32, {"scale": 1, "concentration": 1},
        np.array([1.370903, 0.086532, 0.061688, 3.407599, 0.215077], dtype='float32')),
]
class PrngTest(jtu.JaxTestCase):
  """Tests of the PRNG hash functions, random bits, and PRNG key handling."""

  def testThreefry2x32(self):
    # We test the hash by comparing to known values provided in the test code of
    # the original reference implementation of Threefry. For the values, see
    # https://github.com/DEShawResearch/Random123-Boost/blob/65e3d874b67aa7b3e02d5ad8306462f52d2079c0/libs/random/test/test_threefry.cpp#L30-L32
    def result_to_hex(result):
      # int() turns each 0-d element into a Python int, which hex() accepts.
      return tuple(hex(int(x)) for x in result)

    expected = ("0x6b200159", "0x99ba4efe")
    result = prng.threefry_2x32(np.uint32([0, 0]), np.uint32([0, 0]))
    self.assertEqual(expected, result_to_hex(result))

    # All-ones words. These are written as 0xFFFFFFFF, not -1: NumPy 2 raises
    # OverflowError when converting a negative Python int to uint32.
    expected = ("0x1cb996fc", "0xbb002be7")
    result = prng.threefry_2x32(np.uint32([0xFFFFFFFF, 0xFFFFFFFF]),
                                np.uint32([0xFFFFFFFF, 0xFFFFFFFF]))
    self.assertEqual(expected, result_to_hex(result))

    expected = ("0xc4923a9c", "0x483df7a0")
    result = prng.threefry_2x32(
        np.uint32([0x13198a2e, 0x03707344]),
        np.uint32([0x243f6a88, 0x85a308d3]))
    self.assertEqual(expected, result_to_hex(result))

  def testThreefry2x32Large(self):
    # Uses the same key and counter words as the last testThreefry2x32 case,
    # repeated n times; every output word must match that single-block result.
    n = 10000000
    result = prng.threefry_2x32(
        (np.uint32(0x13198a2e), np.uint32(0x03707344)),
        jnp.concatenate([
            jnp.full((n,), 0x243f6a88, jnp.uint32),
            jnp.full((n,), 0x85a308d3, jnp.uint32)
        ]))
    np.testing.assert_equal(result[:n], np.full((n,), 0xc4923a9c, dtype=np.uint32))
    np.testing.assert_equal(result[n:], np.full((n,), 0x483df7a0, dtype=np.uint32))

  def testThreefry2x32Empty(self):
    # Regression test for an op-by-op crash for empty arrays in CUDA mode.
    with jax.disable_jit():
      result = prng.threefry_2x32(
          (np.uint32(0x13198a2e), np.uint32(0x03707344)),
          jnp.ones((10, 0,), jnp.uint32))
    np.testing.assert_equal(result, np.zeros((10, 0,), dtype=np.uint32))

  def testNoOpByOpUnderHash(self):
    # Temporarily replace xla.apply_primitive with a function that fails if
    # called, so hashing must not dispatch primitives one at a time.
    def fail(*args, **kwargs):
      # Raise explicitly: `assert False` is stripped under `python -O`, which
      # would let op-by-op dispatch go unnoticed.
      raise AssertionError("unexpected op-by-op primitive application")
    apply_primitive, xla.apply_primitive = xla.apply_primitive, fail
    try:
      _ = prng.threefry_2x32(np.zeros(2, np.uint32), np.arange(10, dtype=np.uint32))
    finally:
      xla.apply_primitive = apply_primitive

  def testRngRandomBits(self):
    # Test specific outputs to ensure consistent random values between JAX versions.
    key = random.PRNGKey(1701)

    bits8 = jax._src.random._random_bits(key, 8, (3,))
    expected8 = np.array([216, 115, 43], dtype=np.uint8)
    self.assertArraysEqual(bits8, expected8)

    bits16 = jax._src.random._random_bits(key, 16, (3,))
    expected16 = np.array([41682, 1300, 55017], dtype=np.uint16)
    self.assertArraysEqual(bits16, expected16)

    bits32 = jax._src.random._random_bits(key, 32, (3,))
    expected32 = np.array([56197195, 4200222568, 961309823], dtype=np.uint32)
    self.assertArraysEqual(bits32, expected32)

    with jtu.ignore_warning(category=UserWarning, message="Explicitly requested dtype.*"):
      bits64 = jax._src.random._random_bits(key, 64, (3,))
    # Without x64 the 64-bit request yields uint32 values, per the expected
    # arrays below.
    if config.x64_enabled:
      expected64 = np.array([3982329540505020460, 16822122385914693683,
                             7882654074788531506], dtype=np.uint64)
    else:
      expected64 = np.array([676898860, 3164047411, 4010691890], dtype=np.uint32)
    self.assertArraysEqual(bits64, expected64)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_" + name, "prng_name": name}
      for name, _ in PRNG_IMPLS))
  def testRngRandomBitsShapeDtype(self, prng_name):
    # Like testRngRandomBits, but only meant to exercise random_bits
    # on every PRNG implementation. Instead of values, only checks
    # that shapes/dtypes are as expected.
    with jax.default_prng_impl(prng_name):
      key = random.PRNGKey(1701)

      bits8 = jax._src.random._random_bits(key, 8, (3,))
      self.assertEqual(bits8.shape, (3,))
      self.assertEqual(bits8.dtype, np.dtype('uint8'))

      bits16 = jax._src.random._random_bits(key, 16, (3,))
      self.assertEqual(bits16.shape, (3,))
      self.assertEqual(bits16.dtype, np.dtype('uint16'))

      bits32 = jax._src.random._random_bits(key, 32, (3,))
      self.assertEqual(bits32.shape, (3,))
      self.assertEqual(bits32.dtype, np.dtype('uint32'))

      with jtu.ignore_warning(category=UserWarning, message="Explicitly requested dtype.*"):
        bits64 = jax._src.random._random_bits(key, 64, (3,))
      expected_dtype = np.dtype('uint64' if config.x64_enabled else 'uint32')
      self.assertEqual(bits64.shape, (3,))
      self.assertEqual(bits64.dtype, expected_dtype)

  def testRngRandomBitsViewProperty(self):
    # TODO: add 64-bit if it ever supports this property.
    # TODO: will this property hold across endian-ness?
    # Drawing the same N*64 bits as 8-, 16- or 32-bit words from one key must
    # give identical bytes once each result is reinterpreted as uint32.
    N = 10
    key = random.PRNGKey(1701)
    nbits = [8, 16, 32]
    rand_bits = [jax._src.random._random_bits(key, n, (N * 64 // n,))
                 for n in nbits]
    rand_bits_32 = np.array([np.array(r).view(np.uint32) for r in rand_bits])
    # Use a TestCase assertion: a bare `assert` is stripped under `python -O`.
    self.assertTrue(np.all(rand_bits_32 == rand_bits_32[0]))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": case._testname(), "case": case}
      for case in _RANDOM_VALUES_CASES))
  @jtu.skip_on_devices("tpu")  # TPU precision causes issues.
  def testRandomDistributionValues(self, case):
    """
    Tests values output by various distributions. This will catch any unintentional
    changes to the implementations that could result in different random sequences.

    Any refactoring of random distributions that leads to non-trivial differences in
    this test should involve a deprecation cycle following the procedures outlined at
    https://jax.readthedocs.io/en/latest/api_compatibility.html
    """
    if config.x64_enabled and case.skip_on_x64:
      self.skipTest("test produces different values when jax_enable_x64=True")
    with jax.default_prng_impl(case.prng_impl):
      func = getattr(random, case.name)
      key = random.PRNGKey(case._seed())
      # dtype=None means "call without a dtype", as in RandomValuesCase._testname.
      if case.dtype is not None:
        actual = func(key, **case.params, shape=case.shape, dtype=case.dtype)
      else:
        actual = func(key, **case.params, shape=case.shape)
      self.assertAllClose(actual, case.expected, atol=case.atol, rtol=case.rtol)

  def testPRNGValues(self):
    # Test to ensure consistent random values between JAX versions
    k = random.PRNGKey(0)

    self.assertEqual(random.randint(k, (3, 3), 0, 8).dtype,
                     dtypes.canonicalize_dtype(jnp.int_))
    if config.x64_enabled:
      self.assertAllClose(
          random.randint(k, (3, 3), 0, 8, dtype='int64'),
          np.array([[7, 2, 6],
                    [2, 1, 0],
                    [6, 7, 7]], dtype='int64'))
    self.assertAllClose(
        random.randint(k, (3, 3), 0, 8, dtype='int32'),
        np.array([[2, 1, 3],
                  [6, 1, 5],
                  [6, 3, 4]], dtype='int32'))

    self.assertAllClose(
        _prng_key_as_array(random.split(k, 4)),
        np.array([[2285895361, 1501764800],
                  [1518642379, 4090693311],
                  [ 433833334, 4221794875],
                  [ 839183663, 3740430601]], dtype='uint32'))

    self.assertAllClose(
        _prng_key_as_array(random.fold_in(k, 4)),
        np.array([2285895361,  433833334], dtype='uint32'))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "seed={seed}_type={type}_jit={jit}".format(**dct), **dct} for dct in [
          {"seed": 0, "type": int, "jit": True, "key": [0, 0]},
          {"seed": 0, "type": int, "jit": False, "key": [0, 0]},
          {"seed": 1, "type": np.int32, "jit": True, "key": [0, 1]},
          {"seed": 1, "type": np.int32, "jit": False, "key": [0, 1]},
          {"seed": 2, "type": np.uint32, "jit": True, "key": [0, 2]},
          {"seed": 2, "type": np.uint32, "jit": False, "key": [0, 2]},
          {"seed": 3, "type": np.int64, "jit": True, "key": [0, 3]},
          {"seed": 3, "type": np.int64, "jit": False, "key": [0, 3]},
          {"seed": -1, "type": int, "jit": True, "key": [4294967295, 4294967295] if config.x64_enabled else [0, 4294967295]},
          {"seed": -1, "type": int, "jit": False, "key": [4294967295, 4294967295] if config.x64_enabled else [0, 4294967295]},
          {"seed": -2, "type": np.int32, "jit": True, "key": [0, 4294967294]},
          {"seed": -2, "type": np.int32, "jit": False, "key": [0, 4294967294]},
          {"seed": -3, "type": np.int64, "jit": True, "key": [4294967295, 4294967293] if config.x64_enabled else [0, 4294967293]},
          {"seed": -3, "type": np.int64, "jit": False, "key": [4294967295, 4294967293] if config.x64_enabled else [0, 4294967293]},
          {"seed": np.iinfo(np.int32).max + 100, "type": int, "jit": True, "key": [0, 2147483747]},
          {"seed": np.iinfo(np.int32).max + 100, "type": int, "jit": False, "key": [0, 2147483747]},
          {"seed": np.iinfo(np.int32).max + 101, "type": np.uint32, "jit": True, "key": [0, 2147483748]},
          {"seed": np.iinfo(np.int32).max + 101, "type": np.uint32, "jit": False, "key": [0, 2147483748]},
          {"seed": np.iinfo(np.int32).min - 100, "type": int, "jit": True, "key": [4294967295, 2147483548] if config.x64_enabled else [0, 2147483548]},
          {"seed": np.iinfo(np.int32).min - 100, "type": int, "jit": False, "key": [4294967295, 2147483548] if config.x64_enabled else [0, 2147483548]},
          {"seed": np.iinfo(np.int32).min - 101, "type": np.int64, "jit": True, "key": [4294967295, 2147483547] if config.x64_enabled else [0, 2147483547]},
          {"seed": np.iinfo(np.int32).min - 101, "type": np.int64, "jit": False, "key": [4294967295, 2147483547] if config.x64_enabled else [0, 2147483547]},
      ]
  ))
  def test_prng_seeds_and_keys(self, seed, type, jit, key):
    # NOTE: `type` (shadowing the builtin) is the seed's constructor, e.g.
    # int or np.int32; its name is fixed by the parameter dicts above.
    if (jit and type is int and not config.x64_enabled and
        (seed < np.iinfo('int32').min or seed > np.iinfo('int32').max)):
      self.skipTest("Expected failure: integer out of range for jit.")
    seed = type(seed)
    if jit:
      actual = _prng_key_as_array(jax.jit(random.PRNGKey)(seed))
    else:
      actual = _prng_key_as_array(random.PRNGKey(seed))
    expected = jnp.array(key, dtype=jnp.uint32)
    self.assertArraysEqual(actual, expected)

  def test_default_prng_selection(self):
    if not config.jax_enable_custom_prng:
      self.skipTest("test requires config.jax_enable_custom_prng")
    # Keys made, or split, under jax.default_prng_impl(name) carry that impl.
    for name, impl in PRNG_IMPLS:
      with jax.default_prng_impl(name):
        self.assertIs(random.default_prng_impl(), impl)
        key = random.PRNGKey(42)
        self.assertIs(key.impl, impl)
        k1, k2 = random.split(key, 2)
        self.assertIs(k1.impl, impl)
        self.assertIs(k2.impl, impl)

  def test_default_prng_selection_without_custom_prng_mode(self):
    if config.jax_enable_custom_prng:
      self.skipTest("test requires that config.jax_enable_custom_prng is False")
    # Without custom-PRNG mode, keys are raw arrays shaped like impl.key_shape.
    for name, impl in PRNG_IMPLS:
      with jax.default_prng_impl(name):
        self.assertIs(random.default_prng_impl(), impl)
        key = random.PRNGKey(42)
        self.assertEqual(key.shape, impl.key_shape)
        k1, k2 = random.split(key, 2)
        self.assertEqual(k1.shape, impl.key_shape)
        self.assertEqual(k2.shape, impl.key_shape)

  def test_explicit_threefry2x32_key(self):
    if not config.jax_enable_custom_prng:
      self.skipTest("test requires config.jax_enable_custom_prng")
    key = random.threefry2x32_key(42)
    self.assertIs(key.impl, prng.threefry_prng_impl)

  def test_explicit_rbg_key(self):
    if not config.jax_enable_custom_prng:
      self.skipTest("test requires config.jax_enable_custom_prng")
    key = random.rbg_key(42)
    self.assertIs(key.impl, prng.rbg_prng_impl)

  def test_explicit_unsafe_rbg_key(self):
    if not config.jax_enable_custom_prng:
      self.skipTest("test requires config.jax_enable_custom_prng")
    key = random.unsafe_rbg_key(42)
    self.assertIs(key.impl, prng.unsafe_rbg_prng_impl)

  def test_key_array_indexing_0d(self):
    if not config.jax_enable_custom_prng:
      self.skipTest("test requires config.jax_enable_custom_prng")
    key = random.PRNGKey(1701)
    self.assertEqual(key.shape, ())
    self.assertEqual(key[None].shape, (1,))
    self.assertRaisesRegex(IndexError, 'Too many indices for PRNGKeyArray.*',
                           lambda: key[0])

  def test_key_array_indexing_nd(self):
    if not config.jax_enable_custom_prng:
      self.skipTest("test requires config.jax_enable_custom_prng")
    # A (2, 3) batch of keys, built by vmapping PRNGKey over a 2-D seed array.
    keys = vmap(vmap(random.PRNGKey))(jnp.arange(6).reshape((2, 3)))
    self.assertEqual(keys.shape, (2, 3))
    self.assertEqual(keys[0, 0].shape, ())
    self.assertEqual(keys[0, 1].shape, ())
    self.assertEqual(keys[0].shape, (3,))
    self.assertEqual(keys[1, :].shape, (3,))
    self.assertEqual(keys[:, 1].shape, (2,))
    self.assertEqual(keys[None].shape, (1, 2, 3))
    self.assertEqual(keys[None, None].shape, (1, 1, 2, 3))
    self.assertEqual(keys[None, :, None].shape, (1, 2, 1, 3))
    self.assertEqual(keys[None, None, None, 0, None, None, None, 1].shape,
                     (1,) * 6)
    self.assertEqual(keys[..., 1:, None].shape, (2, 2, 1))
    self.assertEqual(keys[None, 0, ..., 1, None].shape, (1, 1))
    self.assertRaisesRegex(IndexError, 'Too many indices for PRNGKeyArray.*',
                           lambda: keys[0, 1, 2])
    self.assertRaisesRegex(IndexError, 'Too many indices for PRNGKeyArray.*',
                           lambda: keys[0, 1, None, 2])
class LaxRandomTest(jtu.JaxTestCase):
def _CheckCollisions(self, samples, nbits):
  """Assert `samples` have roughly the expected number of distinct values.

  Treats the samples as len(samples) uniform draws into 2**nbits bins. The
  observed count of distinct values must stay close to its expectation.
  """
  fail_prob = 0.01  # conservative bound on statistical fail prob by Chebyshev
  num_draws = len(samples)
  num_bins = 2 ** nbits
  # Expected number of occupied bins after num_draws uniform draws.
  expected_distinct = num_bins * (1 - ((num_bins - 1) / num_bins) ** num_draws)
  observed_distinct = len(np.unique(samples))
  rel_dev_sq = ((observed_distinct - expected_distinct) / expected_distinct) ** 2
  threshold = 1 / np.sqrt(expected_distinct * fail_prob)
  self.assertLess(rel_dev_sq, threshold)
def _CheckKolmogorovSmirnovCDF(self, samples, cdf):
  """Assert `samples` pass a one-sample Kolmogorov-Smirnov test against `cdf`."""
  fail_prob = 0.01  # conservative bound on statistical fail prob by Kolmo CDF
  ks_result = scipy.stats.kstest(samples, cdf)
  self.assertGreater(ks_result.pvalue, fail_prob)
def _CheckChiSquared(self, samples, pmf):
  """Assert that integer `samples` pass a chi-squared test against `pmf`.

  Args:
    samples: array of non-negative integers (counted with np.bincount).
    pmf: vectorized probability mass function evaluated at integer values.
  """
  alpha = 0.01  # significance level, threshold for p-value

  # scipy.stats.chisquare requires the sum of expected and actual to
  # match; this is only the case if we compute the expected frequency
  # at *all* nonzero values of the pmf. We don't know this a priori,
  # so we add extra values past the largest observed value. The number
  # below is empirically enough to get full coverage for the current set
  # of tests. If a new test is added where this is not enough, chisquare()
  # below will error due to the sums of the inputs not matching.
  extra_values = 100
  actual_freq = np.bincount(samples, minlength=samples.max() + extra_values)
  support = np.arange(len(actual_freq))
  expected_freq = pmf(support) * samples.size

  # Only compare bins where the pmf assigns non-zero probability.
  nonzero = expected_freq > 0
  actual_freq, expected_freq = actual_freq[nonzero], expected_freq[nonzero]

  _, p_value = scipy.stats.chisquare(actual_freq, expected_freq)
  self.assertGreater(
      p_value, alpha,
      msg=f'Failed chi-squared test with p={p_value}.\n'
          'Expected vs. actual frequencies:\n'
          f'{expected_freq}\n{actual_freq}')
def seed_prng(self, seed):
  """Return the threefry2x32 PRNG key these tests derive from `seed`."""
  make_key = random.threefry2x32_key
  return make_key(seed)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
    for dtype in jtu.dtypes.floating))
def testNumpyAndXLAAgreeOnFloatEndianness(self, dtype):
  # NumPy's `view` and XLA's bitcast must give the same bit pattern for 1.0.
  # Use the unsigned type with the same width as `dtype`. The previous
  # "uint32 if 32-bit else uint64" rule would bitcast 16-bit floats
  # (float16/bfloat16) to a 64-bit type.
  nbits = jnp.finfo(dtype).bits
  bits_dtype = np.dtype(f"uint{nbits}")
  numpy_bits = np.array(1., dtype).view(bits_dtype)
  xla_bits = jax.jit(
      lambda: lax.bitcast_convert_type(np.array(1., dtype), bits_dtype))()
  self.assertEqual(numpy_bits, xla_bits)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
    for dtype in float_dtypes))
def testRngUniform(self, dtype):
  # Draw uniform samples eagerly and under jit. Check each batch for a
  # plausible collision count and for a uniform distribution.
  key = self.seed_prng(0)
  sample_fn = lambda key: random.uniform(key, (10000,), dtype)
  jitted_fn = jax.jit(sample_fn)
  eager_and_jitted = (sample_fn(key), jitted_fn(key))
  for samples in eager_and_jitted:
    self._CheckCollisions(samples, jnp.finfo(dtype).nmant)
    self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.uniform().cdf)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
    for dtype in int_dtypes + uint_dtypes))
def testRngRandint(self, dtype):
  # Every sample, eager or jitted, must fall in the half-open range [lo, hi).
  lo, hi = 5, 10
  key = self.seed_prng(0)
  sample_fn = lambda key: random.randint(key, (10000,), lo, hi, dtype)
  jitted_fn = jax.jit(sample_fn)
  for samples in (sample_fn(key), jitted_fn(key)):
    self.assertTrue(np.all(lo <= samples))
    self.assertTrue(np.all(samples < hi))
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
    for dtype in float_dtypes))
def testNormal(self, dtype):
  # Eager and jitted samples must both follow a standard normal distribution.
  key = self.seed_prng(0)
  sample_fn = lambda key: random.normal(key, (10000,), dtype)
  jitted_fn = jax.jit(sample_fn)
  batches = [sample_fn(key), jitted_fn(key)]
  standard_normal_cdf = scipy.stats.norm().cdf
  for samples in batches:
    self._CheckKolmogorovSmirnovCDF(samples, standard_normal_cdf)
def testNormalBfloat16(self):
  # Passing bfloat16 as dtype string.
  # https://github.com/google/jax/issues/6813
  # The string spelling 'bfloat16' must behave like the jnp.bfloat16 type.
  key = self.seed_prng(0)
  from_str = random.normal(key, dtype='bfloat16')
  from_type = random.normal(self.seed_prng(0), dtype=jnp.bfloat16)
  self.assertAllClose(from_type, from_str)
@parameterized.named_parameters(jtu.cases_from_list(
    # Leading "_" matches every other dtype-parameterized test name.
    {"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
    for dtype in complex_dtypes))
def testNormalComplex(self, dtype):
  # Complex normal samples: the real and imaginary parts must each follow a
  # normal with scale 1/sqrt(2), and the requested dtype must be preserved.
  key = self.seed_prng(0)
  rand = lambda key: random.normal(key, (10000,), dtype)
  crand = jax.jit(rand)

  uncompiled_samples = rand(key)
  compiled_samples = crand(key)

  # Same reference CDF for both parts; build it once.
  component_cdf = scipy.stats.norm(scale=1/np.sqrt(2)).cdf
  for samples in [uncompiled_samples, compiled_samples]:
    self._CheckKolmogorovSmirnovCDF(jnp.real(samples), component_cdf)
    self._CheckKolmogorovSmirnovCDF(jnp.imag(samples), component_cdf)
    self.assertEqual(dtype, samples.dtype)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
    for dtype in float_dtypes))
def testTruncatedNormal(self, dtype):
  # Samples truncated to (lower, upper) must stay strictly inside the bounds
  # and follow scipy's truncnorm. Both checks run on eager and jitted samples.
  lower, upper = -0.3, 0.3
  key = self.seed_prng(0)
  rand = lambda key: random.truncated_normal(key, lower, upper, (10000,), dtype)
  crand = jax.jit(rand)

  uncompiled_samples = rand(key)
  compiled_samples = crand(key)

  reference_cdf = scipy.stats.truncnorm(lower, upper).cdf
  for samples in [uncompiled_samples, compiled_samples]:
    # Check the bounds on the jitted samples too, not only the eager ones.
    self.assertGreater(np.min(samples), lower)
    self.assertLess(np.max(samples), upper)
    self._CheckKolmogorovSmirnovCDF(samples, reference_cdf)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
    for dtype in jtu.dtypes.floating + jtu.dtypes.integer))
def testShuffle(self, dtype):
  # random.shuffle emits a FutureWarning (presumably a deprecation notice).
  # It must still give the same permutation eagerly and under jit.
  key = self.seed_prng(0)
  x = np.arange(100).astype(dtype)
  shuffle_fn = lambda key: random.shuffle(key, x)
  jitted_shuffle = jax.jit(shuffle_fn)

  with self.assertWarns(FutureWarning):
    eager_perm = shuffle_fn(key)
  with self.assertWarns(FutureWarning):
    jitted_perm = jitted_shuffle(key)

  self.assertAllClose(eager_perm, jitted_perm)
  self.assertFalse(np.all(eager_perm == x))  # seems unlikely!
  # A permutation must contain exactly the original elements.
  self.assertAllClose(np.sort(eager_perm), x, check_dtypes=False)
@parameterized.named_parameters(jtu.cases_from_list(
    dict(
        testcase_name=
        f"_{np.dtype(dtype).name}_input_range_or_shape={input_range_or_shape}"
        f"_shape={shape}_replace={replace}_weighted={weighted}_axis={axis}",
        dtype=dtype, input_range_or_shape=input_range_or_shape,
        shape=shape, replace=replace, weighted=weighted, axis=axis)
    for dtype in jtu.dtypes.floating + jtu.dtypes.integer
    for shape in [(), (5,), (4, 5)]
    for replace in [True, False]
    for weighted in [True, False]
    # An int is passed to random.choice as a range size (as in NumPy's
    # choice); a tuple is the shape of an array input.
    for input_range_or_shape in [100, (10, 10), (10, 5, 2), 1, (1, 5)]
    # The single-element loops below only bind helper values for later clauses.
    for is_range in [type(input_range_or_shape) is int]
    for ndim in [1 if is_range else len(input_range_or_shape)]
    for axis in range(-ndim, ndim or 1)
    for ninputs in [input_range_or_shape if is_range else input_range_or_shape[axis]]
    # Sampling without replacement needs at least as many inputs as draws.
    if replace or np.prod(shape) <= ninputs))
def testChoice(self, dtype, input_range_or_shape, shape, replace, weighted, axis):
  """Compare random.choice with NumPy's Generator.choice.

  Checks the output dtype and shape, that draws without replacement are
  unique, and that a NumPy-array input and a jitted call reproduce the sample.
  """
  # This is the function API that we test against (note that self.rng().choice differs)
  np_choice = np.random.default_rng(0).choice
  key = self.seed_prng(0)
  is_range = type(input_range_or_shape) is int
  x = (input_range_or_shape if is_range else
       self.rng().permutation(jnp.arange(np.prod(
         input_range_or_shape), dtype=dtype)).reshape(input_range_or_shape))
  N = x if is_range else x.shape[axis]
  # Weights proportional to 1, 2, ..., N, normalized to sum to 1.
  p = None if not weighted else (np.arange(N) + 1) / np.sum(np.arange(N) + 1)
  rand = lambda key, x: random.choice(key, x, shape, replace, p, axis)
  sample = rand(key, x)
  if not is_range:
    self.assertEqual(dtype, sample.dtype)
  # NumPy uses size=None for a scalar draw, so shape () is mapped to None.
  np_shape = np.shape(np_choice(x, shape or None, replace, p, axis))
  self.assertEqual(np_shape, sample.shape)
  if not replace and shape:
    # Sort slices along `axis` lexicographically. Without replacement, the
    # sorted sample must equal its sorted unique slices, i.e. no duplicates.
    def lsort(x):
      if not np.prod(x.shape): return x
      ind = np.lexsort(np.swapaxes(x, axis, -1).reshape((-1, x.shape[axis])))
      return jnp.take(x, ind, axis)
    self.assertArraysEqual(lsort(sample), lsort(np.unique(sample, axis=axis)))
  # The same key must reproduce the sample for a NumPy-array input and under
  # jit (an int range input is marked static).
  self.assertArraysEqual(sample, rand(key, np.array(x)))
  self.assertArraysEqual(sample, jax.jit(rand, static_argnames=
    'x' if is_range else None)(key, x))
@parameterized.named_parameters(jtu.cases_from_list(
  dict(
    testcase_name=f"_dtype={np.dtype(dtype).name}_range_or_shape={range_or_shape}"
    f"_axis={axis}_independent={independent}",
    dtype=dtype, range_or_shape=range_or_shape, axis=axis, independent=independent)
  for dtype in jtu.dtypes.floating + jtu.dtypes.integer
  for range_or_shape in [0, 1, 100, (0,), (1,), (100,),
                         (10, 10), (10, 5, 2), (0, 5), (1, 5)]
  for ndim in [1 if type(range_or_shape) is int else len(range_or_shape)]
  for axis in range(-ndim, ndim or 1)
  for independent in [True, False]))
def testPermutation(self, dtype, range_or_shape, axis, independent):
  """Checks random.permutation on integer ranges and on arrays.

  The test checks four things:
    * The result is not the identity when the permuted axis has length >= 10.
    * With independent=False, the result holds the same multiset of slices
      along `axis`.
    * With independent=True and enough vectors, the result is NOT a mere
      reordering of slices.
    * The input is not mutated, and NumPy inputs and jit give the same result.
  """
  key = self.seed_prng(0)
  is_range = type(range_or_shape) is int
  x = (range_or_shape if is_range else
       self.rng().permutation(jnp.arange(
         np.prod(range_or_shape), dtype=dtype)).reshape(range_or_shape))
  shape = ((range_or_shape,) if is_range else range_or_shape)
  x_ = np.copy(x)
  rand = lambda key, x: random.permutation(key, x, axis, independent=independent)
  perm = rand(key, x)
  # For an integer input, the values being permuted are np.arange(x). Comparing
  # `perm` against the scalar `x` itself would make the identity check below
  # vacuous.
  arr = np.arange(x) if is_range else x
  if shape[axis] >= 10:
    self.assertFalse(np.all(perm == arr))  # seems unlikely!
  def lsort(x):
    # Lexicographically sort slices along `axis` so that comparisons ignore
    # slice order.
    if not np.prod(x.shape): return x
    ind = np.lexsort(np.swapaxes(x, axis, -1).reshape((-1, x.shape[axis])))
    return jnp.take(x, ind, axis)
  if not independent:
    self.assertArraysEqual(lsort(arr), lsort(perm), check_dtypes=not is_range)
  if independent and (arr.shape[axis] > 4) and (arr.size // arr.shape[axis] > 4):
    # Check for independent shuffling if there are >4 vectors of size >4.
    # Chance of false positive is 1 in (5!)^4
    with self.assertRaises(AssertionError):
      self.assertArraysEqual(lsort(arr), lsort(perm), check_dtypes=not is_range)
  self.assertArraysEqual(x_, x)
  self.assertArraysEqual(perm, rand(key, np.array(x)))
  self.assertArraysEqual(perm, jax.jit(rand, static_argnames=
    'x' if is_range else None)(key, x))
def testPermutationErrors(self):
  """Checks that invalid arguments to random.permutation raise errors."""
  key = self.seed_prng(0)
  # axis=3 is out of range for the 1-D range input 10.
  with self.assertRaises(ValueError):
    random.permutation(key, 10, axis=3)
  # A float scalar is rejected as a range argument.
  with self.assertRaises(TypeError):
    random.permutation(key, 10.)
  # Passing the range through jit is expected to fail because its value
  # is no longer concrete.
  jitted_permutation = jax.jit(random.permutation)
  with self.assertRaises(core.ConcretizationTypeError):
    jitted_permutation(key, 10)
@parameterized.named_parameters(jtu.cases_from_list(
  {"testcase_name": f"_p={p}_dtype={np.dtype(dtype).name}",
   "p": p, "dtype": dtype}
  for p in [0.1, 0.5, 0.9]
  for dtype in jtu.dtypes.floating))
def testBernoulli(self, p, dtype):
  """Chi-squared test of Bernoulli samples, both eagerly and under jit."""
  key = self.seed_prng(0)
  prob = np.array(p, dtype=dtype)

  def sampler(key, p):
    return random.bernoulli(key, p, (10000,))

  eager_draws = sampler(key, prob)
  jitted_draws = jax.jit(sampler)(key, prob)
  for draws in (eager_draws, jitted_draws):
    self._CheckChiSquared(draws, scipy.stats.bernoulli(prob).pmf)
@parameterized.named_parameters(jtu.cases_from_list(
  {"testcase_name": "_p={}_{}_{}".format(p, np.dtype(dtype).name, sample_shape),
   "p": p, "axis": axis, "dtype": dtype, 'sample_shape': sample_shape}
  for (p, axis) in [
    ([.25] * 4, -1),
    ([.1, .2, .3, .4], -1),
    ([[.5, .5], [.1, .9]], 1),
    ([[.5, .1], [.5, .9]], 0),
  ]
  for sample_shape in [(10000,), (5000, 2)]
  for dtype in jtu.dtypes.floating))
def testCategorical(self, p, axis, dtype, sample_shape):
  """Chi-squared test of random.categorical on unnormalized logits.

  Samples are drawn along `axis` from log(p) - 42, so the logits do not
  sum to one. The check runs eagerly and under jit. For 2-D `p`, each
  independent distribution is tested against its own row/column of `p`.
  """
  key = self.seed_prng(0)
  p = np.array(p, dtype=dtype)
  logits = np.log(p) - 42  # test unnormalized
  out_shape = tuple(np.delete(logits.shape, axis))
  shape = sample_shape + out_shape
  rand = partial(random.categorical, shape=shape, axis=axis)
  crand = jax.jit(rand)
  uncompiled_samples = rand(key, logits)
  compiled_samples = crand(key, logits)
  if axis < 0:
    axis += len(logits.shape)
  # Flatten the sample dimensions. Derive the count from sample_shape rather
  # than hard-coding it, so that new sample shapes cannot silently disagree.
  num_samples = int(np.prod(sample_shape))
  for samples in [uncompiled_samples, compiled_samples]:
    # Use assertEqual rather than a bare assert: it survives `python -O` and
    # reports both shapes on failure.
    self.assertEqual(samples.shape, shape)
    samples = jnp.reshape(samples, (num_samples,) + out_shape)
    if p.ndim > 1:
      # One distribution per row of `ps`. For axis == 0 the categories run down
      # the columns of p, so transpose first.
      ps = np.transpose(p, (1, 0)) if axis == 0 else p
      for cat_samples, cat_p in zip(samples.transpose(), ps):
        pmf = lambda x: np.where(x < len(cat_p), cat_p[np.minimum(len(cat_p) - 1, x)], 0.0)
        self._CheckChiSquared(cat_samples, pmf=pmf)
    else:
      pmf = lambda x: np.where(x < len(p), p[np.minimum(len(p) - 1, x)], 0.0)
      self._CheckChiSquared(samples, pmf=pmf)
def testBernoulliShape(self):
  """Checks that a (2,)-shaped p yields samples of the requested (3, 2) shape."""
  key = self.seed_prng(0)
  # p has lower rank than the requested shape, so rank promotion must be
  # allowed here.
  with jax.numpy_rank_promotion('allow'):
    x = random.bernoulli(key, np.array([0.2, 0.3]), shape=(3, 2))
  # assertEqual (not a bare assert) still runs under `python -O` and reports
  # both shapes.
  self.assertEqual(x.shape, (3, 2))
@parameterized.named_parameters(jtu.cases_from_list(
  {"testcase_name": f"_a={a}_b={b}_dtype={np.dtype(dtype).name}",
   "a": a, "b": b, "dtype": dtype}
  for a in [0.2, 5.]
  for b in [0.2, 5.]
  for dtype in [np.float64]))  # NOTE: KS test fails with float32
def testBeta(self, a, b, dtype):
  """KS test of beta samples against scipy.stats.beta, eagerly and under jit."""
  if not config.x64_enabled:
    raise SkipTest("skip test except on X64")
  key = self.seed_prng(0)

  def sample_beta(key, a, b):
    return random.beta(key, a, b, (10000,), dtype)

  eager_draws = sample_beta(key, a, b)
  jitted_draws = jax.jit(sample_beta)(key, a, b)
  reference_cdf = scipy.stats.beta(a, b).cdf
  for draws in (eager_draws, jitted_draws):
    self._CheckKolmogorovSmirnovCDF(draws, reference_cdf)
def testBetaSmallParameters(self, dtype=np.float32):
  """Regression test for the beta version of https://github.com/google/jax/issues/9896.

  With tiny concentration parameters, every sample is expected to be exactly
  zero or exactly one.
  """
  key = self.seed_prng(0)
  conc_a, conc_b = 0.0001, 0.0002
  draws = random.beta(key, conc_a, conc_b, shape=(100,), dtype=dtype)
  low = draws[draws < 0.5]
  high = draws[draws >= 0.5]
  self.assertAllClose(low, jnp.zeros_like(low))
  self.assertAllClose(high, jnp.ones_like(high))
@parameterized.named_parameters(jtu.cases_from_list(
  {"testcase_name": f"_dtype={np.dtype(dtype).name}", "dtype": dtype}
  for dtype in float_dtypes))
def testCauchy(self, dtype):
  """KS test of Cauchy samples against scipy.stats.cauchy, eagerly and under jit."""
  key = self.seed_prng(0)

  def sample_cauchy(key):
    return random.cauchy(key, (10000,), dtype)

  all_draws = [sample_cauchy(key), jax.jit(sample_cauchy)(key)]
  for draws in all_draws:
    self._CheckKolmogorovSmirnovCDF(draws, scipy.stats.cauchy().cdf)
@parameterized.named_parameters(jtu.cases_from_list(
  {"testcase_name": "_alpha={}_dtype={}".format(alpha, np.dtype(dtype).name),
   "alpha": alpha, "dtype": dtype}
  for alpha in [
    np.array([0.2, 1., 5.]),
  ]
  for dtype in jtu.dtypes.floating))
@jtu.skip_on_devices("tpu")  # TODO(mattjj): slow compilation times
def testDirichlet(self, alpha, dtype):
  """Checks that Dirichlet samples sum to one and have beta marginals.

  Component i of Dirichlet(alpha) is KS-tested against
  Beta(alpha_i, sum(alpha) - alpha_i). The test runs eagerly and under jit.
  """
  key = self.seed_prng(0)
  sample_dirichlet = lambda key, alpha: random.dirichlet(key, alpha, (10000,), dtype)
  eager_draws = sample_dirichlet(key, alpha)
  jitted_draws = jax.jit(sample_dirichlet)(key, alpha)
  total = sum(alpha)
  for draws in (eager_draws, jitted_draws):
    self.assertAllClose(draws.sum(-1), np.ones(10000, dtype=dtype))
    for i, alpha_i in enumerate(alpha):
      marginal_cdf = scipy.stats.beta(alpha_i, total - alpha_i).cdf
      self._CheckKolmogorovSmirnovCDF(draws[..., i], marginal_cdf)
def testDirichletSmallAlpha(self, dtype=np.float32):
  """Regression test for https://github.com/google/jax/issues/9896."""
  key = self.seed_prng(0)
  concentration = 0.0001 * jnp.ones(3)
  draws = random.dirichlet(key, concentration, shape=(100,), dtype=dtype)
  unit = jnp.ones(draws.shape[0])
  # Every sample must lie on the simplex, i.e. its components sum to one.
  self.assertAllClose(draws.sum(1), unit, check_dtypes=False, rtol=1E-5)
  # With such a small alpha, one component is highly likely to take
  # (essentially) all the mass, so each row's maximum should be one.
  self.assertAllClose(draws.max(1), unit, check_dtypes=False, rtol=1E-5)
@parameterized.named_parameters(jtu.cases_from_list(
  {"testcase_name": f"_dtype={np.dtype(dtype).name}", "dtype": dtype}
  for dtype in float_dtypes))
def testExponential(self, dtype):
  """KS test of exponential samples against scipy.stats.expon, eagerly and under jit."""
  key = self.seed_prng(0)

  def sample_exponential(key):
    return random.exponential(key, (10000,), dtype)

  all_draws = [sample_exponential(key), jax.jit(sample_exponential)(key)]
  for draws in all_draws:
    self._CheckKolmogorovSmirnovCDF(draws, scipy.stats.expon().cdf)
@parameterized.named_parameters(jtu.cases_from_list(
  {"testcase_name": f"_a={a}_dtype={np.dtype(dtype).name}",
   "a": a, "dtype": dtype}
  for a in [0.1, 1., 10.]
  for dtype in jtu.dtypes.floating))
def testGammaVsLogGamma(self, a, dtype):
  """Checks that exp(loggamma) equals gamma for the same key, eagerly and under jit."""
  key = self.seed_prng(0)
  # gamma sampling is deterministic given the key, so one draw serves as the
  # reference for both comparisons.
  gamma_draws = random.gamma(key, a, (10000,), dtype)

  def sample_loggamma(key, a):
    return random.loggamma(key, a, (10000,), dtype)

  for log_draws in (sample_loggamma(key, a), jax.jit(sample_loggamma)(key, a)):
    self.assertAllClose(gamma_draws, jnp.exp(log_draws))
@parameterized.named_parameters(jtu.cases_from_list(
  {"testcase_name": f"_a={a}_dtype={np.dtype(dtype).name}",
   "a": a, "dtype": dtype}
  for a in [0.1, 1., 10.]
  for dtype in jtu.dtypes.floating))
def testGamma(self, a, dtype):
  """KS test of gamma samples against scipy.stats.gamma, eagerly and under jit."""
  key = self.seed_prng(0)

  def sample_gamma(key, a):
    return random.gamma(key, a, (10000,), dtype)

  eager_draws = sample_gamma(key, a)
  jitted_draws = jax.jit(sample_gamma)(key, a)
  reference_cdf = scipy.stats.gamma(a).cdf
  for draws in (eager_draws, jitted_draws):
    self._CheckKolmogorovSmirnovCDF(draws, reference_cdf)
def testGammaShape(self):
  """Checks that a (2,)-shaped `a` yields samples of the requested (3, 2) shape."""
  key = self.seed_prng(0)
  x = random.gamma(key, np.array([0.2, 0.3]), shape=(3, 2))
  # assertEqual (not a bare assert) still runs under `python -O` and reports
  # both shapes.
  self.assertEqual(x.shape, (3, 2))
@parameterized.named_parameters(jtu.cases_from_list(
  {"testcase_name": f"_a={alpha}_logspace={log_space}",
   "alpha": alpha, "log_space": log_space}
  for log_space in [True, False]
  for alpha in [1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3, 1e4]))
def testGammaGrad(self, log_space, alpha):
  """Compares the gamma sample gradient with a finite-difference estimate.

  The expected gradient dz/dalpha at samples z is -(dCDF/dalpha) / pdf(z).
  dCDF/dalpha is estimated with a central difference in alpha. When
  `log_space` is set, the gradient flows through exp(loggamma) instead.
  """
  key = self.seed_prng(0)
  alphas = np.full((100,), alpha)
  z = random.gamma(key, alphas)
  if log_space:
    grad_fn = jax.grad(lambda x: lax.exp(random.loggamma(key, x)).sum())
    actual_grad = grad_fn(alphas)
    # TODO(jakevdp): this NaN correction is required because we generate negative infinities
    # in the log-space computation; see related TODO in the source of random._gamma_one().
    actual_grad = jnp.where(jnp.isnan(actual_grad), 0.0, actual_grad)
  else:
    actual_grad = jax.grad(lambda x: random.gamma(key, x).sum())(alphas)

  eps = 0.01 * alpha / (1.0 + np.sqrt(alpha))
  gamma_cdf = scipy.stats.gamma.cdf
  cdf_dot = (gamma_cdf(z, alpha + eps) - gamma_cdf(z, alpha - eps)) / (2 * eps)
  with np.errstate(over='ignore'):
    pdf = scipy.stats.gamma.pdf(z, alpha)
  expected_grad = -cdf_dot / pdf

  if jtu.device_under_test() == "tpu":
    rtol = 2e-2
  else:
    rtol = 7e-4
  self.assertAllClose(actual_grad, expected_grad, check_dtypes=True,
                      rtol=rtol)
def testGammaGradType(self):
  """Regression test for https://github.com/google/jax/issues/2130."""
  key = self.seed_prng(0)
  a = jnp.array(1., dtype=jnp.float32)
  b = jnp.array(3., dtype=jnp.float32)

  def scaled_gamma(x, y):
    return random.gamma(key=key, a=x, dtype=jnp.float32) / y

  # Building the VJP should not crash with a type error.
  jax.vjp(scaled_gamma, a, b)
@parameterized.named_parameters(jtu.cases_from_list(
  {"testcase_name": f"_lam={lam}_dtype={np.dtype(dtype).name}",
   "lam": lam, "dtype": np.dtype(dtype)}
  for lam in [0.5, 3, 9, 11, 50, 500]
  for dtype in [np.int16, np.int32, np.int64]))
def testPoisson(self, lam, dtype):
  """Chi-squared and moment checks of Poisson samples, eagerly and under jit."""
  key = self.seed_prng(0)

  def sample_poisson(key, lam):
    return random.poisson(key, lam, (10000,), dtype)

  eager_draws = sample_poisson(key, lam)
  jitted_draws = jax.jit(sample_poisson)(key, lam)
  reference_pmf = scipy.stats.poisson(lam).pmf
  for draws in (eager_draws, jitted_draws):
    self._CheckChiSquared(draws, reference_pmf)
    # TODO(shoyer): determine error bounds for moments more rigorously (e.g.,
    # based on the central limit theorem).
    # Both the mean and the variance of Poisson(lam) equal lam.
    self.assertAllClose(draws.mean(), lam, rtol=0.01, check_dtypes=False)
    self.assertAllClose(draws.var(), lam, rtol=0.03, check_dtypes=False)
def testPoissonBatched(self):
  """Checks per-element rates: the first half uses lam=2, the second lam=20."""
  key = self.seed_prng(1)
  half = 10000
  lam = jnp.concatenate([2 * jnp.ones(half), 20 * jnp.ones(half)])
  draws = random.poisson(key, lam, shape=(2 * half,))
  low_rate, high_rate = draws[:half], draws[half:]
  self._CheckChiSquared(low_rate, scipy.stats.poisson(2.0).pmf)
  self._CheckChiSquared(high_rate, scipy.stats.poisson(20.0).pmf)
def testPoissonWithoutShape(self):
  """Checks that the sample shape is inferred from `lam` when none is given."""
  key = self.seed_prng(1)
  rates = 2 * jnp.ones(10000)
  self._CheckChiSquared(random.poisson(key, rates),
                        scipy.stats.poisson(2.0).pmf)
def testPoissonShape(self):
  """Checks that a (2,)-shaped `lam` yields samples of the requested (3, 2) shape."""
  key = self.seed_prng(0)
  x = random.poisson(key, np.array([2.0, 20.0]), shape=(3, 2))
  # assertEqual (not a bare assert) still runs under `python -O` and reports
  # both shapes.
  self.assertEqual(x.shape, (3, 2))
def testPoissonZeros(self):
  """Checks that entries with rate zero always produce zero samples."""
  key = self.seed_prng(0)
  rates = jnp.concatenate([jnp.zeros(10), 20 * jnp.ones(10)])
  draws = random.poisson(key, rates, shape=(2, 20))
  zero_rate_block = draws[:, :10]
  self.assertArraysEqual(zero_rate_block, jnp.zeros_like(zero_rate_block))
def testPoissonCornerCases(self):
  """Checks the outputs for a negative, a zero and a NaN rate."""
  key = self.seed_prng(0)
  rates = jnp.array([-1, 0, jnp.nan])
  draws = random.poisson(key, rates, shape=(3,))
  # Negative and NaN rates are expected to give -1; a zero rate gives 0.
  expected = jnp.array([-1, 0, -1])
  self.assertArraysEqual(draws, expected)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
for dtype in jtu.dtypes.floating))
def testGumbel(self, dtype):
key = self.seed_prng(0)
rand = lambda key: random.gumbel(key, (10000,), dtype)
crand = jax.jit(rand)