From 7a9eeaaea988b3fc2c6e59baa67f9ff9b01385d6 Mon Sep 17 00:00:00 2001 From: Kyle Barron Date: Mon, 29 Jul 2024 20:26:29 -0400 Subject: [PATCH] More DataType methods (#78) * Export DataType constructors * Added DataType methods --- DEVELOP.md | 7 +- arro3-core/python/arro3/core/_core.pyi | 145 ++++++++++++ docs/api/core/datatype.md | 8 + mkdocs.yml | 1 + pyo3-arrow/src/datatypes.rs | 295 ++++++++++++++++++++++++- 5 files changed, 447 insertions(+), 9 deletions(-) create mode 100644 docs/api/core/datatype.md diff --git a/DEVELOP.md b/DEVELOP.md index c681f62..d01772a 100644 --- a/DEVELOP.md +++ b/DEVELOP.md @@ -5,9 +5,8 @@ rm -rf .venv poetry install # Note: need to install core first because others depend on core -poetry run maturin build -m arro3-core/Cargo.toml -o dist -poetry run maturin build -m arro3-compute/Cargo.toml -o dist -poetry run maturin build -m arro3-io/Cargo.toml -o dist -poetry run pip install dist/* +poetry run maturin develop -m arro3-core/Cargo.toml +poetry run maturin develop -m arro3-compute/Cargo.toml +poetry run maturin develop -m arro3-io/Cargo.toml poetry run mkdocs serve ``` diff --git a/arro3-core/python/arro3/core/_core.pyi b/arro3-core/python/arro3/core/_core.pyi index a31e6b1..14107e8 100644 --- a/arro3-core/python/arro3/core/_core.pyi +++ b/arro3-core/python/arro3/core/_core.pyi @@ -124,8 +124,43 @@ class DataType: @classmethod def from_arrow_pycapsule(cls, capsule) -> DataType: """Construct this object from a bare Arrow PyCapsule""" + @property def bit_width(self) -> int | None: ... + def equals( + self, other: ArrowSchemaExportable, *, check_metadata: bool = False + ) -> bool: + """Return true if type is equivalent to passed value. + + Args: + other: The data type to compare against. + check_metadata: Whether nested Field metadata equality should be checked as well. Defaults to False. + + Returns: + True if the two types are equal, otherwise False. + """ + @property + def list_size(self) -> int | None: + """The size of the list in the case of fixed size lists. 
+ + This will return `None` if the data type is not a fixed size list. + + Examples: + + ```py + from arro3.core import DataType + DataType.list(DataType.int32(), 2).list_size + # 2 + ``` + + Returns: + The fixed list size, or `None` if this is not a fixed size list. + """ + @property + def num_fields(self) -> int: + """The number of child fields.""" + ################# #### Constructors + ################# @classmethod def null(cls) -> DataType: """Create instance of null type.""" @@ -383,6 +418,116 @@ class DataType: _description_ """ + ################## + #### Type Checking + ################## + @staticmethod + def is_boolean(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_integer(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_signed_integer(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_unsigned_integer(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_int8(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_int16(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_int32(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_int64(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_uint8(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_uint16(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_uint32(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_uint64(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_floating(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_float16(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_float32(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_float64(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_decimal(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_decimal128(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_decimal256(t: ArrowSchemaExportable) -> bool: ... 
+ @staticmethod + def is_list(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_large_list(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_fixed_size_list(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_list_view(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_large_list_view(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_struct(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_union(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_nested(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_run_end_encoded(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_temporal(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_timestamp(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_date(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_date32(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_date64(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_time(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_time32(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_time64(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_duration(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_interval(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_null(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_binary(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_unicode(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_string(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_large_binary(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_large_unicode(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_large_string(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_binary_view(t: ArrowSchemaExportable) -> bool: ... 
+ @staticmethod + def is_string_view(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_fixed_size_binary(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_map(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_dictionary(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_primitive(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_numeric(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_dictionary_key_type(t: ArrowSchemaExportable) -> bool: ... + class Field: def __init__( self, diff --git a/docs/api/core/datatype.md b/docs/api/core/datatype.md new file mode 100644 index 0000000..672baed --- /dev/null +++ b/docs/api/core/datatype.md @@ -0,0 +1,8 @@ +# arro3.core.DataType + +::: arro3.core.DataType + options: + filters: + - "!^_" + - "^__arrow" + show_if_no_docstring: true diff --git a/mkdocs.yml b/mkdocs.yml index 91fa4a7..074cbd7 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -28,6 +28,7 @@ nav: - api/core/record-batch.md - api/core/schema.md - api/core/table.md + - api/core/datatype.md - api/core/types.md - api/compute.md - api/io.md diff --git a/pyo3-arrow/src/datatypes.rs b/pyo3-arrow/src/datatypes.rs index 73ac550..a5799b4 100644 --- a/pyo3-arrow/src/datatypes.rs +++ b/pyo3-arrow/src/datatypes.rs @@ -118,8 +118,8 @@ impl PyDataType { to_schema_pycapsule(py, &self.0) } - pub fn __eq__(&self, other: &PyDataType) -> bool { - self.0 == other.0 + pub fn __eq__(&self, other: PyDataType) -> bool { + self.equals(other, false) } pub fn __repr__(&self) -> String { @@ -147,10 +147,75 @@ impl PyDataType { Ok(Self::new(data_type)) } + #[getter] pub fn bit_width(&self) -> Option { self.0.primitive_width() } + #[pyo3(signature=(other, *, check_metadata=false))] + fn equals(&self, other: PyDataType, check_metadata: bool) -> bool { + let other = other.into_inner(); + if check_metadata { + self.0 == other + } else { + self.0.equals_datatype(&other) + } + } + + #[getter] + fn list_size(&self) -> 
Option { + match &self.0 { + DataType::FixedSizeList(_, list_size) => Some(*list_size), + _ => None, + } + } + + #[getter] + fn num_fields(&self) -> usize { + match &self.0 { + DataType::Null + | DataType::Boolean + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Float16 + | DataType::Float32 + | DataType::Float64 + | DataType::Timestamp(_, _) + | DataType::Date32 + | DataType::Date64 + | DataType::Time32(_) + | DataType::Time64(_) + | DataType::Duration(_) + | DataType::Interval(_) + | DataType::Binary + | DataType::FixedSizeBinary(_) + | DataType::LargeBinary + | DataType::BinaryView + | DataType::Utf8 + | DataType::LargeUtf8 + | DataType::Utf8View + | DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) => 0, + DataType::List(_) + | DataType::ListView(_) + | DataType::FixedSizeList(_, _) + | DataType::LargeList(_) + | DataType::LargeListView(_) => 1, + DataType::Struct(fields) => fields.len(), + DataType::Union(fields, _) => fields.len(), + // Dictionary wraps bare key/value DataTypes (no child Fields); Map has one entries Field. 
+ DataType::Dictionary(_, _) => 0, + DataType::Map(_, _) => 1, + DataType::RunEndEncoded(_, _) => 2, + } + } + + ///////////////////// Constructors + #[classmethod] fn null(_: &Bound) -> Self { Self(DataType::Null) @@ -361,10 +426,10 @@ impl PyDataType { } #[classmethod] - fn dictionary(_: &Bound, index_type: PyField, value_type: PyField) -> Self { + fn dictionary(_: &Bound, index_type: PyDataType, value_type: PyDataType) -> Self { Self(DataType::Dictionary( - Box::new(index_type.into_inner().data_type().clone()), - Box::new(value_type.into_inner().data_type().clone()), + Box::new(index_type.into_inner()), + Box::new(value_type.into_inner()), )) } @@ -375,4 +440,224 @@ impl PyDataType { value_type.into_inner(), )) } + + ///////////////////// Type checking + + #[staticmethod] + fn is_boolean(t: PyDataType) -> bool { + t.0 == DataType::Boolean + } + + #[staticmethod] + fn is_integer(t: PyDataType) -> bool { + t.0.is_integer() + } + + #[staticmethod] + fn is_signed_integer(t: PyDataType) -> bool { + t.0.is_signed_integer() + } + + #[staticmethod] + fn is_unsigned_integer(t: PyDataType) -> bool { + t.0.is_unsigned_integer() + } + + #[staticmethod] + fn is_int8(t: PyDataType) -> bool { + t.0 == DataType::Int8 + } + #[staticmethod] + fn is_int16(t: PyDataType) -> bool { + t.0 == DataType::Int16 + } + #[staticmethod] + fn is_int32(t: PyDataType) -> bool { + t.0 == DataType::Int32 + } + #[staticmethod] + fn is_int64(t: PyDataType) -> bool { + t.0 == DataType::Int64 + } + #[staticmethod] + fn is_uint8(t: PyDataType) -> bool { + t.0 == DataType::UInt8 + } + #[staticmethod] + fn is_uint16(t: PyDataType) -> bool { + t.0 == DataType::UInt16 + } + #[staticmethod] + fn is_uint32(t: PyDataType) -> bool { + t.0 == DataType::UInt32 + } + #[staticmethod] + fn is_uint64(t: PyDataType) -> bool { + t.0 == DataType::UInt64 + } + #[staticmethod] + fn is_floating(t: PyDataType) -> bool { + t.0.is_floating() + } + #[staticmethod] + fn is_float16(t: PyDataType) -> bool { + t.0 == DataType::Float16 + } + 
#[staticmethod] + fn is_float32(t: PyDataType) -> bool { + t.0 == DataType::Float32 + } + #[staticmethod] + fn is_float64(t: PyDataType) -> bool { + t.0 == DataType::Float64 + } + #[staticmethod] + fn is_decimal(t: PyDataType) -> bool { + matches!(t.0, DataType::Decimal128(_, _) | DataType::Decimal256(_, _)) + } + #[staticmethod] + fn is_decimal128(t: PyDataType) -> bool { + matches!(t.0, DataType::Decimal128(_, _)) + } + #[staticmethod] + fn is_decimal256(t: PyDataType) -> bool { + matches!(t.0, DataType::Decimal256(_, _)) + } + + #[staticmethod] + fn is_list(t: PyDataType) -> bool { + matches!(t.0, DataType::List(_)) + } + #[staticmethod] + fn is_large_list(t: PyDataType) -> bool { + matches!(t.0, DataType::LargeList(_)) + } + #[staticmethod] + fn is_fixed_size_list(t: PyDataType) -> bool { + matches!(t.0, DataType::FixedSizeList(_, _)) + } + #[staticmethod] + fn is_list_view(t: PyDataType) -> bool { + matches!(t.0, DataType::ListView(_)) + } + #[staticmethod] + fn is_large_list_view(t: PyDataType) -> bool { + matches!(t.0, DataType::LargeListView(_)) + } + #[staticmethod] + fn is_struct(t: PyDataType) -> bool { + matches!(t.0, DataType::Struct(_)) + } + #[staticmethod] + fn is_union(t: PyDataType) -> bool { + matches!(t.0, DataType::Union(_, _)) + } + #[staticmethod] + fn is_nested(t: PyDataType) -> bool { + t.0.is_nested() + } + #[staticmethod] + fn is_run_end_encoded(t: PyDataType) -> bool { + // `is_run_ends_type` only reports whether a type may serve as the run-ends index. + matches!(t.0, DataType::RunEndEncoded(_, _)) + } + #[staticmethod] + fn is_temporal(t: PyDataType) -> bool { + t.0.is_temporal() + } + #[staticmethod] + fn is_timestamp(t: PyDataType) -> bool { + matches!(t.0, DataType::Timestamp(_, _)) + } + #[staticmethod] + fn is_date(t: PyDataType) -> bool { + matches!(t.0, DataType::Date32 | DataType::Date64) + } + #[staticmethod] + fn is_date32(t: PyDataType) -> bool { + t.0 == DataType::Date32 + } + #[staticmethod] + fn is_date64(t: PyDataType) -> bool { + t.0 == DataType::Date64 + } + #[staticmethod] + fn is_time(t: PyDataType) -> bool { + 
matches!(t.0, DataType::Time32(_) | DataType::Time64(_)) + } + #[staticmethod] + fn is_time32(t: PyDataType) -> bool { + matches!(t.0, DataType::Time32(_)) + } + #[staticmethod] + fn is_time64(t: PyDataType) -> bool { + matches!(t.0, DataType::Time64(_)) + } + #[staticmethod] + fn is_duration(t: PyDataType) -> bool { + matches!(t.0, DataType::Duration(_)) + } + #[staticmethod] + fn is_interval(t: PyDataType) -> bool { + matches!(t.0, DataType::Interval(_)) + } + #[staticmethod] + fn is_null(t: PyDataType) -> bool { + t.0 == DataType::Null + } + #[staticmethod] + fn is_binary(t: PyDataType) -> bool { + t.0 == DataType::Binary + } + #[staticmethod] + fn is_unicode(t: PyDataType) -> bool { + t.0 == DataType::Utf8 + } + #[staticmethod] + fn is_string(t: PyDataType) -> bool { + t.0 == DataType::Utf8 + } + #[staticmethod] + fn is_large_binary(t: PyDataType) -> bool { + t.0 == DataType::LargeBinary + } + #[staticmethod] + fn is_large_unicode(t: PyDataType) -> bool { + t.0 == DataType::LargeUtf8 + } + #[staticmethod] + fn is_large_string(t: PyDataType) -> bool { + t.0 == DataType::LargeUtf8 + } + #[staticmethod] + fn is_binary_view(t: PyDataType) -> bool { + t.0 == DataType::BinaryView + } + #[staticmethod] + fn is_string_view(t: PyDataType) -> bool { + t.0 == DataType::Utf8View + } + #[staticmethod] + fn is_fixed_size_binary(t: PyDataType) -> bool { + matches!(t.0, DataType::FixedSizeBinary(_)) + } + #[staticmethod] + fn is_map(t: PyDataType) -> bool { + matches!(t.0, DataType::Map(_, _)) + } + #[staticmethod] + fn is_dictionary(t: PyDataType) -> bool { + matches!(t.0, DataType::Dictionary(_, _)) + } + #[staticmethod] + fn is_primitive(t: PyDataType) -> bool { + t.0.is_primitive() + } + #[staticmethod] + fn is_numeric(t: PyDataType) -> bool { + t.0.is_numeric() + } + #[staticmethod] + fn is_dictionary_key_type(t: PyDataType) -> bool { + t.0.is_dictionary_key_type() + } }