Skip to content

Commit

Permalink
Add support for extended data types (vector, bson) in query results (#26
Browse files Browse the repository at this point in the history
)

* Support for extended data types
  • Loading branch information
kesmit13 authored Apr 25, 2024
1 parent eb54f3c commit dfff654
Show file tree
Hide file tree
Showing 12 changed files with 904 additions and 10 deletions.
195 changes: 190 additions & 5 deletions accel.c
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,21 @@
#define MYSQL_TYPE_STRING 254
#define MYSQL_TYPE_GEOMETRY 255

// SingleStoreDB extended types
#define MYSQL_TYPE_BSON 1001
#define MYSQL_TYPE_FLOAT32_VECTOR_JSON 2001
#define MYSQL_TYPE_FLOAT64_VECTOR_JSON 2002
#define MYSQL_TYPE_INT8_VECTOR_JSON 2003
#define MYSQL_TYPE_INT16_VECTOR_JSON 2004
#define MYSQL_TYPE_INT32_VECTOR_JSON 2005
#define MYSQL_TYPE_INT64_VECTOR_JSON 2006
#define MYSQL_TYPE_FLOAT32_VECTOR 3001
#define MYSQL_TYPE_FLOAT64_VECTOR 3002
#define MYSQL_TYPE_INT8_VECTOR 3003
#define MYSQL_TYPE_INT16_VECTOR 3004
#define MYSQL_TYPE_INT32_VECTOR 3005
#define MYSQL_TYPE_INT64_VECTOR 3006

#define MYSQL_TYPE_CHAR MYSQL_TYPE_TINY
#define MYSQL_TYPE_INTERVAL MYSQL_TYPE_ENUM

