updated support for gpt-4, pixtral, gemini and molmo
PromtEngineer committed Sep 28, 2024
1 parent 51647db commit 0ba3c76
Showing 3 changed files with 138 additions and 68 deletions.
43 changes: 23 additions & 20 deletions models/model_loader.py
@@ -4,9 +4,15 @@
import torch
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from transformers import MllamaForConditionalGeneration
-from vllm.sampling_params import SamplingParams
from transformers import AutoModelForCausalLM
+import google.generativeai as genai
+from vllm import LLM
+from vllm.sampling_params import SamplingParams
+
+from dotenv import load_dotenv
+
+# Load environment variables from .env file
+load_dotenv()

from logger import get_logger

@@ -21,8 +27,8 @@ def detect_device():
"""
if torch.cuda.is_available():
return 'cuda'
# elif torch.backends.mps.is_available():
# return 'mps'
elif torch.backends.mps.is_available():
return 'mps'
else:
return 'cpu'
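For orientation, a minimal smoke test of the now-enabled MPS branch (assuming a torch build with Metal support; on Apple Silicon without CUDA this returns 'mps'):

import torch
from models.model_loader import detect_device

device = detect_device()              # priority order: 'cuda' > 'mps' > 'cpu'
x = torch.zeros(2, 2, device=device)  # tensors can now target the detected device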

@@ -51,21 +57,13 @@ def load_model(model_choice):

    elif model_choice == 'gemini':
        # Load Gemini model
-        import genai
-        genai.api_key = os.environ.get('GENAI_API_KEY')
-        model = genai.GenerativeModel(model_name="gemini-1.5-pro")
-        processor = None
-        _model_cache[model_choice] = (model, processor)
-        logger.info("Gemini model loaded and cached.")
-        return _model_cache[model_choice]
+        api_key = os.getenv("GOOGLE_API_KEY")
+        if not api_key:
+            raise ValueError("GOOGLE_API_KEY not found in .env file")
+        genai.configure(api_key=api_key)
+        model = genai.GenerativeModel('gemini-1.5-flash-002')  # Use the appropriate model name
+        return model, None
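A minimal sketch of exercising the rewritten branch, assuming a valid GOOGLE_API_KEY in .env; generate_content is the standard google-generativeai call, used the same way in responder.py below:

from models.model_loader import load_model

model, _ = load_model('gemini')   # configures genai from GOOGLE_API_KEY
response = model.generate_content("Say hello in one sentence.")
print(response.text)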

    elif model_choice == 'gpt4':
        # Load OpenAI GPT-4 model
-        import openai
-        openai.api_key = os.environ.get('OPENAI_API_KEY')
-        _model_cache[model_choice] = (None, None)
-        logger.info("GPT-4 model ready and cached.")
-        return _model_cache[model_choice]

    elif model_choice == 'llama-vision':
        # Load Llama-Vision model
@@ -85,21 +83,26 @@ def load_model(model_choice):

elif model_choice == "pixtral":
device = detect_device()
model = LLM(model="mistralai/Pixtral-12B-2409", tokenizer_mode="mistral")
model = LLM(model="mistralai/Pixtral-12B-2409",
tokenizer_mode="mistral",
gpu_memory_utilization=0.8, # Increase GPU memory utilization
max_model_len=8192, # Decrease max model length
dtype="float16", # Use half precision to save memory
trust_remote_code=True)
sampling_params = SamplingParams(max_tokens=1024)
_model_cache[model_choice] = (model, sampling_params, device)
return _model_cache[model_choice]

elif model_choice == "molmo":
device = detect_device()
processor = AutoProcessor.from_pretrained(
'allenai/Molmo-7B-D-0924',
'allenai/MolmoE-1B-0924',
trust_remote_code=True,
torch_dtype='auto',
device_map='auto'
)
model = AutoModelForCausalLM.from_pretrained(
'allenai/Molmo-7B-D-0924',
'allenai/MolmoE-1B-0924',
trust_remote_code=True,
torch_dtype='auto',
device_map='auto'
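The loader memoizes heavyweight models in the module-level _model_cache dict visible in the pixtral branch. A minimal sketch of the presumed pattern (the guard at the top of load_model sits in lines not shown in these hunks; note the rewritten gemini branch returns model, None directly and bypasses the cache, so the Gemini client is rebuilt on every call):

_model_cache = {}

def load_model(model_choice):
    # Presumed guard (hidden lines): repeat calls return the cached tuple.
    if model_choice in _model_cache:
        return _model_cache[model_choice]
    ...  # build model/processor as in the branches above, then:
    # _model_cache[model_choice] = (model, processor)
    # return _model_cache[model_choice]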
159 changes: 112 additions & 47 deletions models/responder.py
@@ -2,13 +2,24 @@

from models.model_loader import load_model
from transformers import GenerationConfig
+import google.generativeai as genai
+from dotenv import load_dotenv
from logger import get_logger
+from openai import OpenAI
from PIL import Image
import torch
import base64
import os
import io


logger = get_logger(__name__)

+# Function to encode the image
+def encode_image(image_path):
+    with open(image_path, "rb") as image_file:
+        return base64.b64encode(image_file.read()).decode('utf-8')
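The new encode_image helper feeds the GPT-4o data URLs built later in this file; a quick sketch of the output shape, with a hypothetical path:

b64 = encode_image("static/session1/page_1.jpeg")   # hypothetical image path
data_url = f"data:image/jpeg;base64,{b64}"          # as assembled in the gpt4 branch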

def generate_response(images, query, session_id, resized_height=280, resized_width=280, model_choice='qwen'):
"""
Generates a response using the selected model based on the query and images.
@@ -56,18 +67,83 @@ def generate_response(images, query, session_id, resized_height=280, resized_width=280, model_choice='qwen'):
        )
        logger.info("Response generated using Qwen model.")
        return output_text[0]

    elif model_choice == 'gemini':
-        from models.gemini_responder import generate_gemini_response
-        model, processor = load_model('gemini')
-        response = generate_gemini_response(images, query, model, processor)
-        logger.info("Response generated using Gemini model.")
-        return response
+        model, _ = load_model('gemini')
+
+        try:
+            content = []
+            content.append(query)  # Add the text query first
+
+            for img_path in images:
+                full_path = os.path.join('static', img_path)
+                if os.path.exists(full_path):
+                    try:
+                        img = Image.open(full_path)
+                        content.append(img)
+                    except Exception as e:
+                        logger.error(f"Error opening image {full_path}: {e}")
+                else:
+                    logger.warning(f"Image file not found: {full_path}")
+
+            if len(content) == 1:  # Only text, no images
+                return "No images could be loaded for analysis."
+
+            response = model.generate_content(content)
+
+            if response.text:
+                generated_text = response.text
+                logger.info("Response generated using Gemini model.")
+                return generated_text
+            else:
+                return "The Gemini model did not generate any text response."
+
+        except Exception as e:
+            logger.error(f"Error in Gemini processing: {str(e)}", exc_info=True)
+            return f"An error occurred while processing the images: {str(e)}"

    elif model_choice == 'gpt4':
-        from models.gpt4_responder import generate_gpt4_response
-        model, _ = load_model('gpt4')
-        response = generate_gpt4_response(images, query, model)
-        logger.info("Response generated using GPT-4 model.")
-        return response
+        api_key = os.getenv("OPENAI_API_KEY")
+        client = OpenAI(api_key=api_key)
+
+        try:
+            content = [{"type": "text", "text": query}]
+
+            for img_path in images:
+                full_path = os.path.join('static', img_path)
+                if os.path.exists(full_path):
+                    base64_image = encode_image(full_path)
+                    content.append({
+                        "type": "image_url",
+                        "image_url": {
+                            "url": f"data:image/jpeg;base64,{base64_image}"
+                        }
+                    })
+                else:
+                    logger.warning(f"Image file not found: {full_path}")
+
+            if len(content) == 1:  # Only text, no images
+                return "No images could be loaded for analysis."
+
+            response = client.chat.completions.create(
+                model="gpt-4o",  # Make sure to use the correct model name
+                messages=[
+                    {
+                        "role": "user",
+                        "content": content
+                    }
+                ],
+                max_tokens=1024
+            )
+
+            generated_text = response.choices[0].message.content
+            logger.info("Response generated using GPT-4 model.")
+            return generated_text
+
+        except Exception as e:
+            logger.error(f"Error in GPT-4 processing: {str(e)}", exc_info=True)
+            return f"An error occurred while processing the images: {str(e)}"

    elif model_choice == 'llama-vision':
        # Load model, processor, and device
@@ -98,20 +174,22 @@ def generate_response(images, query, session_id, resized_height=280, resized_width=280, model_choice='qwen'):

        model, sampling_params, device = load_model('pixtral')

-        image_urls = []
-        for img in images:
-            # Convert PIL Image to base64
-            buffered = io.BytesIO()
-            img.save(buffered, format="PNG")
-            img_str = base64.b64encode(buffered.getvalue()).decode()
-            image_urls.append(f"data:image/png;base64,{img_str}")
+        def image_to_data_url(image_path):
+
+            image_path = os.path.join('static', image_path)
+
+            with open(image_path, "rb") as image_file:
+                encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
+                ext = os.path.splitext(image_path)[1][1:]  # Get the file extension
+                return f"data:image/{ext};base64,{encoded_string}"

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": query},
-                    *[{"type": "image_url", "image_url": {"url": url}} for url in image_urls]
+                    *[{"type": "image_url", "image_url": {"url": image_to_data_url(img_path)}} for i, img_path in enumerate(images) if i<1]
                ]
            },
        ]
@@ -120,10 +198,10 @@ def generate_response(images, query, session_id, resized_height=280, resized_width=280, model_choice='qwen'):
        return outputs[0].outputs[0].text
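The hidden lines just above presumably issue vLLM's chat call with the sampling_params from the loader; a sketch under that assumption, consistent with the outputs[0].outputs[0].text indexing on the return:

outputs = model.chat(messages, sampling_params=sampling_params)
answer = outputs[0].outputs[0].text   # first request, first completion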

elif model_choice == "molmo":

model, processor, device = load_model('molmo')
model = model.half() # Convert model to half precision
pil_images = []
for img_path in images:
for img_path in images[:1]: # Process only the first image for now
full_path = os.path.join('static', img_path)
if os.path.exists(full_path):
try:
@@ -138,53 +216,40 @@ def generate_response(images, query, session_id, resized_height=280, resized_width=280, model_choice='qwen'):
return "No images could be loaded for analysis."

try:
# Log the types and shapes of the images
logger.info(f"Number of images: {len(pil_images)}")
logger.info(f"Image types: {[type(img) for img in pil_images]}")
logger.info(f"Image sizes: {[img.size for img in pil_images]}")

# Process the images and text
inputs = processor.process(
images=pil_images,
text=query
)

# Log the keys and shapes of the inputs
logger.info(f"Input keys: {inputs.keys()}")
for k, v in inputs.items():
if isinstance(v, torch.Tensor):
logger.info(f"Input '{k}' shape: {v.shape}, dtype: {v.dtype}, device: {v.device}")
else:
logger.info(f"Input '{k}' type: {type(v)}")

# Move inputs to the correct device and make a batch of size 1
inputs = {k: v.to(model.device).unsqueeze(0) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}

# Log the updated shapes after moving to device
for k, v in inputs.items():
if isinstance(v, torch.Tensor):
logger.info(f"Updated input '{k}' shape: {v.shape}, dtype: {v.dtype}, device: {v.device}")
# Convert float tensors to half precision, but keep integer tensors as they are
inputs = {k: (v.to(device).unsqueeze(0).half() if v.dtype in [torch.float32, torch.float64] else
v.to(device).unsqueeze(0))
if isinstance(v, torch.Tensor) else v
for k, v in inputs.items()}

# Generate output
output = model.generate_from_batch(
inputs,
GenerationConfig(max_new_tokens=200, stop_strings="<|endoftext|>"),
tokenizer=processor.tokenizer
)
with torch.no_grad(): # Disable gradient calculation
output = model.generate_from_batch(
inputs,
GenerationConfig(max_new_tokens=200, stop_strings="<|endoftext|>"),
tokenizer=processor.tokenizer
)

# Only get generated tokens; decode them to text
generated_tokens = output[0, inputs['input_ids'].size(1):]
generated_text = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)

return generated_text

except Exception as e:
logger.error(f"Error in Molmo processing: {str(e)}", exc_info=True)
return f"An error occurred while processing the images: {str(e)}"
finally:
# Close the opened images to free up resources
for img in pil_images:
img.close()

return generated_text
img.close()
    else:
        logger.error(f"Invalid model choice: {model_choice}")
        return "Invalid model selected."
4 changes: 3 additions & 1 deletion requirements.txt
@@ -8,4 +8,6 @@ docx2pdf
qwen-vl-utils
vllm>=0.6.1.post1
mistral_common>=1.4.1
-einops
+einops
+mistral_common[opencv]
+mistral_common
