Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bpo-43977: Use tp_flags for collection matching #25723

Merged
Merged
7 changes: 7 additions & 0 deletions Include/object.h
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,13 @@ Code can use PyType_HasFeature(type_ob, flag_value) to test whether the
given type object has a specified feature.
*/

#ifndef Py_LIMITED_API
/* Collection flags consulted by pattern matching. They are only defined
 * when Py_LIMITED_API is not set. */
/* Set if instances of the type object are treated as sequences for pattern
 * matching (MATCH_SEQUENCE tests this bit on the subject's type). */
#define Py_TPFLAGS_SEQUENCE (1 << 5)
/* Set if instances of the type object are treated as mappings for pattern
 * matching (MATCH_MAPPING tests this bit on the subject's type). */
#define Py_TPFLAGS_MAPPING (1 << 6)
#endif /* !Py_LIMITED_API */

/* Set if the type object is immutable: type attributes cannot be set nor deleted */
#define Py_TPFLAGS_IMMUTABLETYPE (1UL << 8)

Expand Down
8 changes: 6 additions & 2 deletions Lib/_collections_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,6 +793,7 @@ def __isub__(self, it):

### MAPPINGS ###

# Bit in a type's tp_flags marking its instances as mappings for pattern
# matching. Must stay equal to Py_TPFLAGS_MAPPING (1 << 6) in Include/object.h,
# because _abc._abc_init masks the class's __flags__ value against the C flags.
TPFLAGS_MAPPING = 1 << 6
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is public and can be relied upon, it needs to be added to the docs for collections.abc.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be internal. I'll hide it.


class Mapping(Collection):
"""A Mapping is a generic container for associating key/value
Expand All @@ -804,6 +805,8 @@ class Mapping(Collection):

__slots__ = ()

__flags__ = TPFLAGS_MAPPING

@abstractmethod
def __getitem__(self, key):
raise KeyError
Expand Down Expand Up @@ -842,7 +845,6 @@ def __eq__(self, other):

__reversed__ = None


Mapping.register(mappingproxy)


Expand Down Expand Up @@ -1011,6 +1013,7 @@ def setdefault(self, key, default=None):

### SEQUENCES ###

# Bit in a type's tp_flags marking its instances as sequences for pattern
# matching. Must stay equal to Py_TPFLAGS_SEQUENCE (1 << 5) in Include/object.h,
# because _abc._abc_init masks the class's __flags__ value against the C flags.
TPFLAGS_SEQUENCE = 1 << 5

class Sequence(Reversible, Collection):
"""All the operations on a read-only sequence.
Expand All @@ -1021,6 +1024,8 @@ class Sequence(Reversible, Collection):

__slots__ = ()

__flags__ = TPFLAGS_SEQUENCE

@abstractmethod
def __getitem__(self, index):
raise IndexError
Expand Down Expand Up @@ -1072,7 +1077,6 @@ def count(self, value):
'S.count(value) -> integer -- return number of occurrences of value'
return sum(1 for v in self if v is value or v == value)


Sequence.register(tuple)
Sequence.register(str)
Sequence.register(range)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Use tp_flags on the class object to determine if the subject is a sequence
or mapping when pattern matching. Avoids the need to import collections.abc
when pattern matching.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just some markup to make this a bit richer:

Suggested change
Use tp_flags on the class object to determine if the subject is a sequence
or mapping when pattern matching. Avoids the need to import collections.abc
when pattern matching.
Use :c:member:`~PyTypeObject.tp_flags` on the class object to determine if the subject is a sequence
or mapping when pattern matching. Avoids the need to import :mod:`collections.abc`
when pattern matching.

