inference.py
import os
import json
import argparse
import random
import numpy as np
import pandas as pd
import requests
from tqdm import tqdm
from typing import Union
from PIL import Image
import mimetypes
import cv2
import torch
from torch.utils.data import DataLoader
import transformers
from transformers import LlamaTokenizer, CLIPImageProcessor
from configs.dataset_config import DATASET_CONFIG
from configs.lora_config import openflamingo_tuning_config, otter_tuning_config
from mllm.src.factory import create_model_and_transforms
from mllm.otter.modeling_otter import OtterConfig, OtterForConditionalGeneration
from huggingface_hub import hf_hub_download
from peft import (
    get_peft_model,
    LoraConfig,
    get_peft_model_state_dict,
    PeftConfig,
    PeftModel,
)


def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
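
# Usage sketch (not invoked anywhere in this script): calling setup_seed(seed)
# before building the model makes any sampling-based decoding reproducible;
# the value 42 below is an arbitrary, illustrative choice.
# setup_seed(42)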


def get_content_type(file_path):
    content_type, _ = mimetypes.guess_type(file_path)
    return content_type


# ------------------- Image and Video Handling Functions -------------------
def extract_frames(video_path, num_frames=16):
    video = cv2.VideoCapture(video_path)
    total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
    # Guard against clips shorter than num_frames so the stride never hits zero.
    frame_step = max(total_frames // num_frames, 1)
    frames = []
    for i in range(num_frames):
        video.set(cv2.CAP_PROP_POS_FRAMES, i * frame_step)
        ret, frame = video.read()
        if ret:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = Image.fromarray(frame).convert("RGB")
            frames.append(frame)
    video.release()
    return frames
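
# Illustrative call (the path is hypothetical): sampling 16 evenly spaced RGB
# frames from a local clip returns a list of PIL images.
# frames = extract_frames("local_clip.mp4", num_frames=16)
# assert all(isinstance(f, Image.Image) for f in frames)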


def get_image(url: str) -> Union[Image.Image, list]:
    if "://" not in url:  # Local file
        content_type = get_content_type(url)
    else:  # Remote URL
        content_type = requests.head(url, stream=True, verify=False).headers.get("Content-Type")

    if "image" in content_type:
        if "://" not in url:  # Local file
            return Image.open(url)
        else:  # Remote URL
            return Image.open(requests.get(url, stream=True, verify=False).raw)
    elif "video" in content_type:
        video_path = "temp_video.mp4"
        if "://" not in url:  # Local file
            video_path = url
        else:  # Remote URL
            with open(video_path, "wb") as f:
                f.write(requests.get(url, stream=True, verify=False).content)
        frames = extract_frames(video_path)
        if "://" in url:  # Only remove the temporary video file if it was downloaded
            os.remove(video_path)
        return frames
    else:
        raise ValueError("Invalid content type. Expected image or video.")
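
# Illustrative calls (paths/URLs are hypothetical): get_image returns a single
# PIL image for image inputs and a list of sampled frames for video inputs.
# single_image = get_image("./playground/images/example.jpg")
# video_frames = get_image("https://example.com/clip.mp4")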


def load_pretrained_model():
    peft_config, peft_model_id = None, None
    peft_config = LoraConfig(**openflamingo_tuning_config)
    model, image_processor, tokenizer = create_model_and_transforms(
        clip_vision_encoder_path="ViT-L-14-336",
        clip_vision_encoder_pretrained="openai",
        lang_encoder_path="anas-awadalla/mpt-7b",  # anas-awadalla/mpt-7b
        tokenizer_path="anas-awadalla/mpt-7b",  # anas-awadalla/mpt-7b
        cross_attn_every_n_layers=4,
        use_peft=True,
        peft_config=peft_config,
    )
    checkpoint_path = hf_hub_download("gray311/Dolphins", "checkpoint.pt")
    model.load_state_dict(torch.load(checkpoint_path), strict=False)
    model.half().cuda()
    return model, image_processor, tokenizer
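
# Note: the loader assumes a CUDA device, since the weights are cast to fp16
# and moved to the GPU via model.half().cuda(). A sketch of reusing one loaded
# model for several clips (paths and instruction are hypothetical):
# model, image_processor, tokenizer = load_pretrained_model()
# for clip in ["./playground/videos/1.mp4", "./playground/videos/2.mp4"]:
#     vision_x, inputs = get_model_inputs(clip, "Describe the scene.", model, image_processor, tokenizer)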


def get_model_inputs(video_path, instruction, model, image_processor, tokenizer):
    frames = get_image(video_path)
    vision_x = torch.stack([image_processor(image) for image in frames], dim=0).unsqueeze(0).unsqueeze(0)
    assert vision_x.shape[2] == len(frames)
    prompt = [
        f"USER: <image> is a driving video. {instruction} GPT:<answer>"
    ]
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    print(vision_x.shape)
    print(prompt)
    return vision_x, inputs


if __name__ == "__main__":
    video_path = "./playground/videos/1.mp4"
    instruction = "Please describe this video in detail."

    model, image_processor, tokenizer = load_pretrained_model()
    vision_x, inputs = get_model_inputs(video_path, instruction, model, image_processor, tokenizer)

    generation_kwargs = {
        'max_new_tokens': 512,
        'temperature': 1,
        'top_k': 0,
        'top_p': 1,
        'no_repeat_ngram_size': 3,
        'length_penalty': 1,
        'do_sample': False,
        'early_stopping': True,
    }

    generated_tokens = model.generate(
        vision_x=vision_x.half().cuda(),
        lang_x=inputs["input_ids"].cuda(),
        attention_mask=inputs["attention_mask"].cuda(),
        num_beams=3,
        **generation_kwargs,
    )

    # generate() may return a tuple; unpack it before moving the tokens to CPU.
    if isinstance(generated_tokens, tuple):
        generated_tokens = generated_tokens[0]
    generated_tokens = generated_tokens.cpu().numpy()

    generated_text = tokenizer.batch_decode(generated_tokens)
    print(
        f"Dolphin output:\n\n{generated_text}"
    )
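
    # Hedged post-processing sketch (not part of the original output handling):
    # the prompt ends with "GPT:<answer>", so the text after the "<answer>"
    # marker in the decoded string is the model's reply. "<|endofchunk|>" is
    # assumed to be the OpenFlamingo-style end-of-chunk token the model emits.
    answer = generated_text[0].split("<answer>")[-1].split("<|endofchunk|>")[0].strip()
    print(f"Parsed answer:\n\n{answer}")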