From c774d5df1c866cbb594ae4f2af6ed9c68aec5fe8 Mon Sep 17 00:00:00 2001 From: Anja Kefala Date: Mon, 22 Aug 2022 06:31:05 -0700 Subject: [PATCH] ARROW-17131: [Python] add StructType().field(): returns a field by name or index (#13652) Authored-by: anjakefala Signed-off-by: David Li --- python/pyarrow/tests/test_types.py | 11 +++++ python/pyarrow/types.pxi | 74 +++++++++++++++++++++++++++--- 2 files changed, 79 insertions(+), 6 deletions(-) diff --git a/python/pyarrow/tests/test_types.py b/python/pyarrow/tests/test_types.py index 8cb7cea684274..0ef9f5a86ec6f 100644 --- a/python/pyarrow/tests/test_types.py +++ b/python/pyarrow/tests/test_types.py @@ -577,14 +577,24 @@ def test_struct_type(): assert ty['b'] == ty[2] + assert ty['b'] == ty.field('b') + + assert ty[2] == ty.field(2) + # Not found with pytest.raises(KeyError): ty['c'] + with pytest.raises(KeyError): + ty.field('c') + # Neither integer nor string with pytest.raises(TypeError): ty[None] + with pytest.raises(TypeError): + ty.field(None) + for a, b in zip(ty, fields): a == b @@ -634,6 +644,7 @@ def test_union_type(): def check_fields(ty, fields): assert ty.num_fields == len(fields) assert [ty[i] for i in range(ty.num_fields)] == fields + assert [ty.field(i) for i in range(ty.num_fields)] == fields fields = [pa.field('x', pa.list_(pa.int32())), pa.field('y', pa.binary())] diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi index 8407f95c984c3..1babbc41549c7 100644 --- a/python/pyarrow/types.pxi +++ b/python/pyarrow/types.pxi @@ -429,12 +429,23 @@ cdef class StructType(DataType): Examples -------- >>> import pyarrow as pa + + Accessing fields using direct indexing: + >>> struct_type = pa.struct({'x': pa.int32(), 'y': pa.string()}) >>> struct_type[0] pyarrow.Field >>> struct_type['y'] pyarrow.Field + Accessing fields using ``field()``: + + >>> struct_type.field(1) + pyarrow.Field + >>> struct_type.field('x') + pyarrow.Field + + # Creating a schema from the struct type's fields: >>> pa.schema(list(struct_type)) x: int32 y: string @@ -494,6 +505,41 @@ cdef class StructType(DataType): """ return self.struct_type.GetFieldIndex(tobytes(name)) + def field(self, i): + """ + Select a field by its column name or numeric index. + + Parameters + ---------- + i : int or str + + Returns + ------- + pyarrow.Field + + Examples + -------- + + >>> import pyarrow as pa + >>> struct_type = pa.struct({'x': pa.int32(), 'y': pa.string()}) + + Select the second field: + + >>> struct_type.field(1) + pyarrow.Field + + Select the field named 'x': + + >>> struct_type.field('x') + pyarrow.Field + """ + if isinstance(i, (bytes, str)): + return self.field_by_name(i) + elif isinstance(i, int): + return DataType.field(self, i) + else: + raise TypeError('Expected integer or string index') + def get_all_field_indices(self, name): """ Return sorted list of indices for the fields with the given name. @@ -525,13 +571,10 @@ cdef class StructType(DataType): def __getitem__(self, i): """ Return the struct field with the given index or name. + + Alias of ``field``. """ - if isinstance(i, (bytes, str)): - return self.field_by_name(i) - elif isinstance(i, int): - return self.field(i) - else: - raise TypeError('Expected integer or string index') + return self.field(i) def __reduce__(self): return struct, (list(self),) @@ -579,9 +622,28 @@ cdef class UnionType(DataType): for i in range(len(self)): yield self[i] + def field(self, i): + """ + Return a child field by its numeric index. + + Parameters + ---------- + i : int + + Returns + ------- + pyarrow.Field + """ + if isinstance(i, int): + return DataType.field(self, i) + else: + raise TypeError('Expected integer') + def __getitem__(self, i): """ Return a child field by its index. + + Alias of ``field``. """ return self.field(i)