-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
embedding modalities eeeek so expensive
- Loading branch information
1 parent
acacd94
commit e89ffc2
Showing
5 changed files
with
181 additions
and
36 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from transformers import Wav2Vec2Processor, Wav2Vec2Model | ||
import torch | ||
import librosa | ||
import numpy as np | ||
import io | ||
|
||
class AudioEmbeddingService:
    """Turn raw audio bytes into a single L2-normalized Wav2Vec2 embedding.

    Args:
        model: Checkpoint identifier handed to both
            ``Wav2Vec2Processor.from_pretrained`` and
            ``Wav2Vec2Model.from_pretrained``.
    """

    def __init__(self, model):
        self.processor = Wav2Vec2Processor.from_pretrained(model)
        self.model = Wav2Vec2Model.from_pretrained(model)

    def encode(self, file_stream):
        """Embed one audio clip.

        Args:
            file_stream: Encoded audio content; it is wrapped in
                ``io.BytesIO`` before being passed to ``librosa.load``
                (presumably ``bytes`` -- confirm against the caller).

        Returns:
            The model's ``last_hidden_state`` averaged over dimension 1 and
            L2-normalized along dimension 1.
        """
        # Decode the clip, asking librosa for a 16 kHz sampling rate.
        waveform, sample_rate = librosa.load(io.BytesIO(file_stream), sr=16000)

        # Build the model inputs and keep every tensor on the CPU.
        features = self.processor(
            waveform, return_tensors="pt", sampling_rate=sample_rate
        )
        features = {name: tensor.to("cpu") for name, tensor in features.items()}

        # Inference only: no gradients are needed.
        with torch.no_grad():
            hidden_states = self.model(**features).last_hidden_state

        # Collapse dimension 1 (the original comment calls it the time
        # dimension) by averaging, then scale the result to unit L2 norm.
        pooled = hidden_states.mean(dim=1)
        return torch.nn.functional.normalize(pooled, p=2, dim=1)

    def get_dimensions(self):
        """Return ``hidden_size`` from the loaded model's config."""
        return self.model.config.hidden_size

    def get_token_size(self):
        """Return ``None``; a token limit is not directly applicable to audio."""
        return None
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
from transformers import CLIPProcessor, CLIPModel | ||
import torch | ||
import io | ||
from PIL import Image | ||
|
||
|
||
class ImageEmbeddingService:
    """Embed images with a pretrained CLIP checkpoint.

    Args:
        model: Checkpoint identifier passed to both
            ``CLIPProcessor.from_pretrained`` and
            ``CLIPModel.from_pretrained``.
    """

    def __init__(self, model):
        self.processor = CLIPProcessor.from_pretrained(model)
        self.model = CLIPModel.from_pretrained(model)

    def encode(self, file_stream):
        """Return the L2-normalized CLIP image embedding for one image.

        Args:
            file_stream: Encoded image content; it is wrapped in
                ``io.BytesIO`` before being opened with PIL (presumably
                ``bytes`` -- confirm against the caller).

        Returns:
            The output of ``CLIPModel.get_image_features`` normalized to
            unit L2 norm along dimension 1.
        """
        # Load the image
        image = Image.open(io.BytesIO(file_stream))

        # Process image
        inputs = self.processor(images=image, return_tensors="pt")

        # Move to CPU
        inputs = {k: v.to("cpu") for k, v in inputs.items()}

        # Get the image embedding (inference only, so no gradients)
        with torch.no_grad():
            image_features = self.model.get_image_features(**inputs)

        # Normalize the embeddings
        return torch.nn.functional.normalize(image_features, p=2, dim=1)

    def get_dimensions(self):
        """Return the width of the vectors produced by :meth:`encode`.

        ``encode`` returns ``get_image_features`` output, i.e. embeddings in
        CLIP's shared projection space, so the size to report is
        ``config.projection_dim``. ``text_config.hidden_size`` is the text
        transformer's internal width and is only equal to the projection
        size for some checkpoints.
        """
        return self.model.config.projection_dim

    def get_token_size(self):
        """Return ``None``; a token limit does not apply to images as it does to text."""
        return None
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
from transformers import AutoTokenizer, AutoModel | ||
import torch | ||
import torch.nn.functional as F | ||
import time | ||
|
||
|
||
class TextEmbeddingService:
    """Produce L2-normalized sentence embeddings from a transformer encoder.

    Args:
        model: Checkpoint identifier handed to both
            ``AutoTokenizer.from_pretrained`` and
            ``AutoModel.from_pretrained``.
    """

    def __init__(self, model):
        self.tokenizer = AutoTokenizer.from_pretrained(model)
        self.model = AutoModel.from_pretrained(model)

    def mean_pooling(self, model_output, attention_mask):
        """Average token embeddings, counting only positions the mask keeps.

        Args:
            model_output: Indexable model result whose element ``0`` holds
                the per-token embeddings.
            attention_mask: Mask that is given a trailing axis and expanded
                to the token-embedding shape before use.

        Returns:
            The mask-weighted sum over dimension 1 divided by the mask sum
            over dimension 1 (clamped to at least ``1e-9`` so an all-zero
            mask cannot divide by zero).
        """
        hidden_states = model_output[0]
        weights = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
        weighted_sum = torch.sum(hidden_states * weights, 1)
        weight_total = torch.clamp(weights.sum(1), min=1e-9)
        return weighted_sum / weight_total

    def encode(self, sentences):
        """Tokenize ``sentences``, run the model, and return unit-norm embeddings.

        Args:
            sentences: Input passed straight to the tokenizer (with padding
                and truncation enabled).

        Returns:
            Mean-pooled embeddings normalized to unit L2 norm along
            dimension 1.
        """
        batch = self.tokenizer(
            sentences, padding=True, truncation=True, return_tensors="pt"
        )

        # Inference only: skip gradient tracking.
        with torch.no_grad():
            outputs = self.model(**batch)

        pooled = self.mean_pooling(outputs, batch["attention_mask"])
        return F.normalize(pooled, p=2, dim=1)

    def get_dimensions(self):
        """Return ``hidden_size`` from the loaded model's config."""
        return self.model.config.hidden_size

    def get_token_size(self):
        """Return the tokenizer's ``model_max_length``."""
        return self.tokenizer.model_max_length
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import torch | ||
from PIL import Image | ||
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize | ||
from transformers import CLIPProcessor, CLIPModel | ||
import cv2 | ||
import numpy as np | ||
|
||
|
||
class VideoEmbeddingService:
    """Embed a video by averaging CLIP image embeddings of its frames.

    Args:
        model: Checkpoint identifier passed to both
            ``CLIPProcessor.from_pretrained`` and
            ``CLIPModel.from_pretrained``.
    """

    def __init__(self, model):
        self.processor = CLIPProcessor.from_pretrained(model)
        self.model = CLIPModel.from_pretrained(model)

    def frame_embeddings(self, video_path):
        """Return the stacked CLIP image features of every readable frame.

        Args:
            video_path: Value handed to ``cv2.VideoCapture``.

        Returns:
            ``torch.stack`` of the per-frame ``get_image_features`` outputs,
            in read order.

        Raises:
            RuntimeError: If no frame could be read from ``video_path``.
        """
        # Initialize a video capture object
        cap = cv2.VideoCapture(video_path)
        frame_embeddings = []

        # Release the capture even if decoding or inference fails part-way.
        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                # Convert the color space from BGR to RGB, then convert to PIL Image
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                pil_image = Image.fromarray(frame)

                # Preprocess the image
                inputs = self.processor(
                    images=pil_image, return_tensors="pt", padding=True
                )

                # Move to CPU and get the image embedding
                inputs = {k: v.to("cpu") for k, v in inputs.items()}
                with torch.no_grad():
                    frame_features = self.model.get_image_features(**inputs)

                frame_embeddings.append(frame_features)
        finally:
            cap.release()

        # torch.stack([]) would also raise RuntimeError, but with a message
        # that says nothing about the video; fail with a useful one instead.
        if not frame_embeddings:
            raise RuntimeError(f"No frames could be read from video: {video_path!r}")

        return torch.stack(frame_embeddings)

    def encode(self, file_stream):
        """Return one L2-normalized embedding for the whole video.

        Args:
            file_stream: Assumed to be a path for simplicity; it is passed
                unchanged to :meth:`frame_embeddings`.

        Returns:
            The per-frame embeddings averaged over dimension 0 and then
            normalized to unit L2 norm along dimension 1.
        """
        embeddings = self.frame_embeddings(file_stream)
        # Aggregate embeddings by averaging across frames
        video_embedding = embeddings.mean(dim=0)
        # Normalize the embeddings
        return torch.nn.functional.normalize(video_embedding, p=2, dim=1)

    def get_dimensions(self):
        """Return the width of the vectors produced by :meth:`encode`.

        The embeddings come from ``get_image_features``, i.e. CLIP's shared
        projection space, whose size is ``config.projection_dim``.
        ``visual_projection`` is a layer on the model rather than a config
        field, so reading it from the config raised ``AttributeError``.
        """
        return self.model.config.projection_dim

    def get_token_size(self):
        """Return ``None``; not applicable for videos, similar to audio and images."""
        return None
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters