Skip to content

Commit

Permalink
Harrison/guarded output parser (#1804)
Browse files Browse the repository at this point in the history
Co-authored-by: jerwelborn <jeremy.welborn@gmail.com>
  • Loading branch information
hwchase17 and jerwelborn authored Mar 22, 2023
1 parent 8fa1764 commit ce5d97b
Show file tree
Hide file tree
Showing 17 changed files with 567 additions and 75 deletions.
343 changes: 321 additions & 22 deletions docs/modules/prompts/examples/output_parsers.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion langchain/agents/conversational_chat/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
)
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains import LLMChain
from langchain.output_parsers.base import BaseOutputParser
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.chat import (
ChatPromptTemplate,
Expand All @@ -26,6 +25,7 @@
AIMessage,
BaseLanguageModel,
BaseMessage,
BaseOutputParser,
HumanMessage,
)
from langchain.tools.base import BaseTool
Expand Down
7 changes: 5 additions & 2 deletions langchain/output_parsers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from langchain.output_parsers.base import BaseOutputParser
from langchain.output_parsers.fix import OutputFixingParser
from langchain.output_parsers.list import (
CommaSeparatedListOutputParser,
ListOutputParser,
Expand All @@ -7,16 +7,19 @@
from langchain.output_parsers.rail_parser import GuardrailsOutputParser
from langchain.output_parsers.regex import RegexParser
from langchain.output_parsers.regex_dict import RegexDictParser
from langchain.output_parsers.retry import RetryOutputParser, RetryWithErrorOutputParser
from langchain.output_parsers.structured import ResponseSchema, StructuredOutputParser

__all__ = [
"RegexParser",
"RegexDictParser",
"ListOutputParser",
"CommaSeparatedListOutputParser",
"BaseOutputParser",
"StructuredOutputParser",
"ResponseSchema",
"GuardrailsOutputParser",
"PydanticOutputParser",
"RetryOutputParser",
"RetryWithErrorOutputParser",
"OutputFixingParser",
]
28 changes: 0 additions & 28 deletions langchain/output_parsers/base.py

This file was deleted.

41 changes: 41 additions & 0 deletions langchain/output_parsers/fix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from __future__ import annotations

from typing import Any

from langchain.chains.llm import LLMChain
from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel, BaseOutputParser, OutputParserException


class OutputFixingParser(BaseOutputParser):
"""Wraps a parser and tries to fix parsing errors."""

parser: BaseOutputParser
retry_chain: LLMChain

@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
parser: BaseOutputParser,
prompt: BasePromptTemplate = NAIVE_FIX_PROMPT,
) -> OutputFixingParser:
chain = LLMChain(llm=llm, prompt=prompt)
return cls(parser=parser, retry_chain=chain)

def parse(self, completion: str) -> Any:
try:
parsed_completion = self.parser.parse(completion)
except OutputParserException as e:
new_completion = self.retry_chain.run(
instructions=self.parser.get_format_instructions(),
completion=completion,
error=repr(e),
)
parsed_completion = self.parser.parse(new_completion)

return parsed_completion

def get_format_instructions(self) -> str:
return self.parser.get_format_instructions()
5 changes: 4 additions & 1 deletion langchain/output_parsers/format_instructions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
}}
```"""

PYDANTIC_FORMAT_INSTRUCTIONS = """The output should be formatted as a JSON instance that conforms to the JSON schema below. For example, the object {{"foo": ["bar", "baz"]}} conforms to the schema {{"foo": {{"description": "a list of strings field", "type": "string"}}}}.
PYDANTIC_FORMAT_INSTRUCTIONS = """The output should be formatted as a JSON instance that conforms to the JSON schema below.
As an example, for the schema {{"properties": {{"foo": {{"title": "Foo", "description": "a list of strings", "type": "array", "items": {{"type": "string"}}}}}}, "required": ["foo"]}}}}
the object {{"foo": ["bar", "baz"]}} is a well-formatted instance of the schema. The object {{"properties": {{"foo": ["bar", "baz"]}}}} is not well-formatted.
Here is the output schema:
```
Expand Down
2 changes: 1 addition & 1 deletion langchain/output_parsers/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from abc import abstractmethod
from typing import List

