From 0980a564f4ad4cfa2c9b3770f9eb258acdfaa817 Mon Sep 17 00:00:00 2001 From: Tanja Bunk Date: Fri, 20 Oct 2023 12:11:18 +0200 Subject: [PATCH 1/2] moved imports of langchain --- rasa/shared/utils/llm.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/rasa/shared/utils/llm.py b/rasa/shared/utils/llm.py index 5cf22b85883f..a1930051c9eb 100644 --- a/rasa/shared/utils/llm.py +++ b/rasa/shared/utils/llm.py @@ -2,16 +2,18 @@ import warnings import structlog -from langchain.embeddings.base import Embeddings -from langchain.llms.base import BaseLLM -from langchain.llms.loading import load_llm_from_config -from langchain.cache import SQLiteCache from rasa.shared.core.trackers import DialogueStateTracker from rasa.shared.core.events import BotUttered, UserUttered from rasa.shared.engine.caching import get_local_cache_location import rasa.shared.utils.io +import typing + +if typing.TYPE_CHECKING: + from langchain.embeddings.base import Embeddings + from langchain.llms.base import BaseLLM + structlogger = structlog.get_logger() @@ -122,6 +124,7 @@ def combine_custom_and_default_config( def ensure_cache() -> None: """Ensures that the cache is initialized.""" import langchain + from langchain.cache import SQLiteCache # ensure the cache directory exists cache_location = get_local_cache_location() @@ -133,7 +136,7 @@ def ensure_cache() -> None: def llm_factory( custom_config: Optional[Dict[str, Any]], default_config: Dict[str, Any] -) -> BaseLLM: +) -> "BaseLLM": """Creates an LLM from the given config. Args: @@ -144,6 +147,8 @@ def llm_factory( Returns: Instantiated LLM based on the configuration. """ + from langchain.llms.loading import load_llm_from_config + ensure_cache() config = combine_custom_and_default_config(custom_config, default_config) @@ -165,7 +170,7 @@ def llm_factory( def embedder_factory( custom_config: Optional[Dict[str, Any]], default_config: Dict[str, Any] -) -> Embeddings: +) -> "Embeddings": """Creates an Embedder from the given config. Args: @@ -176,6 +181,7 @@ def embedder_factory( Returns: Instantiated Embedder based on the configuration. """ + from langchain.embeddings.base import Embeddings from langchain.embeddings import ( CohereEmbeddings, HuggingFaceHubEmbeddings, From 9fabde3869c1724ad2bad68bd22874f5d86f3438 Mon Sep 17 00:00:00 2001 From: Tanja Bunk Date: Fri, 20 Oct 2023 13:21:00 +0200 Subject: [PATCH 2/2] improve imports --- rasa/shared/utils/llm.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/rasa/shared/utils/llm.py b/rasa/shared/utils/llm.py index a1930051c9eb..aa2be1f99fac 100644 --- a/rasa/shared/utils/llm.py +++ b/rasa/shared/utils/llm.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, Text, Type +from typing import Any, Dict, Optional, Text, Type, TYPE_CHECKING import warnings import structlog @@ -8,9 +8,7 @@ from rasa.shared.engine.caching import get_local_cache_location import rasa.shared.utils.io -import typing - -if typing.TYPE_CHECKING: +if TYPE_CHECKING: from langchain.embeddings.base import Embeddings from langchain.llms.base import BaseLLM