Skip to content

Commit

Permalink
Add Megaservice support for MMRAG - MultimodalRAGQnAWithVideos usecase (
Browse files Browse the repository at this point in the history
#626)

* updates

Signed-off-by: Tiep Le <tiep.le@intel.com>

* cosmetic

Signed-off-by: siddhivelankar23 <siddhi.velankar@intel.com>

* update redis schema

Signed-off-by: siddhivelankar23 <siddhi.velankar@intel.com>

* update multimodal config and docker compose retriever

Signed-off-by: siddhivelankar23 <siddhi.velankar@intel.com>

* update requirements

Signed-off-by: siddhivelankar23 <siddhi.velankar@intel.com>

* update retriever redis

Signed-off-by: siddhivelankar23 <siddhi.velankar@intel.com>

* multimodal retriever implementation

Signed-off-by: siddhivelankar23 <siddhi.velankar@intel.com>

* test for multimodal retriever

Signed-off-by: siddhivelankar23 <siddhi.velankar@intel.com>

* include prompt preparation for multimodal rag on videos application

Signed-off-by: sjagtap1803 <siddhant.jagtap@intel.com>

* fix template

Signed-off-by: sjagtap1803 <siddhant.jagtap@intel.com>

* add test for llava for mm_rag_on_videos

Signed-off-by: sjagtap1803 <siddhant.jagtap@intel.com>

* update test

Signed-off-by: sjagtap1803 <siddhant.jagtap@intel.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* first update on gateaway

Signed-off-by: sjagtap1803 <siddhant.jagtap@intel.com>

* fix index not found

Signed-off-by: sjagtap1803 <siddhant.jagtap@intel.com>

* add LVMSearchedMultimodalDoc

Signed-off-by: sjagtap1803 <siddhant.jagtap@intel.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* implement gateway for MultimodalRagQnAWithVideos

Signed-off-by: siddhivelankar23 <siddhi.velankar@intel.com>

* remove INDEX_SCHEMA

Signed-off-by: siddhivelankar23 <siddhi.velankar@intel.com>

* update MultimodalRAGQnAWithVideosGateway with 2 megaservices

Signed-off-by: sjagtap1803 <siddhant.jagtap@intel.com>

* revise folder structure to comps/retrievers/langchain/redis_multimodal

Signed-off-by: siddhivelankar23 <siddhi.velankar@intel.com>

* update test

Signed-off-by: siddhivelankar23 <siddhi.velankar@intel.com>

* add unittest for multimodalrag_qna_with_videos_gateway

Signed-off-by: siddhivelankar23 <siddhi.velankar@intel.com>

* update test mmrag qna with videos

Signed-off-by: Tiep Le <tiep.le@intel.com>

* change port of redis to resolve CI test

Signed-off-by: siddhivelankar23 <siddhi.velankar@intel.com>

* update test

Signed-off-by: siddhivelankar23 <siddhi.velankar@intel.com>

* update lvms test

Signed-off-by: siddhivelankar23 <siddhi.velankar@intel.com>

* update test

Signed-off-by: siddhivelankar23 <siddhi.velankar@intel.com>

* update test

Signed-off-by: siddhivelankar23 <siddhi.velankar@intel.com>

* update test for multimodal rag qna with videos gateway

Signed-off-by: siddhivelankar23 <siddhi.velankar@intel.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add more test to increase coverage

Signed-off-by: Tiep Le <tiep.le@intel.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* cosmetic

Signed-off-by: Tiep Le <tiep.le@intel.com>

* add more test

Signed-off-by: Tiep Le <tiep.le@intel.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update name of gateway

Signed-off-by: Tiep Le <tiep.le@intel.com>

---------

Signed-off-by: Tiep Le <tiep.le@intel.com>
Signed-off-by: siddhivelankar23 <siddhi.velankar@intel.com>
Signed-off-by: sjagtap1803 <siddhant.jagtap@intel.com>
Co-authored-by: siddhivelankar23 <siddhi.velankar@intel.com>
Co-authored-by: sjagtap1803 <siddhant.jagtap@intel.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Sihan Chen <39623753+Spycsh@users.noreply.github.com>
  • Loading branch information
5 people authored Sep 6, 2024
1 parent 2705e93 commit 99be1bd
Show file tree
Hide file tree
Showing 4 changed files with 374 additions and 0 deletions.
1 change: 1 addition & 0 deletions comps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
RetrievalToolGateway,
FaqGenGateway,
VisualQnAGateway,
MultimodalRAGWithVideosGateway,
)

