diff --git a/main.c b/main.c index a91bf94..0250203 100644 --- a/main.c +++ b/main.c @@ -54,6 +54,34 @@ gmp_allocate_function(size_t size) } +static void * +gmp_reallocate_function(void *ptr, size_t old_size, size_t new_size) +{ + void *ret = realloc(ptr, new_size); + + if (!ret) { + goto err; + } + for (size_t i = gmp_tracker.size - 1; i >= 0; i--) { + if (gmp_tracker.ptrs[i] == ptr) { + gmp_tracker.ptrs[i] = ret; + break; + } + } + return ret; +err: + for (size_t i = 0; i < gmp_tracker.size; i++) { + if (gmp_tracker.ptrs[i]) { + free(gmp_tracker.ptrs[i]); + gmp_tracker.ptrs[i] = NULL; + } + } + gmp_tracker.alloc = 0; + gmp_tracker.size = 0; + longjmp(gmp_env, 1); +} + + static void gmp_free_function(void *ptr, size_t size) { @@ -70,7 +98,6 @@ gmp_free_function(void *ptr, size_t size) } - typedef struct _mpzobject { PyObject_HEAD uint8_t negative; @@ -2607,7 +2634,7 @@ gmp_gcd(PyObject *self, PyObject * const *args, Py_ssize_t nargs) static PyObject * gmp_isqrt(PyObject *self, PyObject *other) { - static MPZ_Object *x, *res; + static MPZ_Object *x, *res = NULL; if (MPZ_CheckExact(other)) { x = (MPZ_Object*)other; @@ -2627,14 +2654,15 @@ gmp_isqrt(PyObject *self, PyObject *other) if (x->negative) { PyErr_SetString(PyExc_ValueError, "isqrt() argument must be nonnegative"); - return NULL; + goto end; } else if (!x->size) { - return (PyObject*)MPZ_FromDigitSign(0, 0); + res = MPZ_FromDigitSign(0, 0); + goto end; } res = MPZ_new((x->size + 1)/2, 0); if (!res) { - return NULL; + goto end; } if (CHECK_NO_MEM_LEAK) { mpn_sqrtrem(res->digits, NULL, x->digits, x->size); @@ -2653,8 +2681,59 @@ gmp_isqrt(PyObject *self, PyObject *other) static PyObject * gmp_factorial(PyObject *self, PyObject *other) { - PyErr_SetString(PyExc_NotImplementedError, "factorial"); - return NULL; + static MPZ_Object *x, *res = NULL; + + if (MPZ_CheckExact(other)) { + x = (MPZ_Object*)other; + Py_INCREF(x); + } + else if (PyLong_Check(other)) { + x = from_int(other); + if (!x) { + goto end; + } + } + else { + PyErr_SetString(PyExc_TypeError, + "factorial() argument must be an integer"); + return NULL; + } + + __mpz_struct tmp; + + tmp._mp_d = x->digits; + tmp._mp_size = (x->negative ? -1 : 1)*x->size; + if (!mpz_fits_ulong_p(&tmp)) { + PyErr_Format(PyExc_OverflowError, + "factorial() argument should not exceed %ld", LONG_MAX); + goto end; + } + if (x->negative) { + PyErr_SetString(PyExc_ValueError, + "factorial() not defined for negative values"); + goto end; + } + + unsigned long n = mpz_get_ui(&tmp); + + if (CHECK_NO_MEM_LEAK) { + mpz_init(&tmp); + mpz_fac_ui(&tmp, n); + } + else { + Py_DECREF(x); + return PyErr_NoMemory(); + } + res = MPZ_new(tmp._mp_size, 0); + if (!res) { + mpz_clear(&tmp); + goto end; + } + mpn_copyi(res->digits, tmp._mp_d, res->size); + mpz_clear(&tmp); +end: + Py_DECREF(x); + return (PyObject*)res; } @@ -2696,7 +2775,8 @@ PyInit_gmp(void) { return NULL; } - mp_set_memory_functions(gmp_allocate_function, NULL, + mp_set_memory_functions(gmp_allocate_function, + gmp_reallocate_function, gmp_free_function); PyObject *numbers = PyImport_ImportModule("numbers"); diff --git a/tests/test_functions.py b/tests/test_functions.py index 8a7207e..5edf99d 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -13,7 +13,6 @@ def test_isqrt(x): assert isqrt(mx) == isqrt(x) == r -@pytest.mark.xfail(reason="diofant/python-gmp#12") @given(integers(min_value=0, max_value=12345)) def test_factorial(x): mx = mpz(x)