29 changes: 29 additions & 0 deletions Modules/_abc.c
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ PyDoc_STRVAR(_abc__doc__,
_Py_IDENTIFIER(__abstractmethods__);
_Py_IDENTIFIER(__class__);
_Py_IDENTIFIER(__dict__);
_Py_IDENTIFIER(__flags__);
_Py_IDENTIFIER(__bases__);
_Py_IDENTIFIER(_abc_impl);
_Py_IDENTIFIER(__subclasscheck__);
Expand Down Expand Up @@ -417,6 +418,8 @@ compute_abstract_methods(PyObject *self)
return ret;
}

/* Mask of the tp_flags bits that mark a type as a sequence or mapping for
 * pattern matching. Used to restrict which bits a class-level __flags__
 * value or an ABC register() call may set on a type. */
#define COLLECTION_FLAGS (Py_TPFLAGS_SEQUENCE | Py_TPFLAGS_MAPPING)

/*[clinic input]
_abc._abc_init

Expand Down Expand Up @@ -446,6 +449,27 @@ _abc__abc_init(PyObject *module, PyObject *self)
return NULL;
}
Py_DECREF(data);
if (PyType_Check(self)) {
PyTypeObject *cls = (PyTypeObject *)self;
PyObject *flags = _PyDict_GetItemIdWithError(cls->tp_dict, &PyId___flags__);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, this whole mechanism makes me a bit uncomfortable. I'm not aware of any case where writing to __flags__ in Pythonland actually works like this.

But I can't really think of anything better, so I guess it's fine. Maybe we could just use a name like _abc_tpflags or something instead of __flags__?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, it is just a name. I like using a dunder name, as it is a special thing. __abc_tpflags__ perhaps?

if (flags == NULL) {
if (PyErr_Occurred()) {
return NULL;
}
}
else {
if (PyLong_CheckExact(flags)) {
long val = PyLong_AsLong(flags);
if (val == -1 && PyErr_Occurred()) {
return NULL;
}
((PyTypeObject *)self)->tp_flags |= (val & COLLECTION_FLAGS);
}
if (_PyDict_DelItemId(cls->tp_dict, &PyId___flags__) < 0) {
return NULL;
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are we deleting __flags__ here? It looks like it will get overwritten anyways:

>>> class C:
...     __flags__ = "Spam"
... 
>>> C.__flags__
284160

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The __flags__ set in Sequence is in Sequence.__dict__, but normally __flags__ is a descriptor.
So, it is probably safest to delete it. I suspect it doesn't matter much in practice.

}
}
Py_RETURN_NONE;
}

Expand Down Expand Up @@ -499,6 +523,11 @@ _abc__abc_register_impl(PyObject *module, PyObject *self, PyObject *subclass)
/* Invalidate negative cache */
get_abc_state(module)->abc_invalidation_counter++;

if (PyType_Check(subclass) && PyType_Check(self) &&
!PyType_HasFeature((PyTypeObject *)subclass, Py_TPFLAGS_IMMUTABLETYPE))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

{
((PyTypeObject *)subclass)->tp_flags |= (((PyTypeObject *)self)->tp_flags & COLLECTION_FLAGS);
}
Py_INCREF(subclass);
return subclass;
}
Expand Down
3 changes: 2 additions & 1 deletion Modules/_collectionsmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -1662,7 +1662,8 @@ static PyTypeObject deque_type = {
PyObject_GenericGetAttr, /* tp_getattro */
0, /* tp_setattro */
0, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC,
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE |
Py_TPFLAGS_HAVE_GC | Py_TPFLAGS_SEQUENCE,
/* tp_flags */
deque_doc, /* tp_doc */
(traverseproc)deque_traverse, /* tp_traverse */
Expand Down
3 changes: 2 additions & 1 deletion Modules/arraymodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -2848,7 +2848,8 @@ static PyType_Spec array_spec = {
.name = "array.array",
.basicsize = sizeof(arrayobject),
.flags = (Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE |
Py_TPFLAGS_IMMUTABLETYPE),
Py_TPFLAGS_IMMUTABLETYPE |
Py_TPFLAGS_SEQUENCE),
.slots = array_slots,
};

