From 3b8fee0635faf945ec9c1a43f023e62718f93204 Mon Sep 17 00:00:00 2001 From: anjakefala Date: Tue, 19 Jul 2022 14:50:22 -0700 Subject: [PATCH] ARROW-17131: [Python] add StructType().field(): returns a field by name or index --- python/pyarrow/tests/test_types.py | 10 ++++++ python/pyarrow/types.pxi | 55 ++++++++++++++++++++++++++---- 2 files changed, 59 insertions(+), 6 deletions(-) diff --git a/python/pyarrow/tests/test_types.py b/python/pyarrow/tests/test_types.py index 8cb7cea684274..cabf69ed07af0 100644 --- a/python/pyarrow/tests/test_types.py +++ b/python/pyarrow/tests/test_types.py @@ -577,14 +577,24 @@ def test_struct_type(): assert ty['b'] == ty[2] + assert ty['b'] == ty.field('b') + + assert ty[2] == ty.field(2) + # Not found with pytest.raises(KeyError): ty['c'] + with pytest.raises(KeyError): + ty.field('c') + # Neither integer nor string with pytest.raises(TypeError): ty[None] + with pytest.raises(TypeError): + ty.field(None) + for a, b in zip(ty, fields): a == b diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi index 8407f95c984c3..1dae52f2fef81 100644 --- a/python/pyarrow/types.pxi +++ b/python/pyarrow/types.pxi @@ -429,12 +429,23 @@ cdef class StructType(DataType): Examples -------- >>> import pyarrow as pa + + Accessing fields using direct indexing: + >>> struct_type = pa.struct({'x': pa.int32(), 'y': pa.string()}) >>> struct_type[0] pyarrow.Field >>> struct_type['y'] pyarrow.Field + Accessing fields using ``field()``: + + >>> struct_type.field(1) + pyarrow.Field + >>> struct_type.field('x') + pyarrow.Field + + # Creating a schema from the struct type's fields: >>> pa.schema(list(struct_type)) x: int32 y: string @@ -494,6 +505,41 @@ cdef class StructType(DataType): """ return self.struct_type.GetFieldIndex(tobytes(name)) + def field(self, i): + """ + Select a field by its column name or numeric index. + + Parameters + ---------- + i : int or str + + Returns + ------- + pyarrow.Field + + Examples + -------- + + >>> import pyarrow as pa + >>> struct_type = pa.struct({'x': pa.int32(), 'y': pa.string()}) + + Select the second field: + + >>> struct_type.field(1) + pyarrow.Field + + Select the field named 'x': + + >>> struct_type.field('x') + pyarrow.Field + """ + if isinstance(i, (bytes, str)): + return self.field_by_name(i) + elif isinstance(i, int): + return DataType.field(self, i) + else: + raise TypeError('Expected integer or string index') + def get_all_field_indices(self, name): """ Return sorted list of indices for the fields with the given name. @@ -525,13 +571,10 @@ cdef class StructType(DataType): def __getitem__(self, i): """ Return the struct field with the given index or name. + + Alias of ``field``. """ - if isinstance(i, (bytes, str)): - return self.field_by_name(i) - elif isinstance(i, int): - return self.field(i) - else: - raise TypeError('Expected integer or string index') + return self.field(i) def __reduce__(self): return struct, (list(self),)