Skip to content

Commit

Permalink
Add interpret method and batching in collections (#26)
Browse files Browse the repository at this point in the history
This PR adds a few things:
1. We surface token scores in SearchResults.
2. We add an interpret method on collections to use our internal
tokenizer.
3. We add an add_batch method to collections and handle padding
internally.

There was also a bugfix in the tokenizer so that special tokens are added correctly.
  • Loading branch information
mtbarta committed May 19, 2024
1 parent bbdec40 commit 6bc71d3
Show file tree
Hide file tree
Showing 33 changed files with 619 additions and 211 deletions.
3 changes: 2 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,9 @@ callgrind-colbert: build-conda

py-docs:
rm -rf docs/build
sphinx-apidoc -o docs/source/ ./build/lintdb/python/lintdb
sphinx-apidoc -o docs/source/ ./builds/python/lintdb/python/lintdb
cd docs && make html
cp icon.svg docs/build/html/icon.svg

debug-conda:
conda debug lintdb --python 3.10 --output-id 'lintdb-*-py*'
Expand Down
7 changes: 0 additions & 7 deletions docs/docs/source/modules.rst

This file was deleted.

3 changes: 2 additions & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
chardet = "^5.2.0"
sphinx-immaterial = "^0.11.11"
sphinx-immaterial = "^0.11.11"
myst-parser
4 changes: 2 additions & 2 deletions docs/source/benchmarks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ Benchmarks
=============

Mac OS
=======
-------

Benchmark: LoTTE Lifestyle 40k
-------------------------------
Expand Down Expand Up @@ -36,7 +36,7 @@ lintdb:
Linux
======
------

Benchmark: LoTTE Lifestyle 40k
-------------------------------
Expand Down
30 changes: 15 additions & 15 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@

import os
import sys
sys.path.insert(0, os.path.abspath(os.path.join('..', '..', '_build_python/lintdb/python/build/lib')))
sys.path.insert(0, os.path.abspath(os.path.join('..', '..', 'builds/python/lintdb/python/build/lib')))

# Configuration file for the Sphinx documentation builder.
#
Expand All @@ -27,7 +27,7 @@
author = 'DeployQL'

# The full version, including alpha/beta/rc tags
release = '0.1'
release = '0.3.0'


# -- General configuration ---------------------------------------------------
Expand Down Expand Up @@ -119,19 +119,19 @@
},
],
# BEGIN: version_dropdown
# "version_dropdown": True,
# "version_info": [
# {
# "version": "https://sphinx-immaterial.rtfd.io",
# "title": "ReadTheDocs",
# "aliases": [],
# },
# {
# "version": "https://jbms.github.io/sphinx-immaterial",
# "title": "Github Pages",
# "aliases": [],
# },
# ],
"version_dropdown": True,
"version_info": [
{
"version": "LintDB/v0.3.0",
"title": "v0.3.0",
"aliases": [],
},
{
"version": "LintDB/v0.2.1",
"title": "v0.2.1",
"aliases": [],
},
],
# END: version_dropdown
"toc_title_is_page_title": True
# BEGIN: social icons
Expand Down
28 changes: 28 additions & 0 deletions docs/source/icon.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
:parser: myst_parser.sphinx_

.. toctree::
:maxdepth: 2
:maxdepth: 3
:caption: Contents:

installation
Expand Down
10 changes: 5 additions & 5 deletions docs/docs/source/lintdb.rst → docs/source/python.rst
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
lintdb package
python package
==============

Submodules
----------

lintdb.lintdb module
--------------------
python.setup module
-------------------

.. automodule:: lintdb.lintdb
.. automodule:: python.setup
:members:
:undoc-members:
:show-inheritance:

Module contents
---------------

.. automodule:: lintdb
.. automodule:: python
:members:
:undoc-members:
:show-inheritance:
2 changes: 1 addition & 1 deletion lintdb/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ set(
Collection.h
constants.h
cf.h
RawPassage.h
EmbeddingBlock.h
Encoder.h
Passages.h
quantizers/Binarizer.h
quantizers/Quantizer.h
quantizers/ProductEncoder.h
Expand Down
64 changes: 61 additions & 3 deletions lintdb/Collection.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#include "lintdb/Collection.h"
#include "lintdb/RawPassage.h"
#include "lintdb/EmbeddingBlock.h"
#include <glog/logging.h>
#include <iostream>
#include "lintdb/utils/progress_bar.h"
Expand Down Expand Up @@ -29,11 +29,48 @@ namespace lintdb {
input.attention_mask = attn;

auto output = model->encode(input);
auto passage = RawPassage(output.data(), ids.size(), model->get_dims(), id, metadata);

auto passage = EmbeddingPassage(output.data(), ids.size(), model->get_dims(), id, metadata);
index->add(tenant, {passage});
}

void Collection::add_batch(
const uint64_t tenant,
const std::vector<TextPassage> passages
) const {

std::vector<ModelInput> inputs;
for(auto passage: passages) {
auto ids = tokenizer->encode(passage.data);

ModelInput input;
input.input_ids = ids;

std::vector<int32_t> attn;
for(auto id: ids) {
if(id == 0) {
attn.push_back(0);
} else {
attn.push_back(1);
}
}
input.attention_mask = attn;

inputs.push_back(input);
}

auto output = model->encode(inputs);

std::vector<EmbeddingPassage> embedded_passages;
for (size_t i = 0; i < passages.size(); i++) {
auto encoded_vector = output.get(i);
auto passage = EmbeddingPassage(encoded_vector.data(), inputs[i].input_ids.size(), model->get_dims(), passages[i].id, passages[i].metadata);
embedded_passages.push_back(passage);
}

index->add(tenant, embedded_passages);
}

std::vector<SearchResult> Collection::search(
const uint64_t tenant,
const std::string& text,
Expand All @@ -58,6 +95,27 @@ namespace lintdb {
return index->search(tenant, output.data(), ids.size(), model->get_dims(), opts.n_probe, k, opts);
}

/**
 * Map per-token relevance scores back onto the decoded tokens of a query.
 *
 * The text is re-tokenized with the collection's tokenizer; special
 * tokens are dropped, and each remaining token is paired with the score
 * at the same position. Positions beyond the end of `scores` are
 * skipped rather than read out of bounds (the original indexed
 * scores[i] for every token id, which is undefined behavior when
 * `scores` is shorter than the tokenization).
 *
 * @param text The query text that produced the scores.
 * @param scores Per-token scores, parallel to the tokenization of text.
 * @return Token/score pairs for the non-special tokens.
 */
std::vector<TokenScore> Collection::interpret(
    const std::string& text,
    const std::vector<float> scores
) {
    auto ids = tokenizer->encode(text);

    std::vector<std::string> tokens;
    tokens.reserve(ids.size());
    for (const auto id : ids) {
        tokens.push_back(tokenizer->decode({id}));
    }

    std::vector<TokenScore> results;
    // Bound by both sizes so a short score vector cannot cause an
    // out-of-bounds read.
    for (size_t i = 0; i < ids.size() && i < scores.size(); i++) {
        if (tokenizer->is_special(ids[i])) {
            continue;
        }
        results.push_back({tokens[i], scores[i]});
    }

    return results;
}

void Collection::train(const std::vector<std::string> texts) {
std::vector<float> embeddings;
size_t num_embeddings = 0;
Expand Down
25 changes: 24 additions & 1 deletion lintdb/Collection.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "lintdb/index.h"
#include "lintdb/index_builder/EmbeddingModel.h"
#include "lintdb/index_builder/Tokenizer.h"
#include "lintdb/Passages.h"
#include "lintdb/api.h"
#include <string>
#include <vector>
Expand All @@ -14,6 +15,12 @@ namespace lintdb {
std::string tokenizer_file;
size_t max_length = 512;
};

/**
 * A decoded token paired with its relevance score, as produced by
 * Collection::interpret().
 */
struct TokenScore {
std::string token; // decoded token text
float score; // relevance score for this token position
};

/**
* Collection is a collection of documents. Instead of dealing directly with vectors, this
* class allows you to add and search for documents by text.
Expand All @@ -27,9 +34,20 @@ namespace lintdb {
* @param tenant The tenant id.
* @param id The document id.
* @param text The text to add.
* @param metadata a dictionary of metadata to store with the document. only accepts strings.
*/
void add(const uint64_t tenant, const uint64_t id, const std::string& text, const std::map<std::string, std::string>& metadata) const;


/**
* Add a batch of texts to the index.
*
* @param tenant The tenant id.
* @param passages A list of EmbeddingPassage objects to add.
*/
void add_batch(
const uint64_t tenant,
const std::vector<TextPassage> passages
) const;
/**
* Search the index for similar documents.
*
Expand All @@ -44,6 +62,11 @@ namespace lintdb {
const size_t k,
const SearchOptions& opts=SearchOptions()) const;

std::vector<TokenScore> interpret(
const std::string& text,
const std::vector<float> scores
);

void train(const std::vector<std::string> texts);

private:
Expand Down
2 changes: 1 addition & 1 deletion lintdb/Encoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ DefaultEncoder::DefaultEncoder(
}

std::unique_ptr<EncodedDocument> DefaultEncoder::encode_vectors(
const RawPassage& doc) {
const EmbeddingPassage& doc) {
LINTDB_THROW_IF_NOT(nlist <= std::numeric_limits<code_t>::max());
auto num_tokens = doc.embedding_block.num_tokens;

Expand Down
9 changes: 5 additions & 4 deletions lintdb/Encoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
#include <memory>
#include <string>
#include <vector>
#include "lintdb/RawPassage.h"
#include "lintdb/Passages.h"
#include "lintdb/EmbeddingBlock.h"
#include "lintdb/SearchOptions.h"
#include "lintdb/api.h"
#include "lintdb/invlists/EncodedDocument.h"
Expand Down Expand Up @@ -38,11 +39,11 @@ struct Encoder {
virtual size_t get_num_centroids() const = 0;
virtual size_t get_nbits() const = 0;
/**
* Encode vectors translates the embeddings given to us in RawPassage to
* Encode vectors translates the embeddings given to us in EmbeddingPassage to
* the internal representation that we expect to see in the inverted lists.
*/
virtual std::unique_ptr<EncodedDocument> encode_vectors(
const RawPassage& doc) = 0;
const EmbeddingPassage& doc) = 0;

/**
* Decode vectors translates out of our internal representation.
Expand Down Expand Up @@ -138,7 +139,7 @@ struct DefaultEncoder : public Encoder {
}

std::unique_ptr<EncodedDocument> encode_vectors(
const RawPassage& doc) override;
const EmbeddingPassage& doc) override;

std::vector<float> decode_vectors(
gsl::span<const code_t> codes,
Expand Down
Loading

0 comments on commit 6bc71d3

Please sign in to comment.