Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support int8 embedding and uint8 embedding #1527

Merged
merged 32 commits into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
38bfb90
int8 l2
Ognimalf Jul 16, 2024
0528776
update some simd tools
yangzq50 Jul 22, 2024
bfdd6a2
update I8L2 distance func
yangzq50 Jul 22, 2024
b3a8eaf
update U8L2 distance func
yangzq50 Jul 22, 2024
9a0527b
add avx512bw info
yangzq50 Jul 22, 2024
eca7e3f
update simd_init
yangzq50 Jul 22, 2024
1615c81
update simd_functions
yangzq50 Jul 22, 2024
9c7e177
update hnsw
yangzq50 Jul 22, 2024
c294b65
update U8IP function
yangzq50 Jul 22, 2024
4f112c8
update U8Cos function
yangzq50 Jul 22, 2024
ffbe53b
update simd_functions
yangzq50 Jul 22, 2024
0e7ec45
update dist_func_cos
yangzq50 Jul 22, 2024
8e02d17
update dist_func_cos
yangzq50 Jul 23, 2024
3e9b102
add uint8 support for embedding
yangzq50 Jul 23, 2024
4b3ff2f
update parser
yangzq50 Jul 23, 2024
16625ea
support u8 embedding input and output
yangzq50 Jul 23, 2024
8584653
support u8 embedding calculation
yangzq50 Jul 23, 2024
3666e55
update support
yangzq50 Jul 23, 2024
f18273a
fix
yangzq50 Jul 23, 2024
47a5873
fix python test
yangzq50 Jul 24, 2024
b626a33
fix unittest
yangzq50 Jul 24, 2024
cf264f6
add python test for float16 and bfloat16 input and output
yangzq50 Jul 25, 2024
f62f7a4
fix
yangzq50 Jul 25, 2024
9999c7c
fix python support for bool
yangzq50 Jul 25, 2024
fb74369
update example
yangzq50 Jul 25, 2024
f7abe1d
fix python support for Hnsw
yangzq50 Jul 25, 2024
345cc3e
fix python support
yangzq50 Jul 25, 2024
c041507
add more info
yangzq50 Jul 25, 2024
1ec7dd5
fix bool support for python
yangzq50 Jul 25, 2024
66f5421
fix bool unittest
yangzq50 Jul 25, 2024
de14d35
fix
yangzq50 Jul 25, 2024
4a1e370
fix
yangzq50 Jul 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions docs/references/pysdk_api_reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -357,19 +357,21 @@ Create an index by `IndexInfo` list.
A IndexInfo struct contains three fields,`column_name`, `index_type`, and `index_param_list`.
- **column_name : str** Name of the column to build index on.
- **index_type : IndexType**
enum type: `IVFFlat` , `Hnsw`, `HnswLVQ`, `FullText`, or `BMP`. Defined in `infinity.index`.
`Note: The difference between Hnsw and HnswLVQ is only adopting different clustering method. The former uses K-Means while the later uses LVQ(Learning Vector Quantization)`
enum type: `IVFFlat` , `Hnsw`, `FullText`, or `BMP`. Defined in `infinity.index`.
`Note: For Hnsw index, add encode=lvq in index_param_list to use LVQ(Locally-adaptive vector quantization)`
- **index_param_list**
A list of InitParameter. The InitParameter struct is like a key-value pair, with two string fields named param_name and param_value. The optional parameters of each type of index are listed below:
- `IVFFlat`: `'centroids_count'`(default:`'128'`), `'metric'`(required)
- `Hnsw`: `'M'`(default:`'16'`), `'ef_construction'`(default:`'50'`), `'ef'`(default:`'50'`), `'metric'`(required)
- `HnswLVQ`:
- `Hnsw`:
- `'M'`(default:`'16'`)
- `'ef_construction'`(default:`'50'`)
- `'ef'`(default:`'50'`)
- `'metric'`(required)
- `ip`: Inner product
- `l2`: Euclidean distance
- `'encode'`(optional)
- `plain`: Plain encoding (default)
- `lvq`: LVQ(Locally-adaptive vector quantization)
- `FullText`: `'ANALYZER'`(default:`'standard'`)
- `BMP`:
- `block_size=1~256`(default: 16): The size of the block in BMP index
Expand Down
25 changes: 19 additions & 6 deletions example/ColBERT_reranker_example/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,12 @@
# limitations under the License.

