From b366b74c7da4e53db4d9e2f592829ed365529fec Mon Sep 17 00:00:00 2001 From: Shroominic Date: Mon, 12 Feb 2024 11:33:49 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=80=20move=20parser=20base=20model?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/funcchain/parser/custom.py | 38 +++++++++++++++++++++++++--- src/funcchain/syntax/output_types.py | 33 +----------------------- 2 files changed, 36 insertions(+), 35 deletions(-) diff --git a/src/funcchain/parser/custom.py b/src/funcchain/parser/custom.py index dd79a44..d9c6e5b 100644 --- a/src/funcchain/parser/custom.py +++ b/src/funcchain/parser/custom.py @@ -1,12 +1,44 @@ import json +import re from typing import Type, TypeVar from langchain_core.exceptions import OutputParserException -from langchain_core.output_parsers import BaseOutputParser -from pydantic import ValidationError +from langchain_core.output_parsers import BaseLLMOutputParser, BaseOutputParser +from pydantic import BaseModel, ValidationError +from typing_extensions import Self from ..syntax.output_types import CodeBlock as CodeBlock -from ..syntax.output_types import ParserBaseModel + + +class ParserBaseModel(BaseModel): + @classmethod + def output_parser(cls) -> BaseLLMOutputParser[Self]: + from ..parser.custom import CustomPydanticOutputParser + + return CustomPydanticOutputParser(pydantic_object=cls) + + @classmethod + def parse(cls, text: str) -> Self: + """Override for custom parsing.""" + match = re.search(r"\{.*\}", text.strip(), re.MULTILINE | re.IGNORECASE | re.DOTALL) + json_str = "" + if match: + json_str = match.group() + json_object = json.loads(json_str, strict=False) + return cls.model_validate(json_object) + + @staticmethod + def format_instructions() -> str: + return ( + "Please respond with a json result matching the following schema:" + "\n\n```schema\n{schema}\n```\n" + "Do not repeat the schema. Only respond with the result." + ) + + @staticmethod + def custom_grammar() -> str | None: + return None + P = TypeVar("P", bound=ParserBaseModel) diff --git a/src/funcchain/syntax/output_types.py b/src/funcchain/syntax/output_types.py index 829a3e0..059c4f6 100644 --- a/src/funcchain/syntax/output_types.py +++ b/src/funcchain/syntax/output_types.py @@ -1,41 +1,10 @@ -import json import re from typing import Optional from langchain_core.exceptions import OutputParserException -from langchain_core.output_parsers import BaseLLMOutputParser from pydantic import BaseModel, Field -from typing_extensions import Self - -class ParserBaseModel(BaseModel): - @classmethod - def output_parser(cls) -> BaseLLMOutputParser[Self]: - from ..parser.custom import CustomPydanticOutputParser - - return CustomPydanticOutputParser(pydantic_object=cls) - - @classmethod - def parse(cls, text: str) -> Self: - """Override for custom parsing.""" - match = re.search(r"\{.*\}", text.strip(), re.MULTILINE | re.IGNORECASE | re.DOTALL) - json_str = "" - if match: - json_str = match.group() - json_object = json.loads(json_str, strict=False) - return cls.model_validate(json_object) - - @staticmethod - def format_instructions() -> str: - return ( - "Please respond with a json result matching the following schema:" - "\n\n```schema\n{schema}\n```\n" - "Do not repeat the schema. Only respond with the result." - ) - - @staticmethod - def custom_grammar() -> str | None: - return None +from ..parser.custom import ParserBaseModel as ParserBaseModel class CodeBlock(ParserBaseModel):