Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add interpret method and batching in collections #26

Merged
merged 5 commits into from
May 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,9 @@ callgrind-colbert: build-conda

py-docs:
rm -rf docs/build
sphinx-apidoc -o docs/source/ ./build/lintdb/python/lintdb
sphinx-apidoc -o docs/source/ ./builds/python/lintdb/python/lintdb
cd docs && make html
cp icon.svg docs/build/html/icon.svg

debug-conda:
conda debug lintdb --python 3.10 --output-id 'lintdb-*-py*'
Expand Down
7 changes: 0 additions & 7 deletions docs/docs/source/modules.rst

This file was deleted.

3 changes: 2 additions & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
chardet = "^5.2.0"
sphinx-immaterial = "^0.11.11"
sphinx-immaterial = "^0.11.11"
myst-parser
4 changes: 2 additions & 2 deletions docs/source/benchmarks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ Benchmarks
=============

Mac OS
=======
-------

Benchmark: LoTTE Lifestyle 40k
-------------------------------
Expand Down Expand Up @@ -36,7 +36,7 @@ lintdb:


Linux
======
------

Benchmark: LoTTE Lifestyle 40k
-------------------------------
Expand Down
30 changes: 15 additions & 15 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@

import os
import sys
sys.path.insert(0, os.path.abspath(os.path.join('..', '..', '_build_python/lintdb/python/build/lib')))
sys.path.insert(0, os.path.abspath(os.path.join('..', '..', 'builds/python/lintdb/python/build/lib')))

# Configuration file for the Sphinx documentation builder.
#
Expand All @@ -27,7 +27,7 @@
author = 'DeployQL'

# The full version, including alpha/beta/rc tags
release = '0.1'
release = '0.3.0'


# -- General configuration ---------------------------------------------------
Expand Down Expand Up @@ -119,19 +119,19 @@
},
],
# BEGIN: version_dropdown
# "version_dropdown": True,
# "version_info": [
# {
# "version": "https://sphinx-immaterial.rtfd.io",
# "title": "ReadTheDocs",
# "aliases": [],
# },
# {
# "version": "https://jbms.github.io/sphinx-immaterial",
# "title": "Github Pages",
# "aliases": [],
# },
# ],
"version_dropdown": True,
"version_info": [
{
"version": "LintDB/v0.3.0",
"title": "v0.3.0",
"aliases": [],
},
{
"version": "LintDB/v0.2.1",
"title": "v0.2.1",
"aliases": [],
},
],
# END: version_dropdown
"toc_title_is_page_title": True
# BEGIN: social icons
Expand Down
28 changes: 28 additions & 0 deletions docs/source/icon.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
:parser: myst_parser.sphinx_

.. toctree::
:maxdepth: 2
:maxdepth: 3
:caption: Contents:

installation
Expand Down
10 changes: 5 additions & 5 deletions docs/docs/source/lintdb.rst → docs/source/python.rst
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
lintdb package
python package
==============

Submodules
----------

lintdb.lintdb module
--------------------
python.setup module
-------------------

.. automodule:: lintdb.lintdb
.. automodule:: python.setup
:members:
:undoc-members:
:show-inheritance:

Module contents
---------------

