Skip to content

Commit

Permalink
Support int8 embedding and uint8 embedding (#1527)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

Support int8 embedding and uint8 embedding
Fix issue #1516
Fix issue #1530

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)
- [x] Breaking Change (fix or feature that could cause existing
functionality not to work as expected)
- [x] Refactoring
- [x] Test cases
- [x] Python SDK impacted, Need to update PyPI

---------

Co-authored-by: salieri <lomlieri@gmail.com>
  • Loading branch information
yangzq50 and Ognimalf authored Jul 25, 2024
1 parent d0e8e7d commit f7b9026
Show file tree
Hide file tree
Showing 86 changed files with 6,770 additions and 4,479 deletions.
10 changes: 6 additions & 4 deletions docs/references/pysdk_api_reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -357,19 +357,21 @@ Create an index by `IndexInfo` list.
An `IndexInfo` struct contains three fields: `column_name`, `index_type`, and `index_param_list`.
- **column_name : str** Name of the column to build index on.
- **index_type : IndexType**
enum type: `IVFFlat` , `Hnsw`, `HnswLVQ`, `FullText`, or `BMP`. Defined in `infinity.index`.
`Note: The only difference between Hnsw and HnswLVQ is the clustering method they adopt. The former uses K-Means while the latter uses LVQ (Learning Vector Quantization)`
enum type: `IVFFlat` , `Hnsw`, `FullText`, or `BMP`. Defined in `infinity.index`.
`Note: For the Hnsw index, add encode=lvq to index_param_list to use LVQ (Locally-adaptive Vector Quantization)`
- **index_param_list**
A list of `InitParameter`. The `InitParameter` struct is a key-value pair with two string fields named `param_name` and `param_value`. The optional parameters of each type of index are listed below:
- `IVFFlat`: `'centroids_count'`(default:`'128'`), `'metric'`(required)
- `Hnsw`: `'M'`(default:`'16'`), `'ef_construction'`(default:`'50'`), `'ef'`(default:`'50'`), `'metric'`(required)
- `HnswLVQ`:
- `Hnsw`:
- `'M'`(default:`'16'`)
- `'ef_construction'`(default:`'50'`)
- `'ef'`(default:`'50'`)
- `'metric'`(required)
- `ip`: Inner product
- `l2`: Euclidean distance
- `'encode'`(optional)
- `plain`: Plain encoding (default)
- `lvq`: LVQ (Locally-adaptive Vector Quantization)
- `FullText`: `'ANALYZER'`(default:`'standard'`)
- `BMP`:
- `block_size=1~256`(default: 16): The size of the block in BMP index
Expand Down
25 changes: 19 additions & 6 deletions example/ColBERT_reranker_example/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,12 @@
# limitations under the License.

from typing import Union
import torch
from infinity.remote_thrift.types import make_match_tensor_expr
# NOTICE: please check which infinity you are using, local or remote
# this statement is for local infinity
from infinity.local_infinity.types import make_match_tensor_expr
# enable the following import statement to use remote infinity
# from infinity.remote_thrift.types import make_match_tensor_expr
# from infinity.common import LOCAL_HOST


class InfinityHelperForColBERT:
Expand All @@ -40,9 +44,9 @@ def create_test_env(self, schema: dict):
# append two hidden columns: "INNER_HIDDEN_ColBERT_TENSOR_ARRAY_FLOAT", "INNER_HIDDEN_ColBERT_TENSOR_ARRAY_BIT"
self.test_db_name = "colbert_test_db"
self.test_table_name = "colbert_test_table"
self.inner_col_txt = "INNER_HIDDEN_ColBERT_TEXT_FOR_BM25"
self.inner_col_float = "INNER_HIDDEN_ColBERT_TENSOR_ARRAY_FLOAT"
self.inner_col_bit = "INNER_HIDDEN_ColBERT_TENSOR_ARRAY_BIT"
self.inner_col_txt = "inner_hidden_colbert_text_for_bm25"
self.inner_col_float = "inner_hidden_colbert_tensor_array_float"
self.inner_col_bit = "inner_hidden_colbert_tensor_array_bit"
if self.inner_col_txt in schema:
raise ValueError(f"Column name {self.inner_col_txt} is reserved for internal use.")
if self.inner_col_float in schema:
Expand All @@ -56,13 +60,22 @@ def create_test_env(self, schema: dict):
from infinity import NetworkAddress
import infinity.index as index
from infinity.common import ConflictType
self.infinity_obj = infinity.connect(NetworkAddress("127.0.0.1", 23817))
# NOTICE: the following statement is for local infinity
self.infinity_obj = infinity.connect("/var/infinity")
# enable the following statement to use remote infinity
# self.infinity_obj = infinity.connect(LOCAL_HOST)
self.infinity_obj.drop_database(self.test_db_name, ConflictType.Ignore)
self.colbert_test_db = self.infinity_obj.create_database(self.test_db_name)
self.colbert_test_table = self.colbert_test_db.create_table(self.test_table_name, schema, ConflictType.Error)
# NOTICE: the following statement is for english text
self.colbert_test_table.create_index("test_ft_index",
[index.IndexInfo(self.inner_col_txt, index.IndexType.FullText, [])],
ConflictType.Error)
# please enable the following statement to use chinese text
# self.colbert_test_table.create_index("test_ft_index",
# [index.IndexInfo(self.inner_col_txt, index.IndexType.FullText,
# [infinity.index.InitParameter("ANALYZER", "chinese")])],
# ConflictType.Error)

