Skip to content

Commit

Permalink
Add factorial()
Browse files Browse the repository at this point in the history
and override default reallocate_function.

Closes diofant#12
Closes diofant#14
  • Loading branch information
skirpichev committed Dec 14, 2024
1 parent 009ca50 commit d14532c
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 9 deletions.
96 changes: 88 additions & 8 deletions main.c
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,34 @@ gmp_allocate_function(size_t size)
}


static void *
gmp_reallocate_function(void *ptr, size_t old_size, size_t new_size)
{
void *ret = realloc(ptr, new_size);

if (!ret) {
goto err;
}
for (size_t i = gmp_tracker.size - 1; i >= 0; i--) {
if (gmp_tracker.ptrs[i] == ptr) {
gmp_tracker.ptrs[i] = ret;
break;
}
}
return ret;
err:
for (size_t i = 0; i < gmp_tracker.size; i++) {
if (gmp_tracker.ptrs[i]) {
free(gmp_tracker.ptrs[i]);
gmp_tracker.ptrs[i] = NULL;
}
}
gmp_tracker.alloc = 0;
gmp_tracker.size = 0;
longjmp(gmp_env, 1);
}


static void
gmp_free_function(void *ptr, size_t size)
{
Expand All @@ -70,7 +98,6 @@ gmp_free_function(void *ptr, size_t size)
}



typedef struct _mpzobject {
PyObject_HEAD
uint8_t negative;
Expand Down Expand Up @@ -2607,7 +2634,7 @@ gmp_gcd(PyObject *self, PyObject * const *args, Py_ssize_t nargs)
static PyObject *
gmp_isqrt(PyObject *self, PyObject *other)
{
static MPZ_Object *x, *res;
static MPZ_Object *x, *res = NULL;

if (MPZ_CheckExact(other)) {
x = (MPZ_Object*)other;
Expand All @@ -2627,14 +2654,15 @@ gmp_isqrt(PyObject *self, PyObject *other)
if (x->negative) {
PyErr_SetString(PyExc_ValueError,
"isqrt() argument must be nonnegative");
return NULL;
goto end;
}
else if (!x->size) {
return (PyObject*)MPZ_FromDigitSign(0, 0);
res = MPZ_FromDigitSign(0, 0);
goto end;
}
res = MPZ_new((x->size + 1)/2, 0);
if (!res) {
return NULL;
goto end;
}
if (CHECK_NO_MEM_LEAK) {
mpn_sqrtrem(res->digits, NULL, x->digits, x->size);
Expand All @@ -2653,8 +2681,59 @@ gmp_isqrt(PyObject *self, PyObject *other)
static PyObject *
gmp_factorial(PyObject *self, PyObject *other)
{
PyErr_SetString(PyExc_NotImplementedError, "factorial");
return NULL;
static MPZ_Object *x, *res = NULL;

if (MPZ_CheckExact(other)) {
x = (MPZ_Object*)other;
Py_INCREF(x);
}
else if (PyLong_Check(other)) {
x = from_int(other);
if (!x) {
goto end;
}
}
else {
PyErr_SetString(PyExc_TypeError,
"factorial() argument must be an integer");
return NULL;
}

__mpz_struct tmp;

tmp._mp_d = x->digits;
tmp._mp_size = (x->negative ? -1 : 1)*x->size;
if (!mpz_fits_ulong_p(&tmp)) {
PyErr_Format(PyExc_OverflowError,
"factorial() argument should not exceed %ld", LONG_MAX);
goto end;
}
if (x->negative) {
PyErr_SetString(PyExc_ValueError,
"factorial() not defined for negative values");
goto end;
}

unsigned long n = mpz_get_ui(&tmp);

if (CHECK_NO_MEM_LEAK) {
mpz_init(&tmp);
mpz_fac_ui(&tmp, n);
}
else {
Py_DECREF(x);
return PyErr_NoMemory();
}
res = MPZ_new(tmp._mp_size, 0);
if (!res) {
mpz_clear(&tmp);
goto end;
}
mpn_copyi(res->digits, tmp._mp_d, res->size);
mpz_clear(&tmp);
end:
Py_DECREF(x);
return (PyObject*)res;
}


Expand Down Expand Up @@ -2696,7 +2775,8 @@ PyInit_gmp(void)
{
return NULL;
}
mp_set_memory_functions(gmp_allocate_function, NULL,
mp_set_memory_functions(gmp_allocate_function,
gmp_reallocate_function,
gmp_free_function);

PyObject *numbers = PyImport_ImportModule("numbers");
Expand Down
1 change: 0 additions & 1 deletion tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ def test_isqrt(x):
assert isqrt(mx) == isqrt(x) == r


@pytest.mark.xfail(reason="diofant/python-gmp#12")
@given(integers(min_value=0, max_value=12345))
def test_factorial(x):
mx = mpz(x)
Expand Down

0 comments on commit d14532c

Please sign in to comment.