Skip to content

Commit

Permalink
--prefix for llm embed-multi, refs #215
Browse files Browse the repository at this point in the history
  • Loading branch information
simonw committed Sep 3, 2023
1 parent 70a3d4b commit c8c0f80
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 7 deletions.
1 change: 1 addition & 0 deletions docs/help.md
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,7 @@ Options:
--sql TEXT Read input using this SQL query
--attach <TEXT FILE>... Additional databases to attach - specify alias
and file path
--prefix TEXT Prefix to add to the IDs
-m, --model TEXT Embedding model to use
--store Store the text itself in the database
-d, --database FILE
Expand Down
5 changes: 3 additions & 2 deletions llm/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -994,6 +994,7 @@ def get_db():
multiple=True,
help="Additional databases to attach - specify alias and file path",
)
@click.option("--prefix", help="Prefix to add to the IDs", default="")
@click.option("-m", "--model", help="Embedding model to use")
@click.option("--store", is_flag=True, help="Store the text itself in the database")
@click.option(
Expand All @@ -1003,7 +1004,7 @@ def get_db():
envvar="LLM_EMBEDDINGS_DB",
)
def embed_multi(
- collection, input_path, format, files, sql, attach, model, store, database
+ collection, input_path, format, files, sql, attach, prefix, model, store, database
):
"""
Store embeddings for multiple strings at once
Expand Down Expand Up @@ -1092,7 +1093,7 @@ def load_rows(fp):
def tuples():
for row in rows:
values = list(row.values())
- id = values[0]
+ id = prefix + str(values[0])
text = " ".join(v or "" for v in values[1:])
yield id, text

Expand Down
21 changes: 16 additions & 5 deletions tests/test_embed_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ def test_similar_by_content_cli(tmpdir, user_path_with_embeddings, scenario):


@pytest.mark.parametrize("use_stdin", (False, True))
@pytest.mark.parametrize("prefix", (None, "prefix"))
@pytest.mark.parametrize(
"filename,content",
(
Expand All @@ -234,7 +235,7 @@ def test_similar_by_content_cli(tmpdir, user_path_with_embeddings, scenario):
),
),
)
- def test_embed_multi_file_input(tmpdir, use_stdin, filename, content):
+ def test_embed_multi_file_input(tmpdir, use_stdin, prefix, filename, content):
db_path = tmpdir / "embeddings.db"
args = ["embed-multi", "phrases", "-d", str(db_path), "-m", "embed-demo"]
input = None
Expand All @@ -245,6 +246,8 @@ def test_embed_multi_file_input(tmpdir, use_stdin, filename, content):
path = tmpdir / filename
path.write_text(content, "utf-8")
args.append(str(path))
if prefix:
args.extend(("--prefix", prefix))
# Auto-detection can't detect JSON-nl, so make that explicit
if filename.endswith(".jsonl"):
args.extend(("--format", "nl"))
Expand All @@ -254,18 +257,26 @@ def test_embed_multi_file_input(tmpdir, use_stdin, filename, content):
# Check that everything was embedded correctly
db = sqlite_utils.Database(str(db_path))
assert db["embeddings"].count == 2
ids = [row["id"] for row in db["embeddings"].rows]
expected_ids = ["1", "2"]
if prefix:
expected_ids = ["prefix1", "prefix2"]
assert ids == expected_ids


@pytest.mark.parametrize("use_other_db", (True, False))
- def test_sql(tmpdir, use_other_db):
+ @pytest.mark.parametrize("prefix", (None, "prefix"))
+ def test_sql(tmpdir, use_other_db, prefix):
db_path = str(tmpdir / "embeddings.db")
db = sqlite_utils.Database(db_path)
extra_args = []
if use_other_db:
db_path2 = str(tmpdir / "other.db")
db = sqlite_utils.Database(db_path2)
extra_args = ["--attach", "other", db_path2]
other_table = "other.content"

if prefix:
extra_args.extend(("--prefix", prefix))

db["content"].insert_all(
[
Expand Down Expand Up @@ -295,6 +306,6 @@ def test_sql(tmpdir, use_other_db):
assert embeddings_db["embeddings"].count == 2
rows = list(embeddings_db.query("select id, content from embeddings"))
assert rows == [
- {"id": "1", "content": "cli Command line interface"},
- {"id": "2", "content": "sql Structured query language"},
+ {"id": (prefix or "") + "1", "content": "cli Command line interface"},
+ {"id": (prefix or "") + "2", "content": "sql Structured query language"},
]

0 comments on commit c8c0f80

Please sign in to comment.