From 952a7575b73834b267410b20de3cb6528a439961 Mon Sep 17 00:00:00 2001 From: davidmezzetti <561939+davidmezzetti@users.noreply.github.com> Date: Thu, 19 Dec 2024 07:28:51 -0500 Subject: [PATCH] Add defaultrole to LLM pipeline, closes #841 --- docs/pipeline/text/llm.md | 4 ++++ src/python/txtai/pipeline/llm/generation.py | 9 +++++++-- src/python/txtai/pipeline/llm/llm.py | 5 +++-- src/python/txtai/pipeline/llm/rag.py | 1 + test/python/testpipeline/testllm/testlitellm.py | 3 +++ test/python/testpipeline/testllm/testllama.py | 3 +++ test/python/testpipeline/testllm/testllm.py | 8 ++++++++ 7 files changed, 29 insertions(+), 4 deletions(-) diff --git a/docs/pipeline/text/llm.md b/docs/pipeline/text/llm.md index 10990be56..4da905e2c 100644 --- a/docs/pipeline/text/llm.md +++ b/docs/pipeline/text/llm.md @@ -47,10 +47,14 @@ llm([ {"role": "user", "content": "Answer the following question..."} ]) +# Set the default role to user and string inputs are converted to chat messages +llm("Answer the following question...", defaultrole="user") ``` The LLM pipeline automatically detects the underlying LLM framework. This can also be manually set. +[Hugging Face Transformers](https://github.com/huggingface/transformers), [llama.cpp](https://github.com/abetlen/llama-cpp-python) and [hosted API models via LiteLLM](https://github.com/BerriAI/litellm) are all supported by this pipeline. + See the [LiteLLM documentation](https://litellm.vercel.app/docs/providers) for the options available with LiteLLM models. llama.cpp models support both local and remote GGUF paths on the HF Hub. ```python diff --git a/src/python/txtai/pipeline/llm/generation.py b/src/python/txtai/pipeline/llm/generation.py index 371d57e81..885ed4315 100644 --- a/src/python/txtai/pipeline/llm/generation.py +++ b/src/python/txtai/pipeline/llm/generation.py @@ -24,7 +24,7 @@ def __init__(self, path=None, template=None, **kwargs): self.template = template self.kwargs = kwargs - def __call__(self, text, maxlength, stream, stop, **kwargs): + def __call__(self, text, maxlength, stream, stop, defaultrole, **kwargs): """ Generates text. Supports the following input formats: @@ -36,6 +36,7 @@ def __call__(self, text, maxlength, stream, stop, **kwargs): maxlength: maximum sequence length stream: stream response if True, defaults to False stop: list of stop strings + defaultrole: default role to apply to text inputs (prompt for raw prompts (default) or user for user chat messages) kwargs: additional generation keyword arguments Returns: @@ -48,7 +49,11 @@ def __call__(self, text, maxlength, stream, stop, **kwargs): # Apply template, if necessary if self.template: formatter = TemplateFormatter() - texts = [formatter.format(self.template, text=x) for x in texts] + texts = [formatter.format(self.template, text=x) if isinstance(x, str) else x for x in texts] + + # Apply default role, if necessary + if defaultrole == "user": + texts = [[{"role": "user", "content": x}] if isinstance(x, str) else x for x in texts] # Run pipeline results = self.execute(texts, maxlength, stream, stop, **kwargs) diff --git a/src/python/txtai/pipeline/llm/llm.py b/src/python/txtai/pipeline/llm/llm.py index 91d65456a..98350f99e 100644 --- a/src/python/txtai/pipeline/llm/llm.py +++ b/src/python/txtai/pipeline/llm/llm.py @@ -38,7 +38,7 @@ def __init__(self, path=None, method=None, **kwargs): # Generation instance self.generator = GenerationFactory.create(path, method, **kwargs) - def __call__(self, text, maxlength=512, stream=False, stop=None, **kwargs): + def __call__(self, text, maxlength=512, stream=False, stop=None, defaultrole="prompt", **kwargs): """ Generates text. Supports the following input formats: @@ -50,6 +50,7 @@ def __call__(self, text, maxlength=512, stream=False, stop=None, **kwargs): maxlength: maximum sequence length stream: stream response if True, defaults to False stop: list of stop strings, defaults to None + defaultrole: default role to apply to text inputs (prompt for raw prompts (default) or user for user chat messages) kwargs: additional generation keyword arguments Returns: @@ -60,4 +61,4 @@ def __call__(self, text, maxlength=512, stream=False, stop=None, **kwargs): logger.debug(text) # Run LLM generation - return self.generator(text, maxlength, stream, stop, **kwargs) + return self.generator(text, maxlength, stream, stop, defaultrole, **kwargs) diff --git a/src/python/txtai/pipeline/llm/rag.py b/src/python/txtai/pipeline/llm/rag.py index a201ff220..2aee05f70 100644 --- a/src/python/txtai/pipeline/llm/rag.py +++ b/src/python/txtai/pipeline/llm/rag.py @@ -310,6 +310,7 @@ def answers(self, questions, contexts, **kwargs): Args: questions: questions contexts: question context + kwargs: additional keyword arguments to pass to model Returns: answers diff --git a/test/python/testpipeline/testllm/testlitellm.py b/test/python/testpipeline/testllm/testlitellm.py index 6cc5ccacf..381a19828 100644 --- a/test/python/testpipeline/testllm/testlitellm.py +++ b/test/python/testpipeline/testllm/testlitellm.py @@ -78,5 +78,8 @@ def testGeneration(self): model = LLM("huggingface/t5-small", api_base="http://127.0.0.1:8000") self.assertEqual(model("The sky is"), "blue") + # Test default role + self.assertEqual(model("The sky is", defaultrole="user"), "blue") + # Test streaming self.assertEqual(" ".join(x for x in model("The sky is", stream=True)), "blue") diff --git a/test/python/testpipeline/testllm/testllama.py b/test/python/testpipeline/testllm/testllama.py index 2dc4d007a..4d5361b8b 100644 --- a/test/python/testpipeline/testllm/testllama.py +++ b/test/python/testpipeline/testllm/testllama.py @@ -69,5 +69,8 @@ def testGeneration(self): messages = [{"role": "system", "content": "You are a helpful assistant. You answer math problems."}, {"role": "user", "content": "2+2?"}] self.assertIsNotNone(model(messages, maxlength=10, seed=0, stop=["."])) + # Test default role + self.assertIsNotNone(model("2 + 2 = ", maxlength=10, seed=0, stop=["."], defaultrole="user")) + # Test streaming self.assertEqual(" ".join(x for x in model("2 + 2 = ", maxlength=10, stream=True, seed=0, stop=["."]))[0], "4") diff --git a/test/python/testpipeline/testllm/testllm.py b/test/python/testpipeline/testllm/testllm.py index 24599fd2f..7c07e200c 100644 --- a/test/python/testpipeline/testllm/testllm.py +++ b/test/python/testpipeline/testllm/testllm.py @@ -54,6 +54,14 @@ def testCustomNotFound(self): with self.assertRaises(ImportError): LLM("hf-internal-testing/tiny-random-gpt2", method="notfound.generation") + def testDefaultRole(self): + """ + Test default role + """ + + model = LLM("hf-internal-testing/tiny-random-LlamaForCausalLM") + self.assertIsNotNone(model("Hello, how are", defaultrole="user")) + def testExternal(self): """ Test externally loaded model