Skip to content

Commit

Permalink
[ENG-586] Moved imports of langchain to speed up rasa --help (#12930)
Browse files Browse the repository at this point in the history
* moved imports of langchain

* improve imports
  • Loading branch information
tabergma committed Oct 26, 2023
1 parent ed35201 commit e0f2cca
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions rasa/shared/utils/llm.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
from typing import Any, Dict, Optional, Text, Type
from typing import Any, Dict, Optional, Text, Type, TYPE_CHECKING
import warnings

import structlog
from langchain.embeddings.base import Embeddings
from langchain.llms.base import BaseLLM
from langchain.llms.loading import load_llm_from_config
from langchain.cache import SQLiteCache

from rasa.shared.core.trackers import DialogueStateTracker
from rasa.shared.core.events import BotUttered, UserUttered
from rasa.shared.engine.caching import get_local_cache_location
import rasa.shared.utils.io

if TYPE_CHECKING:
from langchain.embeddings.base import Embeddings
from langchain.llms.base import BaseLLM


structlogger = structlog.get_logger()

Expand Down Expand Up @@ -122,6 +122,7 @@ def combine_custom_and_default_config(
def ensure_cache() -> None:
"""Ensures that the cache is initialized."""
import langchain
from langchain.cache import SQLiteCache

# ensure the cache directory exists
cache_location = get_local_cache_location()
Expand All @@ -133,7 +134,7 @@ def ensure_cache() -> None:

def llm_factory(
custom_config: Optional[Dict[str, Any]], default_config: Dict[str, Any]
) -> BaseLLM:
) -> "BaseLLM":
"""Creates an LLM from the given config.
Args:
Expand All @@ -144,6 +145,8 @@ def llm_factory(
Returns:
Instantiated LLM based on the configuration.
"""
from langchain.llms.loading import load_llm_from_config

ensure_cache()

config = combine_custom_and_default_config(custom_config, default_config)
Expand All @@ -165,7 +168,7 @@ def llm_factory(

def embedder_factory(
custom_config: Optional[Dict[str, Any]], default_config: Dict[str, Any]
) -> Embeddings:
) -> "Embeddings":
"""Creates an Embedder from the given config.
Args:
Expand All @@ -176,6 +179,7 @@ def embedder_factory(
Returns:
Instantiated Embedder based on the configuration.
"""
from langchain.embeddings.base import Embeddings
from langchain.embeddings import (
CohereEmbeddings,
HuggingFaceHubEmbeddings,
Expand Down

0 comments on commit e0f2cca

Please sign in to comment.