from langchain.output_parsers.base import BaseOutputParser
from langchain.schema import BaseOutputParser


class ListOutputParser(BaseOutputParser):
Expand Down
22 changes: 22 additions & 0 deletions langchain/output_parsers/prompts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# flake8: noqa
from langchain.prompts.prompt import PromptTemplate

NAIVE_FIX = """Instructions:
--------------
{instructions}
--------------
Completion:
--------------
{completion}
--------------
Above, the Completion did not satisfy the constraints given in the Instructions.
Error:
--------------
{error}
--------------
Please try again. Please only respond with an answer that satisfies the constraints laid out in the Instructions:"""


NAIVE_FIX_PROMPT = PromptTemplate.from_template(NAIVE_FIX)
17 changes: 10 additions & 7 deletions langchain/output_parsers/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

from pydantic import BaseModel, ValidationError

from langchain.output_parsers.base import BaseOutputParser
from langchain.output_parsers.format_instructions import PYDANTIC_FORMAT_INSTRUCTIONS
from langchain.schema import BaseOutputParser, OutputParserException


class PydanticOutputParser(BaseOutputParser):
Expand All @@ -14,7 +14,9 @@ class PydanticOutputParser(BaseOutputParser):
def parse(self, text: str) -> BaseModel:
try:
# Greedy search for 1st json candidate.
match = re.search("\{.*\}", text.strip())
match = re.search(
"\{.*\}", text.strip(), re.MULTILINE | re.IGNORECASE | re.DOTALL
)
json_str = ""
if match:
json_str = match.group()
Expand All @@ -24,16 +26,17 @@ def parse(self, text: str) -> BaseModel:
except (json.JSONDecodeError, ValidationError) as e:
name = self.pydantic_object.__name__
msg = f"Failed to parse {name} from completion {text}. Got: {e}"
raise ValueError(msg)
raise OutputParserException(msg)

def get_format_instructions(self) -> str:
schema = self.pydantic_object.schema()

# Remove extraneous fields.
reduced_schema = {
prop: {"description": data["description"], "type": data["type"]}
for prop, data in schema["properties"].items()
}
reduced_schema = schema
if "title" in reduced_schema:
del reduced_schema["title"]
if "type" in reduced_schema:
del reduced_schema["type"]
# Ensure json in context is well-formed with double quotes.
schema = json.dumps(reduced_schema)

Expand Down
2 changes: 1 addition & 1 deletion langchain/output_parsers/rail_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import Any, Dict

from langchain.output_parsers.base import BaseOutputParser
from langchain.schema import BaseOutputParser


class GuardrailsOutputParser(BaseOutputParser):
Expand Down
2 changes: 1 addition & 1 deletion langchain/output_parsers/regex.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from pydantic import BaseModel

from langchain.output_parsers.base import BaseOutputParser
from langchain.schema import BaseOutputParser


class RegexParser(BaseOutputParser, BaseModel):
Expand Down
2 changes: 1 addition & 1 deletion langchain/output_parsers/regex_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from pydantic import BaseModel

from langchain.output_parsers.base import BaseOutputParser
from langchain.schema import BaseOutputParser


class RegexDictParser(BaseOutputParser, BaseModel):
Expand Down
118 changes: 118 additions & 0 deletions langchain/output_parsers/retry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
from __future__ import annotations

from typing import Any

from langchain.chains.llm import LLMChain
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import (
BaseLanguageModel,
BaseOutputParser,
OutputParserException,
PromptValue,
)

NAIVE_COMPLETION_RETRY = """Prompt:
{prompt}
Completion:
{completion}
Above, the Completion did not satisfy the constraints given in the Prompt.
Please try again:"""

NAIVE_COMPLETION_RETRY_WITH_ERROR = """Prompt:
{prompt}
Completion:
{completion}
Above, the Completion did not satisfy the constraints given in the Prompt.
Details: {error}
Please try again:"""

