feat(New e2e tests, fix retrieval with return_source_documents=True): #23

Merged · 4 commits · Jul 12, 2023
1 change: 1 addition & 0 deletions examples/app.py
@@ -10,4 +10,5 @@
     "examples.ex5:conversation",
     "examples.ex6:conversation_with_summary",
     "examples.ex7_agent:agent",
+    "examples.ex8:qa",
 )
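
Registering examples.ex8:qa exposes the new retrieval chain through the example app. A minimal sketch of how a registration string maps to a route, assuming create_service is importable from the package root as in the project README (the colon becomes a dot in the endpoint prefix, see create_service in langcorn/server/api.py below):

# Sketch, not part of this diff: mirrors what examples/app.py does.
from langcorn import create_service

app = create_service("examples.ex8:qa")
# "examples.ex8:qa" is served at POST /examples.ex8.qa/run
# run with: uvicorn examples.app:app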
3 changes: 2 additions & 1 deletion examples/ex4.py
@@ -1,7 +1,8 @@
 import os
 
-from langchain import OpenAI, PromptTemplate
+from langchain import PromptTemplate
 from langchain.chains import LLMChain, LLMRequestsChain, SequentialChain
+from langchain.llms import OpenAI
 
 os.environ["OPENAI_API_KEY"] = os.environ.get("OPENAI_API_KEY", "sk-********")
6 changes: 3 additions & 3 deletions examples/ex5.py
@@ -1,5 +1,5 @@
 from langchain.chains import ConversationChain
-from langchain.chat_models import ChatOpenAI
+from langchain.llms import OpenAI
 from langchain.memory import ConversationBufferMemory
 from langchain.prompts import (
     ChatPromptTemplate,
@@ -11,14 +11,14 @@
 prompt = ChatPromptTemplate.from_messages(
     [
         SystemMessagePromptTemplate.from_template(
-            "The following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know."
+            "The following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know."  # noqa
         ),
         MessagesPlaceholder(variable_name="history"),
         HumanMessagePromptTemplate.from_template("{input}"),
     ]
 )
 
-llm = ChatOpenAI(temperature=0)
+llm = OpenAI(temperature=0)
 memory = ConversationBufferMemory(return_messages=True)
 conversation = ConversationChain(memory=memory, prompt=prompt, llm=llm)
5 changes: 2 additions & 3 deletions examples/ex7_agent.py
@@ -1,13 +1,12 @@
+from langchain import LLMMathChain
+from langchain.agents.tools import Tool
-from langchain.chat_models import ChatOpenAI
 from langchain.experimental.plan_and_execute import (
     PlanAndExecute,
     load_agent_executor,
     load_chat_planner,
 )
 from langchain.llms import OpenAI
-from langchain.agents.tools import Tool
-from langchain import LLMMathChain
 
 
 llm = OpenAI(temperature=0)
 llm_math_chain = LLMMathChain.from_llm(llm=llm, verbose=True)
22 changes: 22 additions & 0 deletions examples/ex8.py
@@ -0,0 +1,22 @@
+from langchain.chains import RetrievalQA
+from langchain.document_loaders import TextLoader
+from langchain.embeddings.fake import FakeEmbeddings
+from langchain.llms import OpenAI
+from langchain.text_splitter import CharacterTextSplitter
+from langchain.vectorstores import Chroma
+
+loader = TextLoader("Readme.md")
+documents = loader.load()
+text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
+texts = text_splitter.split_documents(documents)
+
+embeddings = FakeEmbeddings(size=1504)
+docsearch = Chroma.from_documents(texts, embeddings)
+
+
+qa = RetrievalQA.from_chain_type(
+    llm=OpenAI(),
+    chain_type="stuff",
+    retriever=docsearch.as_retriever(),
+    return_source_documents=True,
+)
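
This is the chain the PR title refers to. Because return_source_documents=True gives it two output keys (result and source_documents), chain.run() raises a ValueError for it; the chain has to be called directly. A minimal sketch of consuming qa, with an illustrative query:

# Sketch: calling a multi-output chain directly returns a dict.
output = qa({"query": "What is this repository about?"})
print(output["result"])            # the answer text
print(output["source_documents"])  # the retrieved Documents used as context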
1 change: 1 addition & 0 deletions langcorn/__main__.py
@@ -3,6 +3,7 @@
 
 import fire
 import uvicorn
+
 from langcorn.server import api
 
 
20 changes: 17 additions & 3 deletions langcorn/server/api.py
@@ -39,6 +39,10 @@ class LangResponse(BaseModel):
     memory: list[Memory]
 
 
+class LangResponseDocuments(LangResponse):
+    source_documents: list[str]
+
+
 def authenticate_or_401(auth_token):
     if not auth_token:
         # Auth is not enabled.
@@ -85,14 +89,17 @@ def set_openai_key(new_key: str) -> str:
 def make_handler(request_cls, chain):
     async def handler(request: request_cls, http_request: Request):
         llm_api_key = http_request.headers.get("x-llm-api-key")
+        retrieval_chain = len(chain.output_keys) > 1
         try:
             api_key = set_openai_key(llm_api_key)
             run_params = request.dict()
             memory = run_params.pop("memory", [])
             if chain.memory and memory and memory[0]:
                 chain.memory.chat_memory.messages = messages_from_dict(memory)
-            output = chain.run(run_params)
-
+            if not retrieval_chain:
+                output = chain.run(run_params)
+            else:
+                output = chain(run_params)
             # add error handling
             memory = (
                 []
@@ -103,6 +110,13 @@ async def handler(request: request_cls, http_request: Request):
             raise HTTPException(status_code=500, detail=dict(error=str(e)))
         finally:
             set_openai_key(api_key)
+        if retrieval_chain:
+            return LangResponseDocuments(
+                output=output.get("result"),
+                error="",
+                memory=memory,
+                source_documents=[str(t) for t in output.get("source_documents")],
+            )
         return LangResponse(output=output, error="", memory=memory)
 
     return handler
@@ -125,7 +139,7 @@ def create_service(*lc_apps, auth_token: str = "", app: FastAPI = None):
         inn, out = derive_fields(chain)
         logger.debug(f"inputs:{inn=}")
         logger.info(f"{lang_app=}:{chain.__class__.__name__}({inn})")
-        endpoint_prefix = lang_app.replace(":", '.')
+        endpoint_prefix = lang_app.replace(":", ".")
         cls_name = "".join([c.capitalize() for c in endpoint_prefix.split(".")])
         request_cls = derive_class(cls_name, inn, add_memory=chain.memory)
         logger.debug(f"{request_cls=}")
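
The handler now branches on len(chain.output_keys) > 1: single-output chains keep the chain.run() fast path, while retrieval chains are called directly and answered with the extended LangResponseDocuments schema. A sketch of the round trip, assuming a valid OPENAI_API_KEY is set (the tests below stub the LLM out instead); the query text is illustrative:

from fastapi.testclient import TestClient

from examples import app

client = TestClient(app.app)
resp = client.post("/examples.ex8.qa/run", json={"query": "What is langcorn?"})
# Expected shape for a retrieval chain:
# {"output": "...", "error": "", "memory": [], "source_documents": ["..."]}
print(resp.json()["output"])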
60 changes: 59 additions & 1 deletion langcorn/server/test_api.py
@@ -1,5 +1,8 @@
+from unittest.mock import patch
+
 import pytest
 from fastapi.testclient import TestClient
+from langchain.llms.fake import FakeListLLM
 
 from examples import app

@@ -10,6 +13,20 @@
 example_app = TestClient(app.app)
 
 
+@pytest.fixture(autouse=True)
+def suppress_openai():
+    llm = FakeListLLM(responses=["FakeListLLM" for i in range(100)])
+    with patch("langchain.llms.OpenAI._generate", new=llm._generate), patch(
+        "langchain.llms.OpenAI._agenerate", new=llm._agenerate
+    ):
+        yield
+
+
+@pytest.fixture(autouse=True)
+def example_app():
+    yield TestClient(app.app)
+
+
 @pytest.fixture(
     scope="session",
 )
@@ -18,7 +35,7 @@ def fn_executor():
 
 
 class TestRoutes:
-    def test_examples(self):
+    def test_examples(self, example_app):
         response = example_app.get("/")
         assert response.status_code == 404
 
@@ -40,3 +57,44 @@ def test_create_service(self, apps):
         client = TestClient(create_service(*apps))
         response = client.get("/")
         assert response.status_code == 404
+
+    def test_chain_x(self, suppress_openai, example_app):
+        response = example_app.post("/examples.ex8.qa/run", json=dict(query="query"))
+        assert response.status_code == 200, response.text
+        assert response.json() == {"error": "", "memory": [], "output": "FakeListLLM"}
+
+    @pytest.mark.parametrize(
+        "endpoint, query",
+        [
+            ("/examples.ex1.chain/run", dict(product="QUERY")),
+            (
+                "/examples.ex2.chain/run",
+                dict(
+                    input="QUERY",
+                    url="https://github.com/msoedov/langcorn/blob/main/examples/ex7_agent.py",
+                ),
+            ),
+            # ("/examples.ex3.chain/run", dict(question="QUERY")),  # requires llm response format
+            (
+                "/examples.ex4.sequential_chain/run",
+                dict(
+                    query="QUERY",
+                    url="https://github.com/msoedov/langcorn/blob/main/examples/ex7_agent.py",
+                ),
+            ),
+            (
+                "/examples.ex5.conversation/run",
+                dict(input="QUERY", history="", memory=[]),
+            ),
+            (
+                "/examples.ex6.conversation_with_summary/run",
+                dict(input="QUERY", history="", memory=[]),
+            ),
+            # ("/examples.ex7_agent.agent/run", dict(input="QUERY")),  # requires llm response format
+            ("/examples.ex8.qa/run", dict(query="QUERY")),
+        ],
+    )
+    def test_chain_e2e(self, suppress_openai, example_app, endpoint, query):
+        response = example_app.post(endpoint, json=dict(**query))
+        assert response.status_code == 200, response.text
+        assert response.json()
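
The autouse suppress_openai fixture is what makes these e2e tests run offline: it patches OpenAI._generate and _agenerate with the bound methods of a FakeListLLM, so every example chain answers "FakeListLLM" without an API key or a network call. A sketch of the same trick in isolation, with an illustrative prompt:

from unittest.mock import patch

from langchain.llms import OpenAI
from langchain.llms.fake import FakeListLLM

fake = FakeListLLM(responses=["FakeListLLM"] * 10)
with patch("langchain.llms.OpenAI._generate", new=fake._generate):
    llm = OpenAI(openai_api_key="sk-none")  # the key is never used
    print(llm("any prompt"))  # -> "FakeListLLM"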