Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

works on windows #410

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions llm/cli.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os

import click
from click_default_group import DefaultGroup
from dataclasses import asdict
Expand Down Expand Up @@ -342,8 +344,11 @@ def chat(
Hold an ongoing chat with a model.
"""
# Left and right arrow keys to move cursor:
readline.parse_and_bind("\\e[D: backward-char")
readline.parse_and_bind("\\e[C: forward-char")
if os.name != "nt":
# I'm pretty sure this can't be done without win32 API calls.
# pyreadline is unsupported and no longer python 3.12 compatible.
readline.parse_and_bind("\\e[D: backward-char")
readline.parse_and_bind("\\e[C: forward-char")
log_path = logs_db_path()
(log_path.parent).mkdir(parents=True, exist_ok=True)
db = sqlite_utils.Database(log_path)
Expand Down
5 changes: 0 additions & 5 deletions tests/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import llm.cli
from unittest.mock import ANY
import pytest
import sys


def test_mock_model(mock_model):
Expand All @@ -17,7 +16,6 @@ def test_mock_model(mock_model):
assert response2.text() == "second"


@pytest.mark.xfail(sys.platform == "win32", reason="Expected to fail on Windows")
def test_chat_basic(mock_model, logs_db):
runner = CliRunner()
mock_model.enqueue(["one world"])
Expand Down Expand Up @@ -114,7 +112,6 @@ def test_chat_basic(mock_model, logs_db):
]


@pytest.mark.xfail(sys.platform == "win32", reason="Expected to fail on Windows")
def test_chat_system(mock_model, logs_db):
runner = CliRunner()
mock_model.enqueue(["I am mean"])
Expand Down Expand Up @@ -151,7 +148,6 @@ def test_chat_system(mock_model, logs_db):
]


@pytest.mark.xfail(sys.platform == "win32", reason="Expected to fail on Windows")
def test_chat_options(mock_model, logs_db):
runner = CliRunner()
mock_model.enqueue(["Some text"])
Expand Down Expand Up @@ -179,7 +175,6 @@ def test_chat_options(mock_model, logs_db):
]


@pytest.mark.xfail(sys.platform == "win32", reason="Expected to fail on Windows")
@pytest.mark.parametrize(
"input,expected",
(
Expand Down
10 changes: 5 additions & 5 deletions tests/test_embed_cli.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import os

from click.testing import CliRunner
from llm.cli import cli
from llm import Collection
import json
import pathlib
import pytest
import sqlite_utils
import sys
from unittest.mock import ANY


Expand Down Expand Up @@ -423,7 +424,6 @@ def multi_files(tmpdir):
return db_path, tmpdir / "files"


@pytest.mark.xfail(sys.platform == "win32", reason="Expected to fail on Windows")
@pytest.mark.parametrize("scenario", ("single", "multi"))
def test_embed_multi_files(multi_files, scenario):
db_path, files = multi_files
Expand Down Expand Up @@ -473,9 +473,9 @@ def test_embed_multi_files(multi_files, scenario):
assert rows == [
{"id": "file1.txt", "content": "hello world"},
{"id": "file2.txt", "content": "goodbye world"},
{"id": "nested/more/three.txt", "content": "three"},
{"id": "nested/one.txt", "content": "one"},
{"id": "nested/two.txt", "content": "two"},
{"id": f"nested{os.sep}more{os.sep}three.txt", "content": "three"},
{"id": f"nested{os.sep}one.txt", "content": "one"},
{"id": f"nested{os.sep}two.txt", "content": "two"},
]
else:
assert rows == [
Expand Down
18 changes: 11 additions & 7 deletions tests/test_keys.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import os
from pathlib import Path
from click.testing import CliRunner
import json
from llm.cli import cli
import pathlib
import pytest
import sys


@pytest.mark.xfail(sys.platform == "win32", reason="Expected to fail on Windows")
@pytest.mark.parametrize("env", ({}, {"LLM_USER_PATH": "/tmp/llm-keys-test"}))
def test_keys_in_user_path(monkeypatch, env, user_path):
for key, value in env.items():
Expand All @@ -15,13 +15,13 @@ def test_keys_in_user_path(monkeypatch, env, user_path):
result = runner.invoke(cli, ["keys", "path"])
assert result.exit_code == 0
if env:
expected = env["LLM_USER_PATH"] + "/keys.json"
expected_path = Path(env["LLM_USER_PATH"]) / "keys.json"
else:
expected = user_path + "/keys.json"
expected_path = Path(user_path) / "keys.json"
expected = str(expected_path)
assert result.output.strip() == expected


@pytest.mark.xfail(sys.platform == "win32", reason="Expected to fail on Windows")
def test_keys_set(monkeypatch, tmpdir):
user_path = tmpdir / "user/keys"
monkeypatch.setenv("LLM_USER_PATH", str(user_path))
Expand All @@ -31,8 +31,12 @@ def test_keys_set(monkeypatch, tmpdir):
result = runner.invoke(cli, ["keys", "set", "openai"], input="foo")
assert result.exit_code == 0
assert keys_path.exists()
# Should be chmod 600
assert oct(keys_path.stat().mode)[-3:] == "600"
if os.name != "nt":
# Should be chmod 600
assert oct(keys_path.stat().mode)[-3:] == "600"
else:
# Windows file permissions don't work that way.
pass
content = keys_path.read_text("utf-8")
assert json.loads(content) == {
"// Note": "This file stores secret API credentials. Do not share!",
Expand Down
8 changes: 4 additions & 4 deletions tests/test_llm.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from pathlib import Path
from click.testing import CliRunner
import datetime
import llm
Expand All @@ -8,7 +9,6 @@
import pytest
import re
import sqlite_utils
import sys
from ulid import ULID
from unittest import mock

Expand Down Expand Up @@ -97,7 +97,6 @@ def test_logs_json(n, log_path):
assert len(logs) == expected_length


@pytest.mark.xfail(sys.platform == "win32", reason="Expected to fail on Windows")
@pytest.mark.parametrize("env", ({}, {"LLM_USER_PATH": "/tmp/llm-user-path"}))
def test_logs_path(monkeypatch, env, user_path):
for key, value in env.items():
Expand All @@ -106,9 +105,10 @@ def test_logs_path(monkeypatch, env, user_path):
result = runner.invoke(cli, ["logs", "path"])
assert result.exit_code == 0
if env:
expected = env["LLM_USER_PATH"] + "/logs.db"
expected_path = Path(env["LLM_USER_PATH"]) / "logs.db"
else:
expected = str(user_path) + "/logs.db"
expected_path = Path(user_path) / "logs.db"
expected = str(expected_path)
assert result.output.strip() == expected


Expand Down