Skip to content

Commit

Permalink
BUG FIX: LVM security fix (#572)
Browse files Browse the repository at this point in the history
* add url validator

Signed-off-by: BaoHuiling <huiling.bao@intel.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add validation for video_url

Signed-off-by: BaoHuiling <huiling.bao@intel.com>

---------

Signed-off-by: BaoHuiling <huiling.bao@intel.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
BaoHuiling and pre-commit-ci[bot] authored Aug 28, 2024
1 parent e38ed6d commit 3e548f3
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 34 deletions.
2 changes: 2 additions & 0 deletions comps/lvms/video-llama/server/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,6 @@ torchaudio==0.13.1 --index-url https://download.pytorch.org/whl/cpu
torchvision==0.14.1 --index-url https://download.pytorch.org/whl/cpu
transformers
uvicorn
validators
webdataset
werkzeug
114 changes: 81 additions & 33 deletions comps/lvms/video-llama/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
import argparse
import logging
import os
import re
from threading import Thread
from urllib.parse import urlparse

import decord
import requests
import uvicorn
import validators
from extract_vl_embedding import VLEmbeddingExtractor as VL
from fastapi import FastAPI, Query
from fastapi.middleware.cors import CORSMiddleware
Expand All @@ -21,6 +23,7 @@
from transformers import TextIteratorStreamer, set_seed
from video_llama.common.registry import registry
from video_llama.conversation.conversation_video import Chat
from werkzeug.utils import secure_filename

# Initialize decord bridge and seed
decord.bridge.set_bridge("torch")
Expand All @@ -33,7 +36,7 @@
context_db = None
streamer = None
chat = None
VIDEO_DIR = "/home/user/videos"
VIDEO_DIR = "/home/user/comps/lvms/video-llama/server/data"
CFG_PATH = "video_llama_config/video_llama_eval_only_vl.yaml"
MODEL_TYPE = "llama_v2"

Expand Down Expand Up @@ -161,6 +164,43 @@ def is_local_file(url):
return not url.startswith("http://") and not url.startswith("https://")


def is_valid_url(url):
# Validate the URL's structure
validation = validators.url(url)
if not validation:
logging.error("URL is invalid")
return False

# Parse the URL to components
parsed_url = urlparse(url)

# Check the scheme
if parsed_url.scheme not in ["http", "https"]:
logging.error("URL scheme is invalid")
return False

# Check for "../" in the path
if "../" in parsed_url.path:
logging.error("URL contains '../', which is not allowed")
return False

# Check that the path only contains one "." for the file extension
if parsed_url.path.count(".") != 1:
logging.error("URL path does not meet the requirement of having only one '.'")
return False

# If all checks pass, the URL is valid
logging.info("URL is valid")
return True


def is_valid_video(filename):
if re.match(r"^[a-zA-Z0-9-_]+\.(mp4)$", filename, re.IGNORECASE):
return secure_filename(filename)
else:
return False


@app.get("/health")
async def health() -> Response:
"""Health check."""
Expand All @@ -175,46 +215,54 @@ async def generate(
prompt: str = Query(..., description="Query for Video-LLama", examples="What is the man doing?"),
max_new_tokens: int = Query(150, description="Maximum number of tokens to generate", examples=150),
) -> StreamingResponse:
if not is_local_file(video_url):
parsed_url = urlparse(video_url)
video_name = os.path.basename(parsed_url.path)
else:
video_name = os.path.basename(video_url)

if video_name.lower().endswith(".mp4"):
logging.info(f"Format check passed, the file '{video_name}' is an MP4 file.")
if video_url.lower().endswith(".mp4"):
logging.info(f"Format check passed, the file '{video_url}' is an MP4 file.")
else:
logging.info(f"Format check failed, the file '{video_name}' is not an MP4 file.")
return JSONResponse(status_code=400, content={"message": "Invalid file type. Only mp4 videos are allowed."})

if not is_local_file(video_url):
try:
video_path = os.path.join(VIDEO_DIR, video_name)
response = requests.get(video_url, stream=True)

if response.status_code == 200:
with open(video_path, "wb") as file:
for chunk in response.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
file.write(chunk)
logging.info(f"File downloaded: {video_path}")
else:
logging.info(f"Format check failed, the file '{video_url}' is not an MP4 file.")
return JSONResponse(status_code=500, content={"message": "Invalid file type. Only mp4 videos are allowed."})

if is_local_file(video_url):
# validate the video name
if is_valid_video(video_url):
secure_video_name = is_valid_video(video_url) # only support video name without path
else:
return JSONResponse(status_code=500, content={"message": "Invalid file name."})

video_path = os.path.join(VIDEO_DIR, secure_video_name)
if os.path.exists(video_path):
logging.info(f"File found: {video_path}")
else:
logging.error(f"File not found: {video_path}")
return JSONResponse(
status_code=404, content={"message": "File not found. Only local files under data folder are allowed."}
)
else:
# validate the remote URL
if not is_valid_url(video_url):
return JSONResponse(status_code=500, content={"message": "Invalid URL."})
else:
parsed_url = urlparse(video_url)
video_path = os.path.join(VIDEO_DIR, os.path.basename(parsed_url.path))
try:
response = requests.get(video_url, stream=True)
if response.status_code == 200:
with open(video_path, "wb") as file:
for chunk in response.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
file.write(chunk)
logging.info(f"File downloaded: {video_path}")
else:
logging.info(f"Error downloading file: {response.status_code}")
return JSONResponse(status_code=500, content={"message": "Error downloading file."})
except Exception as e:
logging.info(f"Error downloading file: {response.status_code}")
return JSONResponse(status_code=500, content={"message": "Error downloading file."})
except Exception as e:
logging.info(f"Error downloading file: {response.status_code}")
return JSONResponse(status_code=500, content={"message": "Error downloading file."})
else:
# check if the video exist
video_path = video_url
if not os.path.exists(video_path):
logging.info(f"File not found: {video_path}")
return JSONResponse(status_code=404, content={"message": "File not found."})

video_info = videoInfo(start_time=start, duration=duration, video_path=video_path)

# format context and instruction
instruction = f"{get_context(prompt,context_db)[0]}: {prompt}"
# logging.info("instruction:",instruction)

return StreamingResponse(stream_res(video_info, instruction, max_new_tokens))

Expand Down
2 changes: 1 addition & 1 deletion tests/test_lvms_video-llama.sh
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ function start_service() {
}

function validate_microservice() {
result=$(http_proxy="" curl http://localhost:5031/v1/lvm -X POST -d '{"video_url":"./data/silence_girl.mp4","chunk_start": 0,"chunk_duration": 7,"prompt":"What is the person doing?","max_new_tokens": 50}' -H 'Content-Type: application/json')
result=$(http_proxy="" curl http://localhost:5031/v1/lvm -X POST -d '{"video_url":"silence_girl.mp4","chunk_start": 0,"chunk_duration": 7,"prompt":"What is the person doing?","max_new_tokens": 50}' -H 'Content-Type: application/json')
if [[ $result == *"silence"* ]]; then
echo "Result correct."
else
Expand Down

0 comments on commit 3e548f3

Please sign in to comment.