diff --git a/docs/embeddings/binary.md b/docs/embeddings/binary.md new file mode 100644 index 00000000..af7e8e27 --- /dev/null +++ b/docs/embeddings/binary.md @@ -0,0 +1,21 @@ +(embeddings-binary)= +# Binary embedding formats + +The default output format of the `llm embed` command is a JSON array of floating point numbers. + +LLM stores embeddings in a more space-efficient format: little-endian binary sequences of 32-bit floating point numbers, each represented using 4 bytes. + +The following Python functions can be used to convert between the two formats: + +```python +import struct + +def encode(values): + return struct.pack("<" + "f" * len(values), *values) + +def decode(binary): + return struct.unpack("<" + "f" * (len(binary) // 4), binary) +``` +When using `llm embed` directly, the default output format is JSON. + +Use `--format blob` for the binary output, `--format hex` for that binary output as hexadecimal and `--format base64` for that binary output encoded using base64. diff --git a/docs/embeddings/cli.md b/docs/embeddings/cli.md new file mode 100644 index 00000000..3d451ec8 --- /dev/null +++ b/docs/embeddings/cli.md @@ -0,0 +1,97 @@ +(embeddings-cli)= +# Embedding with the CLI + +LLM provides command-line utilities for calculating and storing embeddings for pieces of content. + +(embeddings-llm-embed)= +## llm embed + +The `llm embed` command can be used to calculate embedding vectors for a string of content. These can be returned directly to the terminal, stored in a SQLite database, or both. + +### Returning embeddings to the terminal + +The simplest way to use this command is to pass content to it using the `-c/--content` option, like this: + +```bash +llm embed -c 'This is some content' +``` +The command will return a JSON array of floating point numbers directly to the terminal: + +```json +[0.123, 0.456, 0.789...] +``` +By default it uses the {ref}`default embedding model <embeddings-cli-embed-models-default>`.
+ +Use the `-m/--model` option to specify a different model: + +```bash +llm embed -m sentence-transformers/all-MiniLM-L6-v2 \ + -c 'This is some content' +``` +See {ref}`embeddings-binary` for options to get back embeddings in formats other than JSON. + +### Storing embeddings in SQLite + +Embeddings are much more useful if you store them somewhere, so you can calculate similarity scores between different embeddings later on. + +LLM includes a concept of a "collection" of embeddings. This is a named object where multiple pieces of content can be stored, each with a unique ID. + +The `llm embed` command can store results directly in a named collection like this: + +```bash +cat one.txt | llm embed my-files one +``` +This will store the embedding for the contents of `one.txt` in the `my-files` collection under the key `one`. + +A collection will be created the first time you mention it. + +Collections have a fixed embedding model, which is the model that was used for the first embedding stored in that collection. + +In the above example this would have been the default embedding model at the time that the command was run. + +This example stores the embedding of the string "my happy hound" in a collection called `phrases` under the key `hound` and using the model `ada-002`: + +```bash +llm embed -m ada-002 -c 'my happy hound' phrases hound +``` +By default, the SQLite database used to store embeddings is the `embeddings.db` file in the user content directory managed by LLM. + +You can see the path to this database file by running `llm embed-db path`. + +You can store embeddings in a different SQLite database by passing a path to it using the `-d/--database` option to `llm embed`. If this file does not exist yet the command will create it: + +```bash +llm embed -d my-embeddings.db -c 'my happy hound' phrases hound +``` +This creates a database file called `my-embeddings.db` in the current directory.
+ +(embeddings-cli-embed-models-default)= +## llm embed-models default + +This command can be used to get and set the default embedding model. + +This will return the name of the current default model: +```bash +llm embed-models default +``` +You can set a different default like this: +```bash +llm embed-models default name-of-other-model +``` +Any of the supported aliases for a model can be passed to this command. + +## llm embed-db collections + +To list all of the collections in the embeddings database, run this command: + +```bash +llm embed-db collections +``` +Add `--json` for JSON output: +```bash +llm embed-db collections --json +``` +Add `-d/--database` to specify a different database file: +```bash +llm embed-db collections -d my-embeddings.db +``` diff --git a/docs/embeddings/index.md b/docs/embeddings/index.md new file mode 100644 index 00000000..5055f117 --- /dev/null +++ b/docs/embeddings/index.md @@ -0,0 +1,21 @@ +(embeddings)= +# Embeddings + +Embedding models allow you to take a piece of text - a word, sentence, paragraph or even a whole article - and convert that into an array of floating point numbers. + +This floating point array is called an "embedding vector", and works as a numerical representation of the semantic meaning of the content in a multi-dimensional space. + +By calculating the distance between embedding vectors, we can identify which content is semantically "nearest" to other content. + +This can be used to build features like related article lookups. It can also be used to build semantic search, where a user can search for a phrase and get back results that are semantically similar to that phrase even if they do not share any exact keywords. + +LLM supports multiple embedding models through {ref}`plugins <plugins>`. Once installed, an embedding model can be used on the command-line or via the Python API to calculate and store embeddings for content, and then to perform similarity searches against those embeddings.
+ +```{toctree} +--- +maxdepth: 3 +--- +cli +writing-plugins +binary +``` diff --git a/docs/embeddings/writing-plugins.md b/docs/embeddings/writing-plugins.md new file mode 100644 index 00000000..520732b2 --- /dev/null +++ b/docs/embeddings/writing-plugins.md @@ -0,0 +1,48 @@ +(embeddings-writing-plugins)= +# Writing plugins to add new embedding models + +Read the {ref}`plugin tutorial ` for details on how to develop and package a plugin. + +This page shows an example plugin that implements and registers a new embedding model. + +There are two components to an embedding model plugin: + +1. An implementation of the `register_embedding_models()` hook, which takes a `register` callback function and calls it to register the new model with the LLM plugin system. +2. A class that extends the `llm.EmbeddingModel` abstract base class. + + The only required method on this class is `embed(text)`, which takes a string and returns a list of floating point numbers. + +The following example uses the [sentence-transformers](https://github.com/UKPLab/sentence-transformers) package to provide access to the [MiniLM-L6](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) embedding model. 
+ +```python +import llm +from sentence_transformers import SentenceTransformer + + +@llm.hookimpl +def register_embedding_models(register): + model_id = "sentence-transformers/all-MiniLM-L6-v2" + register(SentenceTransformerModel(model_id, model_id, 384), aliases=("all-MiniLM-L6-v2",)) + + +class SentenceTransformerModel(llm.EmbeddingModel): + def __init__(self, model_id, model_name, embedding_size): + self.model_id = model_id + self.model_name = model_name + self.embedding_size = embedding_size + self._model = None + + def embed(self, text): + if self._model is None: + self._model = SentenceTransformer(self.model_name) + return list(map(float, self._model.encode([text])[0])) +``` +Once installed, the model provided by this plugin can be used with the {ref}`llm embed <embeddings-llm-embed>` command like this: + +```bash +cat file.txt | llm embed -m sentence-transformers/all-MiniLM-L6-v2 +``` +Or via its registered alias like this: +```bash +cat file.txt | llm embed -m all-MiniLM-L6-v2 +``` diff --git a/docs/help.md b/docs/help.md index ed7fb981..c34f94fb 100644 --- a/docs/help.md +++ b/docs/help.md @@ -53,16 +53,19 @@ Options: --help Show this message and exit.
Commands: - prompt* Execute a prompt - aliases Manage model aliases - install Install packages from PyPI into the same environment as LLM - keys Manage stored API keys for different models - logs Tools for exploring logged prompts and responses - models Manage available models - openai Commands for working directly with the OpenAI API - plugins List installed plugins - templates Manage stored prompt templates - uninstall Uninstall Python packages from the LLM environment + prompt* Execute a prompt + aliases Manage model aliases + embed Embed text and store or return the result + embed-db Manage the embeddings database + embed-models Manage available embedding models + install Install packages from PyPI into the same environment as LLM + keys Manage stored API keys for different models + logs Tools for exploring logged prompts and responses + models Manage available models + openai Commands for working directly with the OpenAI API + plugins List installed plugins + templates Manage stored prompt templates + uninstall Uninstall Python packages from the LLM environment ``` ### llm prompt --help ``` @@ -380,6 +383,86 @@ Options: -y, --yes Don't ask for confirmation --help Show this message and exit. ``` +### llm embed --help +``` +Usage: llm embed [OPTIONS] [COLLECTION] [ID] + + Embed text and store or return the result + +Options: + -i, --input FILE Content to embed + -m, --model TEXT Embedding model to use + --store Store the text itself in the database + -d, --database FILE + -c, --content FILE + -f, --format [json|blob|base64|hex] + Output format + --help Show this message and exit. +``` +### llm embed-models --help +``` +Usage: llm embed-models [OPTIONS] COMMAND [ARGS]... + + Manage available embedding models + +Options: + --help Show this message and exit. 
+ +Commands: + list* List available embedding models + default Show or set the default embedding model +``` +#### llm embed-models list --help +``` +Usage: llm embed-models list [OPTIONS] + + List available embedding models + +Options: + --help Show this message and exit. +``` +#### llm embed-models default --help +``` +Usage: llm embed-models default [OPTIONS] [MODEL] + + Show or set the default embedding model + +Options: + --help Show this message and exit. +``` +### llm embed-db --help +``` +Usage: llm embed-db [OPTIONS] COMMAND [ARGS]... + + Manage the embeddings database + +Options: + --help Show this message and exit. + +Commands: + collections Output the path to the embeddings database + path Output the path to the embeddings database +``` +#### llm embed-db path --help +``` +Usage: llm embed-db path [OPTIONS] + + Output the path to the embeddings database + +Options: + --help Show this message and exit. +``` +#### llm embed-db collections --help +``` +Usage: llm embed-db collections [OPTIONS] + + Output the path to the embeddings database + +Options: + -d, --database FILE Path to embeddings database + --json Output as JSON + --help Show this message and exit. +``` ### llm openai --help ``` Usage: llm openai [OPTIONS] COMMAND [ARGS]... 
def get_embedding_models_with_aliases() -> List["EmbeddingModelWithAliases"]:
    """
    Collect every plugin-registered embedding model together with its aliases.

    Aliases come from two places: those a plugin passes to ``register()`` and
    those configured in ``aliases.json`` inside the LLM user directory.
    """
    # aliases.json maps alias -> model_id; invert it to model_id -> [aliases]
    aliases_by_model_id: Dict[str, list] = {}
    aliases_file = user_dir() / "aliases.json"
    if aliases_file.exists():
        for alias, model_id in json.loads(aliases_file.read_text()).items():
            aliases_by_model_id.setdefault(model_id, []).append(alias)

    collected = []

    def register(model, aliases=None):
        # Plugin-supplied aliases first, then any user-configured ones
        combined = list(aliases or [])
        combined.extend(aliases_by_model_id.get(model.model_id, []))
        collected.append(EmbeddingModelWithAliases(model, combined))

    pm.hook.register_embedding_models(register=register)
    return collected