Expand Down Expand Up @@ -333,6 +348,8 @@ typedef struct {
inline int IMAX(int a, int b) { return((a) > (b) ? a : b); }
inline int IMIN(int a, int b) { return((a) < (b) ? a : b); }

static PyObject *create_numpy_array(PyObject *py_memview, char *data_format, int data_type, PyObject *py_objs);

char *_PyUnicode_AsUTF8(PyObject *unicode) {
PyObject *bytes = PyUnicode_AsEncodedString(unicode, "utf-8", "strict");
if (!bytes) return NULL;
Expand Down Expand Up @@ -396,6 +413,14 @@ typedef struct {
PyObject *DataFrame;
PyObject *Table;
PyObject *from_pylist;
PyObject *int8;
PyObject *int16;
PyObject *int32;
PyObject *int64;
PyObject *float32;
PyObject *float64;
PyObject *unpack;
PyObject *decode;
} PyStrings;

static PyStrings PyStr = {0};
Expand All @@ -417,6 +442,8 @@ typedef struct {
PyObject *polars_DataFrame;
PyObject *pyarrow_Table;
PyObject *pyarrow_Table_from_pylist;
PyObject *struct_unpack;
PyObject *bson_decode;
} PyFunctions;

static PyFunctions PyFunc = {0};
Expand All @@ -428,6 +455,9 @@ typedef struct {
PyObject *namedtuple_kwargs;
PyObject *create_numpy_array_args;
PyObject *create_numpy_array_kwargs;
PyObject *create_numpy_array_kwargs_vector[7];
PyObject *struct_unpack_args;
PyObject *bson_decode_args;
} PyObjects;

static PyObjects PyObj = {0};
Expand Down Expand Up @@ -498,6 +528,7 @@ int ensure_numpy() {
return 0;

error:
PyErr_Clear();
return -1;
}

Expand All @@ -516,6 +547,7 @@ int ensure_pandas() {
return 0;

error:
PyErr_Clear();
return -1;
}

Expand All @@ -534,6 +566,7 @@ int ensure_polars() {
return 0;

error:
PyErr_Clear();
return -1;
}

Expand All @@ -555,6 +588,26 @@ int ensure_pyarrow() {
return 0;

error:
PyErr_Clear();
return -1;
}


int ensure_bson() {
if (PyFunc.bson_decode) goto exit;

// Import bson if it exists
PyObject *bson_mod = PyImport_ImportModule("bson");
if (!bson_mod) goto error;

PyFunc.bson_decode = PyObject_GetAttr(bson_mod, PyStr.decode);
if (!PyFunc.bson_decode) goto error;

exit:
return 0;

error:
PyErr_Clear();
return -1;
}

Expand Down Expand Up @@ -768,9 +821,10 @@ static int State_init(StateObject *self, PyObject *args, PyObject *kwds) {
NULL : py_converter;
Py_XINCREF(self->py_invalid_values[i]);

self->py_converters[i] = (!py_converter
// || py_converter == Py_None
|| py_converter == py_default_converter) ?
self->py_converters[i] = ((!py_converter || py_converter == py_default_converter)
// TODO: Need C accelerated converters for extended types
// && self->type_codes[i] < 256
) ?
NULL : py_converter;
Py_XINCREF(self->py_converters[i]);
}
Expand Down Expand Up @@ -1420,7 +1474,10 @@ static PyObject *read_row_from_packet(
PyObject *py_result = NULL;
PyObject *py_item = NULL;
PyObject *py_str = NULL;
PyObject *py_memview = NULL;
char end = '\0';
char *cast_type_codes[] = {"", "f", "d", "b", "h", "l", "q"};
int item_type_lengths[] = {0, 4, 8, 1, 2, 4, 8};

int sign = 1;
int year = 0;
Expand All @@ -1436,7 +1493,7 @@ static PyObject *read_row_from_packet(
case ACCEL_OUT_ARROW:
py_result = PyDict_New();
break;
case ACCEL_OUT_STRUCTSEQUENCES: {
case ACCEL_OUT_STRUCTSEQUENCES: {
if (!py_state->structsequence) goto error;
py_result = PyStructSequence_New(py_state->structsequence);
break;
Expand Down Expand Up @@ -1665,6 +1722,12 @@ static PyObject *read_row_from_packet(
case MYSQL_TYPE_VARCHAR:
case MYSQL_TYPE_VAR_STRING:
case MYSQL_TYPE_STRING:
case MYSQL_TYPE_FLOAT32_VECTOR_JSON:
case MYSQL_TYPE_FLOAT64_VECTOR_JSON:
case MYSQL_TYPE_INT8_VECTOR_JSON:
case MYSQL_TYPE_INT16_VECTOR_JSON:
case MYSQL_TYPE_INT32_VECTOR_JSON:
case MYSQL_TYPE_INT64_VECTOR_JSON:
if (!py_state->encodings[i]) {
py_item = PyBytes_FromStringAndSize(out, out_l);
if (!py_item) goto error;
Expand All @@ -1675,13 +1738,86 @@ static PyObject *read_row_from_packet(
if (!py_item) goto error;

// Parse JSON string.
if (py_state->type_codes[i] == MYSQL_TYPE_JSON && py_state->options.parse_json) {
if ((py_state->type_codes[i] == MYSQL_TYPE_JSON && py_state->options.parse_json)
|| (py_state->type_codes[i] >= MYSQL_TYPE_FLOAT32_VECTOR_JSON
&& py_state->type_codes[i] <= MYSQL_TYPE_INT64_VECTOR_JSON)) {
py_str = py_item;
py_item = PyObject_CallFunctionObjArgs(PyFunc.json_loads, py_str, NULL);
Py_CLEAR(py_str);
if (!py_item) goto error;
}

if (ensure_numpy() == 0) {
switch (py_state->type_codes[i]) {
case MYSQL_TYPE_FLOAT32_VECTOR_JSON:
case MYSQL_TYPE_FLOAT64_VECTOR_JSON:
case MYSQL_TYPE_INT8_VECTOR_JSON:
case MYSQL_TYPE_INT16_VECTOR_JSON:
case MYSQL_TYPE_INT32_VECTOR_JSON:
case MYSQL_TYPE_INT64_VECTOR_JSON:
CHECKRC(PyTuple_SetItem(PyObj.create_numpy_array_args, 0, py_item));
py_item = PyObject_Call(
PyFunc.numpy_array,
PyObj.create_numpy_array_args,
PyObj.create_numpy_array_kwargs_vector[py_state->type_codes[i] % 1000]
);
if (!py_item) goto error;
}
}

break;

case MYSQL_TYPE_FLOAT32_VECTOR:
case MYSQL_TYPE_FLOAT64_VECTOR:
case MYSQL_TYPE_INT8_VECTOR:
case MYSQL_TYPE_INT16_VECTOR:
case MYSQL_TYPE_INT32_VECTOR:
case MYSQL_TYPE_INT64_VECTOR:
if (ensure_numpy() == 0) {
py_memview = PyMemoryView_FromMemory(out, out_l, PyBUF_WRITE);
if (!py_memview) goto error;

py_item = create_numpy_array(
py_memview,
cast_type_codes[py_state->type_codes[i] % 1000],
py_state->type_codes[i],
NULL
);
Py_CLEAR(py_memview);
if (!py_item) goto error;

} else {
py_memview = PyBytes_FromStringAndSize(out, out_l);
if (!py_memview) goto error;

CHECKRC(PyTuple_SetItem(PyObj.struct_unpack_args, 0,
PyUnicode_FromFormat("<%l%s", out_l / item_type_lengths[i], cast_type_codes[i])));
CHECKRC(PyTuple_SetItem(PyObj.struct_unpack_args, 1, py_memview));

py_item = PyObject_Call(
PyFunc.struct_unpack,
PyObj.struct_unpack_args,
NULL
);
if (!py_item) goto error;
}

break;

case MYSQL_TYPE_BSON:
py_item = PyBytes_FromStringAndSize(out, out_l);
if (!py_item) goto error;

if (ensure_bson() == 0) {
CHECKRC(PyTuple_SetItem(PyObj.bson_decode_args, 0, py_item));
py_item = PyObject_Call(
PyFunc.bson_decode,
PyObj.bson_decode_args,
NULL
);
if (!py_item) goto error;
}

break;

default:
Expand Down Expand Up @@ -4470,6 +4606,14 @@ PyMODINIT_FUNC PyInit__singlestoredb_accel(void) {
PyStr.DataFrame = PyUnicode_FromString("DataFrame");
PyStr.Table = PyUnicode_FromString("Table");
PyStr.from_pylist = PyUnicode_FromString("from_pylist");
PyStr.int8 = PyUnicode_FromString("int8");
PyStr.int16 = PyUnicode_FromString("int16");
PyStr.int32 = PyUnicode_FromString("int32");
PyStr.int64 = PyUnicode_FromString("int64");
PyStr.float32 = PyUnicode_FromString("float32");
PyStr.float64 = PyUnicode_FromString("float64");
PyStr.unpack = PyUnicode_FromString("unpack");
PyStr.decode = PyUnicode_FromString("decode");

PyObject *decimal_mod = PyImport_ImportModule("decimal");
if (!decimal_mod) goto error;
Expand All @@ -4479,6 +4623,8 @@ PyMODINIT_FUNC PyInit__singlestoredb_accel(void) {
if (!json_mod) goto error;
PyObject *collections_mod = PyImport_ImportModule("collections");
if (!collections_mod) goto error;
PyObject *struct_mod = PyImport_ImportModule("struct");
if (!struct_mod) goto error;

PyFunc.decimal_Decimal = PyObject_GetAttr(decimal_mod, PyStr.Decimal);
if (!PyFunc.decimal_Decimal) goto error;
Expand All @@ -4494,6 +4640,8 @@ PyMODINIT_FUNC PyInit__singlestoredb_accel(void) {
if (!PyFunc.json_loads) goto error;
PyFunc.collections_namedtuple = PyObject_GetAttr(collections_mod, PyStr.namedtuple);
if (!PyFunc.collections_namedtuple) goto error;
PyFunc.struct_unpack = PyObject_GetAttr(struct_mod, PyStr.unpack);
if (!PyFunc.struct_unpack) goto error;

PyObj.namedtuple_kwargs = PyDict_New();
if (!PyObj.namedtuple_kwargs) goto error;
Expand All @@ -4510,6 +4658,43 @@ PyMODINIT_FUNC PyInit__singlestoredb_accel(void) {
goto error;
}

PyObj.create_numpy_array_kwargs_vector[1] = PyDict_New();
if (!PyObj.create_numpy_array_kwargs_vector[1]) goto error;
if (PyDict_SetItemString(PyObj.create_numpy_array_kwargs_vector[1], "dtype", PyStr.float32)) {
goto error;
}
PyObj.create_numpy_array_kwargs_vector[2] = PyDict_New();
if (!PyObj.create_numpy_array_kwargs_vector[2]) goto error;
if (PyDict_SetItemString(PyObj.create_numpy_array_kwargs_vector[2], "dtype", PyStr.float64)) {
goto error;
}
PyObj.create_numpy_array_kwargs_vector[3] = PyDict_New();
if (!PyObj.create_numpy_array_kwargs_vector[3]) goto error;
if (PyDict_SetItemString(PyObj.create_numpy_array_kwargs_vector[3], "dtype", PyStr.int8)) {
goto error;
}
PyObj.create_numpy_array_kwargs_vector[4] = PyDict_New();
if (!PyObj.create_numpy_array_kwargs_vector[4]) goto error;
if (PyDict_SetItemString(PyObj.create_numpy_array_kwargs_vector[4], "dtype", PyStr.int16)) {
goto error;
}
PyObj.create_numpy_array_kwargs_vector[5] = PyDict_New();
if (!PyObj.create_numpy_array_kwargs_vector[5]) goto error;
if (PyDict_SetItemString(PyObj.create_numpy_array_kwargs_vector[5], "dtype", PyStr.int32)) {
goto error;
}
PyObj.create_numpy_array_kwargs_vector[6] = PyDict_New();
if (!PyObj.create_numpy_array_kwargs_vector[6]) goto error;
if (PyDict_SetItemString(PyObj.create_numpy_array_kwargs_vector[6], "dtype", PyStr.int64)) {
goto error;
}

PyObj.struct_unpack_args = PyTuple_New(2);
if (!PyObj.struct_unpack_args) goto error;

PyObj.bson_decode_args = PyTuple_New(1);
if (!PyObj.bson_decode_args) goto error;

return PyModule_Create(&_singlestoredb_accelmodule);

error:
Expand Down
6 changes: 6 additions & 0 deletions singlestoredb/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,12 @@
environ='SINGLESTOREDB_TRACK_ENV',
)

register_option(
'enable_extended_data_types', 'bool', check_bool, True,
'Should extended data types (BSON, vector) be enabled?',
environ='SINGLESTOREDB_ENABLE_EXTENDED_DATA_TYPES',
)

register_option(
'fusion.enabled', 'bool', check_bool, False,
'Should Fusion SQL queries be enabled?',
Expand Down
3 changes: 3 additions & 0 deletions singlestoredb/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1285,6 +1285,7 @@ def connect(
inf_as_null: Optional[bool] = None,
encoding_errors: Optional[str] = None,
track_env: Optional[bool] = None,
enable_extended_data_types: Optional[bool] = None,
) -> Connection:
"""
Return a SingleStoreDB connection.
Expand Down Expand Up @@ -1361,6 +1362,8 @@ def connect(
The error handler name for value decoding errors
track_env : bool, optional
Should the connection track the SINGLESTOREDB_URL environment variable?
enable_extended_data_types : bool, optional
Should extended data types (BSON, vector) be enabled?
Examples
--------
Expand Down
Loading

0 comments on commit dfff654

Please sign in to comment.