diff --git a/rasa/shared/utils/llm.py b/rasa/shared/utils/llm.py
new file mode 100644
index 000000000000..cc22682019eb
--- /dev/null
+++ b/rasa/shared/utils/llm.py
@@ -0,0 +1,260 @@
+from typing import Any, Dict, Optional, Text, Type, TYPE_CHECKING
+import warnings
+
+import structlog
+
+from rasa.shared.core.events import BotUttered, UserUttered
+from rasa.shared.engine.caching import get_local_cache_location
+import rasa.shared.utils.io
+
+if TYPE_CHECKING:
+    from langchain.schema.embeddings import Embeddings
+    from langchain.llms.base import BaseLLM
+    from rasa.shared.core.trackers import DialogueStateTracker
+
+from rasa.shared.constants import (
+    RASA_PATTERN_INTERNAL_ERROR_USER_INPUT_TOO_LONG,
+    RASA_PATTERN_INTERNAL_ERROR_USER_INPUT_EMPTY,
+)
+
+structlogger = structlog.get_logger()
+
+USER = "USER"
+
+AI = "AI"
+
+DEFAULT_OPENAI_GENERATE_MODEL_NAME = "gpt-3.5-turbo"
+
+DEFAULT_OPENAI_CHAT_MODEL_NAME = "gpt-3.5-turbo"
+
+DEFAULT_OPENAI_CHAT_MODEL_NAME_ADVANCED = "gpt-4"
+
+DEFAULT_OPENAI_EMBEDDING_MODEL_NAME = "text-embedding-ada-002"
+
+DEFAULT_OPENAI_TEMPERATURE = 0.7
+
+DEFAULT_OPENAI_MAX_GENERATED_TOKENS = 256
+
+DEFAULT_MAX_USER_INPUT_CHARACTERS = 420
+
+
+# Placeholder messages used in the transcript for
+# instances where user input results in an error
+ERROR_PLACEHOLDER = {
+    RASA_PATTERN_INTERNAL_ERROR_USER_INPUT_TOO_LONG: "[User sent really long message]",
+    RASA_PATTERN_INTERNAL_ERROR_USER_INPUT_EMPTY: "",
+    "default": "[User input triggered an error]",
+}
+
+
+def tracker_as_readable_transcript(
+    tracker: "DialogueStateTracker",
+    human_prefix: str = USER,
+    ai_prefix: str = AI,
+    max_turns: Optional[int] = 20,
+) -> str:
+    """Creates a readable dialogue from a tracker.
+
+    Args:
+        tracker: the tracker to convert
+        human_prefix: the prefix to use for human utterances
+        ai_prefix: the prefix to use for ai utterances
+        max_turns: the maximum number of turns to include in the transcript
+
+    Example:
+        >>> tracker = Tracker(
+        ...     sender_id="test",
+        ...     slots=[],
+        ...     events=[
+        ...         UserUttered("hello"),
+        ...         BotUttered("hi"),
+        ...     ],
+        ... )
+        >>> tracker_as_readable_transcript(tracker)
+        USER: hello
+        AI: hi
+
+    Returns:
+        A string representing the transcript of the tracker
+    """
+    transcript = []
+
+    for event in tracker.events:
+
+        if isinstance(event, UserUttered):
+            if event.has_triggered_error:
+                first_error = event.error_commands[0]
+                error_type = first_error.get("error_type")
+                message = ERROR_PLACEHOLDER.get(
+                    error_type, ERROR_PLACEHOLDER["default"]
+                )
+            else:
+                message = sanitize_message_for_prompt(event.text)
+            transcript.append(f"{human_prefix}: {message}")
+
+        elif isinstance(event, BotUttered):
+            transcript.append(f"{ai_prefix}: {sanitize_message_for_prompt(event.text)}")
+
+    if max_turns:
+        transcript = transcript[-max_turns:]
+
+    return "\n".join(transcript)
+
+
+def sanitize_message_for_prompt(text: Optional[str]) -> str:
+    """Removes new lines from a string.
+
+    Args:
+        text: the text to sanitize
+
+    Returns:
+        A string with new lines removed.
+    """
+    return text.replace("\n", " ") if text else ""
+
+
+def combine_custom_and_default_config(
+    custom_config: Optional[Dict[Text, Any]], default_config: Dict[Text, Any]
+) -> Dict[Text, Any]:
+    """Merges the given llm config with the default config.
+
+    Only uses the default configuration arguments if the type set in the
+    custom config matches the type in the default config. Otherwise, only
+    the custom config is used.
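+
+    Example (illustrative sketch of the merge behaviour; key order in the
+    result may differ):
+        >>> combine_custom_and_default_config(
+        ...     {"temperature": 0.3}, {"_type": "openai", "temperature": 0.7}
+        ... )
+        {'_type': 'openai', 'temperature': 0.3}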
+
+    Args:
+        custom_config: The custom config containing values to overwrite defaults
+        default_config: The default config.
+
+    Returns:
+        The merged config.
+    """
+    if custom_config is None:
+        return default_config
+
+    if "type" in custom_config:
+        # rename "type" to "_type" as "type" is the convention we use
+        # across the different components in config files,
+        # while langchain expects "_type" as the key
+        custom_config["_type"] = custom_config.pop("type")
+
+    if "_type" in custom_config and custom_config["_type"] != default_config.get(
+        "_type"
+    ):
+        return custom_config
+    return {**default_config, **custom_config}
+
+
+def ensure_cache() -> None:
+    """Ensures that the LLM cache is initialized."""
+    import langchain
+    from langchain.cache import SQLiteCache
+
+    # ensure the cache directory exists
+    cache_location = get_local_cache_location()
+    cache_location.mkdir(parents=True, exist_ok=True)
+
+    db_location = cache_location / "rasa-llm-cache.db"
+    langchain.llm_cache = SQLiteCache(database_path=str(db_location))
+
+
+def llm_factory(
+    custom_config: Optional[Dict[str, Any]], default_config: Dict[str, Any]
+) -> "BaseLLM":
+    """Creates an LLM from the given config.
+
+    Args:
+        custom_config: The custom config containing values to overwrite defaults
+        default_config: The default config.
+
+    Returns:
+        Instantiated LLM based on the configuration.
+    """
+    from langchain.llms.loading import load_llm_from_config
+
+    ensure_cache()
+
+    config = combine_custom_and_default_config(custom_config, default_config)
+
+    structlogger.debug("llmfactory.create.llm", config=config)
+    # langchain issues a user warning when using chat models. At the same time
+    # it doesn't provide a way to instantiate a chat model directly using the
+    # config, so for now we need to suppress the warning here. Original
+    # warning:
+    #   packages/langchain/llms/openai.py:189: UserWarning: You are trying to
+    #   use a chat model. This way of initializing it is no longer supported.
+    #   Instead, please use: `from langchain.chat_models import ChatOpenAI`
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore", category=UserWarning)
+        # pass a copy, as the langchain loader modifies the config in place
+        return load_llm_from_config(config.copy())
+
+
+def embedder_factory(
+    custom_config: Optional[Dict[str, Any]], default_config: Dict[str, Any]
+) -> "Embeddings":
+    """Creates an Embedder from the given config.
+
+    Args:
+        custom_config: The custom config containing values to overwrite defaults
+        default_config: The default config.
+
+    Returns:
+        Instantiated Embedder based on the configuration.
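+
+    Example (illustrative; mirrors the defaults used in the tests and
+    assumes OPENAI_API_KEY is set in the environment):
+        >>> embedder = embedder_factory(None, {"_type": "openai"})
+        >>> type(embedder).__name__
+        'OpenAIEmbeddings'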
+    """
+    from langchain.schema.embeddings import Embeddings
+    from langchain.embeddings import (
+        CohereEmbeddings,
+        HuggingFaceHubEmbeddings,
+        HuggingFaceInstructEmbeddings,
+        LlamaCppEmbeddings,
+        OpenAIEmbeddings,
+        SpacyEmbeddings,
+        VertexAIEmbeddings,
+    )
+
+    type_to_embedding_cls_dict: Dict[str, Type[Embeddings]] = {
+        "openai": OpenAIEmbeddings,
+        "cohere": CohereEmbeddings,
+        "spacy": SpacyEmbeddings,
+        "vertexai": VertexAIEmbeddings,
+        "huggingface_instruct": HuggingFaceInstructEmbeddings,
+        "huggingface_hub": HuggingFaceHubEmbeddings,
+        "llamacpp": LlamaCppEmbeddings,
+    }
+
+    config = combine_custom_and_default_config(custom_config, default_config)
+    typ = config.get("_type")
+
+    structlogger.debug("llmfactory.create.embedder", config=config)
+
+    if not typ:
+        return OpenAIEmbeddings()
+    elif embeddings_cls := type_to_embedding_cls_dict.get(typ):
+        parameters = config.copy()
+        parameters.pop("_type")
+        return embeddings_cls(**parameters)
+    else:
+        raise ValueError(f"Unsupported embeddings type '{typ}'")
+
+
+def get_prompt_template(
+    jinja_file_path: Optional[Text], default_prompt_template: Text
+) -> Text:
+    """Returns the prompt template.
+
+    Args:
+        jinja_file_path: the path to the jinja file
+        default_prompt_template: the default prompt template
+
+    Returns:
+        The prompt template.
+    """
+    return (
+        rasa.shared.utils.io.read_file(jinja_file_path)
+        if jinja_file_path is not None
+        else default_prompt_template
+    )
diff --git a/tests/utils/test_llm.py b/tests/utils/test_llm.py
new file mode 100644
index 000000000000..d2065a2730fb
--- /dev/null
+++ b/tests/utils/test_llm.py
@@ -0,0 +1,255 @@
+from typing import Text, Any, Dict
+from rasa.shared.constants import (
+    RASA_PATTERN_INTERNAL_ERROR_USER_INPUT_TOO_LONG,
+    RASA_PATTERN_INTERNAL_ERROR_USER_INPUT_EMPTY,
+)
+from rasa.shared.core.domain import Domain
+from rasa.shared.core.events import BotUttered, UserUttered
+from rasa.shared.core.trackers import DialogueStateTracker
+from rasa.shared.utils.llm import (
+    sanitize_message_for_prompt,
+    tracker_as_readable_transcript,
+    embedder_factory,
+    llm_factory,
+    ERROR_PLACEHOLDER,
+)
+from langchain import OpenAI
+from langchain.embeddings import OpenAIEmbeddings
+import pytest
+from pytest import MonkeyPatch
+
+
+def test_tracker_as_readable_transcript_handles_empty_tracker():
+    tracker = DialogueStateTracker(sender_id="test", slots=[])
+    assert tracker_as_readable_transcript(tracker) == ""
+
+
+def test_tracker_as_readable_transcript_handles_tracker_with_events(domain: Domain):
+    tracker = DialogueStateTracker(sender_id="test", slots=domain.slots)
+    tracker.update_with_events(
+        [
+            UserUttered("hello"),
+            BotUttered("hi"),
+        ],
+        domain,
+    )
+    assert tracker_as_readable_transcript(tracker) == ("""USER: hello\nAI: hi""")
+
+
+def test_tracker_as_readable_transcript_handles_tracker_with_events_and_prefixes(
+    domain: Domain,
+):
+    tracker = DialogueStateTracker(sender_id="test", slots=domain.slots)
+    tracker.update_with_events(
+        [
+            UserUttered("hello"),
+            BotUttered("hi"),
+        ],
+        domain,
+    )
+    assert tracker_as_readable_transcript(
+        tracker, human_prefix="FOO", ai_prefix="BAR"
+    ) == ("""FOO: hello\nBAR: hi""")
+
+
+def test_tracker_as_readable_transcript_handles_tracker_with_events_and_max_turns(
+    domain: Domain,
+):
+    tracker = DialogueStateTracker(sender_id="test", slots=domain.slots)
+    tracker.update_with_events(
+        [
+            UserUttered("hello"),
+            BotUttered("hi"),
+        ],
+        domain,
+    )
+    assert tracker_as_readable_transcript(tracker, max_turns=1) == ("""AI: hi""")
+
+
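+# Note: max_turns counts transcript entries (one per user or bot utterance),
+# not user/bot exchanges, so the default of 20 keeps the last 20 utterances;
+# the test below exercises exactly that cut-off.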
+def test_tracker_as_readable_transcript_and_discard_excess_turns_with_default_max_turns(
+    domain: Domain,
+):
+    tracker = DialogueStateTracker(sender_id="test", slots=domain.slots)
+    tracker.update_with_events(
+        [
+            UserUttered("A0"),
+            BotUttered("B1"),
+            UserUttered("C2"),
+            BotUttered("D3"),
+            UserUttered("E4"),
+            BotUttered("F5"),
+            UserUttered("G6"),
+            BotUttered("H7"),
+            UserUttered("I8"),
+            BotUttered("J9"),
+            UserUttered("K10"),
+            BotUttered("L11"),
+            UserUttered("M12"),
+            BotUttered("N13"),
+            UserUttered("O14"),
+            BotUttered("P15"),
+            UserUttered("Q16"),
+            BotUttered("R17"),
+            UserUttered("S18"),
+            BotUttered("T19"),
+            UserUttered("U20"),
+            BotUttered("V21"),
+            UserUttered("W22"),
+            BotUttered("X23"),
+            UserUttered("Y24"),
+        ],
+        domain,
+    )
+    response = tracker_as_readable_transcript(tracker)
+    assert response == (
+        """AI: F5\nUSER: G6\nAI: H7\nUSER: I8\nAI: J9\nUSER: K10\nAI: L11\n"""
+        """USER: M12\nAI: N13\nUSER: O14\nAI: P15\nUSER: Q16\nAI: R17\nUSER: S18\n"""
+        """AI: T19\nUSER: U20\nAI: V21\nUSER: W22\nAI: X23\nUSER: Y24"""
+    )
+    assert response.count("\n") == 19
+
+
+@pytest.mark.parametrize(
+    "message, command, expected_response",
+    [
+        (
+            "Very long message",
+            {
+                "command": "error",
+                "error_type": RASA_PATTERN_INTERNAL_ERROR_USER_INPUT_TOO_LONG,
+            },
+            ERROR_PLACEHOLDER[RASA_PATTERN_INTERNAL_ERROR_USER_INPUT_TOO_LONG],
+        ),
+        (
+            "",
+            {
+                "command": "error",
+                "error_type": RASA_PATTERN_INTERNAL_ERROR_USER_INPUT_EMPTY,
+            },
+            ERROR_PLACEHOLDER[RASA_PATTERN_INTERNAL_ERROR_USER_INPUT_EMPTY],
+        ),
+    ],
+)
+def test_tracker_as_readable_transcript_with_messages_that_triggered_error(
+    message: Text,
+    command: Dict[Text, Any],
+    expected_response: Text,
+    domain: Domain,
+):
+    # Given
+    tracker = DialogueStateTracker(sender_id="test", slots=domain.slots)
+    tracker.update_with_events(
+        [
+            UserUttered("Hi"),
+            BotUttered("Hi, how can I help you"),
+            UserUttered(text=message, parse_data={"commands": [command]}),
+            BotUttered("Error response"),
+        ],
+        domain,
+    )
+    # When
+    response = tracker_as_readable_transcript(tracker)
+    # Then
+    assert response == (
+        f"USER: Hi\n"
+        f"AI: Hi, how can I help you\n"
+        f"USER: {expected_response}\n"
+        f"AI: Error response"
+    )
+    assert response.count("\n") == 3
+
+
+def test_sanitize_message_for_prompt_handles_none():
+    assert sanitize_message_for_prompt(None) == ""
+
+
+def test_sanitize_message_for_prompt_handles_empty_string():
+    assert sanitize_message_for_prompt("") == ""
+
+
+def test_sanitize_message_for_prompt_handles_string_with_newlines():
+    assert sanitize_message_for_prompt("hello\nworld") == "hello world"
+
+
+def test_llm_factory(monkeypatch: MonkeyPatch) -> None:
+    monkeypatch.setenv("OPENAI_API_KEY", "test")
+
+    llm = llm_factory(None, {"_type": "openai"})
+    assert isinstance(llm, OpenAI)
+
+
+def test_llm_factory_handles_type_without_underscore(monkeypatch: MonkeyPatch) -> None:
+    monkeypatch.setenv("OPENAI_API_KEY", "test")
+
+    llm = llm_factory({"type": "openai"}, {})
+    assert isinstance(llm, OpenAI)
+
+
+def test_llm_factory_uses_custom_type(monkeypatch: MonkeyPatch) -> None:
+    monkeypatch.setenv("OPENAI_API_KEY", "test")
+
+    llm = llm_factory({"type": "openai"}, {"_type": "foobar"})
+    assert isinstance(llm, OpenAI)
+
+
+def test_llm_factory_ignores_irrelevant_default_args(monkeypatch: MonkeyPatch) -> None:
+    monkeypatch.setenv("OPENAI_API_KEY", "test")
+
+    # since the types of the custom config and the default are different,
+    # all default arguments should be removed.
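+    # (illustrative) e.g. combine_custom_and_default_config({"type": "openai"},
+    # {"_type": "foobar", "temperature": -1}) evaluates to {"_type": "openai"}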
+    llm = llm_factory({"type": "openai"}, {"_type": "foobar", "temperature": -1})
+    assert isinstance(llm, OpenAI)
+    # since the default argument should be removed, this should be the default -
+    # which is not -1
+    assert llm.temperature != -1
+
+
+def test_llm_factory_uses_default_args_when_types_match(
+    monkeypatch: MonkeyPatch,
+) -> None:
+    monkeypatch.setenv("OPENAI_API_KEY", "test")
+
+    # since the types of the custom config and the default are the same,
+    # all default arguments should be kept and merged with the custom config
+    llm = llm_factory({"type": "openai"}, {"_type": "openai", "temperature": -1})
+    assert isinstance(llm, OpenAI)
+    # since the default argument should NOT be removed, this should be -1 now
+    assert llm.temperature == -1
+
+
+def test_llm_factory_uses_additional_args_from_custom(
+    monkeypatch: MonkeyPatch,
+) -> None:
+    monkeypatch.setenv("OPENAI_API_KEY", "test")
+
+    llm = llm_factory({"temperature": -1}, {"_type": "openai"})
+    assert isinstance(llm, OpenAI)
+    assert llm.temperature == -1
+
+
+def test_embedder_factory(monkeypatch: MonkeyPatch) -> None:
+    monkeypatch.setenv("OPENAI_API_KEY", "test")
+
+    embedder = embedder_factory(None, {"_type": "openai"})
+    assert isinstance(embedder, OpenAIEmbeddings)
+
+
+def test_embedder_factory_handles_type_without_underscore(
+    monkeypatch: MonkeyPatch,
+) -> None:
+    monkeypatch.setenv("OPENAI_API_KEY", "test")
+
+    embedder = embedder_factory({"type": "openai"}, {})
+    assert isinstance(embedder, OpenAIEmbeddings)
+
+
+def test_embedder_factory_uses_custom_type(monkeypatch: MonkeyPatch) -> None:
+    monkeypatch.setenv("OPENAI_API_KEY", "test")
+
+    embedder = embedder_factory({"type": "openai"}, {"_type": "foobar"})
+    assert isinstance(embedder, OpenAIEmbeddings)
+
+
+def test_embedder_factory_ignores_irrelevant_default_args(
+    monkeypatch: MonkeyPatch,
+) -> None:
+    monkeypatch.setenv("OPENAI_API_KEY", "test")
+
+    # the embedder classes don't expect these args; since the default type
+    # doesn't match, the irrelevant default arguments should just be dropped
+    embedder = embedder_factory({"type": "openai"}, {"_type": "foobar", "foo": "bar"})
+    assert isinstance(embedder, OpenAIEmbeddings)
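
A minimal end-to-end sketch of how the new helpers compose (illustrative only;
assumes OPENAI_API_KEY is set and mirrors the OpenAI defaults used in the
tests above):

    from rasa.shared.core.events import BotUttered, UserUttered
    from rasa.shared.core.trackers import DialogueStateTracker
    from rasa.shared.utils.llm import llm_factory, tracker_as_readable_transcript

    # build a small tracker and render it as a prompt-ready transcript
    tracker = DialogueStateTracker.from_events(
        "sender", [UserUttered("hello"), BotUttered("hi")]
    )
    transcript = tracker_as_readable_transcript(tracker)  # "USER: hello\nAI: hi"

    # instantiate the default OpenAI LLM; passing None as the custom config
    # keeps all default arguments
    llm = llm_factory(None, {"_type": "openai", "temperature": 0.7})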