Skip to content

Commit

Permalink
Extend the GroupMetadata functionality to support NumPy arrays (#2085)
Browse files Browse the repository at this point in the history
* Extend GroupMetadata functionality
* Add tests
  • Loading branch information
kounelisagis authored Oct 23, 2024
1 parent d8109c3 commit 61b1ce2
Show file tree
Hide file tree
Showing 5 changed files with 279 additions and 70 deletions.
20 changes: 11 additions & 9 deletions tiledb/cc/common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -192,18 +192,20 @@ bool is_tdb_str(tiledb_datatype_t type) {
}

py::size_t get_ncells(py::dtype type) {
if (type.is(py::dtype("S")))
return type.itemsize() == 0 ? TILEDB_VAR_NUM : type.itemsize();

if (type.is(py::dtype("U"))) {
auto np_unicode_size = py::dtype("U").itemsize();
return type.itemsize() == 0 ? TILEDB_VAR_NUM
: type.itemsize() / np_unicode_size;
}

auto np = py::module::import("numpy");
auto np_issubdtype = np.attr("issubdtype");
auto np_complexfloating = np.attr("complexfloating");
auto np_character = np.attr("character");

py::bool_ ischaracter = np_issubdtype(type, np_character);
if (ischaracter) {
py::dtype base_dtype =
np.attr("dtype")(py::make_tuple(type.attr("kind"), 1));
if (type.itemsize() == 0)
return TILEDB_VAR_NUM;
return type.itemsize() / base_dtype.itemsize();
}

py::bool_ iscomplexfloating = np_issubdtype(type, np_complexfloating);
if (iscomplexfloating)
return 2;
Expand Down
118 changes: 96 additions & 22 deletions tiledb/cc/group.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,14 @@ void put_metadata_numpy(Group &group, const std::string &key, py::array value) {
throw py::type_error(e.what());
}

if (is_tdb_str(value_type) && value.size() > 1)
throw py::type_error("array/list of strings not supported");

py::buffer_info value_buffer = value.request();
if (value_buffer.ndim != 1)
if (value.ndim() != 1)
throw py::type_error("Only 1D Numpy arrays can be stored as metadata");

py::size_t ncells = get_ncells(value.dtype());
if (ncells != 1)
throw py::type_error("Unsupported dtype for metadata");
throw py::type_error("Unsupported dtype '" +
std::string(py::str(value.dtype())) +
"' for metadata");

auto value_num = is_tdb_str(value_type) ? value.nbytes() : value.size();
group.put_metadata(key, value_type, value_num,
Expand All @@ -40,8 +38,10 @@ void put_metadata_numpy(Group &group, const std::string &key, py::array value) {

void put_metadata(Group &group, const std::string &key,
tiledb_datatype_t value_type, uint32_t value_num,
const char *value) {
group.put_metadata(key, value_type, value_num, value);
py::buffer &value) {

py::buffer_info info = value.request();
group.put_metadata(key, value_type, value_num, info.ptr);
}

bool has_metadata(Group &group, const std::string &key) {
Expand All @@ -60,28 +60,102 @@ std::string get_key_from_index(Group &group, uint64_t index) {
return key;
}

py::tuple get_metadata(Group &group, const std::string &key) {
tiledb_datatype_t tdb_type;
uint32_t value_num;
const void *value;
py::object unpack_metadata_val(tiledb_datatype_t value_type, uint32_t value_num,
const char *value_ptr) {
if (value_num == 0)
throw TileDBError("internal error: unexpected value_num==0");

if (value_type == TILEDB_STRING_UTF8) {
return value_ptr == nullptr ? py::str() : py::str(value_ptr, value_num);
}

if (value_type == TILEDB_BLOB || value_type == TILEDB_CHAR ||
value_type == TILEDB_STRING_ASCII) {
return value_ptr == nullptr ? py::bytes() : py::bytes(value_ptr, value_num);
}

group.get_metadata(key, &tdb_type, &value_num, &value);
if (value_ptr == nullptr)
return py::tuple();

py::tuple unpacked(value_num);
for (uint32_t i = 0; i < value_num; i++) {
switch (value_type) {
case TILEDB_INT64:
unpacked[i] = *((int64_t *)value_ptr);
break;
case TILEDB_FLOAT64:
unpacked[i] = *((double *)value_ptr);
break;
case TILEDB_FLOAT32:
unpacked[i] = *((float *)value_ptr);
break;
case TILEDB_INT32:
unpacked[i] = *((int32_t *)value_ptr);
break;
case TILEDB_UINT32:
unpacked[i] = *((uint32_t *)value_ptr);
break;
case TILEDB_UINT64:
unpacked[i] = *((uint64_t *)value_ptr);
break;
case TILEDB_INT8:
unpacked[i] = *((int8_t *)value_ptr);
break;
case TILEDB_UINT8:
unpacked[i] = *((uint8_t *)value_ptr);
break;
case TILEDB_INT16:
unpacked[i] = *((int16_t *)value_ptr);
break;
case TILEDB_UINT16:
unpacked[i] = *((uint16_t *)value_ptr);
break;
default:
throw TileDBError("TileDB datatype not supported");
}
value_ptr += tiledb_datatype_size(value_type);
}

if (value_num > 1)
return unpacked;

py::dtype value_type = tdb_to_np_dtype(tdb_type, 1);
// for single values, return the value directly
return unpacked[0];
}

py::array py_buf;
if (value == nullptr) {
py_buf = py::array(value_type, 0);
return py::make_tuple(py_buf, tdb_type);
py::array unpack_metadata_ndarray(tiledb_datatype_t value_type,
uint32_t value_num, const char *value_ptr) {
py::dtype dtype = tdb_to_np_dtype(value_type, 1);

if (value_ptr == nullptr) {
auto np = py::module::import("numpy");
return np.attr("empty")(py::make_tuple(0), dtype);
}

if (tdb_type == TILEDB_STRING_UTF8) {
value_type = py::dtype("|S1");
// special case for TILEDB_STRING_UTF8: TileDB assumes size=1
if (value_type != TILEDB_STRING_UTF8) {
value_num *= tiledb_datatype_size(value_type);
}

py_buf = py::array(value_type, value_num, value);
auto buf = py::memoryview::from_memory(value_ptr, value_num);

return py::make_tuple(py_buf, tdb_type);
auto np = py::module::import("numpy");
return np.attr("frombuffer")(buf, dtype);
}

// Look up the group metadata entry stored under `key` and hand it back to
// Python as a 2-tuple `(value, tiledb_datatype)`.
//
// `is_ndarray` selects how the raw buffer is decoded:
//   - true:  via unpack_metadata_ndarray (yields a py::array)
//   - false: via unpack_metadata_val (yields a py::object)
py::tuple get_metadata(Group &group, const py::str &key, bool is_ndarray) {
  tiledb_datatype_t datatype;
  uint32_t count;
  const char *raw;

  // Group::get_metadata fills the out-parameters through a `const void **`;
  // the buffer is kept as `const char *` because both unpack helpers take it
  // in that form.
  group.get_metadata(key, &datatype, &count,
                     reinterpret_cast<const void **>(&raw));

  if (!is_ndarray) {
    auto decoded = unpack_metadata_val(datatype, count, raw);
    return py::make_tuple(decoded, datatype);
  }

  auto ndarray = unpack_metadata_ndarray(datatype, count, raw);
  return py::make_tuple(ndarray, datatype);
}

bool has_member(Group &group, std::string obj) {
Expand Down
79 changes: 43 additions & 36 deletions tiledb/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ class Group(CtxMixin, lt.Group):
"""

_NP_DATA_PREFIX = "__np_flat_"
_NP_SHAPE_PREFIX = "__np_shape_"

_mode_to_query_type = {
"r": lt.QueryType.READ,
Expand Down Expand Up @@ -112,19 +113,21 @@ def __setitem__(self, key: str, value: GroupMetadataValueType):

put_metadata = self._group._put_metadata
if isinstance(value, np.ndarray):
put_metadata(f"{Group._NP_DATA_PREFIX}{key}", np.array(value))
elif isinstance(value, bytes):
put_metadata(key, lt.DataType.BLOB, len(value), value)
elif isinstance(value, str):
value = value.encode("UTF-8")
put_metadata(key, lt.DataType.STRING_UTF8, len(value), value)
elif isinstance(value, (list, tuple)):
put_metadata(key, np.array(value))
flat_value = value.ravel()
put_metadata(f"{Group._NP_DATA_PREFIX}{key}", flat_value)
if value.shape != flat_value.shape:
# If the value is not a 1D ndarray, store its associated shape.
# The value's shape will be stored as separate metadata with the correct prefix.
self.__setitem__(f"{Group._NP_SHAPE_PREFIX}{key}", value.shape)
else:
if isinstance(value, int):
# raise OverflowError too large to convert to int64
value = np.int64(value)
put_metadata(key, np.array([value]))
from .metadata import pack_metadata_val

packed_buf = pack_metadata_val(value)
tiledb_type = packed_buf.tdbtype
value_num = packed_buf.value_num
data_view = packed_buf.data

put_metadata(key, tiledb_type, value_num, data_view)

def __getitem__(self, key: str, include_type=False) -> GroupMetadataValueType:
"""
Expand All @@ -137,25 +140,20 @@ def __getitem__(self, key: str, include_type=False) -> GroupMetadataValueType:
raise TypeError(f"Unexpected key type '{type(key)}': expected str")

if self._group._has_metadata(key):
pass
data, tdb_type = self._group._get_metadata(key, False)
elif self._group._has_metadata(f"{Group._NP_DATA_PREFIX}{key}"):
key = f"{Group._NP_DATA_PREFIX}{key}"
data, tdb_type = self._group._get_metadata(
f"{Group._NP_DATA_PREFIX}{key}", True
)
# reshape numpy array back to original shape, if needed
shape_key = f"{Group._NP_SHAPE_PREFIX}{key}"
if self._group._has_metadata(shape_key):
shape, tdb_type = self._group._get_metadata(shape_key, False)
data = data.reshape(shape)
else:
raise KeyError(f"KeyError: {key}")

data, tdb_type = self._group._get_metadata(key)
dtype = DataType.from_tiledb(tdb_type).np_dtype
if np.issubdtype(dtype, np.character):
value = data.tobytes()
if np.issubdtype(dtype, np.str_):
value = value.decode("UTF-8")
elif key.startswith(Group._NP_DATA_PREFIX):
value = data
elif len(data) == 1:
value = data[0]
else:
value = tuple(data)
return (value, tdb_type) if include_type else value
return (data, tdb_type) if include_type else data

def __delitem__(self, key: str):
"""Removes the entry from the Group metadata.
Expand All @@ -168,8 +166,8 @@ def __delitem__(self, key: str):

# key may be stored as is or it may be prefixed (for numpy values)
# we don't know this here so delete all potential internal keys
self._group._delete_metadata(key)
self._group._delete_metadata(f"{Group._NP_DATA_PREFIX}{key}")
for k in key, Group._NP_DATA_PREFIX + key, Group._NP_SHAPE_PREFIX + key:
self._group._delete_metadata(k)

def __contains__(self, key: str) -> bool:
"""
Expand All @@ -193,12 +191,19 @@ def __len__(self) -> int:
:return: Number of entries in the Group metadata
"""
return self._group._metadata_num()
num = self._group._metadata_num()
# subtract the _NP_SHAPE_PREFIX prefixed keys
for key in self._iter(keys_only=True):
if key.startswith(Group._NP_SHAPE_PREFIX):
num -= 1

return num

def _iter(self, keys_only: bool = True, dump: bool = False):
"""
Iterate over Group metadata keys or (key, value) tuples
:param keys_only: whether to yield just keys or values too
:param dump: whether to yield a formatted string for each metadata entry
"""
if keys_only and dump:
raise ValueError("keys_only and dump cannot both be True")
Expand All @@ -207,9 +212,6 @@ def _iter(self, keys_only: bool = True, dump: bool = False):
for i in range(metadata_num):
key = self._group._get_key_from_index(i)

if key.startswith(Group._NP_DATA_PREFIX):
key = key[len(Group._NP_DATA_PREFIX) :]

if keys_only:
yield key
else:
Expand All @@ -226,11 +228,16 @@ def _iter(self, keys_only: bool = True, dump: bool = False):
yield key, val

def __iter__(self):
for key in self._iter():
yield key
np_data_prefix_len = len(Group._NP_DATA_PREFIX)
for key in self._iter(keys_only=True):
if key.startswith(Group._NP_DATA_PREFIX):
yield key[np_data_prefix_len:]
elif not key.startswith(Group._NP_SHAPE_PREFIX):
yield key
# else: ignore the shape keys

def __repr__(self):
return str(dict(self._iter(keys_only=False)))
return str(dict(self))

def setdefault(self, key, default=None):
raise NotImplementedError(
Expand Down
4 changes: 2 additions & 2 deletions tiledb/tests/cc/test_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ def test_group_metadata(tmp_path):
grp._open(lt.QueryType.READ)
assert grp._metadata_num() == 2
assert grp._has_metadata("int")
assert_array_equal(grp._get_metadata("int")[0], int_data)
assert_array_equal(grp._get_metadata("int", False)[0], int_data)
assert grp._has_metadata("flt")
assert_array_equal(grp._get_metadata("flt")[0], flt_data)
assert_array_equal(grp._get_metadata("flt", False)[0], flt_data)
grp._close()

time.sleep(0.001)
Expand Down
Loading

0 comments on commit 61b1ce2

Please sign in to comment.