# Telemetry
Expand Down
1 change: 1 addition & 0 deletions comps/cores/mega/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class MegaServiceEndpoint(Enum):
CODE_TRANS = "/v1/codetrans"
DOC_SUMMARY = "/v1/docsum"
SEARCH_QNA = "/v1/searchqna"
MULTIMODAL_RAG_WITH_VIDEOS = "/v1/mmragvideoqna"
TRANSLATION = "/v1/translation"
RETRIEVALTOOL = "/v1/retrievaltool"
FAQ_GEN = "/v1/faqgen"
Expand Down
157 changes: 157 additions & 0 deletions comps/cores/mega/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def _handle_message(self, messages):
messages_dict[msg_role] = message["content"]
else:
raise ValueError(f"Unknown role: {msg_role}")

if system_prompt:
prompt = system_prompt + "\n"
for role, message in messages_dict.items():
Expand Down Expand Up @@ -582,3 +583,159 @@ def parser_input(data, TypeClass, key):
response = result_dict[last_node]
print("response is ", response)
return response


class MultimodalRAGWithVideosGateway(Gateway):
def __init__(self, multimodal_rag_megaservice, lvm_megaservice, host="0.0.0.0", port=9999):
self.lvm_megaservice = lvm_megaservice
super().__init__(
multimodal_rag_megaservice,
host,
port,
str(MegaServiceEndpoint.MULTIMODAL_RAG_WITH_VIDEOS),
ChatCompletionRequest,
ChatCompletionResponse,
)

# this overrides _handle_message method of Gateway
def _handle_message(self, messages):
images = []
messages_dicts = []
if isinstance(messages, str):
prompt = messages
else:
messages_dict = {}
system_prompt = ""
prompt = ""
for message in messages:
msg_role = message["role"]
messages_dict = {}
if msg_role == "system":
system_prompt = message["content"]
elif msg_role == "user":
if type(message["content"]) == list:
text = ""
text_list = [item["text"] for item in message["content"] if item["type"] == "text"]
text += "\n".join(text_list)
image_list = [
item["image_url"]["url"] for item in message["content"] if item["type"] == "image_url"
]
if image_list:
messages_dict[msg_role] = (text, image_list)
else:
messages_dict[msg_role] = text
else:
messages_dict[msg_role] = message["content"]
messages_dicts.append(messages_dict)
elif msg_role == "assistant":
messages_dict[msg_role] = message["content"]
messages_dicts.append(messages_dict)
else:
raise ValueError(f"Unknown role: {msg_role}")

if system_prompt:
prompt = system_prompt + "\n"
for messages_dict in messages_dicts:
for i, (role, message) in enumerate(messages_dict.items()):
if isinstance(message, tuple):
text, image_list = message
if i == 0:
# do not add role for the very first message.
# this will be added by llava_server
if text:
prompt += text + "\n"
else:
if text:
prompt += role.upper() + ": " + text + "\n"
else:
prompt += role.upper() + ":"
for img in image_list:
# URL
if img.startswith("http://") or img.startswith("https://"):
response = requests.get(img)
image = Image.open(BytesIO(response.content)).convert("RGBA")
image_bytes = BytesIO()
image.save(image_bytes, format="PNG")
img_b64_str = base64.b64encode(image_bytes.getvalue()).decode()
# Local Path
elif os.path.exists(img):
image = Image.open(img).convert("RGBA")
image_bytes = BytesIO()
image.save(image_bytes, format="PNG")
img_b64_str = base64.b64encode(image_bytes.getvalue()).decode()
# Bytes
else:
img_b64_str = img

images.append(img_b64_str)
else:
if i == 0:
# do not add role for the very first message.
# this will be added by llava_server
if message:
prompt += role.upper() + ": " + message + "\n"
else:
if message:
prompt += role.upper() + ": " + message + "\n"
else:
prompt += role.upper() + ":"
if images:
return prompt, images
else:
return prompt

