diff --git a/langchain/output_parsers/rail_parser.py b/langchain/output_parsers/rail_parser.py index 0dab50d9a894c..cd02b337330cd 100644 --- a/langchain/output_parsers/rail_parser.py +++ b/langchain/output_parsers/rail_parser.py @@ -1,19 +1,29 @@ from __future__ import annotations -from typing import Any, Dict +from typing import Any, Callable, Dict, Optional from langchain.schema import BaseOutputParser class GuardrailsOutputParser(BaseOutputParser): guard: Any + api: Optional[Callable] + args: Any + kwargs: Any @property def _type(self) -> str: return "guardrails" @classmethod - def from_rail(cls, rail_file: str, num_reasks: int = 1) -> GuardrailsOutputParser: + def from_rail( + cls, + rail_file: str, + num_reasks: int = 1, + api: Optional[Callable] = None, + *args: Any, + **kwargs: Any, + ) -> GuardrailsOutputParser: try: from guardrails import Guard except ImportError: @@ -21,11 +31,21 @@ def from_rail(cls, rail_file: str, num_reasks: int = 1) -> GuardrailsOutputParse "guardrails-ai package not installed. " "Install it by running `pip install guardrails-ai`." ) - return cls(guard=Guard.from_rail(rail_file, num_reasks=num_reasks)) + return cls( + guard=Guard.from_rail(rail_file, num_reasks=num_reasks), + api=api, + args=args, + kwargs=kwargs, + ) @classmethod def from_rail_string( - cls, rail_str: str, num_reasks: int = 1 + cls, + rail_str: str, + num_reasks: int = 1, + api: Optional[Callable] = None, + *args: Any, + **kwargs: Any, ) -> GuardrailsOutputParser: try: from guardrails import Guard @@ -34,10 +54,15 @@ def from_rail_string( "guardrails-ai package not installed. " "Install it by running `pip install guardrails-ai`." ) - return cls(guard=Guard.from_rail_string(rail_str, num_reasks=num_reasks)) + return cls( + guard=Guard.from_rail_string(rail_str, num_reasks=num_reasks), + api=api, + args=args, + kwargs=kwargs, + ) def get_format_instructions(self) -> str: return self.guard.raw_prompt.format_instructions def parse(self, text: str) -> Dict: - return self.guard.parse(text) + return self.guard.parse(text, llm_api=self.api, *self.args, **self.kwargs)