-
Notifications
You must be signed in to change notification settings - Fork 15.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Harrison/guarded output parser (#1804)
Co-authored-by: jerwelborn <jeremy.welborn@gmail.com>
- Loading branch information
1 parent
8fa1764
commit ce5d97b
Showing
17 changed files
with
567 additions
and
75 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Any | ||
|
||
from langchain.chains.llm import LLMChain | ||
from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT | ||
from langchain.prompts.base import BasePromptTemplate | ||
from langchain.schema import BaseLanguageModel, BaseOutputParser, OutputParserException | ||
|
||
|
||
class OutputFixingParser(BaseOutputParser): | ||
"""Wraps a parser and tries to fix parsing errors.""" | ||
|
||
parser: BaseOutputParser | ||
retry_chain: LLMChain | ||
|
||
@classmethod | ||
def from_llm( | ||
cls, | ||
llm: BaseLanguageModel, | ||
parser: BaseOutputParser, | ||
prompt: BasePromptTemplate = NAIVE_FIX_PROMPT, | ||
) -> OutputFixingParser: | ||
chain = LLMChain(llm=llm, prompt=prompt) | ||
return cls(parser=parser, retry_chain=chain) | ||
|
||
def parse(self, completion: str) -> Any: | ||
try: | ||
parsed_completion = self.parser.parse(completion) | ||
except OutputParserException as e: | ||
new_completion = self.retry_chain.run( | ||
instructions=self.parser.get_format_instructions(), | ||
completion=completion, | ||
error=repr(e), | ||
) | ||
parsed_completion = self.parser.parse(new_completion) | ||
|
||
return parsed_completion | ||
|
||
def get_format_instructions(self) -> str: | ||
return self.parser.get_format_instructions() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# flake8: noqa | ||
from langchain.prompts.prompt import PromptTemplate | ||
|
||
NAIVE_FIX = """Instructions: | ||
-------------- | ||
{instructions} | ||
-------------- | ||
Completion: | ||
-------------- | ||
{completion} | ||
-------------- | ||
Above, the Completion did not satisfy the constraints given in the Instructions. | ||
Error: | ||
-------------- | ||
{error} | ||
-------------- | ||
Please try again. Please only respond with an answer that satisfies the constraints laid out in the Instructions:""" | ||
|
||
|
||
NAIVE_FIX_PROMPT = PromptTemplate.from_template(NAIVE_FIX) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Any | ||
|
||
from langchain.chains.llm import LLMChain | ||
from langchain.prompts.base import BasePromptTemplate | ||
from langchain.prompts.prompt import PromptTemplate | ||
from langchain.schema import ( | ||
BaseLanguageModel, | ||
BaseOutputParser, | ||
OutputParserException, | ||
PromptValue, | ||
) | ||
|
||
NAIVE_COMPLETION_RETRY = """Prompt: | ||
{prompt} | ||
Completion: | ||
{completion} | ||
Above, the Completion did not satisfy the constraints given in the Prompt. | ||
Please try again:""" | ||
|
||
NAIVE_COMPLETION_RETRY_WITH_ERROR = """Prompt: | ||
{prompt} | ||
Completion: | ||
{completion} | ||
Above, the Completion did not satisfy the constraints given in the Prompt. | ||
Details: {error} | ||
Please try again:""" | ||
|
||
NAIVE_RETRY_PROMPT = PromptTemplate.from_template(NAIVE_COMPLETION_RETRY) | ||
NAIVE_RETRY_WITH_ERROR_PROMPT = PromptTemplate.from_template( | ||
NAIVE_COMPLETION_RETRY_WITH_ERROR | ||
) | ||
|
||
|
||
class RetryOutputParser(BaseOutputParser): | ||
"""Wraps a parser and tries to fix parsing errors. | ||
Does this by passing the original prompt and the completion to another | ||
LLM, and telling it the completion did not satisfy criteria in the prompt. | ||
""" | ||
|
||
parser: BaseOutputParser | ||
retry_chain: LLMChain | ||
|
||
@classmethod | ||
def from_llm( | ||
cls, | ||
llm: BaseLanguageModel, | ||
parser: BaseOutputParser, | ||
prompt: BasePromptTemplate = NAIVE_RETRY_PROMPT, | ||
) -> RetryOutputParser: | ||
chain = LLMChain(llm=llm, prompt=prompt) | ||
return cls(parser=parser, retry_chain=chain) | ||
|
||
def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> Any: | ||
try: | ||
parsed_completion = self.parser.parse(completion) | ||
except OutputParserException: | ||
new_completion = self.retry_chain.run( | ||
prompt=prompt_value.to_string(), completion=completion | ||
) | ||
parsed_completion = self.parser.parse(new_completion) | ||
|
||
return parsed_completion | ||
|
||
def parse(self, completion: str) -> Any: | ||
raise NotImplementedError( | ||
"This OutputParser can only be called by the `parse_with_prompt` method." | ||
) | ||
|
||
def get_format_instructions(self) -> str: | ||
return self.parser.get_format_instructions() | ||
|
||
|
||
class RetryWithErrorOutputParser(BaseOutputParser): | ||
"""Wraps a parser and tries to fix parsing errors. | ||
Does this by passing the original prompt, the completion, AND the error | ||
that was raised to another language and telling it that the completion | ||
did not work, and raised the given error. Differs from RetryOutputParser | ||
in that this implementation provides the error that was raised back to the | ||
LLM, which in theory should give it more information on how to fix it. | ||
""" | ||
|
||
parser: BaseOutputParser | ||
retry_chain: LLMChain | ||
|
||
@classmethod | ||
def from_llm( | ||
cls, | ||
llm: BaseLanguageModel, | ||
parser: BaseOutputParser, | ||
prompt: BasePromptTemplate = NAIVE_RETRY_WITH_ERROR_PROMPT, | ||
) -> RetryWithErrorOutputParser: | ||
chain = LLMChain(llm=llm, prompt=prompt) | ||
return cls(parser=parser, retry_chain=chain) | ||
|
||
def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> Any: | ||
try: | ||
parsed_completion = self.parser.parse(completion) | ||
except OutputParserException as e: | ||
new_completion = self.retry_chain.run( | ||
prompt=prompt_value.to_string(), completion=completion, error=repr(e) | ||
) | ||
parsed_completion = self.parser.parse(new_completion) | ||
|
||
return parsed_completion | ||
|
||
def parse(self, completion: str) -> Any: | ||
raise NotImplementedError( | ||
"This OutputParser can only be called by the `parse_with_prompt` method." | ||
) | ||
|
||
def get_format_instructions(self) -> str: | ||
return self.parser.get_format_instructions() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.