Skip to content

Commit

Permalink
pythongh-101277: Add chain type to module state
Browse files Browse the repository at this point in the history
  • Loading branch information
erlend-aasland committed Feb 1, 2023
1 parent 50316a7 commit da955ae
Showing 1 changed file with 38 additions and 58 deletions.
96 changes: 38 additions & 58 deletions Modules/itertoolsmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

typedef struct {
PyTypeObject *accumulate_type;
PyTypeObject *chain_type;
PyTypeObject *combinations_type;
PyTypeObject *compress_type;
PyTypeObject *count_type;
Expand Down Expand Up @@ -70,7 +71,7 @@ class itertools.cycle "cycleobject *" "clinic_state()->cycle_type"
class itertools.dropwhile "dropwhileobject *" "clinic_state()->dropwhile_type"
class itertools.takewhile "takewhileobject *" "clinic_state()->takewhile_type"
class itertools.starmap "starmapobject *" "clinic_state()->starmap_type"
class itertools.chain "chainobject *" "&chain_type"
class itertools.chain "chainobject *" "clinic_state()->chain_type"
class itertools.combinations "combinationsobject *" "clinic_state()->combinations_type"
class itertools.combinations_with_replacement "cwr_object *" "clinic_state()->cwr_type"
class itertools.permutations "permutationsobject *" "clinic_state()->permutations_type"
Expand All @@ -80,7 +81,7 @@ class itertools.filterfalse "filterfalseobject *" "clinic_state()->filterfalse_t
class itertools.count "countobject *" "clinic_state()->count_type"
class itertools.pairwise "pairwiseobject *" "clinic_state()->pairwise_type"
[clinic start generated code]*/
/*[clinic end generated code: output=da39a3ee5e6b4b0d input=28ffff5c0c93eed7]*/
/*[clinic end generated code: output=da39a3ee5e6b4b0d input=102319065d3e090b]*/

static PyTypeObject teedataobject_type;
static PyTypeObject tee_type;
Expand Down Expand Up @@ -2038,8 +2039,6 @@ typedef struct {
PyObject *active; /* Currently running input iterator */
} chainobject;

static PyTypeObject chain_type;

static PyObject *
chain_new_internal(PyTypeObject *type, PyObject *source)
{
Expand All @@ -2061,9 +2060,12 @@ chain_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
{
PyObject *source;

if ((type == &chain_type || type->tp_init == chain_type.tp_init) &&
!_PyArg_NoKeywords("chain", kwds))
itertools_state *st = find_state_by_type(type);
if ((type == st->chain_type || type->tp_init == st->chain_type->tp_init)
&& !_PyArg_NoKeywords("chain", kwds))
{
return NULL;
}

source = PyObject_GetIter(args);
if (source == NULL)
Expand Down Expand Up @@ -2096,15 +2098,18 @@ itertools_chain_from_iterable(PyTypeObject *type, PyObject *arg)
static void
chain_dealloc(chainobject *lz)
{
PyTypeObject *tp = Py_TYPE(lz);
PyObject_GC_UnTrack(lz);
Py_XDECREF(lz->active);
Py_XDECREF(lz->source);
Py_TYPE(lz)->tp_free(lz);
tp->tp_free(lz);
Py_DECREF(tp);
}

static int
chain_traverse(chainobject *lz, visitproc visit, void *arg)
{
Py_VISIT(Py_TYPE(lz));
Py_VISIT(lz->source);
Py_VISIT(lz->active);
return 0;
Expand Down Expand Up @@ -2209,48 +2214,24 @@ static PyMethodDef chain_methods[] = {
{NULL, NULL} /* sentinel */
};

static PyTypeObject chain_type = {
PyVarObject_HEAD_INIT(NULL, 0)
"itertools.chain", /* tp_name */
sizeof(chainobject), /* tp_basicsize */
0, /* tp_itemsize */
/* methods */
(destructor)chain_dealloc, /* tp_dealloc */
0, /* tp_vectorcall_offset */
0, /* tp_getattr */
0, /* tp_setattr */
0, /* tp_as_async */
0, /* tp_repr */
0, /* tp_as_number */
0, /* tp_as_sequence */
0, /* tp_as_mapping */
0, /* tp_hash */
0, /* tp_call */
0, /* tp_str */
PyObject_GenericGetAttr, /* tp_getattro */
0, /* tp_setattro */
0, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC |
Py_TPFLAGS_BASETYPE, /* tp_flags */
chain_doc, /* tp_doc */
(traverseproc)chain_traverse, /* tp_traverse */
0, /* tp_clear */
0, /* tp_richcompare */
0, /* tp_weaklistoffset */
PyObject_SelfIter, /* tp_iter */
(iternextfunc)chain_next, /* tp_iternext */
chain_methods, /* tp_methods */
0, /* tp_members */
0, /* tp_getset */
0, /* tp_base */
0, /* tp_dict */
0, /* tp_descr_get */
0, /* tp_descr_set */
0, /* tp_dictoffset */
0, /* tp_init */
0, /* tp_alloc */
chain_new, /* tp_new */
PyObject_GC_Del, /* tp_free */
static PyType_Slot chain_slots[] = {
{Py_tp_dealloc, chain_dealloc},
{Py_tp_getattro, PyObject_GenericGetAttr},
{Py_tp_doc, (void *)chain_doc},
{Py_tp_traverse, chain_traverse},
{Py_tp_iter, PyObject_SelfIter},
{Py_tp_iternext, chain_next},
{Py_tp_methods, chain_methods},
{Py_tp_new, chain_new},
{Py_tp_free, PyObject_GC_Del},
{0, NULL},
};

static PyType_Spec chain_spec = {
.name = "itertools.chain",
.basicsize = sizeof(chainobject),
.flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | Py_TPFLAGS_BASETYPE,
.slots = chain_slots,
};


Expand Down Expand Up @@ -3656,13 +3637,14 @@ accumulate_next(accumulateobject *lz)
static PyObject *
accumulate_reduce(accumulateobject *lz, PyObject *Py_UNUSED(ignored))
{
PyTypeObject *tp = Py_TYPE(lz);
itertools_state *state = find_state_by_type(tp);

if (lz->initial != Py_None) {
PyObject *it;

assert(lz->total == NULL);
if (PyType_Ready(&chain_type) < 0)
return NULL;
it = PyObject_CallFunction((PyObject *)&chain_type, "(O)O",
it = PyObject_CallFunction((PyObject *)(state->chain_type), "(O)O",
lz->initial, lz->it);
if (it == NULL)
return NULL;
Expand All @@ -3672,9 +3654,7 @@ accumulate_reduce(accumulateobject *lz, PyObject *Py_UNUSED(ignored))
if (lz->total == Py_None) {
PyObject *it;

if (PyType_Ready(&chain_type) < 0)
return NULL;
it = PyObject_CallFunction((PyObject *)&chain_type, "(O)O",
it = PyObject_CallFunction((PyObject *)(state->chain_type), "(O)O",
lz->total, lz->it);
if (it == NULL)
return NULL;
Expand All @@ -3683,8 +3663,6 @@ accumulate_reduce(accumulateobject *lz, PyObject *Py_UNUSED(ignored))
if (it == NULL)
return NULL;

PyTypeObject *tp = Py_TYPE(lz);
itertools_state *state = find_state_by_type(tp);
return Py_BuildValue("O(NiO)", state->islice_type, it, 1, Py_None);
}
return Py_BuildValue("O(OO)O", Py_TYPE(lz),
Expand Down Expand Up @@ -4656,6 +4634,7 @@ itertoolsmodule_traverse(PyObject *mod, visitproc visit, void *arg)
{
itertools_state *state = get_module_state(mod);
Py_VISIT(state->accumulate_type);
Py_VISIT(state->chain_type);
Py_VISIT(state->combinations_type);
Py_VISIT(state->compress_type);
Py_VISIT(state->count_type);
Expand All @@ -4681,6 +4660,7 @@ itertoolsmodule_clear(PyObject *mod)
{
itertools_state *state = get_module_state(mod);
Py_CLEAR(state->accumulate_type);
Py_CLEAR(state->chain_type);
Py_CLEAR(state->combinations_type);
Py_CLEAR(state->compress_type);
Py_CLEAR(state->count_type);
Expand Down Expand Up @@ -4723,6 +4703,7 @@ itertoolsmodule_exec(PyObject *mod)
{
itertools_state *state = get_module_state(mod);
ADD_TYPE(mod, state->accumulate_type, &accumulate_spec);
ADD_TYPE(mod, state->chain_type, &chain_spec);
ADD_TYPE(mod, state->combinations_type, &combinations_spec);
ADD_TYPE(mod, state->compress_type, &compress_spec);
ADD_TYPE(mod, state->count_type, &count_spec);
Expand All @@ -4743,7 +4724,6 @@ itertoolsmodule_exec(PyObject *mod)

PyTypeObject *typelist[] = {
&batched_type,
&chain_type,
&tee_type,
&teedataobject_type
};
Expand Down

0 comments on commit da955ae

Please sign in to comment.