# clear the test environment for ColBERT
def clear_test_env(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def faiss_import(self, data_set):
def infinity_benchmark_flat_knn(self, data_set):
threads = int(input("Enter number of threads:\n"))
rounds = int(input("Enter number of rounds:\n"))
benchmark(threads, rounds, data_set, self.dataset_path_map[data_set])
benchmark(threads, rounds, data_set, 200, True, self.dataset_path_map[data_set])

def faiss_benchmark_flat_knn_batch_query(self, data_set):
xq = fvecs_read(self.dataset_path_map[data_set] + self.query_suffix[data_set])
Expand Down
15 changes: 5 additions & 10 deletions python/infinity/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,16 @@

class IndexType(Enum):
IVFFlat = 1
HnswLVQ = 2
Hnsw = 3
FullText = 4
Secondary = 5
EMVB = 6
BMP = 7
Hnsw = 2
FullText = 3
Secondary = 4
EMVB = 5
BMP = 6

def to_ttype(self):
match self:
case IndexType.IVFFlat:
return ttypes.IndexType.IVFFlat
case IndexType.HnswLVQ:
return ttypes.IndexType.HnswLVQ
case IndexType.Hnsw:
return ttypes.IndexType.Hnsw
case IndexType.FullText:
Expand All @@ -54,8 +51,6 @@ def to_local_type(self):
match self:
case IndexType.IVFFlat:
return LocalIndexType.kIVFFlat
case IndexType.HnswLVQ:
return LocalIndexType.kHnswLVQ
case IndexType.Hnsw:
return LocalIndexType.kHnsw
case IndexType.FullText:
Expand Down
8 changes: 8 additions & 0 deletions python/infinity/local_infinity/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ def get_ordinary_info(column_info, column_defs, column_name, index):
proto_column_type.logical_type = LogicalType.kFloat
case "double" | "float64":
proto_column_type.logical_type = LogicalType.kDouble
case "float16":
proto_column_type.logical_type = LogicalType.kFloat16
case "bfloat16":
proto_column_type.logical_type = LogicalType.kBFloat16
case "varchar":
proto_column_type.logical_type = LogicalType.kVarchar
case "bool":
Expand Down Expand Up @@ -157,6 +161,8 @@ def get_embedding_info(column_info, column_defs, column_name, index):
embedding_type.element_type = EmbeddingDataType.kElemFloat
elif element_type == "float64" or element_type == "double":
embedding_type.element_type = EmbeddingDataType.kElemDouble
elif element_type == "uint8":
embedding_type.element_type = EmbeddingDataType.kElemUInt8
elif element_type == "int8":
embedding_type.element_type = EmbeddingDataType.kElemInt8
elif element_type == "int16":
Expand Down Expand Up @@ -205,6 +211,8 @@ def get_sparse_info(column_info, column_defs, column_name, index):
sparse_type.element_type = EmbeddingDataType.kElemFloat
elif element_type == "float64" or element_type == "double":
sparse_type.element_type = EmbeddingDataType.kElemDouble
elif element_type == "uint8":
sparse_type.element_type = EmbeddingDataType.kElemUInt8
elif element_type == "int8":
sparse_type.element_type = EmbeddingDataType.kElemInt8
elif element_type == "int16":
Expand Down
23 changes: 11 additions & 12 deletions python/infinity/local_infinity/query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,35 +98,34 @@ def knn(
f"Invalid embedding data, type should be embedded, but get {type(embedding_data)}",
)

if (
embedding_data_type == "tinyint"
or embedding_data_type == "smallint"
or embedding_data_type == "int"
or embedding_data_type == "bigint"
):
if embedding_data_type in ["int", "uint8", "int8", "int16", "int32", "int64", "tinyint", "unsigned tinyint",
"smallint", "bigint"]:
embedding_data = [int(x) for x in embedding_data]

data = EmbeddingData()
elem_type = EmbeddingDataType.kElemFloat
if embedding_data_type == "bit":
elem_type = EmbeddingDataType.kElemBit
raise Exception(f"Invalid embedding {embedding_data[0]} type")
elif embedding_data_type == "tinyint":
elif embedding_data_type in ["unsigned tinyint", "uint8"]:
elem_type = EmbeddingDataType.kElemUInt8
data.u8_array_value = embedding_data
elif embedding_data_type in ["tinyint", "int8"]:
elem_type = EmbeddingDataType.kElemInt8
data.i8_array_value = embedding_data
elif embedding_data_type == "smallint":
elif embedding_data_type in ["smallint", "int16"]:
elem_type = EmbeddingDataType.kElemInt16
data.i16_array_value = embedding_data
elif embedding_data_type == "int":
elif embedding_data_type in ["int", "int32"]:
elem_type = EmbeddingDataType.kElemInt32
data.i32_array_value = embedding_data
elif embedding_data_type == "bigint":
elif embedding_data_type in ["bigint", "int64"]:
elem_type = EmbeddingDataType.kElemInt64
data.i64_array_value = embedding_data
elif embedding_data_type == "float":
elif embedding_data_type in ["float", "float32"]:
elem_type = EmbeddingDataType.kElemFloat
data.f32_array_value = embedding_data
elif embedding_data_type == "double":
elif embedding_data_type in ["double", "float64"]:
elem_type = EmbeddingDataType.kElemDouble
data.f64_array_value = embedding_data
else:
Expand Down
3 changes: 3 additions & 0 deletions python/infinity/local_infinity/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,9 @@ def insert(self, data: Union[INSERT_DATA, list[INSERT_DATA]]):
if isinstance(value, str):
constant_expression.literal_type = LiteralType.kString
constant_expression.str_value = value
elif isinstance(value, bool):
constant_expression.literal_type = LiteralType.kBoolean
constant_expression.bool_value = value
elif isinstance(value, int):
constant_expression.literal_type = LiteralType.kInteger
constant_expression.i64_value = value
Expand Down
40 changes: 33 additions & 7 deletions python/infinity/local_infinity/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,17 @@ def logic_type_to_dtype(ttype: WrapDataType):
return dtype('float32')
case LogicalType.kDouble:
return dtype('float64')
case LogicalType.kFloat16:
return dtype('float32')
case LogicalType.kBFloat16:
return dtype('float32')
case LogicalType.kVarchar:
return dtype('str')
case LogicalType.kEmbedding:
if ttype.embedding_type is not None:
match ttype.embedding_type.element_type:
case EmbeddingDataType.kElemUInt8:
return object
case EmbeddingDataType.kElemInt8:
return object
case EmbeddingDataType.kElemInt16:
Expand Down Expand Up @@ -81,6 +87,9 @@ def tensor_to_list(column_data_type, binary_data) -> list[list[Any]]:
mid_res_int = mid_res_int * 256 + j
result.append([f"\u007b0:0{dimension}b\u007d".format(mid_res_int)[::-1]])
return result
elif column_data_type.embedding_type.element_type == EmbeddingDataType.kElemUInt8:
all_list = list(struct.unpack('<{}B'.format(len(binary_data)), binary_data))
return [all_list[i:i + dimension] for i in range(0, len(all_list), dimension)]
elif column_data_type.embedding_type.element_type == EmbeddingDataType.kElemInt8:
all_list = list(struct.unpack('<{}b'.format(len(binary_data)), binary_data))
return [all_list[i:i + dimension] for i in range(0, len(all_list), dimension)]
Expand Down Expand Up @@ -144,6 +153,14 @@ def column_vector_to_list(column_type, column_data_type, column_vectors) -> \
return list(struct.unpack('<{}f'.format(len(column_vector) // 4), column_vector))
case LogicalType.kDouble:
return list(struct.unpack('<{}d'.format(len(column_vector) // 8), column_vector))
case LogicalType.kFloat16:
return list(struct.unpack('<{}e'.format(len(column_vector) // 2), column_vector))
case LogicalType.kBFloat16:
tmp_u16 = np.frombuffer(column_vector, dtype='<i2')
result_arr = np.zeros(2 * len(tmp_u16), dtype='<i2')
result_arr[1::2] = tmp_u16
view_float32 = result_arr.view('<f4')
return list(view_float32)
case LogicalType.kVarchar:
return list(parse_bytes(column_vector))
case LogicalType.kBoolean:
Expand All @@ -158,7 +175,10 @@ def column_vector_to_list(column_type, column_data_type, column_vectors) -> \
case LogicalType.kEmbedding:
dimension = column_data_type.embedding_type.dimension
element_type = column_data_type.embedding_type.element_type
if element_type == EmbeddingDataType.kElemInt8:
if element_type == EmbeddingDataType.kElemUInt8:
all_list = list(struct.unpack('<{}B'.format(len(column_vector)), column_vector))
return [all_list[i:i + dimension] for i in range(0, len(all_list), dimension)]
elif element_type == EmbeddingDataType.kElemInt8:
all_list = list(struct.unpack('<{}b'.format(len(column_vector)), column_vector))
return [all_list[i:i + dimension] for i in range(0, len(all_list), dimension)]
elif element_type == EmbeddingDataType.kElemInt16:
Expand Down Expand Up @@ -225,6 +245,9 @@ def parse_sparse_bytes(column_data_type, column_vector):
case _:
raise NotImplementedError(f"Unsupported type {index_type}")
match element_type:
case EmbeddingDataType.kElemUInt8:
values = struct.unpack('<{}B'.format(nnz), column_vector[offset:offset + nnz])
offset += nnz
case EmbeddingDataType.kElemInt8:
values = struct.unpack('<{}b'.format(nnz), column_vector[offset:offset + nnz])
offset += nnz
Expand Down Expand Up @@ -277,22 +300,25 @@ def make_match_tensor_expr(vector_column_name: str, embedding_data: VEC, embeddi
elem_type = EmbeddingDataType.kElemFloat
if embedding_data_type == 'bit':
raise InfinityException(ErrorCode.INVALID_EMBEDDING_DATA_TYPE, f"Invalid embedding {embedding_data[0]} type")
elif embedding_data_type == 'tinyint' or embedding_data_type == 'int8' or embedding_data_type == 'i8':
elif embedding_data_type in ['unsigned tinyint', 'uint8', 'u8']:
elem_type = EmbeddingDataType.kElemUInt8
data.u8_array_value = np.asarray(embedding_data, dtype=np.uint8).flatten()
elif embedding_data_type in ['tinyint', 'int8', 'i8']:
elem_type = EmbeddingDataType.kElemInt8
data.i8_array_value = np.asarray(embedding_data, dtype=np.int8).flatten()
elif embedding_data_type == 'smallint' or embedding_data_type == 'int16' or embedding_data_type == 'i16':
elif embedding_data_type in ['smallint', 'int16', 'i16']:
elem_type = EmbeddingDataType.kElemInt16
data.i16_array_value = np.asarray(embedding_data, dtype=np.int16).flatten()
elif embedding_data_type == 'int' or embedding_data_type == 'int32' or embedding_data_type == 'i32':
elif embedding_data_type in ['int', 'int32', 'i32']:
elem_type = EmbeddingDataType.kElemInt32
data.i32_array_value = np.asarray(embedding_data, dtype=np.int32).flatten()
elif embedding_data_type == 'bigint' or embedding_data_type == 'int64' or embedding_data_type == 'i64':
elif embedding_data_type in ['bigint', 'int64', 'i64']:
elem_type = EmbeddingDataType.kElemInt64
data.i64_array_value = np.asarray(embedding_data, dtype=np.int64).flatten()
elif embedding_data_type == 'float' or embedding_data_type == 'float32' or embedding_data_type == 'f32':
elif embedding_data_type in ['float', 'float32', 'f32']:
elem_type = EmbeddingDataType.kElemFloat
data.f32_array_value = np.asarray(embedding_data, dtype=np.float32).flatten()
elif embedding_data_type == 'double' or embedding_data_type == 'float64' or embedding_data_type == 'f64':
elif embedding_data_type in ['double', 'float64', 'f64']:
elem_type = EmbeddingDataType.kElemDouble
data.f64_array_value = np.asarray(embedding_data, dtype=np.float64).flatten()
else:
Expand Down
9 changes: 9 additions & 0 deletions python/infinity/local_infinity/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,15 @@ def traverse_conditions(cons, fn=None):

return parsed_expr

elif isinstance(cons, exp.Boolean):
parsed_expr = WrapParsedExpr()
constant_expr = WrapConstantExpr()
constant_expr.literal_type = LiteralType.kBoolean
constant_expr.bool_value = cons.this
parsed_expr.type = ParsedExprType.kConstant
parsed_expr.constant_expr = constant_expr
return parsed_expr

elif isinstance(cons, exp.Literal):
parsed_expr = WrapParsedExpr()
constant_expr = WrapConstantExpr()
Expand Down
Loading

0 comments on commit f7b9026

Please sign in to comment.