Skip to content

Commit

Permalink
sagemathgh-38405: Call more general algorithm when lattice basis isn'…
Browse files Browse the repository at this point in the history
…t trivial

    
fixes sagemath#38400

- [x] I have linked a relevant issue or discussion.
- [x] I have created tests covering the changes.
- [x] I have updated the documentation and checked the documentation
preview.
    
URL: sagemath#38405
Reported by: Martin R. Albrecht
Reviewer(s): Martin R. Albrecht, Matthias Köppe
  • Loading branch information
Release Manager committed Aug 1, 2024
2 parents 03d2081 + 261730a commit a01ce04
Showing 1 changed file with 31 additions and 19 deletions.
50 changes: 31 additions & 19 deletions src/sage/stats/distributions/discrete_gaussian_lattice.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,9 @@ def _normalisation_factor_zz(self, tau=None, prec=None):
sage: M = Matrix(ZZ, [[1, 3, 0], [-2, 5, 1], [3, -4, 2]])
sage: D = DGL(M, 1.7)
sage: D._normalisation_factor_zz() # long time
7247.1975...
Traceback (most recent call last):
...
NotImplementedError: center must be at zero and basis must be trivial
sage: Sigma = Matrix(ZZ, [[5, -2, 4], [-2, 10, -5], [4, -5, 5]])
sage: D = DGL(ZZ^3, Sigma, [7, 2, 5])
Expand All @@ -260,19 +262,19 @@ def _normalisation_factor_zz(self, tau=None, prec=None):
sage: D._normalisation_factor_zz()
Traceback (most recent call last):
...
NotImplementedError: basis must be a square matrix for now
NotImplementedError: basis must be a square matrix
sage: D = DGL(ZZ^3, c=(1/2, 0, 0))
sage: D._normalisation_factor_zz()
Traceback (most recent call last):
...
NotImplementedError: lattice must contain 0 for now
NotImplementedError: center must be at zero and basis must be trivial
sage: D = DGL(Matrix(3, 3, 1/2))
sage: D._normalisation_factor_zz()
Traceback (most recent call last):
...
NotImplementedError: lattice must be integral for now
NotImplementedError: lattice must be integral
"""
# If σ > 1:
# We use the Fourier transform g(t) of f(x) = exp(-k^2 / 2σ^2), but
Expand Down Expand Up @@ -312,13 +314,13 @@ def f_or_hat(x):
return sum(self.f((vector(u) + base) * self.B) for u in coords)

if self.B.nrows() != self.B.ncols():
raise NotImplementedError("basis must be a square matrix for now")

if self.is_spherical and not self._c_in_lattice:
raise NotImplementedError("lattice must contain 0 for now")
raise NotImplementedError("basis must be a square matrix")

if self.B.base_ring() != ZZ:
raise NotImplementedError("lattice must be integral for now")
raise NotImplementedError("lattice must be integral")

if self.is_spherical and not self._c_in_lattice_and_lattice_trivial:
raise NotImplementedError("center must be at zero and basis must be trivial")

sigma = self._sigma
prec = DiscreteGaussianDistributionLatticeSampler.compute_precision(
Expand Down Expand Up @@ -583,7 +585,7 @@ def __init__(self, B, sigma=1, c=0, r=None, precision=None, sigma_basis=False):
self.B = B
self.Q = B * B.T
self._G = B.gram_schmidt()[0]
self._c_in_lattice = False
self._c_in_lattice_and_lattice_trivial = False

self.D = None
self.VS = None
Expand Down Expand Up @@ -612,18 +614,19 @@ def _precompute_data(self):
Do not call this method directly, it is called automatically from
:func:`DiscreteGaussianDistributionLatticeSampler.__init__`.
"""

if self.is_spherical:
# deal with trivial case first, it is common
if self._G == 1 and self._c == 0:
self._c_in_lattice = True
if self._c == 0 and self._G == 1:
self._c_in_lattice_and_lattice_trivial = True
D = DiscreteGaussianDistributionIntegerSampler(sigma=self._sigma)
self.D = tuple([D for _ in range(self.B.nrows())])
self.VS = FreeModule(ZZ, self.B.nrows())

else:
w = self.B.solve_left(self._c)
if w in ZZ ** self.B.nrows():
self._c_in_lattice = True
if w in ZZ ** self.B.nrows() and self._G == 1:
self._c_in_lattice_and_lattice_trivial = True
D = []
for i in range(self.B.nrows()):
sigma_ = self._sigma / self._G[i].norm()
Expand Down Expand Up @@ -673,11 +676,20 @@ def __call__(self):
sage: mean_L = sum(L) / len(L) # long time
sage: norm(mean_L.n() - D.c()) < 0.25 # long time
True
sage: import numpy
sage: M = matrix(ZZ, [[1,2],[0,1]])
sage: D = distributions.DiscreteGaussianDistributionLatticeSampler(M, 20.0)
sage: L = [D() for _ in range(2^12)] # long time
sage: div = numpy.mean([abs(x) for x,y in L]) / numpy.mean([abs(y) for x,y, in L]) # long time
sage: 0.9 < div < 1.1 # long time
True
"""
if not self.is_spherical:
v = self._call_non_spherical()
elif self._c_in_lattice:
v = self._call_in_lattice()
elif self._c_in_lattice_and_lattice_trivial:
v = self._call_simple()
else:
v = self._call()
v.set_immutable()
Expand Down Expand Up @@ -807,14 +819,14 @@ def __repr__(self):
sigma_str = f"Σ =\n{self._sigma}"
return f"Discrete Gaussian sampler with Gaussian parameter {sigma_str}, c={self._c} over lattice with basis\n\n{self.B}"

def _call_in_lattice(self):
def _call_simple(self):
r"""
Return a new sample assuming `c \in \Lambda(B)`.
Return a new sample assuming `c \in \Lambda(B)` and `B^* = 1`.
EXAMPLES::
sage: D = distributions.DiscreteGaussianDistributionLatticeSampler(ZZ^3, 3.0, c=(1,0,0))
sage: L = [D._call_in_lattice() for _ in range(2^12)]
sage: L = [D._call_simple() for _ in range(2^12)]
sage: mean_L = sum(L) / len(L)
sage: norm(mean_L.n() - D.c()) < 0.25
True
Expand Down

0 comments on commit a01ce04

Please sign in to comment.