Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vlad/l2 sumofsquares #486

Merged
merged 54 commits into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
54 commits
Select commit Hold shift + click to select a range
3166e79
Distance metrics added for flat
cainamisir Jun 19, 2024
e946cbb
Added cosine distance function [skip ci]
cainamisir Jun 20, 2024
3796274
Changed distance to Enum
cainamisir Jun 20, 2024
9e914dc
exception for invalid distance metric
cainamisir Jun 20, 2024
81565e7
enum use and exception for non-flat distance metrics
cainamisir Jun 20, 2024
504d046
distance metric test
cainamisir Jun 20, 2024
c571281
formatting
cainamisir Jun 20, 2024
5af20ee
Cleaning up integration by moving to create
cainamisir Jul 1, 2024
2f67be5
Formatting and cleaning up test
cainamisir Jul 1, 2024
3525ff9
Merge branch 'main' of https://github.com/TileDB-Inc/TileDB-Vector-Se…
cainamisir Jul 4, 2024
27b93f3
Cosine optimization
cainamisir Jul 4, 2024
9396bdc
Formatting
cainamisir Jul 4, 2024
4417a28
distance metric integration in C++ layer for vamana+ivfpq
cainamisir Jul 5, 2024
face742
Final FLAT changes
cainamisir Jul 15, 2024
e096fbb
IVF_FLAT cosine distance + setup for vamana and pq
cainamisir Jul 15, 2024
cfe1a5b
Normalization node
cainamisir Jul 17, 2024
b4f5518
Delete apis/python/test/test_scale.py
cainamisir Jul 17, 2024
2926ba2
Formatting and test
cainamisir Jul 18, 2024
8a8bcf8
Merge branch 'vlad/distIVF' of https://github.com/TileDB-Inc/TileDB-V…
cainamisir Jul 18, 2024
f37cd37
better metadata setting
cainamisir Jul 18, 2024
ab6b3fe
Optimizing cosine for IVF
cainamisir Jul 21, 2024
7cb919f
Remove comments
cainamisir Jul 22, 2024
c56b817
Try at distance function type erasure
cainamisir Jul 24, 2024
eb9ffe1
Broken distance metric (not persisting)
cainamisir Jul 29, 2024
6fae482
Working distance metrics
cainamisir Jul 29, 2024
b476fdd
Run all tests
cainamisir Jul 29, 2024
c13408b
Formatting
cainamisir Jul 29, 2024
7b54b99
Merge branch 'main' of https://github.com/TileDB-Inc/TileDB-Vector-Se…
cainamisir Jul 29, 2024
32bb6ae
Merge errors and formatting
cainamisir Jul 29, 2024
c376a5c
Bugfix
cainamisir Jul 30, 2024
38e6905
Redundant code removal
cainamisir Jul 30, 2024
34595dd
Merge branch 'main' of https://github.com/TileDB-Inc/TileDB-Vector-Se…
cainamisir Jul 30, 2024
a02ca75
Fixing merge errors
cainamisir Jul 30, 2024
d06923f
Implementing review suggestions
cainamisir Jul 31, 2024
530ae64
remove redundant variables
cainamisir Aug 1, 2024
85ba8bf
whitespace
cainamisir Aug 1, 2024
f9547c8
Fix warnings
cainamisir Aug 1, 2024
1bce407
test update
cainamisir Aug 1, 2024
b95cfcd
format
cainamisir Aug 1, 2024
ff1aa90
Extra testing
cainamisir Aug 1, 2024
2a47a83
pr suggestions
cainamisir Aug 1, 2024
0138564
Optimize ingestion for cosine ivf flat
cainamisir Aug 6, 2024
c90ddd9
Inner product normal results (non-negative) and extra testing
cainamisir Aug 6, 2024
6dbb21d
Change inner product to reciprocal
cainamisir Aug 7, 2024
5257451
Merge branch 'vlad/distivf' into vlad/distvamana
cainamisir Aug 9, 2024
4f50c0f
Changing L2 to Sum of squares and adding actual L2
cainamisir Aug 11, 2024
2bacb2a
Merge branch 'main' of https://github.com/TileDB-Inc/TileDB-Vector-Se…
cainamisir Aug 20, 2024
2e3f949
Bugfix
cainamisir Aug 20, 2024
55627c6
Formatting
cainamisir Aug 20, 2024
9a28531
remove duplicate function
cainamisir Aug 20, 2024
8b299a0
Remove whitespace
cainamisir Aug 20, 2024
07a50cb
Bugfix
cainamisir Aug 22, 2024
49f8835
Change error message
cainamisir Aug 22, 2024
806ad54
Fix test
cainamisir Aug 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion apis/python/src/tiledb/vector_search/flat_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def create(
group_exists: bool = False,
config: Optional[Mapping[str, Any]] = None,
storage_version: str = STORAGE_VERSION,
distance_metric: vspy.DistanceMetric = vspy.DistanceMetric.L2,
distance_metric: vspy.DistanceMetric = vspy.DistanceMetric.SUM_OF_SQUARES,
**kwargs,
) -> FlatIndex:
"""
Expand Down
5 changes: 4 additions & 1 deletion apis/python/src/tiledb/vector_search/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,15 @@ def __init__(
self.storage_version = self.group.meta.get("storage_version", "0.1")
try:
self.distance_metric = vspy.DistanceMetric(
self.group.meta.get("distance_metric", vspy.DistanceMetric.L2)
self.group.meta.get(
"distance_metric", vspy.DistanceMetric.SUM_OF_SQUARES
)
)
except ValueError:
raise ValueError(
f"Invalid distance metric in metadata: {self.group.meta.get('distance_metric')}."
)

if (
not storage_formats[self.storage_version]["SUPPORT_TIMETRAVEL"]
and timestamp is not None
Expand Down
8 changes: 4 additions & 4 deletions apis/python/src/tiledb/vector_search/ingestion.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def ingest(
] = None,
write_centroids_resources: Optional[Mapping[str, Any]] = None,
partial_index_resources: Optional[Mapping[str, Any]] = None,
distance_metric: vspy.DistanceMetric = vspy.DistanceMetric.L2,
distance_metric: vspy.DistanceMetric = vspy.DistanceMetric.SUM_OF_SQUARES,
normalized: bool = False,
**kwargs,
):
Expand Down Expand Up @@ -197,7 +197,7 @@ def ingest(
partial_index_resources: Optional[Mapping[str, Any]]
Resources to request when performing the computation of partial indexing, only applies to BATCH mode
distance_metric: vspy.DistanceMetric
Distance metric to use for the index, defaults to 'vspy.DistanceMetric.L2'. Options are 'vspy.DistanceMetric.L2', 'vspy.DistanceMetric.INNER_PRODUCT', 'vspy.DistanceMetric.COSINE'.
Distance metric to use for the index, defaults to 'vspy.DistanceMetric.SUM_OF_SQUARES'. Options are 'vspy.DistanceMetric.SUM_OF_SQUARES', 'vspy.DistanceMetric.INNER_PRODUCT', 'vspy.DistanceMetric.COSINE', 'vspy.DistanceMetric.L2'.
"""
import enum
import json
Expand Down Expand Up @@ -1252,7 +1252,7 @@ def init_centroids(
config: Optional[Mapping[str, Any]] = None,
verbose: bool = False,
trace_id: Optional[str] = None,
distance_metric: vspy.DistanceMetric = vspy.DistanceMetric.L2,
distance_metric: vspy.DistanceMetric = vspy.DistanceMetric.SUM_OF_SQUARES,
) -> np.ndarray:
logger = setup(config, verbose)
logger.debug(
Expand Down Expand Up @@ -1682,7 +1682,7 @@ def ingest_vectors_udf(
config: Optional[Mapping[str, Any]] = None,
verbose: bool = False,
trace_id: Optional[str] = None,
distance_metric: vspy.DistanceMetric = vspy.DistanceMetric.L2,
distance_metric: vspy.DistanceMetric = vspy.DistanceMetric.SUM_OF_SQUARES,
):
import os
import random
Expand Down
7 changes: 4 additions & 3 deletions apis/python/src/tiledb/vector_search/ivf_flat_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ def _taskgraph_query(
num_partitions: int = -1,
num_workers: int = -1,
config: Optional[Mapping[str, Any]] = None,
distance_metric: vspy.DistanceMetric = vspy.DistanceMetric.L2,
distance_metric: vspy.DistanceMetric = vspy.DistanceMetric.SUM_OF_SQUARES,
):
"""
Query an IVF_FLAT index using TileDB cloud taskgraphs
Expand Down Expand Up @@ -536,7 +536,7 @@ def create(
group_exists: bool = False,
config: Optional[Mapping[str, Any]] = None,
storage_version: str = STORAGE_VERSION,
distance_metric: vspy.DistanceMetric = vspy.DistanceMetric.L2,
distance_metric: vspy.DistanceMetric = vspy.DistanceMetric.SUM_OF_SQUARES,
**kwargs,
) -> IVFFlatIndex:
"""
Expand All @@ -562,7 +562,8 @@ def create(
"""
validate_storage_version(storage_version)
if (
distance_metric != vspy.DistanceMetric.L2
distance_metric != vspy.DistanceMetric.SUM_OF_SQUARES
and distance_metric != vspy.DistanceMetric.L2
and distance_metric != vspy.DistanceMetric.COSINE
):
raise ValueError(
Expand Down
7 changes: 5 additions & 2 deletions apis/python/src/tiledb/vector_search/ivf_pq_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def create(
config: Optional[Mapping[str, Any]] = None,
storage_version: str = STORAGE_VERSION,
partitions: Optional[int] = None,
distance_metric: vspy.DistanceMetric = vspy.DistanceMetric.L2,
distance_metric: vspy.DistanceMetric = vspy.DistanceMetric.SUM_OF_SQUARES,
**kwargs,
) -> IVFPQIndex:
"""
Expand Down Expand Up @@ -181,7 +181,10 @@ def create(
raise ValueError(
f"Number of dimensions ({dimensions}) must be divisible by num_subspaces ({num_subspaces})."
)
if distance_metric != vspy.DistanceMetric.L2:
if (
distance_metric != vspy.DistanceMetric.SUM_OF_SQUARES
and distance_metric != vspy.DistanceMetric.L2
):
raise ValueError(
f"Distance metric {distance_metric} is not supported in IVF_PQ"
)
Expand Down
115 changes: 96 additions & 19 deletions apis/python/src/tiledb/vector_search/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -167,13 +167,14 @@ static void declare_qv_query_heap_infinite_ram(
size_t nprobe,
size_t k_nn,
size_t nthreads,
DistanceMetric distance_metric = DistanceMetric::L2) -> py::tuple {
DistanceMetric distance_metric =
DistanceMetric::SUM_OF_SQUARES) -> py::tuple {
auto mat = ColMajorPartitionedMatrixWrapper<T, Id_Type, Id_Type>(
parts, ids, indices);

auto top_centroids = detail::ivf::ivf_top_centroids(
centroids, query_vectors, nprobe, nthreads);
if (distance_metric == DistanceMetric::L2) {
if (distance_metric == DistanceMetric::SUM_OF_SQUARES) {
auto r = detail::ivf::qv_query_heap_infinite_ram(
top_centroids,
mat,
Expand Down Expand Up @@ -203,6 +204,16 @@ static void declare_qv_query_heap_infinite_ram(
nthreads,
cosine_distance_normalized{});
return make_python_pair(std::move(r));
} else if (distance_metric == DistanceMetric::L2) {
auto r = detail::ivf::qv_query_heap_infinite_ram(
top_centroids,
mat,
query_vectors,
nprobe,
k_nn,
nthreads,
sqrt_sum_of_squares_distance{});
return make_python_pair(std::move(r));
} else {
throw std::runtime_error("Invalid distance metric");
}
Expand All @@ -227,11 +238,11 @@ static void declare_qv_query_heap_finite_ram(
size_t upper_bound,
size_t nthreads,
uint64_t timestamp,
DistanceMetric distance_metric = DistanceMetric::L2)
DistanceMetric distance_metric = DistanceMetric::SUM_OF_SQUARES)
-> py::tuple { // std::tuple<ColMajorMatrix<float>,
// ColMajorMatrix<size_t>> { //
// TODO change return type
if (distance_metric == DistanceMetric::L2) {
if (distance_metric == DistanceMetric::SUM_OF_SQUARES) {
auto r = detail::ivf::qv_query_heap_finite_ram<T, Id_Type>(
ctx,
parts_uri,
Expand Down Expand Up @@ -276,6 +287,21 @@ static void declare_qv_query_heap_finite_ram(
timestamp,
cosine_distance_normalized{});
return make_python_pair(std::move(r));
} else if (distance_metric == DistanceMetric::L2) {
auto r = detail::ivf::qv_query_heap_finite_ram<T, Id_Type>(
ctx,
parts_uri,
centroids,
query_vectors,
indices,
ids_uri,
nprobe,
k_nn,
upper_bound,
nthreads,
timestamp,
sqrt_sum_of_squares_distance{});
return make_python_pair(std::move(r));
} else {
throw std::runtime_error("Invalid distance metric");
}
Expand All @@ -296,7 +322,7 @@ static void declare_nuv_query_heap_infinite_ram(
size_t nprobe,
size_t k_nn,
size_t nthreads,
DistanceMetric distance_metric = DistanceMetric::L2)
DistanceMetric distance_metric = DistanceMetric::SUM_OF_SQUARES)
-> std::tuple<
ColMajorMatrix<float>,
ColMajorMatrix<uint64_t>> { // TODO change return type
Expand All @@ -307,7 +333,7 @@ static void declare_nuv_query_heap_infinite_ram(
detail::ivf::partition_ivf_flat_index<Id_Type>(
centroids, query_vectors, nprobe, nthreads);

if (distance_metric == DistanceMetric::L2) {
if (distance_metric == DistanceMetric::SUM_OF_SQUARES) {
auto r = detail::ivf::nuv_query_heap_infinite_ram(
mat,
active_partitions,
Expand Down Expand Up @@ -337,6 +363,16 @@ static void declare_nuv_query_heap_infinite_ram(
nthreads,
cosine_distance_normalized{});
return r;
} else if (distance_metric == DistanceMetric::L2) {
auto r = detail::ivf::nuv_query_heap_infinite_ram(
mat,
active_partitions,
query_vectors,
active_queries,
k_nn,
nthreads,
sqrt_sum_of_squares_distance{});
return r;
} else {
throw std::runtime_error("Invalid distance metric");
}
Expand All @@ -360,7 +396,7 @@ static void declare_nuv_query_heap_finite_ram(
size_t upper_bound,
size_t nthreads,
uint64_t timestamp,
DistanceMetric distance_metric = DistanceMetric::L2)
DistanceMetric distance_metric = DistanceMetric::SUM_OF_SQUARES)
-> std::tuple<
ColMajorMatrix<float>,
ColMajorMatrix<uint64_t>> { // TODO change return type
Expand All @@ -381,7 +417,7 @@ static void declare_nuv_query_heap_finite_ram(
upper_bound,
temporal_policy);

if (distance_metric == DistanceMetric::L2) {
if (distance_metric == DistanceMetric::SUM_OF_SQUARES) {
auto r = detail::ivf::nuv_query_heap_finite_ram_reg_blocked(
mat,
query_vectors,
Expand Down Expand Up @@ -411,6 +447,16 @@ static void declare_nuv_query_heap_finite_ram(
nthreads,
cosine_distance_normalized{});
return r;
} else if (distance_metric == DistanceMetric::L2) {
auto r = detail::ivf::nuv_query_heap_finite_ram_reg_blocked(
mat,
query_vectors,
active_queries,
k_nn,
upper_bound,
nthreads,
sqrt_sum_of_squares_distance{});
return r;
} else {
throw std::runtime_error("Invalid distance metric");
}
Expand Down Expand Up @@ -702,9 +748,9 @@ static void declare_vq_query_heap(py::module& m, const std::string& suffix) {
const std::vector<uint64_t>& ids,
int k,
size_t nthreads,
DistanceMetric distance_metric = DistanceMetric::L2)
DistanceMetric distance_metric = DistanceMetric::SUM_OF_SQUARES)
-> std::tuple<ColMajorMatrix<float>, ColMajorMatrix<uint64_t>> {
if (distance_metric == DistanceMetric::L2) {
if (distance_metric == DistanceMetric::SUM_OF_SQUARES) {
auto r = detail::flat::vq_query_heap(
data, query_vectors, ids, k, nthreads, sum_of_squares_distance{});
return r;
Expand All @@ -716,6 +762,15 @@ static void declare_vq_query_heap(py::module& m, const std::string& suffix) {
auto r = detail::flat::vq_query_heap(
data, query_vectors, ids, k, nthreads, cosine_distance{});
return r;
} else if (distance_metric == DistanceMetric::L2) {
auto r = detail::flat::vq_query_heap(
data,
query_vectors,
ids,
k,
nthreads,
sqrt_sum_of_squares_distance{});
return r;
} else {
throw std::runtime_error("Invalid distance metric");
}
Expand All @@ -732,9 +787,9 @@ static void declare_vq_query_heap_pyarray(
const std::vector<uint64_t>& ids,
int k,
size_t nthreads,
DistanceMetric distance_metric = DistanceMetric::L2)
DistanceMetric distance_metric = DistanceMetric::SUM_OF_SQUARES)
-> std::tuple<ColMajorMatrix<float>, ColMajorMatrix<uint64_t>> {
if (distance_metric == DistanceMetric::L2) {
if (distance_metric == DistanceMetric::SUM_OF_SQUARES) {
auto r = detail::flat::vq_query_heap(
data, query_vectors, ids, k, nthreads, sum_of_squares_distance{});
return r;
Expand All @@ -746,6 +801,15 @@ static void declare_vq_query_heap_pyarray(
auto r = detail::flat::vq_query_heap(
data, query_vectors, ids, k, nthreads, cosine_distance{});
return r;
} else if (distance_metric == DistanceMetric::L2) {
auto r = detail::flat::vq_query_heap(
data,
query_vectors,
ids,
k,
nthreads,
sqrt_sum_of_squares_distance{});
return r;
} else {
throw std::runtime_error("Invalid distance metric");
}
Expand Down Expand Up @@ -869,9 +933,9 @@ PYBIND11_MODULE(_tiledbvspy, m) {
ColMajorMatrix<float>& query_vectors,
int k,
size_t nthreads,
DistanceMetric distance_metric = DistanceMetric::L2)
DistanceMetric distance_metric = DistanceMetric::SUM_OF_SQUARES)
-> std::tuple<ColMajorMatrix<float>, ColMajorMatrix<uint64_t>> {
if (distance_metric == DistanceMetric::L2) {
if (distance_metric == DistanceMetric::SUM_OF_SQUARES) {
auto r = detail::flat::vq_query_heap(
data, query_vectors, k, nthreads, sum_of_squares_distance{});
return r;
Expand All @@ -883,6 +947,10 @@ PYBIND11_MODULE(_tiledbvspy, m) {
auto r = detail::flat::vq_query_heap(
data, query_vectors, k, nthreads, cosine_distance{});
return r;
} else if (distance_metric == DistanceMetric::L2) {
auto r = detail::flat::vq_query_heap(
data, query_vectors, k, nthreads, sqrt_sum_of_squares_distance{});
return r;
} else {
throw std::runtime_error("Invalid distance metric");
}
Expand All @@ -894,9 +962,9 @@ PYBIND11_MODULE(_tiledbvspy, m) {
ColMajorMatrix<float>& query_vectors,
int k,
size_t nthreads,
DistanceMetric distance_metric = DistanceMetric::L2)
DistanceMetric distance_metric = DistanceMetric::SUM_OF_SQUARES)
-> std::tuple<ColMajorMatrix<float>, ColMajorMatrix<uint64_t>> {
if (distance_metric == DistanceMetric::L2) {
if (distance_metric == DistanceMetric::SUM_OF_SQUARES) {
auto r = detail::flat::vq_query_heap(
data, query_vectors, k, nthreads, sum_of_squares_distance{});
return r;
Expand All @@ -908,6 +976,10 @@ PYBIND11_MODULE(_tiledbvspy, m) {
auto r = detail::flat::vq_query_heap(
data, query_vectors, k, nthreads, cosine_distance{});
return r;
} else if (distance_metric == DistanceMetric::L2) {
auto r = detail::flat::vq_query_heap(
data, query_vectors, k, nthreads, sqrt_sum_of_squares_distance{});
return r;
} else {
throw std::runtime_error("Invalid distance metric");
}
Expand All @@ -919,9 +991,9 @@ PYBIND11_MODULE(_tiledbvspy, m) {
ColMajorMatrix<float>& query_vectors,
int k,
size_t nthreads,
DistanceMetric distance_metric = DistanceMetric::L2)
DistanceMetric distance_metric = DistanceMetric::SUM_OF_SQUARES)
-> std::tuple<ColMajorMatrix<float>, ColMajorMatrix<uint64_t>> {
if (distance_metric == DistanceMetric::L2) {
if (distance_metric == DistanceMetric::SUM_OF_SQUARES) {
auto r = detail::flat::vq_query_heap(
data, query_vectors, k, nthreads, sum_of_squares_distance{});
return r;
Expand All @@ -933,6 +1005,10 @@ PYBIND11_MODULE(_tiledbvspy, m) {
auto r = detail::flat::vq_query_heap(
data, query_vectors, k, nthreads, cosine_distance{});
return r;
} else if (distance_metric == DistanceMetric::L2) {
auto r = detail::flat::vq_query_heap(
data, query_vectors, k, nthreads, sqrt_sum_of_squares_distance{});
return r;
} else {
throw std::runtime_error("Invalid distance metric");
}
Expand Down Expand Up @@ -1014,9 +1090,10 @@ PYBIND11_MODULE(_tiledbvspy, m) {
declare_debug_matrix<uint64_t>(m, "_u64");

py::enum_<DistanceMetric>(m, "DistanceMetric")
.value("L2", DistanceMetric::L2)
.value("SUM_OF_SQUARES", DistanceMetric::SUM_OF_SQUARES)
.value("INNER_PRODUCT", DistanceMetric::INNER_PRODUCT)
.value("COSINE", DistanceMetric::COSINE)
.value("L2", DistanceMetric::L2)
.export_values();

/* === Module inits === */
Expand Down
Loading
Loading