def get_embedding_models():
    """Return the embedding model instances registered by plugins, without aliases."""
    found = []

    def register(model, aliases=None):
        found.append(model)

    pm.hook.register_embedding_models(register=register)
    return found


def get_embedding_model(name):
    """
    Look up an embedding model by its model ID or one of its aliases.

    Raises UnknownModelError if nothing is registered under ``name``.
    """
    registry = get_embedding_model_aliases()
    if name in registry:
        return registry[name]
    raise UnknownModelError("Unknown model: " + name)


def get_embedding_model_aliases() -> Dict[str, EmbeddingModel]:
    """Build a mapping of every alias and every model ID to its model instance."""
    lookup = {}
    for entry in get_embedding_models_with_aliases():
        lookup.update(dict.fromkeys(entry.aliases, entry.model))
        # The model ID is written last so it overrides an identical alias key
        lookup[entry.model.model_id] = entry.model
    return lookup
@cli.command()
@click.argument("collection", required=False)
@click.argument("id", required=False)
@click.option(
    "-i",
    "--input",
    type=click.Path(file_okay=True, allow_dash=True, dir_okay=False),
    help="File to embed",
)
@click.option("-m", "--model", help="Embedding model to use")
@click.option("--store", is_flag=True, help="Store the text itself in the database")
@click.option(
    "-d",
    "--database",
    type=click.Path(file_okay=True, allow_dash=False, dir_okay=False, writable=True),
    envvar="LLM_EMBEDDINGS_DB",
)
@click.option(
    "-c",
    "--content",
    # Literal text, not a filesystem path, so no click.Path validation here
    help="Content to embed",
)
@click.option(
    "format_",
    "-f",
    "--format",
    type=click.Choice(["json", "blob", "base64", "hex"]),
    help="Output format",
)
def embed(collection, id, input, model, store, database, content, format_):
    """Embed text and store or return the result"""
    # NOTE: the docstring above is the CLI help text, so details live in comments.
    #
    # Content is taken from -c/--content, else the -i/--input file, else stdin.
    # With COLLECTION and ID the embedding is written to the embeddings database
    # and only echoed if --format was given; without them it is always echoed.
    if collection and not id:
        raise click.ClickException("Must provide both collection and id")

    def get_db():
        # -d/--database (or LLM_EMBEDDINGS_DB) overrides the default location
        return sqlite_utils.Database(database or (user_dir() / "embeddings.db"))

    db = None
    existing_collection = None
    if collection:
        db = get_db()
        if db["collections"].exists():
            try:
                existing_collection = get_collection(db, collection)
            except NotFoundError:
                # First use of this collection; it is created further down
                pass

    if model is None:
        if existing_collection:
            # A collection is pinned to the model used for its first embedding
            model = existing_collection["model"]
        else:
            model = get_default_embedding_model()

    try:
        model_obj = get_embedding_model(model)
    except UnknownModelError as ex:
        raise click.ClickException(str(ex))

    # Compare the resolved model ID rather than the name the user typed, so
    # that passing an alias of the collection's model is accepted.
    if existing_collection and model_obj.model_id != existing_collection["model"]:
        raise click.ClickException(
            "Model '{}' does not match '{}' collection model of '{}'".format(
                model, collection, existing_collection["model"]
            )
        )

    # Resolve input text
    if not content:
        if not input or input == "-":
            content = sys.stdin.read()
        else:
            # click.Path hands us a path string, so the file must be opened here
            try:
                with open(input, encoding="utf-8") as fp:
                    content = fp.read()
            except OSError as ex:
                raise click.ClickException(
                    "Could not read {}: {}".format(input, ex)
                )
    if not content:
        raise click.ClickException("No content provided")

    embedding = model_obj.embed(content)

    if collection:
        # db was opened above whenever a collection was named
        embeddings_migrations.apply(db)

        if not existing_collection:
            db["collections"].insert(
                {
                    "name": collection,
                    "model": model_obj.model_id,
                }
            )
            existing_collection = get_collection(db, collection)

        # replace=True overwrites an earlier embedding stored under the same ID
        db["embeddings"].insert(
            {
                "collection_id": existing_collection["id"],
                "id": id,
                "content": content if store else None,
                "embedding": encode(embedding),
            },
            replace=True,
        )

    # When storing, stay quiet unless an output format was explicitly requested
    if collection and format_ is None:
        return

    if format_ == "json" or format_ is None:
        click.echo(json.dumps(embedding))
    elif format_ == "blob":
        click.echo(encode(embedding))
    elif format_ == "base64":
        click.echo(base64.b64encode(encode(embedding)).decode("ascii"))
    elif format_ == "hex":
        click.echo(encode(embedding).hex())
