Skip to content

Commit

Permalink
Merge pull request #6057 from inailuig:fix-complex-normal
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 362930693
  • Loading branch information
jax authors committed Mar 15, 2021
2 parents 1ad99d3 + d78fe6b commit 3fb6a11
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 1 deletion.
2 changes: 1 addition & 1 deletion jax/_src/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,7 +634,7 @@ def _normal(key, shape, dtype) -> jnp.ndarray:
dtype = dtypes.dtype_real(dtype)
_re = _normal_real(key_re, shape, dtype)
_im = _normal_real(key_im, shape, dtype)
return 1 / sqrt2 * (_re + 1j * _im)
return (_re + 1j * _im) / sqrt2
else:
return _normal_real(key, shape, dtype) # type: ignore

Expand Down
1 change: 1 addition & 0 deletions tests/random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ def testNormalComplex(self, dtype):
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(jnp.real(samples), scipy.stats.norm(scale=1/np.sqrt(2)).cdf)
self._CheckKolmogorovSmirnovCDF(jnp.imag(samples), scipy.stats.norm(scale=1/np.sqrt(2)).cdf)
self.assertEqual(dtype, samples.dtype)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
Expand Down

0 comments on commit 3fb6a11

Please sign in to comment.