Expand Down
3 changes: 2 additions & 1 deletion Objects/descrobject.c
Original file line number Diff line number Diff line change
Expand Up @@ -1852,7 +1852,8 @@ PyTypeObject PyDictProxy_Type = {
PyObject_GenericGetAttr, /* tp_getattro */
0, /* tp_setattro */
0, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, /* tp_flags */
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC |
Py_TPFLAGS_MAPPING, /* tp_flags */
0, /* tp_doc */
mappingproxy_traverse, /* tp_traverse */
0, /* tp_clear */
Expand Down
2 changes: 1 addition & 1 deletion Objects/dictobject.c
Original file line number Diff line number Diff line change
Expand Up @@ -3553,7 +3553,7 @@ PyTypeObject PyDict_Type = {
0, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC |
Py_TPFLAGS_BASETYPE | Py_TPFLAGS_DICT_SUBCLASS |
_Py_TPFLAGS_MATCH_SELF, /* tp_flags */
_Py_TPFLAGS_MATCH_SELF | Py_TPFLAGS_MAPPING, /* tp_flags */
dictionary_doc, /* tp_doc */
dict_traverse, /* tp_traverse */
dict_tp_clear, /* tp_clear */
Expand Down
2 changes: 1 addition & 1 deletion Objects/listobject.c
Original file line number Diff line number Diff line change
Expand Up @@ -3053,7 +3053,7 @@ PyTypeObject PyList_Type = {
0, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC |
Py_TPFLAGS_BASETYPE | Py_TPFLAGS_LIST_SUBCLASS |
_Py_TPFLAGS_MATCH_SELF, /* tp_flags */
_Py_TPFLAGS_MATCH_SELF | Py_TPFLAGS_SEQUENCE, /* tp_flags */
list___init____doc__, /* tp_doc */
(traverseproc)list_traverse, /* tp_traverse */
(inquiry)_list_clear, /* tp_clear */
Expand Down
3 changes: 2 additions & 1 deletion Objects/memoryobject.c
Original file line number Diff line number Diff line change
Expand Up @@ -3287,7 +3287,8 @@ PyTypeObject PyMemoryView_Type = {
PyObject_GenericGetAttr, /* tp_getattro */
0, /* tp_setattro */
&memory_as_buffer, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, /* tp_flags */
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC |
Py_TPFLAGS_SEQUENCE, /* tp_flags */
memoryview__doc__, /* tp_doc */
(traverseproc)memory_traverse, /* tp_traverse */
(inquiry)memory_clear, /* tp_clear */
Expand Down
2 changes: 1 addition & 1 deletion Objects/rangeobject.c
Original file line number Diff line number Diff line change
Expand Up @@ -735,7 +735,7 @@ PyTypeObject PyRange_Type = {
PyObject_GenericGetAttr, /* tp_getattro */
0, /* tp_setattro */
0, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT, /* tp_flags */
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_SEQUENCE, /* tp_flags */
range_doc, /* tp_doc */
0, /* tp_traverse */
0, /* tp_clear */
Expand Down
2 changes: 1 addition & 1 deletion Objects/tupleobject.c
Original file line number Diff line number Diff line change
Expand Up @@ -918,7 +918,7 @@ PyTypeObject PyTuple_Type = {
0, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC |
Py_TPFLAGS_BASETYPE | Py_TPFLAGS_TUPLE_SUBCLASS |
_Py_TPFLAGS_MATCH_SELF, /* tp_flags */
_Py_TPFLAGS_MATCH_SELF | Py_TPFLAGS_SEQUENCE, /* tp_flags */
tuple_new__doc__, /* tp_doc */
(traverseproc)tupletraverse, /* tp_traverse */
0, /* tp_clear */
Expand Down
7 changes: 6 additions & 1 deletion Objects/typeobject.c
Original file line number Diff line number Diff line change
Expand Up @@ -5648,10 +5648,15 @@ inherit_special(PyTypeObject *type, PyTypeObject *base)
else if (PyType_IsSubtype(base, &PyDict_Type)) {
type->tp_flags |= Py_TPFLAGS_DICT_SUBCLASS;
}

if (PyType_HasFeature(base, _Py_TPFLAGS_MATCH_SELF)) {
type->tp_flags |= _Py_TPFLAGS_MATCH_SELF;
}
if (PyType_HasFeature(base, Py_TPFLAGS_SEQUENCE)) {
type->tp_flags |= Py_TPFLAGS_SEQUENCE;
}
if (PyType_HasFeature(base, Py_TPFLAGS_MAPPING)) {
type->tp_flags |= Py_TPFLAGS_MAPPING;
}
}

static int
Expand Down
72 changes: 8 additions & 64 deletions Python/ceval.c
Original file line number Diff line number Diff line change
Expand Up @@ -3889,76 +3889,20 @@ _PyEval_EvalFrameDefault(PyThreadState *tstate, PyFrameObject *f, int throwflag)
}

case TARGET(MATCH_MAPPING): {
// PUSH(isinstance(TOS, _collections_abc.Mapping))
PyObject *subject = TOP();
// Fast path for dicts:
if (PyDict_Check(subject)) {
Py_INCREF(Py_True);
PUSH(Py_True);
DISPATCH();
}
// Lazily import _collections_abc.Mapping, and keep it handy on the
// PyInterpreterState struct (it gets cleaned up at exit):
PyInterpreterState *interp = PyInterpreterState_Get();
if (interp->map_abc == NULL) {
PyObject *abc = PyImport_ImportModule("_collections_abc");
if (abc == NULL) {
goto error;
}
interp->map_abc = PyObject_GetAttrString(abc, "Mapping");
if (interp->map_abc == NULL) {
goto error;
}
}
int match = PyObject_IsInstance(subject, interp->map_abc);
if (match < 0) {
goto error;
}
PUSH(PyBool_FromLong(match));
int match = Py_TYPE(subject)->tp_flags & Py_TPFLAGS_MAPPING;
PyObject *res = match ? Py_True : Py_False;
Py_INCREF(res);
PUSH(res);
DISPATCH();
}

case TARGET(MATCH_SEQUENCE): {
// PUSH(not isinstance(TOS, (bytearray, bytes, str))
// and isinstance(TOS, _collections_abc.Sequence))
PyObject *subject = TOP();
// Fast path for lists and tuples:
if (PyType_FastSubclass(Py_TYPE(subject),
Py_TPFLAGS_LIST_SUBCLASS |
Py_TPFLAGS_TUPLE_SUBCLASS))
{
Py_INCREF(Py_True);
PUSH(Py_True);
DISPATCH();
}
// Bail on some possible Sequences that we intentionally exclude:
if (PyType_FastSubclass(Py_TYPE(subject),
Py_TPFLAGS_BYTES_SUBCLASS |
Py_TPFLAGS_UNICODE_SUBCLASS) ||
PyByteArray_Check(subject))
{
Py_INCREF(Py_False);
PUSH(Py_False);
DISPATCH();
}
// Lazily import _collections_abc.Sequence, and keep it handy on the
// PyInterpreterState struct (it gets cleaned up at exit):
PyInterpreterState *interp = PyInterpreterState_Get();
if (interp->seq_abc == NULL) {
PyObject *abc = PyImport_ImportModule("_collections_abc");
if (abc == NULL) {
goto error;
}
interp->seq_abc = PyObject_GetAttrString(abc, "Sequence");
if (interp->seq_abc == NULL) {
goto error;
}
}
int match = PyObject_IsInstance(subject, interp->seq_abc);
if (match < 0) {
goto error;
}
PUSH(PyBool_FromLong(match));
int match = Py_TYPE(subject)->tp_flags & Py_TPFLAGS_SEQUENCE;
PyObject *res = match ? Py_True : Py_False;
Py_INCREF(res);
PUSH(res);
DISPATCH();
}

Expand Down