From e0f2cca4ac26269952ef81769c84f1636c206aa7 Mon Sep 17 00:00:00 2001 From: Tanja Date: Thu, 26 Oct 2023 09:59:49 +0200 Subject: [PATCH] [ENG-586] Moved imports of langchain to speed up rasa --help (#12930) * moved imports of langchain * improve imports --- rasa/shared/utils/llm.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/rasa/shared/utils/llm.py b/rasa/shared/utils/llm.py index 5cf22b85883f..aa2be1f99fac 100644 --- a/rasa/shared/utils/llm.py +++ b/rasa/shared/utils/llm.py @@ -1,17 +1,17 @@ -from typing import Any, Dict, Optional, Text, Type +from typing import Any, Dict, Optional, Text, Type, TYPE_CHECKING import warnings import structlog -from langchain.embeddings.base import Embeddings -from langchain.llms.base import BaseLLM -from langchain.llms.loading import load_llm_from_config -from langchain.cache import SQLiteCache from rasa.shared.core.trackers import DialogueStateTracker from rasa.shared.core.events import BotUttered, UserUttered from rasa.shared.engine.caching import get_local_cache_location import rasa.shared.utils.io +if TYPE_CHECKING: + from langchain.embeddings.base import Embeddings + from langchain.llms.base import BaseLLM + structlogger = structlog.get_logger() @@ -122,6 +122,7 @@ def combine_custom_and_default_config( def ensure_cache() -> None: """Ensures that the cache is initialized.""" import langchain + from langchain.cache import SQLiteCache # ensure the cache directory exists cache_location = get_local_cache_location() @@ -133,7 +134,7 @@ def ensure_cache() -> None: def llm_factory( custom_config: Optional[Dict[str, Any]], default_config: Dict[str, Any] -) -> BaseLLM: +) -> "BaseLLM": """Creates an LLM from the given config. Args: @@ -144,6 +145,8 @@ def llm_factory( Returns: Instantiated LLM based on the configuration. """ + from langchain.llms.loading import load_llm_from_config + ensure_cache() config = combine_custom_and_default_config(custom_config, default_config) @@ -165,7 +168,7 @@ def llm_factory( def embedder_factory( custom_config: Optional[Dict[str, Any]], default_config: Dict[str, Any] -) -> Embeddings: +) -> "Embeddings": """Creates an Embedder from the given config. Args: @@ -176,6 +179,7 @@ def embedder_factory( Returns: Instantiated Embedder based on the configuration. """ + from langchain.embeddings.base import Embeddings from langchain.embeddings import ( CohereEmbeddings, HuggingFaceHubEmbeddings,