@cli.group(
    cls=DefaultGroup,
    default="list",
    default_if_no_args=True,
)
def embed_models():
    "Manage available embedding models"


@embed_models.command(name="list")
def embed_models_list():
    "List available embedding models"
    # One line per model: its ID, followed by any aliases in parentheses
    lines = []
    for entry in get_embedding_models_with_aliases():
        line = str(entry.model.model_id)
        if entry.aliases:
            line += " (aliases: {})".format(", ".join(entry.aliases))
        lines.append(line)
    click.echo("\n".join(lines))


@embed_models.command(name="default")
@click.argument("model", required=False)
def embed_models_default(model):
    "Show or set the default embedding model"
    if not model:
        click.echo(get_default_embedding_model())
        return
    # Validate it is a known model. get_embedding_model() signals an unknown
    # name with UnknownModelError, and the user-supplied name is kept in its
    # own variable so the error message can report what was typed.
    try:
        resolved = get_embedding_model(model)
    except UnknownModelError:
        raise click.ClickException("Unknown embedding model: {}".format(model))
    # Persist the canonical model ID even if an alias was passed
    set_default_embedding_model(resolved.model_id)


@cli.group()
def embed_db():
    "Manage the embeddings database"


@embed_db.command(name="path")
def embed_db_path():
    "Output the path to the embeddings database"
    click.echo(user_dir() / "embeddings.db")


