Skip to content

Commit

Permalink
Bugfix for PR 496 to add format_video_name function (#602)
Browse files Browse the repository at this point in the history
* add format_video_name

Signed-off-by: BaoHuiling <huiling.bao@intel.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add test for negative case

Signed-off-by: BaoHuiling <huiling.bao@intel.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update file path

Signed-off-by: BaoHuiling <huiling.bao@intel.com>

---------

Signed-off-by: BaoHuiling <huiling.bao@intel.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
BaoHuiling and pre-commit-ci[bot] authored Sep 4, 2024
1 parent b873cf8 commit 54aa943
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 18 deletions.
60 changes: 45 additions & 15 deletions comps/reranks/video-rag-qna/local_reranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@

import logging
import os
import re
import time

from fastapi import HTTPException

from comps import (
LVMVideoDoc,
SearchedMultimodalDoc,
Expand Down Expand Up @@ -52,6 +55,25 @@ def find_timestamp_from_video(metadata_list, video):
)


def format_video_name(video_name):
# Check for an existing file extension
match = re.search(r"\.(\w+)$", video_name)

if match:
extension = match.group(1)
# If the extension is not 'mp4', raise an error
if extension != "mp4":
raise ValueError(f"Invalid file extension: .{extension}. Only '.mp4' is allowed.")

# Use regex to remove any suffix after the base name (e.g., '_interval_0', etc.)
base_name = re.sub(r"(_interval_\d+)?(\.mp4)?$", "", video_name)

# Add the '.mp4' extension
formatted_name = f"{base_name}.mp4"

return formatted_name


@register_microservice(
name="opea_service@reranking_visual_rag",
service_type=ServiceType.RERANK,
Expand All @@ -64,22 +86,30 @@ def find_timestamp_from_video(metadata_list, video):
@register_statistics(names=["opea_service@reranking_visual_rag"])
def reranking(input: SearchedMultimodalDoc) -> LVMVideoDoc:
start = time.time()
try:
# get top video name from metadata
video_names = [meta["video"] for meta in input.metadata]
top_video_names = get_top_doc(input.top_n, video_names)

# only use the first top video
timestamp = find_timestamp_from_video(input.metadata, top_video_names[0])
formatted_video_name = format_video_name(top_video_names[0])
video_url = f"{file_server_endpoint.rstrip('/')}/{formatted_video_name}"

result = LVMVideoDoc(
video_url=video_url,
prompt=input.initial_query,
chunk_start=timestamp,
chunk_duration=float(chunk_duration),
max_new_tokens=512,
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logging.error(f"Unexpected error in reranking: {str(e)}")
# Handle any other exceptions with a generic server error response
raise HTTPException(status_code=500, detail="An unexpected error occurred.")

# get top video name from metadata
video_names = [meta["video"] for meta in input.metadata]
top_video_names = get_top_doc(input.top_n, video_names)

# only use the first top video
timestamp = find_timestamp_from_video(input.metadata, top_video_names[0])
video_url = f"{file_server_endpoint.rstrip('/')}/{top_video_names[0]}"

result = LVMVideoDoc(
video_url=video_url,
prompt=input.initial_query,
chunk_start=timestamp,
chunk_duration=float(chunk_duration),
max_new_tokens=512,
)
statistics_dict["opea_service@reranking_visual_rag"].append_latency(time.time() - start, None)

return result
Expand Down
2 changes: 1 addition & 1 deletion comps/retrievers/langchain/vdms/retriever_vdms.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
if use_clip:
import sys

sys.path.append("../../../embeddings/langchain_multimodal/")
sys.path.append("../../../embeddings/multimodal_clip/")
from embeddings_clip import vCLIP

# Debugging
Expand Down
25 changes: 23 additions & 2 deletions tests/test_reranks_video-rag-qna.sh
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,30 @@ function validate_microservice() {
]
}')
if [[ $result == *"this is the query"* ]]; then
echo "Result correct."
echo "Result correct for the positive case."
else
echo "Result wrong."
echo "Result wrong for the positive case. Received was $result"
exit 1
fi

# Add test for negative case
result=$(\
http_proxy="" \
curl -X 'POST' \
"http://${ip_address}:5037/v1/reranking" \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
"retrieved_docs": [
{"doc": [{"text": "this is the retrieved text"}]}
],
"initial_query": "this is the query",
"top_n": 1,
"metadata": [{"other_key": "value", "video":"top_video_name_bad_format.avi", "timestamp":"20"}]}')
if [[ $result == *"Invalid file extension"* ]]; then
echo "Result correct for the negative case."
else
echo "Result wrong for the negative case. Received was $result"
exit 1
fi
}
Expand Down

0 comments on commit 54aa943

Please sign in to comment.