from typing import Union
import torch
from infinity.remote_thrift.types import make_match_tensor_expr
# NOTICE: please check which infinity you are using, local or remote
# this statement is for local infinity
from infinity.local_infinity.types import make_match_tensor_expr
# enable the following import statement to use remote infinity
# from infinity.remote_thrift.types import make_match_tensor_expr
# from infinity.common import LOCAL_HOST


class InfinityHelperForColBERT:
Expand All @@ -40,9 +44,9 @@ def create_test_env(self, schema: dict):
# append two hidden columns: "INNER_HIDDEN_ColBERT_TENSOR_ARRAY_FLOAT", "INNER_HIDDEN_ColBERT_TENSOR_ARRAY_BIT"
self.test_db_name = "colbert_test_db"
self.test_table_name = "colbert_test_table"
self.inner_col_txt = "INNER_HIDDEN_ColBERT_TEXT_FOR_BM25"
self.inner_col_float = "INNER_HIDDEN_ColBERT_TENSOR_ARRAY_FLOAT"
self.inner_col_bit = "INNER_HIDDEN_ColBERT_TENSOR_ARRAY_BIT"
self.inner_col_txt = "inner_hidden_colbert_text_for_bm25"
self.inner_col_float = "inner_hidden_colbert_tensor_array_float"
self.inner_col_bit = "inner_hidden_colbert_tensor_array_bit"
if self.inner_col_txt in schema:
raise ValueError(f"Column name {self.inner_col_txt} is reserved for internal use.")
if self.inner_col_float in schema:
Expand All @@ -56,13 +60,22 @@ def create_test_env(self, schema: dict):
from infinity import NetworkAddress
import infinity.index as index
from infinity.common import ConflictType
self.infinity_obj = infinity.connect(NetworkAddress("127.0.0.1", 23817))
# NOTICE: the following statement is for local infinity
self.infinity_obj = infinity.connect("/var/infinity")
# enable the following statement to use remote infinity
# self.infinity_obj = infinity.connect(LOCAL_HOST)
self.infinity_obj.drop_database(self.test_db_name, ConflictType.Ignore)
self.colbert_test_db = self.infinity_obj.create_database(self.test_db_name)
self.colbert_test_table = self.colbert_test_db.create_table(self.test_table_name, schema, ConflictType.Error)
# NOTICE: the following statement is for english text
self.colbert_test_table.create_index("test_ft_index",
[index.IndexInfo(self.inner_col_txt, index.IndexType.FullText, [])],
ConflictType.Error)
# please enable the following statement to use chinese text
# self.colbert_test_table.create_index("test_ft_index",
# [index.IndexInfo(self.inner_col_txt, index.IndexType.FullText,
# [infinity.index.InitParameter("ANALYZER", "chinese")])],
# ConflictType.Error)

