From f206545c1b29839bd2831295acdb7782e88c3475 Mon Sep 17 00:00:00 2001 From: Agisilaos Kounelis <36283973+kounelisagis@users.noreply.github.com> Date: Wed, 23 Oct 2024 14:33:43 +0300 Subject: [PATCH] Add `TILEDB_DATETIME_DAY` type support for Arrow (#2002) * Add in place buffer shift for TILEDB_DATETIME_DAY * Add tests --- tiledb/py_arrow_io_impl.h | 32 ++++-- tiledb/tests/test_pandas_dataframe.py | 135 +++++++++++++++++++++++++- 2 files changed, 160 insertions(+), 7 deletions(-) diff --git a/tiledb/py_arrow_io_impl.h b/tiledb/py_arrow_io_impl.h index f2e8a60dda..060bb5d374 100644 --- a/tiledb/py_arrow_io_impl.h +++ b/tiledb/py_arrow_io_impl.h @@ -233,6 +233,8 @@ ArrowInfo tiledb_buffer_arrow_fmt(BufferInfo bufferinfo, bool use_list = true) { return ArrowInfo("tsu:"); case TILEDB_DATETIME_NS: return ArrowInfo("tsn:"); + case TILEDB_DATETIME_DAY: + return ArrowInfo("tdD"); // TILEDB_BOOL is stored as a uint8_t but arrow::Type::BOOL is 1 bit case TILEDB_BOOL: return ArrowInfo("C"); @@ -242,7 +244,6 @@ ArrowInfo tiledb_buffer_arrow_fmt(BufferInfo bufferinfo, bool use_list = true) { case TILEDB_DATETIME_YEAR: case TILEDB_DATETIME_MONTH: case TILEDB_DATETIME_WEEK: - case TILEDB_DATETIME_DAY: case TILEDB_DATETIME_HR: case TILEDB_DATETIME_MIN: case TILEDB_DATETIME_PS: @@ -739,6 +740,14 @@ int64_t flags_for_buffer(BufferInfo binfo) { return 0; } +template T cast_checked(uint64_t val) { + if (val > std::numeric_limits::max()) { + throw tiledb::TileDBError( + "[TileDB-Arrow] Value too large to cast to requested type"); + } + return static_cast(val); +} + void ArrowExporter::export_(const std::string &name, ArrowArray *array, ArrowSchema *schema, ArrowAdapter::release_cb cb, void *cb_data) { @@ -762,13 +771,11 @@ void ArrowExporter::export_(const std::string &name, ArrowArray *array, if (bufferinfo.is_var) { buffers = {nullptr, bufferinfo.offsets, bufferinfo.data}; } else { - cpp_schema = new CPPArrowSchema(name, arrow_fmt.fmt_, std::nullopt, - arrow_flags, {}, {}); buffers = {nullptr, bufferinfo.data}; } cpp_schema->export_ptr(schema); - size_t elem_num = 0; + size_t elem_num = bufferinfo.data_num; if (bufferinfo.is_var) { // adjust for arrow offset unless empty result elem_num = (bufferinfo.offsets_num == 0) ? 0 : bufferinfo.offsets_num - 1; @@ -778,8 +785,21 @@ void ArrowExporter::export_(const std::string &name, ArrowArray *array, // take the size of the entire buffer and divide by the size of each // element elem_num = bufferinfo.data_num / bufferinfo.tdbtype.cell_val_num; - } else { - elem_num = bufferinfo.data_num; + } else if (arrow_fmt.fmt_ == "tdD") { + // for Arrow date32 we only need the first 4 bytes of each 8-byte + // TILEDB_DATETIME_DAY element which we keep by in-place left shifting + for (size_t i = 0; i < bufferinfo.data_num; i++) { + uint32_t lost_data = *(reinterpret_cast( + static_cast(buffers[1]) + i * 8 + 4)); + if (lost_data != 0) { + throw tiledb::TileDBError( + "[TileDB-Arrow] Non-zero data detected in the memory buffer at " + "position that will be overwritten"); + } + + static_cast(buffers[1])[i] = + cast_checked(static_cast(buffers[1])[i]); + } } } diff --git a/tiledb/tests/test_pandas_dataframe.py b/tiledb/tests/test_pandas_dataframe.py index 4d80aeba47..f7374252c0 100644 --- a/tiledb/tests/test_pandas_dataframe.py +++ b/tiledb/tests/test_pandas_dataframe.py @@ -5,6 +5,7 @@ import string import sys import uuid +from collections import OrderedDict import numpy as np import pyarrow @@ -16,6 +17,7 @@ from .common import ( DiskTestCase, + assert_dict_arrays_equal, dtype_max, dtype_min, has_pandas, @@ -219,7 +221,6 @@ def test_object_dtype(self): "