diff --git a/src/fpylll/fplll/fplll.pxd b/src/fpylll/fplll/fplll.pxd index 60d6e58d..74cc7b45 100644 --- a/src/fpylll/fplll/fplll.pxd +++ b/src/fpylll/fplll/fplll.pxd @@ -755,6 +755,7 @@ cdef extern from "fplll/enum/enumerate_ext.h" namespace "fplll": bool dual, bool findsubsols) void set_external_enumerator(function[extenum_fc_enumerate] extenum) + function[extenum_fc_enumerate] get_external_enumerator() # SVP diff --git a/src/fpylll/util.pyx b/src/fpylll/util.pyx index 8b5384ab..89c8d3e0 100644 --- a/src/fpylll/util.pyx +++ b/src/fpylll/util.pyx @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- include "fpylll/config.pxi" - + +from contextlib import contextmanager from fpylll.fplll.decl cimport fp_nr_t from fpylll.fplll.fplll cimport FP_NR, RandGen, dpe_t @@ -8,6 +9,7 @@ from fpylll.fplll.fplll cimport FT_DEFAULT, FT_DOUBLE, FT_LONG_DOUBLE, FT_DPE, F from fpylll.fplll.fplll cimport IntType, ZT_LONG, ZT_MPZ from fpylll.fplll.fplll cimport adjust_radius_to_gh_bound as adjust_radius_to_gh_bound_c from fpylll.fplll.fplll cimport set_external_enumerator as set_external_enumerator_c +from fpylll.fplll.fplll cimport get_external_enumerator as get_external_enumerator_c from fpylll.fplll.fplll cimport extenum_fc_enumerate from fpylll.fplll.fplll cimport get_root_det as get_root_det_c from fpylll.fplll.fplll cimport PRUNER_METRIC_PROBABILITY_OF_SHORTEST, PRUNER_METRIC_EXPECTED_SOLUTIONS, PrunerMetric @@ -199,28 +201,27 @@ def set_precision(unsigned int prec): raise ValueError("Precision (%d) too small."%prec) return FP_NR[mpfr_t].set_prec(prec) -class PrecisionContext: - def __init__(self, prec): - """Create new precision context. - - :param prec: internal precision - - """ - self.prec = prec - - def __enter__(self): - self.prec = set_precision(self.prec) - - def __exit__(self, exception_type, exception_value, exception_traceback): - self.prec = set_precision(self.prec) - +@contextmanager def precision(prec): - """Create new precision context. + """Run with precision ``prec`` temporarily. - :param prec: internal precision + :param prec: temporary precision + :returns: temporary precision being used + + >>> from fpylll import FPLLL + >>> with FPLLL.precision(212) as prec: print(prec) + 212 + >>> FPLLL.get_precision() + 53 + >>> with FPLLL.precision(212): FPLLL.get_precision() + 212 """ - return PrecisionContext(prec) + old_prec = set_precision(prec) + try: + yield get_precision() + finally: + set_precision(old_prec) def adjust_radius_to_gh_bound(double dist, int dist_expo, int block_size, double root_det, double gh_factor): @@ -291,8 +292,7 @@ cpdef set_external_enumerator(enumerator): We grab the external enumeration function >>> fn = enumlib._Z17enumlib_enumerateidSt8functionIFvPdmbS0_S0_EES_IFddS0_EES_IFvdS0_iEEbb # doctest: +SKIP - - and pass it to Fplll + and pass it to FPLLL >>> FPLLL.set_external_enumerator(fn) # doctest: +SKIP @@ -300,6 +300,8 @@ cpdef set_external_enumerator(enumerator): >>> FPLLL.set_external_enumerator(None) # doctest: +SKIP + :param enumerator: CTypes handle + """ import ctypes cdef unsigned long p @@ -309,6 +311,21 @@ cpdef set_external_enumerator(enumerator): p = ctypes.cast(enumerator, ctypes.c_void_p).value set_external_enumerator_c(void_ptr_to_function(p)) +@contextmanager +def external_enumerator(enumerator): + """ + Temporarily use ``enumerator``. + + :param enumerator: CTypes handle + + """ + cdef function[extenum_fc_enumerate] fn = get_external_enumerator_c() + set_external_enumerator(enumerator) + try: + yield + finally: + set_external_enumerator_c(fn) + def set_threads(int th=1): """ Set the number of threads. @@ -333,12 +350,43 @@ def get_threads(): """ return get_threads_c() +@contextmanager +def threads(int th=1): + """ + Run with ``th`` threads temporarily + + :param th: number of threads ≥ 1 + :returns: number of threads used + + >>> from fpylll import FPLLL + >>> import multiprocessing + >>> max_th = multiprocessing.cpu_count() + >>> with FPLLL.threads(4) as th: th == min(max_th, 4) + True + >>> FPLLL.get_threads() + 1 + >>> with FPLLL.threads(4) as th: FPLLL.get_threads() == min(max_th, 4) + True + + """ + old_th = get_threads() + set_threads(th) + try: + yield get_threads() + finally: + set_threads(old_th) class FPLLL: set_precision = staticmethod(set_precision) get_precision = staticmethod(get_precision) + precision = staticmethod(precision) + set_threads = staticmethod(set_threads) get_threads = staticmethod(get_threads) + threads = staticmethod(threads) + set_random_seed = staticmethod(set_random_seed) randint = staticmethod(randint) + set_external_enumerator = staticmethod(set_external_enumerator) + external_enumerator = staticmethod(external_enumerator)