# clear the test environment for ColBERT
def clear_test_env(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def faiss_import(self, data_set):
def infinity_benchmark_flat_knn(self, data_set):
threads = int(input("Enter number of threads:\n"))
rounds = int(input("Enter number of rounds:\n"))
benchmark(threads, rounds, data_set, self.dataset_path_map[data_set])
benchmark(threads, rounds, data_set, 200, True, self.dataset_path_map[data_set])

def faiss_benchmark_flat_knn_batch_query(self, data_set):
xq = fvecs_read(self.dataset_path_map[data_set] + self.query_suffix[data_set])
Expand Down
15 changes: 5 additions & 10 deletions python/infinity/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,16 @@

class IndexType(Enum):
IVFFlat = 1
HnswLVQ = 2
Hnsw = 3
FullText = 4
Secondary = 5
EMVB = 6
BMP = 7
Hnsw = 2
FullText = 3
Secondary = 4
EMVB = 5
BMP = 6

def to_ttype(self):
match self:
case IndexType.IVFFlat:
return ttypes.IndexType.IVFFlat
case IndexType.HnswLVQ:
return ttypes.IndexType.HnswLVQ
case IndexType.Hnsw:
return ttypes.IndexType.Hnsw
case IndexType.FullText:
Expand All @@ -54,8 +51,6 @@ def to_local_type(self):
match self:
case IndexType.IVFFlat:
return LocalIndexType.kIVFFlat
case IndexType.HnswLVQ:
return LocalIndexType.kHnswLVQ
case IndexType.Hnsw:
return LocalIndexType.kHnsw
case IndexType.FullText:
Expand Down
8 changes: 8 additions & 0 deletions python/infinity/local_infinity/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ def get_ordinary_info(column_info, column_defs, column_name, index):
proto_column_type.logical_type = LogicalType.kFloat
case "double" | "float64":
proto_column_type.logical_type = LogicalType.kDouble
case "float16":
proto_column_type.logical_type = LogicalType.kFloat16
case "bfloat16":
proto_column_type.logical_type = LogicalType.kBFloat16
case "varchar":
proto_column_type.logical_type = LogicalType.kVarchar
case "bool":
Expand Down Expand Up @@ -157,6 +161,8 @@ def get_embedding_info(column_info, column_defs, column_name, index):
embedding_type.element_type = EmbeddingDataType.kElemFloat
elif element_type == "float64" or element_type == "double":
embedding_type.element_type = EmbeddingDataType.kElemDouble
elif element_type == "uint8":
embedding_type.element_type = EmbeddingDataType.kElemUInt8
elif element_type == "int8":
embedding_type.element_type = EmbeddingDataType.kElemInt8
elif element_type == "int16":
Expand Down Expand Up @@ -205,6 +211,8 @@ def get_sparse_info(column_info, column_defs, column_name, index):
sparse_type.element_type = EmbeddingDataType.kElemFloat
elif element_type == "float64" or element_type == "double":
sparse_type.element_type = EmbeddingDataType.kElemDouble
elif element_type == "uint8":
sparse_type.element_type = EmbeddingDataType.kElemUInt8
elif element_type == "int8":
sparse_type.element_type = EmbeddingDataType.kElemInt8
elif element_type == "int16":
Expand Down
23 changes: 11 additions & 12 deletions python/infinity/local_infinity/query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,35 +98,34 @@ def knn(
f"Invalid embedding data, type should be embedded, but get {type(embedding_data)}",
)

if (
embedding_data_type == "tinyint"
or embedding_data_type == "smallint"
or embedding_data_type == "int"
or embedding_data_type == "bigint"
):
if embedding_data_type in ["int", "uint8", "int8", "int16", "int32", "int64", "tinyint", "unsigned tinyint",
"smallint", "bigint"]:
embedding_data = [int(x) for x in embedding_data]

data = EmbeddingData()
elem_type = EmbeddingDataType.kElemFloat
if embedding_data_type == "bit":
elem_type = EmbeddingDataType.kElemBit
raise Exception(f"Invalid embedding {embedding_data[0]} type")
elif embedding_data_type == "tinyint":
elif embedding_data_type in ["unsigned tinyint", "uint8"]:
elem_type = EmbeddingDataType.kElemUInt8
data.u8_array_value = embedding_data
elif embedding_data_type in ["tinyint", "int8"]:
elem_type = EmbeddingDataType.kElemInt8
data.i8_array_value = embedding_data
elif embedding_data_type == "smallint":
elif embedding_data_type in ["smallint", "int16"]:
elem_type = EmbeddingDataType.kElemInt16
data.i16_array_value = embedding_data
elif embedding_data_type == "int":
elif embedding_data_type in ["int", "int32"]:
elem_type = EmbeddingDataType.kElemInt32
data.i32_array_value = embedding_data
elif embedding_data_type == "bigint":
elif embedding_data_type in ["bigint", "int64"]:
elem_type = EmbeddingDataType.kElemInt64
data.i64_array_value = embedding_data
elif embedding_data_type == "float":
elif embedding_data_type in ["float", "float32"]:
elem_type = EmbeddingDataType.kElemFloat
data.f32_array_value = embedding_data
elif embedding_data_type == "double":
elif embedding_data_type in ["double", "float64"]:
elem_type = EmbeddingDataType.kElemDouble
data.f64_array_value = embedding_data
else:
Expand Down
3 changes: 3 additions & 0 deletions python/infinity/local_infinity/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,9 @@ def insert(self, data: Union[INSERT_DATA, list[INSERT_DATA]]):
if isinstance(value, str):
constant_expression.literal_type = LiteralType.kString
constant_expression.str_value = value
elif isinstance(value, bool):
constant_expression.literal_type = LiteralType.kBoolean
constant_expression.bool_value = value
elif isinstance(value, int):
constant_expression.literal_type = LiteralType.kInteger
constant_expression.i64_value = value
Expand Down
40 changes: 33 additions & 7 deletions python/infinity/local_infinity/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,17 @@ def logic_type_to_dtype(ttype: WrapDataType):
return dtype('float32')
case LogicalType.kDouble:
return dtype('float64')
case LogicalType.kFloat16:
return dtype('float32')
case LogicalType.kBFloat16:
return dtype('float32')
case LogicalType.kVarchar:
return dtype('str')
case LogicalType.kEmbedding:
if ttype.embedding_type is not None:
match ttype.embedding_type.element_type:
case EmbeddingDataType.kElemUInt8:
return object
case EmbeddingDataType.kElemInt8:
return object
case EmbeddingDataType.kElemInt16:
Expand Down Expand Up @@ -81,6 +87,9 @@ def tensor_to_list(column_data_type, binary_data) -> list[list[Any]]:
mid_res_int = mid_res_int * 256 + j
result.append([f"\u007b0:0{dimension}b\u007d".format(mid_res_int)[::-1]])
return result
elif column_data_type.embedding_type.element_type == EmbeddingDataType.kElemUInt8:
all_list = list(struct.unpack('<{}B'.format(len(binary_data)), binary_data))
return [all_list[i:i + dimension] for i in range(0, len(all_list), dimension)]
elif column_data_type.embedding_type.element_type == EmbeddingDataType.kElemInt8:
all_list = list(struct.unpack('<{}b'.format(len(binary_data)), binary_data))
return [all_list[i:i + dimension] for i in range(0, len(all_list), dimension)]
Expand Down Expand Up @@ -144,6 +153,14 @@ def column_vector_to_list(column_type, column_data_type, column_vectors) -> \
return list(struct.unpack('<{}f'.format(len(column_vector) // 4), column_vector))
case LogicalType.kDouble:
return list(struct.unpack('<{}d'.format(len(column_vector) // 8), column_vector))
case LogicalType.kFloat16:
return list(struct.unpack('<{}e'.format(len(column_vector) // 2), column_vector))
case LogicalType.kBFloat16:
tmp_u16 = np.frombuffer(column_vector, dtype='<i2')
result_arr = np.zeros(2 * len(tmp_u16), dtype='<i2')
result_arr[1::2] = tmp_u16
view_float32 = result_arr.view('<f4')
return list(view_float32)
case LogicalType.kVarchar:
return list(parse_bytes(column_vector))
case LogicalType.kBoolean:
Expand All @@ -158,7 +175,10 @@ def column_vector_to_list(column_type, column_data_type, column_vectors) -> \
case LogicalType.kEmbedding:
dimension = column_data_type.embedding_type.dimension
element_type = column_data_type.embedding_type.element_type
if element_type == EmbeddingDataType.kElemInt8:
if element_type == EmbeddingDataType.kElemUInt8:
all_list = list(struct.unpack('<{}B'.format(len(column_vector)), column_vector))
return [all_list[i:i + dimension] for i in range(0, len(all_list), dimension)]
elif element_type == EmbeddingDataType.kElemInt8:
all_list = list(struct.unpack('<{}b'.format(len(column_vector)), column_vector))
return [all_list[i:i + dimension] for i in range(0, len(all_list), dimension)]
elif element_type == EmbeddingDataType.kElemInt16:
Expand Down Expand Up @@ -225,6 +245,9 @@ def parse_sparse_bytes(column_data_type, column_vector):
case _:
raise NotImplementedError(f"Unsupported type {index_type}")
match element_type:
case EmbeddingDataType.kElemUInt8:
values = struct.unpack('<{}B'.format(nnz), column_vector[offset:offset + nnz])
offset += nnz
case EmbeddingDataType.kElemInt8:
values = struct.unpack('<{}b'.format(nnz), column_vector[offset:offset + nnz])
offset += nnz
Expand Down Expand Up @@ -277,22 +300,25 @@ def make_match_tensor_expr(vector_column_name: str, embedding_data: VEC, embeddi
elem_type = EmbeddingDataType.kElemFloat
if embedding_data_type == 'bit':
raise InfinityException(ErrorCode.INVALID_EMBEDDING_DATA_TYPE, f"Invalid embedding {embedding_data[0]} type")
elif embedding_data_type == 'tinyint' or embedding_data_type == 'int8' or embedding_data_type == 'i8':
elif embedding_data_type in ['unsigned tinyint', 'uint8', 'u8']:
elem_type = EmbeddingDataType.kElemUInt8
data.u8_array_value = np.asarray(embedding_data, dtype=np.uint8).flatten()
elif embedding_data_type in ['tinyint', 'int8', 'i8']:
elem_type = EmbeddingDataType.kElemInt8
data.i8_array_value = np.asarray(embedding_data, dtype=np.int8).flatten()
elif embedding_data_type == 'smallint' or embedding_data_type == 'int16' or embedding_data_type == 'i16':
elif embedding_data_type in ['smallint', 'int16', 'i16']:
elem_type = EmbeddingDataType.kElemInt16
data.i16_array_value = np.asarray(embedding_data, dtype=np.int16).flatten()
elif embedding_data_type == 'int' or embedding_data_type == 'int32' or embedding_data_type == 'i32':
elif embedding_data_type in ['int', 'int32', 'i32']:
elem_type = EmbeddingDataType.kElemInt32
data.i32_array_value = np.asarray(embedding_data, dtype=np.int32).flatten()
elif embedding_data_type == 'bigint' or embedding_data_type == 'int64' or embedding_data_type == 'i64':
elif embedding_data_type in ['bigint', 'int64', 'i64']:
elem_type = EmbeddingDataType.kElemInt64
data.i64_array_value = np.asarray(embedding_data, dtype=np.int64).flatten()
elif embedding_data_type == 'float' or embedding_data_type == 'float32' or embedding_data_type == 'f32':
elif embedding_data_type in ['float', 'float32', 'f32']:
elem_type = EmbeddingDataType.kElemFloat
data.f32_array_value = np.asarray(embedding_data, dtype=np.float32).flatten()
elif embedding_data_type == 'double' or embedding_data_type == 'float64' or embedding_data_type == 'f64':
elif embedding_data_type in ['double', 'float64', 'f64']:
elem_type = EmbeddingDataType.kElemDouble
data.f64_array_value = np.asarray(embedding_data, dtype=np.float64).flatten()
else:
Expand Down
9 changes: 9 additions & 0 deletions python/infinity/local_infinity/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,15 @@ def traverse_conditions(cons, fn=None):

return parsed_expr

elif isinstance(cons, exp.Boolean):
parsed_expr = WrapParsedExpr()
constant_expr = WrapConstantExpr()
constant_expr.literal_type = LiteralType.kBoolean
constant_expr.bool_value = cons.this
parsed_expr.type = ParsedExprType.kConstant
parsed_expr.constant_expr = constant_expr
return parsed_expr

elif isinstance(cons, exp.Literal):
parsed_expr = WrapParsedExpr()
constant_expr = WrapConstantExpr()
Expand Down
Loading
Loading