diff --git a/docs/embeddings/cli.md b/docs/embeddings/cli.md index 80cfa142..cbabbbd8 100644 --- a/docs/embeddings/cli.md +++ b/docs/embeddings/cli.md @@ -256,7 +256,7 @@ plugins/index.md ``` Each corresponding to embedded content for the file in question. -The `--prefix` option can be useful here to add a prefix to each ID: +The `--prefix` option can be used to add a prefix to each ID: ```bash llm embed-multi documentation \ @@ -279,6 +279,19 @@ llm-docs/logging.md llm-docs/plugins/directory.md llm-docs/plugins/index.md ``` +Files are assumed to be `utf-8`, but LLM will fall back to `latin-1` if it encounters an encoding error. You can specify a different set of encodings using the `--encoding` option. + +This example will try `utf-16` first and then `mac_roman` before falling back to `latin-1`: +```bash +llm embed-multi documentation \ + -m ada-002 \ + --files docs '**/*.md' \ + -d documentation.db \ + --encoding utf-16 \ + --encoding mac_roman \ + --encoding latin-1 +``` +If a file cannot be decoded using any of the specified encodings, a message is written to standard error and the command carries on with the remaining files. (embeddings-cli-similar)= ## llm similar diff --git a/docs/help.md b/docs/help.md index ed3593a8..a77c22cd 100644 --- a/docs/help.md +++ b/docs/help.md @@ -458,6 +458,7 @@ Options: --format [json|csv|tsv|nl] Format of input file - defaults to auto-detect --files ... Embed files in this directory - specify directory and glob pattern + --encoding TEXT Encoding to use when reading --files --sql TEXT Read input using this SQL query --attach ... 
Additional databases to attach - specify alias and file path diff --git a/llm/cli.py b/llm/cli.py index d0f48086..e6588ac5 100644 --- a/llm/cli.py +++ b/llm/cli.py @@ -1009,6 +1009,12 @@ def get_db(): multiple=True, help="Embed files in this directory - specify directory and glob pattern", ) +@click.option( + "encodings", + "--encoding", + help="Encoding to use when reading --files", + multiple=True, +) @click.option("--sql", help="Read input using this SQL query") @click.option( "--attach", @@ -1026,7 +1032,17 @@ def get_db(): envvar="LLM_EMBEDDINGS_DB", ) def embed_multi( - collection, input_path, format, files, sql, attach, prefix, model, store, database + collection, + input_path, + format, + files, + encodings, + sql, + attach, + prefix, + model, + store, + database, ): """ Store embeddings for multiple strings at once @@ -1072,6 +1088,7 @@ def embed_multi( expected_length = None if files: + encodings = encodings or ("utf-8", "latin-1") def count_files(): i = 0 @@ -1084,7 +1101,22 @@ def iterate_files(): for directory, pattern in files: for path in pathlib.Path(directory).glob(pattern): relative = path.relative_to(directory) - yield {"id": str(relative), "content": path.read_text()} + content = None + for encoding in encodings: + try: + content = path.read_text(encoding=encoding) + # The first encoding that decodes cleanly wins - without this a later, more permissive encoding (latin-1 accepts any bytes) would overwrite a correct decode + break + except UnicodeDecodeError: + continue + if content is None: + # Log to stderr + click.echo( + "Could not decode text in file {}".format(path), + err=True, + ) + else: + yield {"id": str(relative), "content": content} expected_length = count_files() rows = iterate_files() diff --git a/tests/test_embed_cli.py b/tests/test_embed_cli.py index 2f88f28b..25dd392d 100644 --- a/tests/test_embed_cli.py +++ b/tests/test_embed_cli.py @@ -313,21 +313,40 @@ def test_embed_multi_sql(tmpdir, use_other_db, prefix): ] -@pytest.mark.parametrize("scenario", ("single", "multi")) -def test_embed_multi_files(tmpdir, scenario): +@pytest.fixture +def multi_files(tmpdir): db_path = str(tmpdir / "files.db") 
files = tmpdir / "files" for filename, content in ( - ("file1.txt", "hello world"), - ("file2.txt", "goodbye world"), - ("nested/one.txt", "one"), - ("nested/two.txt", "two"), - ("nested/more/three.txt", "three"), - ("nested/more/ignored.ini", "Does not match glob"), + ("file1.txt", b"hello world"), + ("file2.txt", b"goodbye world"), + ("nested/one.txt", b"one"), + ("nested/two.txt", b"two"), + ("nested/more/three.txt", b"three"), + # This tests the fallback to latin-1 encoding: + ("nested/more/ignored.ini", b"Has weird \x96 character"), ): path = pathlib.Path(files / filename) path.parent.mkdir(parents=True, exist_ok=True) - path.write_text(content, "utf-8") + path.write_bytes(content) + return db_path, tmpdir / "files" + + +@pytest.mark.parametrize("scenario", ("single", "multi")) +def test_embed_multi_files(multi_files, scenario): + db_path, files = multi_files + for filename, content in ( + ("file1.txt", b"hello world"), + ("file2.txt", b"goodbye world"), + ("nested/one.txt", b"one"), + ("nested/two.txt", b"two"), + ("nested/more/three.txt", b"three"), + # This tests the fallback to latin-1 encoding: + ("nested/more/ignored.ini", b"Has weird \x96 character"), + ): + path = pathlib.Path(files / filename) + path.parent.mkdir(parents=True, exist_ok=True) + path.write_bytes(content) if scenario == "single": extra_args = ["--files", str(files), "**/*.txt"] @@ -368,12 +387,56 @@ def test_embed_multi_files(tmpdir, scenario): ] else: assert rows == [ - {"id": "ignored.ini", "content": "Does not match glob"}, + {"id": "ignored.ini", "content": "Has weird \x96 character"}, {"id": "two.txt", "content": "two"}, {"id": "one.txt", "content": "one"}, ] +@pytest.mark.parametrize( + "extra_args,expected_error", + ( + # With no args default utf-8 with latin-1 fallback should work + ([], None), + (["--encoding", "utf-8"], "Could not decode text in file"), + (["--encoding", "latin-1"], None), + (["--encoding", "latin-1", "--encoding", "utf-8"], None), + (["--encoding", "utf-8", 
"--encoding", "latin-1"], None), + ), +) +def test_embed_multi_files_encoding(multi_files, extra_args, expected_error): + db_path, files = multi_files + runner = CliRunner(mix_stderr=False) + result = runner.invoke( + cli, + [ + "embed-multi", + "files", + "-d", + db_path, + "-m", + "embed-demo", + "--files", + str(files / "nested" / "more"), + "*.ini", + "--store", + ] + + extra_args, + ) + if expected_error: + # Should still succeed with 0, but show a warning + assert result.exit_code == 0 + assert expected_error in result.stderr + else: + assert result.exit_code == 0 + assert not result.stderr + embeddings_db = sqlite_utils.Database(db_path) + rows = list(embeddings_db.query("select id, content from embeddings")) + assert rows == [ + {"id": "ignored.ini", "content": "Has weird \x96 character"}, + ] + + def test_default_embedding_model(): runner = CliRunner() result = runner.invoke(cli, ["embed-models", "default"])