Skip to content

Commit

Permalink
feat(New e2e tests, fix retrieval with return_source_documents=True):
Browse files Browse the repository at this point in the history
  • Loading branch information
msoedov committed Jul 12, 2023
1 parent 6c1aa0c commit 0a017e0
Show file tree
Hide file tree
Showing 6 changed files with 1,185 additions and 7 deletions.
1 change: 1 addition & 0 deletions examples/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@
"examples.ex5:conversation",
"examples.ex6:conversation_with_summary",
"examples.ex7_agent:agent",
"examples.ex8:qa",
)
22 changes: 22 additions & 0 deletions examples/ex8.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from langchain.chains import RetrievalQA
from langchain.document_loaders import TextLoader
from langchain.embeddings.fake import FakeEmbeddings
from langchain.llms import OpenAI
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma

# Example 8: a retrieval question-answering chain that also returns the
# documents it retrieved. The finished chain is published as the module-level
# name `qa`.

# Load the text to index. The path is relative, so it is presumably resolved
# against the process working directory -- TODO confirm where the server and
# the tests are launched from.
loader = TextLoader("Readme.md")
documents = loader.load()
# Split the loaded documents into chunks before indexing; chunk_size and
# chunk_overlap are passed straight to the splitter.
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
texts = text_splitter.split_documents(documents)

# FakeEmbeddings is used instead of a real embedding model, presumably so the
# example can be indexed without calling an external embedding service.
# NOTE(review): the meaning of size=1504 is not visible here -- confirm.
embeddings = FakeEmbeddings(size=1504)
# Build the vector store that backs the retriever used by the chain below.
docsearch = Chroma.from_documents(texts, embeddings)


# return_source_documents=True asks the chain to report the retrieved
# documents along with its answer.
# NOTE(review): callers should expect a multi-key output from this chain
# rather than a single string -- confirm against the RetrievalQA docs.
qa = RetrievalQA.from_chain_type(
    llm=OpenAI(),
    chain_type="stuff",
    retriever=docsearch.as_retriever(),
    return_source_documents=True,
)
18 changes: 16 additions & 2 deletions langcorn/server/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ class LangResponse(BaseModel):
memory: list[Memory]


# Response model for chains that also report the documents they drew on.
# It extends LangResponse with one extra field and inherits the rest.
class LangResponseDocuments(LangResponse):
    # One string per source document.
    # NOTE(review): presumably the str() form of each retrieved document --
    # confirm against the code that populates this field.
    source_documents: list[str]


def authenticate_or_401(auth_token):
if not auth_token:
# Auth is not enabled.
Expand Down Expand Up @@ -85,14 +89,17 @@ def set_openai_key(new_key: str) -> str:
def make_handler(request_cls, chain):
async def handler(request: request_cls, http_request: Request):
llm_api_key = http_request.headers.get("x-llm-api-key")
retrieval_chain = len(chain.output_keys) > 1
try:
api_key = set_openai_key(llm_api_key)
run_params = request.dict()
memory = run_params.pop("memory", [])
if chain.memory and memory and memory[0]:
chain.memory.chat_memory.messages = messages_from_dict(memory)
output = chain.run(run_params)

if not retrieval_chain:
output = chain.run(run_params)
else:
output = chain(run_params)
# add error handling
memory = (
[]
Expand All @@ -103,6 +110,13 @@ async def handler(request: request_cls, http_request: Request):
raise HTTPException(status_code=500, detail=dict(error=str(e)))
finally:
set_openai_key(api_key)
if retrieval_chain:
return LangResponseDocuments(
output=output.get("result"),
error="",
memory=memory,
source_documents=[str(t) for t in output.get("source_documents")],
)
return LangResponse(output=output, error="", memory=memory)

return handler
Expand Down
61 changes: 58 additions & 3 deletions langcorn/server/test_api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import pytest
from fastapi.testclient import TestClient
from unittest.mock import MagicMock, patch

import pytest
from examples import app
from fastapi.testclient import TestClient
from langchain.llms.fake import FakeListLLM

from .api import create_service

Expand All @@ -10,6 +12,18 @@
example_app = TestClient(app.app)


@pytest.fixture(autouse=True)
def suppress_openai():
    """Swap ``OpenAI._generate`` for a canned fake while a test runs.

    The replacement comes from a ``FakeListLLM`` preloaded with 100 copies of
    the string ``"FakeListLLM"``. The patch is active for the duration of the
    test and is undone when the fixture is torn down.
    """
    canned_llm = FakeListLLM(responses=["FakeListLLM"] * 100)
    openai_patch = patch("langchain.llms.OpenAI._generate", new=canned_llm._generate)
    with openai_patch:
        yield


@pytest.fixture(autouse=True)
def example_app():
    """Yield a new ``TestClient`` wrapping the examples application."""
    client = TestClient(app.app)
    yield client


@pytest.fixture(
scope="session",
)
Expand All @@ -18,7 +32,7 @@ def fn_executor():


class TestRoutes:
def test_examples(self):
def test_examples(self, example_app):
response = example_app.get("/")
assert response.status_code == 404

Expand All @@ -40,3 +54,44 @@ def test_create_service(self, apps):
client = TestClient(create_service(*apps))
response = client.get("/")
assert response.status_code == 404

def test_chain_x(self, suppress_openai, example_app):
    """POST a query to the ex8 ``qa`` endpoint and check the exact JSON body."""
    payload = {"query": "query"}
    expected_body = {"error": "", "memory": [], "output": "FakeListLLM"}

    response = example_app.post("/examples.ex8.qa/run", json=payload)

    # Include the raw body in the failure message to ease debugging.
    assert response.status_code == 200, response.text
    assert response.json() == expected_body

@pytest.mark.parametrize(
    "endpoint, query",
    [
        ("/examples.ex1.chain/run", {"product": "QUERY"}),
        (
            "/examples.ex2.chain/run",
            {
                "input": "QUERY",
                "url": "https://github.com/msoedov/langcorn/blob/main/examples/ex7_agent.py",
            },
        ),
        # ("/examples.ex3.chain/run", dict(question="QUERY")),
        (
            "/examples.ex4.sequential_chain/run",
            {
                "query": "QUERY",
                "url": "https://github.com/msoedov/langcorn/blob/main/examples/ex7_agent.py",
            },
        ),
        (
            "/examples.ex5.conversation/run",
            {"input": "QUERY", "history": "", "memory": []},
        ),
        (
            "/examples.ex6.conversation_with_summary/run",
            {"input": "QUERY", "history": "", "memory": []},
        ),
        ("/examples.ex7_agent.agent/run", {"input": "QUERY"}),
        ("/examples.ex8.qa/run", {"query": "QUERY"}),
    ],
)
def test_chain_e2e(self, suppress_openai, example_app, endpoint, query):
    """POST each example payload to its endpoint and expect a 200 with a body.

    Only the status code and the truthiness of the decoded JSON are checked;
    the exact response content is not asserted here.
    """
    # Send a shallow copy so the parametrized dict itself is never handed out.
    payload = {**query}
    response = example_app.post(endpoint, json=payload)
    assert response.status_code == 200, response.text
    assert response.json()
Loading

0 comments on commit 0a017e0

Please sign in to comment.