.. automodule:: lintdb
.. automodule:: python
:members:
:undoc-members:
:show-inheritance:
2 changes: 1 addition & 1 deletion lintdb/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ set(
Collection.h
constants.h
cf.h
RawPassage.h
EmbeddingBlock.h
Encoder.h
Passages.h
quantizers/Binarizer.h
quantizers/Quantizer.h
quantizers/ProductEncoder.h
Expand Down
64 changes: 61 additions & 3 deletions lintdb/Collection.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#include "lintdb/Collection.h"
#include "lintdb/RawPassage.h"
#include "lintdb/EmbeddingBlock.h"
#include <glog/logging.h>
#include <iostream>
#include "lintdb/utils/progress_bar.h"
Expand Down Expand Up @@ -29,11 +29,48 @@ namespace lintdb {
input.attention_mask = attn;

auto output = model->encode(input);
auto passage = RawPassage(output.data(), ids.size(), model->get_dims(), id, metadata);

auto passage = EmbeddingPassage(output.data(), ids.size(), model->get_dims(), id, metadata);
index->add(tenant, {passage});
}

void Collection::add_batch(
const uint64_t tenant,
const std::vector<TextPassage> passages
) const {

std::vector<ModelInput> inputs;
for(auto passage: passages) {
auto ids = tokenizer->encode(passage.data);

ModelInput input;
input.input_ids = ids;

std::vector<int32_t> attn;
for(auto id: ids) {
if(id == 0) {
attn.push_back(0);
} else {
attn.push_back(1);
}
}
input.attention_mask = attn;

inputs.push_back(input);
}

auto output = model->encode(inputs);

std::vector<EmbeddingPassage> embedded_passages;
for (size_t i = 0; i < passages.size(); i++) {
auto encoded_vector = output.get(i);
auto passage = EmbeddingPassage(encoded_vector.data(), inputs[i].input_ids.size(), model->get_dims(), passages[i].id, passages[i].metadata);
embedded_passages.push_back(passage);
}

index->add(tenant, embedded_passages);
}

std::vector<SearchResult> Collection::search(
const uint64_t tenant,
const std::string& text,
Expand All @@ -58,6 +95,27 @@ namespace lintdb {
return index->search(tenant, output.data(), ids.size(), model->get_dims(), opts.n_probe, k, opts);
}

std::vector<TokenScore> Collection::interpret(
const std::string& text,
const std::vector<float> scores
) {
auto ids = tokenizer->encode(text);
std::vector<std::string> tokens;
for(auto id: ids) {
tokens.push_back(tokenizer->decode({id}));
}

std::vector<TokenScore> results;
for(size_t i = 0; i < ids.size(); i++) {
if (tokenizer->is_special(ids[i])) {
continue;
}
results.push_back({tokens[i], scores[i]});
}

return results;
}

void Collection::train(const std::vector<std::string> texts) {
std::vector<float> embeddings;
size_t num_embeddings = 0;
Expand Down
25 changes: 24 additions & 1 deletion lintdb/Collection.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "lintdb/index.h"
#include "lintdb/index_builder/EmbeddingModel.h"
#include "lintdb/index_builder/Tokenizer.h"
#include "lintdb/Passages.h"
#include "lintdb/api.h"
#include <string>
#include <vector>
Expand All @@ -14,6 +15,12 @@ namespace lintdb {
std::string tokenizer_file;
size_t max_length = 512;
};

/**
 * Pairs a decoded token with its score, as produced by Collection::interpret().
 */
struct TokenScore {
std::string token; // decoded token text
float score; // score supplied by the caller for this token
};

/**
* Collection is a collection of documents. Instead of dealing directly with vectors, this
* class allows you to add and search for documents by text.
Expand All @@ -27,9 +34,20 @@ namespace lintdb {
* @param tenant The tenant id.
* @param id The document id.
* @param text The text to add.
* @param metadata a dictionary of metadata to store with the document. only accepts strings.
*/
void add(const uint64_t tenant, const uint64_t id, const std::string& text, const std::map<std::string, std::string>& metadata) const;


/**
* Add a batch of texts to the index.
*
* @param tenant The tenant id.
* @param passages A list of TextPassage objects to add.
*/
void add_batch(
const uint64_t tenant,
const std::vector<TextPassage> passages
) const;
/**
* Search the index for similar documents.
*
Expand All @@ -44,6 +62,11 @@ namespace lintdb {
const size_t k,
const SearchOptions& opts=SearchOptions()) const;

/**
 * Pair each token of `text` with its score for interpretability.
 *
 * @param text The query text to tokenize.
 * @param scores One score per token of the tokenized text, in token order.
 * @return TokenScore entries for the non-special tokens.
 */
std::vector<TokenScore> interpret(
const std::string& text,
const std::vector<float> scores
);

void train(const std::vector<std::string> texts);

private:
Expand Down
2 changes: 1 addition & 1 deletion lintdb/Encoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
// for normalized vectors, this should be the same as IP.
this->coarse_quantizer = std::make_unique<faiss::IndexFlatIP>(dim);
auto quantizer_config = QuantizerConfig{
.nlist = nlist,

Check warning on line 59 in lintdb/Encoder.cpp

View workflow job for this annotation

GitHub Actions / Run Cmake

designated initializers are a C++20 extension [-Wc++20-designator]
.nbits = nbits,
.niter = niter,
.dim = dim,
Expand All @@ -66,7 +66,7 @@
}

std::unique_ptr<EncodedDocument> DefaultEncoder::encode_vectors(
const RawPassage& doc) {
const EmbeddingPassage& doc) {
LINTDB_THROW_IF_NOT(nlist <= std::numeric_limits<code_t>::max());
auto num_tokens = doc.embedding_block.num_tokens;

Expand Down Expand Up @@ -294,7 +294,7 @@
encoder->coarse_quantizer = std::move(coarse_quantizer);

auto quantizer_config = QuantizerConfig{
.nlist = config.nlist,

Check warning on line 297 in lintdb/Encoder.cpp

View workflow job for this annotation

GitHub Actions / Run Cmake

designated initializers are a C++20 extension [-Wc++20-designator]
.nbits = config.nbits,
.niter = config.niter,
.dim = config.dim,
Expand Down
9 changes: 5 additions & 4 deletions lintdb/Encoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
#include <memory>
#include <string>
#include <vector>
#include "lintdb/RawPassage.h"
#include "lintdb/Passages.h"
#include "lintdb/EmbeddingBlock.h"
#include "lintdb/SearchOptions.h"
#include "lintdb/api.h"
#include "lintdb/invlists/EncodedDocument.h"
Expand Down Expand Up @@ -38,11 +39,11 @@ struct Encoder {
virtual size_t get_num_centroids() const = 0;
virtual size_t get_nbits() const = 0;
/**
* Encode vectors translates the embeddings given to us in RawPassage to
* Encode vectors translates the embeddings given to us in EmbeddingPassage to
* the internal representation that we expect to see in the inverted lists.
*/
virtual std::unique_ptr<EncodedDocument> encode_vectors(
const RawPassage& doc) = 0;
const EmbeddingPassage& doc) = 0;

/**
* Decode vectors translates out of our internal representation.
Expand Down Expand Up @@ -138,7 +139,7 @@ struct DefaultEncoder : public Encoder {
}

std::unique_ptr<EncodedDocument> encode_vectors(
const RawPassage& doc) override;
const EmbeddingPassage& doc) override;

std::vector<float> decode_vectors(
gsl::span<const code_t> codes,
Expand Down
Loading
Loading