async def handle_request(self, request: Request):
data = await request.json()
stream_opt = bool(data.get("stream", False))
if stream_opt:
print("[ MultimodalRAGWithVideosGateway ] stream=True not used, this has not support streaming yet!")
stream_opt = False
chat_request = ChatCompletionRequest.model_validate(data)
# Multimodal RAG QnA With Videos has not yet accepts image as input during QnA.
prompt_and_image = self._handle_message(chat_request.messages)
if isinstance(prompt_and_image, tuple):
# print(f"This request include image, thus it is a follow-up query. Using lvm megaservice")
prompt, images = prompt_and_image
cur_megaservice = self.lvm_megaservice
initial_inputs = {"prompt": prompt, "image": images[0]}
else:
# print(f"This is the first query, requiring multimodal retrieval. Using multimodal rag megaservice")
prompt = prompt_and_image
cur_megaservice = self.megaservice
initial_inputs = {"text": prompt}

parameters = LLMParams(
max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
top_k=chat_request.top_k if chat_request.top_k else 10,
top_p=chat_request.top_p if chat_request.top_p else 0.95,
temperature=chat_request.temperature if chat_request.temperature else 0.01,
repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
streaming=stream_opt,
chat_template=chat_request.chat_template if chat_request.chat_template else None,
)
result_dict, runtime_graph = await cur_megaservice.schedule(
initial_inputs=initial_inputs, llm_parameters=parameters
)
for node, response in result_dict.items():
# the last microservice in this megaservice is LVM.
# checking if LVM returns StreamingResponse
# Currently, LVM with LLAVA has not yet supported streaming.
# @TODO: Will need to test this once LVM with LLAVA supports streaming
if (
isinstance(response, StreamingResponse)
and node == runtime_graph.all_leaves()[-1]
and self.megaservice.services[node].service_type == ServiceType.LVM
):
return response
last_node = runtime_graph.all_leaves()[-1]
response = result_dict[last_node]["text"]
choices = []
usage = UsageInfo()
choices.append(
ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role="assistant", content=response),
finish_reason="stop",
)
)
return ChatCompletionResponse(model="multimodalragwithvideos", choices=choices, usage=usage)
215 changes: 215 additions & 0 deletions tests/cores/mega/test_multimodalrag_with_videos_gateway.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import json
import unittest
from typing import Union

import requests
from fastapi import Request

from comps import (
EmbedDoc,
EmbedMultimodalDoc,
LVMDoc,
LVMSearchedMultimodalDoc,
MultimodalDoc,
MultimodalRAGWithVideosGateway,
SearchedMultimodalDoc,
ServiceOrchestrator,
TextDoc,
opea_microservices,
register_microservice,
)


@register_microservice(name="mm_embedding", host="0.0.0.0", port=8083, endpoint="/v1/mm_embedding")
async def mm_embedding_add(request: MultimodalDoc) -> EmbedDoc:
req = request.model_dump_json()
req_dict = json.loads(req)
text = req_dict["text"]
res = {}
res["text"] = text
res["embedding"] = [0.12, 0.45]
return res


@register_microservice(name="mm_retriever", host="0.0.0.0", port=8084, endpoint="/v1/mm_retriever")
async def mm_retriever_add(request: EmbedMultimodalDoc) -> SearchedMultimodalDoc:
req = request.model_dump_json()
req_dict = json.loads(req)
text = req_dict["text"]
res = {}
res["retrieved_docs"] = []
res["initial_query"] = text
res["top_n"] = 1
res["metadata"] = [
{
"b64_img_str": "iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mP8/5+hnoEIwDiqkL4KAcT9GO0U4BxoAAAAAElFTkSuQmCC",
"transcript_for_inference": "yellow image",
}
]
res["chat_template"] = "The caption of the image is: '{context}'. {question}"
return res


@register_microservice(name="lvm", host="0.0.0.0", port=8085, endpoint="/v1/lvm")
async def lvm_add(request: Union[LVMDoc, LVMSearchedMultimodalDoc]) -> TextDoc:
req = request.model_dump_json()
req_dict = json.loads(req)
if isinstance(request, LVMSearchedMultimodalDoc):
print("request is the output of multimodal retriever")
text = req_dict["initial_query"]
text += "opea project!"

else:
print("request is from user.")
text = req_dict["prompt"]
text = f"<image>\nUSER: {text}\nASSISTANT:"

res = {}
res["text"] = text
return res


