Commit e89ffc2
embedding modalities eeeek so expensive
esteininger committed Mar 18, 2024
1 parent acacd94 commit e89ffc2
Showing 5 changed files with 181 additions and 36 deletions.
39 changes: 39 additions & 0 deletions src/inference/embed/modalities/audio.py
@@ -0,0 +1,39 @@
from transformers import Wav2Vec2Processor, Wav2Vec2Model
import torch
import librosa
import io


class AudioEmbeddingService:
    def __init__(self, model):
        self.processor = Wav2Vec2Processor.from_pretrained(model)
        self.model = Wav2Vec2Model.from_pretrained(model)

    def encode(self, file_stream):
        # Load the audio file, resampling to the 16 kHz rate Wav2Vec2 expects
        audio_input, sr = librosa.load(io.BytesIO(file_stream), sr=16000)

        # Process audio into model-ready tensors
        inputs = self.processor(audio_input, return_tensors="pt", sampling_rate=sr)

        # Move to CPU
        inputs = {k: v.to("cpu") for k, v in inputs.items()}

        # Get the audio embedding
        with torch.no_grad():
            audio_features = self.model(**inputs).last_hidden_state

        # Mean-pool the embeddings across the time dimension
        embeddings = audio_features.mean(dim=1)

        # L2-normalize the embeddings
        normalized_embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

        return normalized_embeddings

    def get_dimensions(self):
        return self.model.config.hidden_size

    def get_token_size(self):
        # As with images, token size isn't directly applicable to audio
        return None
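
A minimal usage sketch for the audio service above (not part of the commit); the checkpoint name and file path are illustrative assumptions:

# Hypothetical usage sketch; checkpoint and path are placeholders.
with open("sample.wav", "rb") as f:
    audio_bytes = f.read()

service = AudioEmbeddingService("facebook/wav2vec2-base-960h")
embedding = service.encode(audio_bytes)
print(embedding.shape)  # e.g. torch.Size([1, 768]) for the base checkpoint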
38 changes: 38 additions & 0 deletions src/inference/embed/modalities/image.py
@@ -0,0 +1,38 @@
from transformers import CLIPProcessor, CLIPModel
import torch
import io
from PIL import Image


class ImageEmbeddingService:
    def __init__(self, model):
        self.processor = CLIPProcessor.from_pretrained(model)
        self.model = CLIPModel.from_pretrained(model)

    def encode(self, file_stream):
        # Load the image
        image = Image.open(io.BytesIO(file_stream))

        # Process image
        inputs = self.processor(images=image, return_tensors="pt")

        # Move to CPU
        inputs = {k: v.to("cpu") for k, v in inputs.items()}

        # Get the image embedding
        with torch.no_grad():
            image_features = self.model.get_image_features(**inputs)

        # L2-normalize the embeddings
        image_embeddings = torch.nn.functional.normalize(image_features, p=2, dim=1)

        return image_embeddings

    def get_dimensions(self):
        # get_image_features returns vectors in CLIP's shared projection space,
        # so the embedding width is projection_dim (image and text share it)
        return self.model.config.projection_dim

    def get_token_size(self):
        # This method isn't directly applicable to images as it is to text
        return None
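
A usage sketch for the image service (not part of the commit); the CLIP checkpoint and image path are illustrative assumptions:

# Hypothetical usage sketch; checkpoint and path are placeholders.
with open("photo.jpg", "rb") as f:
    image_bytes = f.read()

service = ImageEmbeddingService("openai/clip-vit-base-patch32")
embedding = service.encode(image_bytes)
print(embedding.shape)  # torch.Size([1, 512]) for the patch32 checkpoint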
40 changes: 40 additions & 0 deletions src/inference/embed/modalities/text.py
@@ -0,0 +1,40 @@
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F


class TextEmbeddingService:
    def __init__(self, model):
        self.tokenizer = AutoTokenizer.from_pretrained(model)
        self.model = AutoModel.from_pretrained(model)

    def mean_pooling(self, model_output, attention_mask):
        # Average token embeddings, weighted by the attention mask so
        # padding tokens don't contribute to the sentence vector
        token_embeddings = model_output[0]
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )

    def encode(self, sentences):
        encoded_input = self.tokenizer(
            sentences, padding=True, truncation=True, return_tensors="pt"
        )

        with torch.no_grad():
            model_output = self.model(**encoded_input)

        sentence_embeddings = self.mean_pooling(
            model_output, encoded_input["attention_mask"]
        )
        sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)

        return sentence_embeddings

    def get_dimensions(self):
        return self.model.config.hidden_size

    def get_token_size(self):
        return self.tokenizer.model_max_length
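
A usage sketch for the text service (not part of the commit); the sentence-transformers checkpoint is an illustrative assumption:

# Hypothetical usage sketch; checkpoint is a placeholder.
service = TextEmbeddingService("sentence-transformers/all-MiniLM-L6-v2")
embeddings = service.encode(["first sentence", "second sentence"])
print(embeddings.shape)  # torch.Size([2, 384]) for MiniLM-L6-v2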
57 changes: 57 additions & 0 deletions src/inference/embed/modalities/video.py
@@ -0,0 +1,57 @@
import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
import cv2


class VideoEmbeddingService:
    def __init__(self, model):
        self.processor = CLIPProcessor.from_pretrained(model)
        self.model = CLIPModel.from_pretrained(model)

    def frame_embeddings(self, video_path):
        # Initialize a video capture object
        cap = cv2.VideoCapture(video_path)
        frame_embeddings = []

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # Convert the color space from BGR to RGB, then convert to PIL Image
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(frame)

            # Preprocess the frame
            inputs = self.processor(images=pil_image, return_tensors="pt", padding=True)

            # Move to CPU and get the frame embedding
            inputs = {k: v.to("cpu") for k, v in inputs.items()}
            with torch.no_grad():
                frame_features = self.model.get_image_features(**inputs)

            frame_embeddings.append(frame_features)

        cap.release()
        # Shape: (num_frames, 1, embedding_dim)
        return torch.stack(frame_embeddings)

    def encode(self, file_stream):
        # Assume file_stream is a path for simplicity; adapt as necessary for actual streams
        embeddings = self.frame_embeddings(file_stream)
        # Aggregate frame embeddings by averaging over the time dimension
        video_embedding = embeddings.mean(dim=0)
        # L2-normalize the embedding
        normalized_embedding = torch.nn.functional.normalize(
            video_embedding, p=2, dim=1
        )
        return normalized_embedding

    def get_dimensions(self):
        return self.model.config.visual_projection.out_features

    def get_token_size(self):
        # Not applicable to videos, as with audio and images
        return None
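
A usage sketch for the video service (not part of the commit); as the code itself notes, encode currently expects a filesystem path, and the checkpoint and path here are illustrative assumptions:

# Hypothetical usage sketch; checkpoint and path are placeholders.
service = VideoEmbeddingService("openai/clip-vit-base-patch32")
embedding = service.encode("clip.mp4")
print(embedding.shape)  # torch.Size([1, 512]): one pooled vector per video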
43 changes: 7 additions & 36 deletions src/inference/embed/service.py
@@ -5,11 +5,18 @@

from _utils import create_success_response

from modalities.audio import AudioEmbeddingService
from modalities.image import ImageEmbeddingService
from modalities.text import TextEmbeddingService


class EmbeddingHandler:
    def __init__(self, modality, model):
        if modality == "text":
            self.service = TextEmbeddingService(model)
        elif modality == "image":
            self.service = ImageEmbeddingService(model)
        elif modality == "audio":
            self.service = AudioEmbeddingService(model)
        else:
            raise ValueError(f"Unknown modality: {modality}")

@@ -34,39 +41,3 @@ def get_configs(self):
"elapsed_time": (time.time() * 1000) - start_time,
}
)


class TextEmbeddingService:
    def __init__(self, model):
        self.tokenizer = AutoTokenizer.from_pretrained(model)
        self.model = AutoModel.from_pretrained(model)

    def mean_pooling(self, model_output, attention_mask):
        token_embeddings = model_output[0]
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )

    def encode(self, sentences):
        encoded_input = self.tokenizer(
            sentences, padding=True, truncation=True, return_tensors="pt"
        )

        with torch.no_grad():
            model_output = self.model(**encoded_input)

        sentence_embeddings = self.mean_pooling(
            model_output, encoded_input["attention_mask"]
        )
        sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)

        return sentence_embeddings

    def get_dimensions(self):
        return self.model.config.hidden_size

    def get_token_size(self):
        return self.tokenizer.model_max_length
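
With the refactor, a hedged sketch of how the handler might now be called (checkpoints are placeholders, and note the new video.py is not yet dispatched by the handler):

# Hypothetical usage sketch of the refactored handler.
handler = EmbeddingHandler(modality="text", model="sentence-transformers/all-MiniLM-L6-v2")
vectors = handler.service.encode(["one sentence", "another sentence"])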
