Skip to content

Commit

Permalink
Error conditions for 'llm similar', refs #190
Browse files Browse the repository at this point in the history
  • Loading branch information
simonw committed Sep 2, 2023
1 parent de6d257 commit 3ee9215
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 3 deletions.
14 changes: 11 additions & 3 deletions llm/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -1028,12 +1028,20 @@ def similar(collection, id, input, content, number, database):
if not db["embeddings"].exists():
raise click.ClickException("No embeddings table found in database")

collection_obj = Collection(db, collection)
if not collection_obj.exists():
collection_exists = False
try:
collection_obj = Collection(db, collection)
collection_exists = collection_obj.exists()
except ValueError:
collection_exists = False
if not collection_exists:
raise click.ClickException("Collection does not exist")

if id:
results = collection_obj.similar_by_id(id, number)
try:
results = collection_obj.similar_by_id(id, number)
except ValueError:
raise click.ClickException("ID not found in collection")
else:
if not content:
if not input:
Expand Down
2 changes: 2 additions & 0 deletions llm/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ def model(self) -> EmbeddingModel:
if self._model:
return self._model
try:
if not self._model_id:
raise ValueError("No model_id specified")
self._model = llm.get_embedding_model(self._model_id)
except llm.UnknownModelError:
raise ValueError("No model_id specified and no model found with that name")
Expand Down
9 changes: 9 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
import sqlite_utils
import llm
from llm.plugins import pm

Expand All @@ -16,6 +17,14 @@ def user_path(tmpdir):
return dir


@pytest.fixture
def user_path_with_embeddings(user_path):
path = str(user_path / "embeddings.db")
db = sqlite_utils.Database(path)
collection = llm.Collection(db, "demo", model_id="embed-demo")
collection.embed("1", "hello world")


@pytest.fixture
def templates_path(user_path):
dir = user_path / "templates"
Expand Down
15 changes: 15 additions & 0 deletions tests/test_embed_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,18 @@ def test_embed_store(user_path):
]
else:
assert result2.output == "items: embed-demo\n 1 embedding\n"


@pytest.mark.parametrize(
"args,expected_error",
(
([], "Missing argument 'COLLECTION'"),
(["badcollection", "-c", "content"], "Collection does not exist"),
(["demo", "2"], "ID not found in collection"),
),
)
def test_similar_errors(args, expected_error, user_path_with_embeddings):
runner = CliRunner()
result = runner.invoke(cli, ["similar"] + args, catch_exceptions=False)
assert result.exit_code != 0
assert expected_error in result.output

0 comments on commit 3ee9215

Please sign in to comment.