Skip to content

Commit

Permalink
apacheGH-43388: [Python] Give precedence to pycapsule interface in pa…
Browse files Browse the repository at this point in the history
….schema(..)
  • Loading branch information
jorisvandenbossche committed Jul 30, 2024
1 parent f00a306 commit db7dfeb
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 11 deletions.
35 changes: 27 additions & 8 deletions python/pyarrow/tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.

from collections import OrderedDict
from collections.abc import Iterator
from collections.abc import Iterator, Mapping
from functools import partial
import datetime
import sys
Expand Down Expand Up @@ -1325,17 +1325,36 @@ def test_types_come_back_with_specific_type():
assert type(type_back) is type(arrow_type)


def test_schema_import_c_schema_interface():
class Wrapper:
def __init__(self, schema):
self.schema = schema
class SchemaWrapper:
def __init__(self, schema):
self.schema = schema

def __arrow_c_schema__(self):
return self.schema.__arrow_c_schema__()


class SchemaMapping(Mapping):
def __init__(self, schema):
self.schema = schema

def __arrow_c_schema__(self):
return self.schema.__arrow_c_schema__()

def __getitem__(self, key):
return self.schema[key]

def __iter__(self):
return iter(self.schema)

def __len__(self):
return len(self.schema)

def __arrow_c_schema__(self):
return self.schema.__arrow_c_schema__()

@pytest.mark.parametrize("wrapper_class", [SchemaWrapper, SchemaMapping])
def test_schema_import_c_schema_interface(wrapper_class):
schema = pa.schema([pa.field("field_name", pa.int32())], metadata={"a": "b"})
assert schema.metadata == {b"a": b"b"}
wrapped_schema = Wrapper(schema)
wrapped_schema = wrapper_class(schema)

assert pa.schema(wrapped_schema) == schema
assert pa.schema(wrapped_schema).metadata == {b"a": b"b"}
Expand Down
7 changes: 4 additions & 3 deletions python/pyarrow/types.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -5347,14 +5347,15 @@ def schema(fields, metadata=None):
Field py_field
vector[shared_ptr[CField]] c_fields

if isinstance(fields, Mapping):
fields = fields.items()
elif hasattr(fields, "__arrow_c_schema__"):
if hasattr(fields, "__arrow_c_schema__"):
result = Schema._import_from_c_capsule(fields.__arrow_c_schema__())
if metadata is not None:
result = result.with_metadata(metadata)
return result

if isinstance(fields, Mapping):
fields = fields.items()

for item in fields:
if isinstance(item, tuple):
py_field = field(*item)
Expand Down

0 comments on commit db7dfeb

Please sign in to comment.