@embed_db.command(name="collections")
@click.option(
    "-d",
    "--database",
    type=click.Path(file_okay=True, allow_dash=False, dir_okay=False, writable=True),
    envvar="LLM_EMBEDDINGS_DB",
    help="Path to embeddings database",
)
@click.option("json_", "--json", is_flag=True, help="Output as JSON")
def embed_db_collections(database, json_):
    "View a list of collections"
    database = database or (user_dir() / "embeddings.db")
    db = sqlite_utils.Database(str(database))
    if not db["collections"].exists():
        raise click.ClickException("No collections table found in {}".format(database))
    # Left join so collections with no stored embeddings are listed with a count of 0
    rows = db.query(
        """
        select
            collections.name,
            collections.model,
            count(embeddings.id) as num_embeddings
        from
            collections left join embeddings
            on collections.id = embeddings.collection_id
        group by
            collections.name, collections.model
        """
    )
    if json_:
        click.echo(json.dumps(list(rows), indent=4))
    else:
        for row in rows:
            click.echo("{}: {}".format(row["name"], row["model"]))
            click.echo(
                " {} embedding{}".format(
                    row["num_embeddings"], "s" if row["num_embeddings"] != 1 else ""
                )
            )
def get_collection(db, collection):
    """
    Return the row of the ``collections`` table whose name is ``collection``.

    Raises NotFoundError if no such collection exists.
    """
    for row in db["collections"].rows_where("name = ?", [collection]):
        return row
    raise NotFoundError("Collection not found: {}".format(collection))