class TestServiceOrchestrator(unittest.IsolatedAsyncioTestCase):
@classmethod
def setUpClass(cls):
cls.mm_embedding = opea_microservices["mm_embedding"]
cls.mm_retriever = opea_microservices["mm_retriever"]
cls.lvm = opea_microservices["lvm"]
cls.mm_embedding.start()
cls.mm_retriever.start()
cls.lvm.start()

cls.service_builder = ServiceOrchestrator()

cls.service_builder.add(opea_microservices["mm_embedding"]).add(opea_microservices["mm_retriever"]).add(
opea_microservices["lvm"]
)
cls.service_builder.flow_to(cls.mm_embedding, cls.mm_retriever)
cls.service_builder.flow_to(cls.mm_retriever, cls.lvm)

cls.follow_up_query_service_builder = ServiceOrchestrator()
cls.follow_up_query_service_builder.add(cls.lvm)

cls.gateway = MultimodalRAGWithVideosGateway(
cls.service_builder, cls.follow_up_query_service_builder, port=9898
)

@classmethod
def tearDownClass(cls):
cls.mm_embedding.stop()
cls.mm_retriever.stop()
cls.lvm.stop()
cls.gateway.stop()

async def test_service_builder_schedule(self):
result_dict, _ = await self.service_builder.schedule(initial_inputs={"text": "hello, "})
self.assertEqual(result_dict[self.lvm.name]["text"], "hello, opea project!")

async def test_follow_up_query_service_builder_schedule(self):
result_dict, _ = await self.follow_up_query_service_builder.schedule(
initial_inputs={"prompt": "chao, ", "image": "some image"}
)
# print(result_dict)
self.assertEqual(result_dict[self.lvm.name]["text"], "<image>\nUSER: chao, \nASSISTANT:")

def test_multimodal_rag_with_videos_gateway(self):
json_data = {"messages": "hello, "}
response = requests.post("http://0.0.0.0:9898/v1/mmragvideoqna", json=json_data)
response = response.json()
self.assertEqual(response["choices"][-1]["message"]["content"], "hello, opea project!")

def test_follow_up_mm_rag_with_videos_gateway(self):
json_data = {
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "hello, "},
{
"type": "image_url",
"image_url": {"url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
},
],
},
{"role": "assistant", "content": "opea project! "},
{"role": "user", "content": "chao, "},
],
"max_tokens": 300,
}
response = requests.post("http://0.0.0.0:9898/v1/mmragvideoqna", json=json_data)
response = response.json()
self.assertEqual(
response["choices"][-1]["message"]["content"],
"<image>\nUSER: hello, \nASSISTANT: opea project! \nUSER: chao, \n\nASSISTANT:",
)

def test_handle_message(self):
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": "hello, "},
{
"type": "image_url",
"image_url": {"url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
},
],
},
{"role": "assistant", "content": "opea project! "},
{"role": "user", "content": "chao, "},
]
prompt, images = self.gateway._handle_message(messages)
self.assertEqual(prompt, "hello, \nASSISTANT: opea project! \nUSER: chao, \n")

def test_handle_message_with_system_prompt(self):
messages = [
{"role": "system", "content": "System Prompt"},
{
"role": "user",
"content": [
{"type": "text", "text": "hello, "},
{
"type": "image_url",
"image_url": {"url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
},
],
},
{"role": "assistant", "content": "opea project! "},
{"role": "user", "content": "chao, "},
]
prompt, images = self.gateway._handle_message(messages)
self.assertEqual(prompt, "System Prompt\nhello, \nASSISTANT: opea project! \nUSER: chao, \n")

async def test_handle_request(self):
json_data = {
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "hello, "},
{
"type": "image_url",
"image_url": {"url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
},
],
},
{"role": "assistant", "content": "opea project! "},
{"role": "user", "content": "chao, "},
],
"max_tokens": 300,
}
mock_request = Request(scope={"type": "http"})
mock_request._json = json_data
res = await self.gateway.handle_request(mock_request)
res = json.loads(res.json())
self.assertEqual(
res["choices"][-1]["message"]["content"],
"<image>\nUSER: hello, \nASSISTANT: opea project! \nUSER: chao, \n\nASSISTANT:",
)


if __name__ == "__main__":
unittest.main()

0 comments on commit 99be1bd

Please sign in to comment.