NAIVE_RETRY_PROMPT = PromptTemplate.from_template(NAIVE_COMPLETION_RETRY)
NAIVE_RETRY_WITH_ERROR_PROMPT = PromptTemplate.from_template(
NAIVE_COMPLETION_RETRY_WITH_ERROR
)


class RetryOutputParser(BaseOutputParser):
"""Wraps a parser and tries to fix parsing errors.
Does this by passing the original prompt and the completion to another
LLM, and telling it the completion did not satisfy criteria in the prompt.
"""

parser: BaseOutputParser
retry_chain: LLMChain

@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
parser: BaseOutputParser,
prompt: BasePromptTemplate = NAIVE_RETRY_PROMPT,
) -> RetryOutputParser:
chain = LLMChain(llm=llm, prompt=prompt)
return cls(parser=parser, retry_chain=chain)

def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> Any:
try:
parsed_completion = self.parser.parse(completion)
except OutputParserException:
new_completion = self.retry_chain.run(
prompt=prompt_value.to_string(), completion=completion
)
parsed_completion = self.parser.parse(new_completion)

return parsed_completion

def parse(self, completion: str) -> Any:
raise NotImplementedError(
"This OutputParser can only be called by the `parse_with_prompt` method."
)

def get_format_instructions(self) -> str:
return self.parser.get_format_instructions()


class RetryWithErrorOutputParser(BaseOutputParser):
"""Wraps a parser and tries to fix parsing errors.
Does this by passing the original prompt, the completion, AND the error
that was raised to another language and telling it that the completion
did not work, and raised the given error. Differs from RetryOutputParser
in that this implementation provides the error that was raised back to the
LLM, which in theory should give it more information on how to fix it.
"""

parser: BaseOutputParser
retry_chain: LLMChain

@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
parser: BaseOutputParser,
prompt: BasePromptTemplate = NAIVE_RETRY_WITH_ERROR_PROMPT,
) -> RetryWithErrorOutputParser:
chain = LLMChain(llm=llm, prompt=prompt)
return cls(parser=parser, retry_chain=chain)

def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> Any:
try:
parsed_completion = self.parser.parse(completion)
except OutputParserException as e:
new_completion = self.retry_chain.run(
prompt=prompt_value.to_string(), completion=completion, error=repr(e)
)
parsed_completion = self.parser.parse(new_completion)

return parsed_completion

def parse(self, completion: str) -> Any:
raise NotImplementedError(
"This OutputParser can only be called by the `parse_with_prompt` method."
)

def get_format_instructions(self) -> str:
return self.parser.get_format_instructions()
4 changes: 2 additions & 2 deletions langchain/output_parsers/structured.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

from pydantic import BaseModel

from langchain.output_parsers.base import BaseOutputParser
from langchain.output_parsers.format_instructions import STRUCTURED_FORMAT_INSTRUCTIONS
from langchain.schema import BaseOutputParser, OutputParserException

line_template = '\t"{name}": {type} // {description}'

Expand Down Expand Up @@ -42,7 +42,7 @@ def parse(self, text: str) -> BaseModel:
json_obj = json.loads(json_string)
for schema in self.response_schemas:
if schema.name not in json_obj:
raise ValueError(
raise OutputParserException(
f"Got invalid return object. Expected key `{schema.name}` "
f"to be present, but got {json_obj}"
)
Expand Down
8 changes: 1 addition & 7 deletions langchain/prompts/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,7 @@
from pydantic import BaseModel, Extra, Field, root_validator

from langchain.formatting import formatter
from langchain.output_parsers.base import BaseOutputParser
from langchain.output_parsers.list import ( # noqa: F401
CommaSeparatedListOutputParser,
ListOutputParser,
)
from langchain.output_parsers.regex import RegexParser # noqa: F401
from langchain.schema import BaseMessage, HumanMessage, PromptValue
from langchain.schema import BaseMessage, BaseOutputParser, HumanMessage, PromptValue


def jinja2_formatter(template: str, **kwargs: Any) -> str:
Expand Down
Loading

0 comments on commit ce5d97b

Please sign in to comment.