Skip to content

Commit

Permalink
🔧 fix input dict type
Browse files Browse the repository at this point in the history
  • Loading branch information
shroominic committed Feb 12, 2024
1 parent be77287 commit 07c0e25
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 13 deletions.
10 changes: 5 additions & 5 deletions src/funcchain/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def create_union_chain(
context: list[BaseMessage],
llm: BaseChatModel,
input_kwargs: dict[str, Any],
) -> Runnable[dict[str, str], Any]:
) -> Runnable[dict[str, Any], Any]:
"""
Compile a langchain runnable chain from the funcchain syntax.
"""
Expand Down Expand Up @@ -78,7 +78,7 @@ def create_union_chain(
def patch_openai_function_to_pydantic(
llm: BaseChatModel,
output_type: type[BaseModel],
input_kwargs: dict[str, str],
input_kwargs: dict[str, Any],
primitive_type: bool = False,
) -> tuple[BaseChatModel, BaseGenerationOutputParser]:
input_kwargs["format_instructions"] = f"Extract to {output_type.__name__}."
Expand All @@ -101,7 +101,7 @@ def create_chain(
settings: FuncchainSettings,
input_args: list[tuple[str, type]],
temp_images: list[Image] = [],
) -> Runnable[dict[str, str], ChainOutput]:
) -> Runnable[dict[str, Any], ChainOutput]:
"""
Compile a langchain runnable chain from the funcchain syntax.
"""
Expand Down Expand Up @@ -209,7 +209,7 @@ def create_chain(
return chat_prompt | llm | parser


def compile_chain(signature: Signature, temp_images: list[Image] = []) -> Runnable[dict[str, str], ChainOutput]:
def compile_chain(signature: Signature, temp_images: list[Image] = []) -> Runnable[dict[str, Any], ChainOutput]:
"""
Compile a signature to a runnable chain.
"""
Expand All @@ -236,7 +236,7 @@ def compile_chain(signature: Signature, temp_images: list[Image] = []) -> Runnab
def _add_format_instructions(
parser: BaseOutputParser,
instruction: str,
input_kwargs: dict[str, str],
input_kwargs: dict[str, Any],
) -> tuple[str, str | None]:
"""
Add parsing format instructions
Expand Down
4 changes: 2 additions & 2 deletions src/funcchain/backend/meta_inspect.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from inspect import FrameInfo, currentframe, getouterframes
from types import FunctionType, UnionType
from typing import Optional
from typing import Any, Optional

FUNC_DEPTH = 4

Expand Down Expand Up @@ -53,7 +53,7 @@ def get_output_types(f: Optional[FunctionType] = None) -> list[type]:
raise ValueError("The funcchain must have a return type annotation")


def kwargs_from_parent() -> dict[str, str]:
def kwargs_from_parent() -> dict[str, Any]:
"""
Get the kwargs from the parent function.
"""
Expand Down
2 changes: 1 addition & 1 deletion src/funcchain/syntax/components/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def runnable(self) -> RunnableSerializable[HumanMessage, AIMessage]:
runnables={name: run["handler"] for name, run in self.routes.items()},
) # maybe add auto conversion of strings to AI Messages/Chunks

def _selector(self) -> Runnable[dict[str, str], Any]:
def _selector(self) -> Runnable[dict[str, Any], Any]:
RouteChoices = Enum( # type: ignore
"RouteChoices",
{r: r for r in self.routes.keys()},
Expand Down
6 changes: 3 additions & 3 deletions src/funcchain/syntax/decorators.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from types import FunctionType
from typing import Callable, Optional, TypeVar, Union, overload
from typing import Any, Callable, Optional, TypeVar, Union, overload

from langchain_core.runnables import Runnable

Expand All @@ -15,7 +15,7 @@
@overload
def runnable(
f: Callable[..., OutputT],
) -> Runnable[dict[str, str], OutputT]:
) -> Runnable[dict[str, Any], OutputT]:
...


Expand All @@ -25,7 +25,7 @@ def runnable(
llm: UniversalChatModel = None,
settings: SettingsOverride = {},
auto_tune: bool = False,
) -> Callable[[Callable], Runnable[dict[str, str], OutputT]]:
) -> Callable[[Callable], Runnable[dict[str, Any], OutputT]]:
...


Expand Down
4 changes: 2 additions & 2 deletions src/funcchain/syntax/executable.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ async def achain(
history=context,
settings=settings,
)
chain: Runnable[dict[str, str], Any] = compile_chain(sig, temp_images)
chain: Runnable[dict[str, Any], Any] = compile_chain(sig, temp_images)
result = await chain.ainvoke(input_kwargs, {"run_name": get_parent_frame(2).function, "callbacks": callbacks})

if memory and isinstance(result, str):
Expand All @@ -132,7 +132,7 @@ def compile_runnable(
llm: UniversalChatModel = None,
system: str = "",
settings_override: SettingsOverride = {},
) -> Runnable[dict[str, str], ChainOut]:
) -> Runnable[dict[str, Any], ChainOut]:
"""
On the fly compilation of the funcchain syntax.
"""
Expand Down

0 comments on commit 07c0e25

Please sign in to comment.