Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Call more general algorithm when lattice basis isn't trivial #38405

Merged
merged 8 commits into from
Aug 3, 2024
Merged
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 31 additions & 19 deletions src/sage/stats/distributions/discrete_gaussian_lattice.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,9 @@ def _normalisation_factor_zz(self, tau=None, prec=None):
sage: M = Matrix(ZZ, [[1, 3, 0], [-2, 5, 1], [3, -4, 2]])
sage: D = DGL(M, 1.7)
sage: D._normalisation_factor_zz() # long time
7247.1975...
Traceback (most recent call last):
...
NotImplementedError: Center must be at zero and basis must be trivial.

sage: Sigma = Matrix(ZZ, [[5, -2, 4], [-2, 10, -5], [4, -5, 5]])
sage: D = DGL(ZZ^3, Sigma, [7, 2, 5])
Expand All @@ -262,19 +264,19 @@ def _normalisation_factor_zz(self, tau=None, prec=None):
sage: D._normalisation_factor_zz()
Traceback (most recent call last):
...
NotImplementedError: basis must be a square matrix for now
NotImplementedError: Basis must be a square matrix.

sage: D = DGL(ZZ^3, c=(1/2, 0, 0))
sage: D._normalisation_factor_zz()
Traceback (most recent call last):
...
NotImplementedError: lattice must contain 0 for now
NotImplementedError: Center must be at zero and basis must be trivial.

sage: D = DGL(Matrix(3, 3, 1/2))
sage: D._normalisation_factor_zz()
Traceback (most recent call last):
...
NotImplementedError: lattice must be integral for now
NotImplementedError: Lattice must be integral.
"""
# If σ > 1:
# We use the Fourier transform g(t) of f(x) = exp(-k^2 / 2σ^2), but
Expand Down Expand Up @@ -314,13 +316,13 @@ def f_or_hat(x):
return sum(self.f((vector(u) + base) * self.B) for u in coords)

if self.B.nrows() != self.B.ncols():
raise NotImplementedError("basis must be a square matrix for now")

if self.is_spherical and not self._c_in_lattice:
raise NotImplementedError("lattice must contain 0 for now")
raise NotImplementedError("Basis must be a square matrix.")

if self.B.base_ring() != ZZ:
raise NotImplementedError("lattice must be integral for now")
raise NotImplementedError("Lattice must be integral.")
mkoeppe marked this conversation as resolved.
Show resolved Hide resolved

if self.is_spherical and not self._c_in_lattice_and_lattice_trivial:
raise NotImplementedError("Center must be at zero and basis must be trivial.")

sigma = self._sigma
prec = DiscreteGaussianDistributionLatticeSampler.compute_precision(
Expand Down Expand Up @@ -585,7 +587,7 @@ def __init__(self, B, sigma=1, c=0, r=None, precision=None, sigma_basis=False):
self.B = B
self.Q = B * B.T
self._G = B.gram_schmidt()[0]
self._c_in_lattice = False
self._c_in_lattice_and_lattice_trivial = False

self.D = None
self.VS = None
Expand Down Expand Up @@ -614,18 +616,19 @@ def _precompute_data(self):
Do not call this method directly, it is called automatically from
:func:`DiscreteGaussianDistributionLatticeSampler.__init__`.
"""

if self.is_spherical:
# deal with trivial case first, it is common
if self._G == 1 and self._c == 0:
self._c_in_lattice = True
if self._c == 0 and self._G == 1:
self._c_in_lattice_and_lattice_trivial = True
D = DiscreteGaussianDistributionIntegerSampler(sigma=self._sigma)
self.D = tuple([D for _ in range(self.B.nrows())])
self.VS = FreeModule(ZZ, self.B.nrows())

else:
w = self.B.solve_left(self._c)
if w in ZZ ** self.B.nrows():
self._c_in_lattice = True
if w in ZZ ** self.B.nrows() and self._G == 1:
self._c_in_lattice_and_lattice_trivial = True
D = []
for i in range(self.B.nrows()):
sigma_ = self._sigma / self._G[i].norm()
Expand Down Expand Up @@ -675,11 +678,20 @@ def __call__(self):
sage: mean_L = sum(L) / len(L) # long time
sage: norm(mean_L.n() - D.c()) < 0.25 # long time
True

sage: import numpy
sage: M = matrix(ZZ, [[1,2],[0,1]])
sage: D = distributions.DiscreteGaussianDistributionLatticeSampler(M, 20.0)
sage: L = [D() for _ in range(2^12)] # long time
sage: div = (numpy.mean([abs(x) for x,y in L])/numpy.mean([abs(y) for x,y, in L])) # long time
mkoeppe marked this conversation as resolved.
Show resolved Hide resolved
sage: 0.9 < div < 1.1 # long time
True

"""
if not self.is_spherical:
v = self._call_non_spherical()
elif self._c_in_lattice:
v = self._call_in_lattice()
elif self._c_in_lattice_and_lattice_trivial:
v = self._call_simple()
else:
v = self._call()
v.set_immutable()
Expand Down Expand Up @@ -809,14 +821,14 @@ def __repr__(self):
sigma_str = f"Σ =\n{self._sigma}"
return f"Discrete Gaussian sampler with Gaussian parameter {sigma_str}, c={self._c} over lattice with basis\n\n{self.B}"

def _call_in_lattice(self):
def _call_simple(self):
r"""
Return a new sample assuming `c \in \Lambda(B)`.
Return a new sample assuming `c \in \Lambda(B)` and `B^* = 1`.

EXAMPLES::

sage: D = distributions.DiscreteGaussianDistributionLatticeSampler(ZZ^3, 3.0, c=(1,0,0))
sage: L = [D._call_in_lattice() for _ in range(2^12)]
sage: L = [D._call_simple() for _ in range(2^12)]
sage: mean_L = sum(L) / len(L)
sage: norm(mean_L.n() - D.c()) < 0.25
True
Expand Down
Loading