def get_default_model(filename="default_model.txt", default=DEFAULT_MODEL):
    """
    Read the default model name stored in ``filename`` in the user directory,
    returning ``default`` when that file does not exist.
    """
    stored = user_dir() / filename
    return stored.read_text().strip() if stored.exists() else default


def set_default_model(model, filename="default_model.txt"):
    """Write ``model`` to ``filename`` in the user directory."""
    (user_dir() / filename).write_text(model)


def get_default_embedding_model():
    """Return the default embedding model, falling back to DEFAULT_EMBEDDING_MODEL."""
    return get_default_model(
        filename="default_embedding_model.txt", default=DEFAULT_EMBEDDING_MODEL
    )


def set_default_embedding_model(model):
    """Store ``model`` as the default embedding model."""
    set_default_model(model, filename="default_embedding_model.txt")


def encode(values):
    """Pack a sequence of numbers as little-endian 32-bit floats."""
    return struct.pack("<{}f".format(len(values)), *values)


def decode(binary):
    """Unpack little-endian 32-bit floats (4 bytes each) into a tuple of floats."""
    return struct.unpack("<{}f".format(len(binary) // 4), binary)
@hookimpl
def register_embedding_models(register):
    # Register the Ada002 embedding model under the short alias "ada"
    register(Ada002(), aliases=("ada",))


class Ada002(EmbeddingModel):
    """Embedding model that calls the OpenAI "text-embedding-ada-002" model."""

    model_id = "ada-002"
    embedding_size = 1536
    needs_key = "openai"
    key_env_var = "OPENAI_API_KEY"

    def embed(self, text):
        """Return the embedding for ``text`` from the first item of the API response."""
        response = openai.Embedding.create(
            input=text,
            model="text-embedding-ada-002",
            api_key=self.get_key(),
        )
        first_result = response["data"][0]
        return first_result["embedding"]
# --- llm/embeddings_migrations.py ---
from sqlite_migrate import Migrations

embeddings_migrations = Migrations("llm.embeddings")


@embeddings_migrations()
def m001_create_tables(db):
    """Create the ``collections`` and ``embeddings`` tables."""
    collections = db["collections"]
    collections.create({"id": int, "name": str, "model": str}, pk="id")
    # Collection names must be unique
    collections.create_index(["name"], unique=True)

    embedding_columns = {
        "collection_id": int,
        "id": str,
        "embedding": bytes,
        "content": str,
        "metadata": str,
    }
    # Compound primary key: one row per (collection, id) pair
    db["embeddings"].create(embedding_columns, pk=("collection_id", "id"))


# --- llm/hookspecs.py ---
@hookspec
def register_embedding_models(register):
    "Return a list of model instances that can be used for embedding"
class EmbeddingModel(ABC, _get_key_mixin):
    """
    Abstract base class for embedding models.

    Subclasses must implement embed(); get_key() is provided by _get_key_mixin.
    """

    model_id: str
    embedding_size: int
    key: Optional[str] = None
    needs_key: Optional[str] = None
    key_env_var: Optional[str] = None

    @abstractmethod
    def embed(self, text: str) -> List[float]:
        """
        Embed some text as a list of floats
        """


@dataclass
class EmbeddingModelWithAliases:
    """Pairs an embedding model with the aliases it can be referred to by."""

    model: EmbeddingModel
    aliases: Set[str]
from click.testing import CliRunner
from llm.cli import cli
from llm.plugins import pm
import json
import llm
import pytest
import sqlite_utils


# --- tests/conftest.py ---
class EmbedDemo(llm.EmbeddingModel):
    """Deterministic test model: the lengths of the first 16 words, zero-padded."""

    model_id = "embed-demo"

    def embed(self, text):
        lengths = [len(word) for word in text.split()[:16]]
        # Pad with 0 up to 16 words
        return lengths + [0] * (16 - len(lengths))


@pytest.fixture(autouse=True)
def register_embed_demo_model():
    """Register the EmbedDemo plugin for each test and unregister it afterwards."""
    plugin_name = "undo-embed-demo-plugin"

    class EmbedDemoPlugin:
        __name__ = "EmbedDemoPlugin"

        @llm.hookimpl
        def register_embedding_models(self, register):
            register(EmbedDemo())

    pm.register(EmbedDemoPlugin(), name=plugin_name)
    try:
        yield
    finally:
        pm.unregister(name=plugin_name)


# --- tests/test_embed.py ---
def test_demo_plugin():
    demo = llm.get_embedding_model("embed-demo")
    assert demo.embed("hello world") == [5, 5] + [0] * 14


# --- tests/test_embed_cli.py ---
# Expected CLI output for embedding "hello world" in each --format
EXPECTED_JSON = "[5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n"
EXPECTED_BASE64 = (
    "AACgQAAAoEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
    "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==\n"
)
EXPECTED_HEX = (
    "0000a0400000a04000000000000000000000000000000000000000000"
    "000000000000000000000000000000000000000000000000000000000"
    "00000000000000\n"
)
# The literal holds \xef\xbf\xbd (UTF-8 for U+FFFD) where the hex form has a0
EXPECTED_BLOB = (
    b"\x00\x00\xef\xbf\xbd@\x00\x00\xef\xbf\xbd@\x00\x00\x00"
    b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
    b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
    b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
    b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\n"
).decode("utf-8")


@pytest.mark.parametrize(
    "format_,expected",
    (
        ("json", EXPECTED_JSON),
        ("base64", EXPECTED_BASE64),
        ("hex", EXPECTED_HEX),
        ("blob", EXPECTED_BLOB),
    ),
)
def test_embed_output_format(format_, expected):
    cli_args = ["embed", "--format", format_, "-c", "hello world", "-m", "embed-demo"]
    result = CliRunner().invoke(cli, cli_args)
    assert result.exit_code == 0
    assert result.output == expected


@pytest.mark.parametrize(
    "args,expected_error",
    ((["-c", "Content", "stories"], "Must provide both collection and id"),),
)
def test_embed_errors(args, expected_error):
    result = CliRunner().invoke(cli, ["embed", *args])
    assert result.exit_code == 1
    assert expected_error in result.output


def test_embed_store(user_path):
    db_file = user_path / "embeddings.db"
    assert not db_file.exists()
    runner = CliRunner()

    # Embedding without a collection must not create the database
    unstored = runner.invoke(cli, ["embed", "-c", "hello", "-m", "embed-demo"])
    assert unstored.exit_code == 0
    assert not db_file.exists()

    # Now run it to store
    stored = runner.invoke(
        cli, ["embed", "-c", "hello", "-m", "embed-demo", "items", "1"]
    )
    assert stored.exit_code == 0
    assert db_file.exists()

    # Check the contents
    db = sqlite_utils.Database(str(db_file))
    expected_collections = [{"id": 1, "name": "items", "model": "embed-demo"}]
    assert list(db["collections"].rows) == expected_collections

    expected_embedding_row = {
        "collection_id": 1,
        "id": "1",
        "embedding": (
            b"\x00\x00\xa0@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
            b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
            b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
            b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
            b"\x00\x00\x00\x00\x00\x00\x00"
        ),
        "content": None,
        "metadata": None,
    }
    assert list(db["embeddings"].rows) == [expected_embedding_row]

    # Should show up in 'llm embed-db collections', plain text first, then JSON
    plain = runner.invoke(cli, ["embed-db", "collections"])
    assert plain.exit_code == 0
    assert plain.output == "items: embed-demo\n 1 embedding\n"

    as_json = runner.invoke(cli, ["embed-db", "collections", "--json"])
    assert as_json.exit_code == 0
    assert json.loads(as_json.output) == [
        {"name": "items", "model": "embed-demo", "num_embeddings": 1}
    ]