Improve: Fetch modalities separately
ashvardanian committed Apr 16, 2024
1 parent a5d84fc commit 2246f13
Showing 5 changed files with 78 additions and 126 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -7,4 +7,5 @@ package-lock.json
 *.onnx
 __pycache__
 .build
-.swiftpm
\ No newline at end of file
+.swiftpm
+node_modules
78 changes: 64 additions & 14 deletions python/uform/__init__.py
@@ -1,30 +1,80 @@
 from json import load
-from os.path import join
+from os.path import join, exists
 from typing import Mapping, Optional, Tuple
+from enum import Enum
 
 from huggingface_hub import snapshot_download
 
 
-def get_checkpoint(model_name: str, token: str) -> Tuple[str, Mapping, str]:
-    import torch
-
-    model_path = snapshot_download(repo_id=model_name, token=token)
-    config_path = join(model_path, "torch_config.json")
+class Modality(Enum):
+    TEXT = "text"
+    IMAGE = "image"
 
-    state = torch.load(join(model_path, "torch_weight.pt"))
-    return config_path, state, join(model_path, "tokenizer.json")
 
-def get_model(model_name: str, token: Optional[str] = None):
-    from uform.torch_models import VLM
+def get_checkpoint(model_name: str, token: Optional[str], modalities: Tuple[str]) -> Tuple[str, Mapping, str]:
+    import torch
+
+    # It is not recommended to use `.pth` extension when checkpointing models
+    # because it collides with Python path (`.pth`) configuration files.
+    merged_model_names = ["torch_weight.pt", "weights.pt", "model.pt"]
+    separate_modality_names = [str(x) + ".pt" for x in modalities]
+    config_names = ["config.json", "torch_config.json"]
+    tokenizer_names = ["tokenizer.json"]
+
+    # The download stats depend on the number of times the `config.json` is pulled
+    # https://huggingface.co/docs/hub/models-download-stats
+    model_path = snapshot_download(
+        repo_id=model_name,
+        token=token,
+        allow_patterns=merged_model_names + separate_modality_names + config_names + tokenizer_names,
+    )
+
+    # Find the first name in `config_names` that is present
+    config_path = None
+    for config_name in config_names:
+        if exists(join(model_path, config_name)):
+            config_path = join(model_path, config_name)
+            break
+
+    # Same for the tokenizer
+    tokenizer_path = None
+    for tokenizer_name in tokenizer_names:
+        if exists(join(model_path, tokenizer_name)):
+            tokenizer_path = join(model_path, tokenizer_name)
+            break
+
+    # Ideally, we want to separately fetch all the models.
+    # If those aren't available, aggregate separate modalities and merge them.
+    state = None
+    for file_name in merged_model_names:
+        if exists(join(model_path, file_name)):
+            state = torch.load(join(model_path, file_name))
+            break
+
+    if state is None:
+        state = {}
+        for file_name in separate_modality_names:
+            if exists(join(model_path, file_name)):
+                modality_name, _, _ = file_name.partition(".")
+                property_name = modality_name + "_encoder"
+                state[property_name] = torch.load(join(model_path, file_name))
+
+    return config_path, state, tokenizer_path
+
+
+def get_model(model_name: str, token: Optional[str] = None, modalities: Optional[Tuple[str]] = None):
+    from uform.torch_models import TextVisualEncoder
     from uform.torch_preprocessor import TorchProcessor
 
-    config_path, state, tokenizer_path = get_checkpoint(model_name, token)
+    if modalities is None:
+        modalities = (Modality.TEXT, Modality.IMAGE)
+
+    config_path, state, tokenizer_path = get_checkpoint(model_name, token, modalities)
 
     with open(config_path) as f:
         config = load(f)
 
-    model = VLM(config, tokenizer_path)
+    model = TextVisualEncoder(config, tokenizer_path)
     model.image_encoder.load_state_dict(state["image_encoder"])
     model.text_encoder.load_state_dict(state["text_encoder"])
     processor = TorchProcessor(config, tokenizer_path)
@@ -33,7 +83,7 @@ def get_model(model_name: str, token: Optional[str] = None):
 
 
 def get_model_onnx(model_name: str, device: str, dtype: str, token: Optional[str] = None):
-    from uform.onnx_models import VLM_ONNX
+    from uform.onnx_models import TextVisualEncoder
     from uform.numpy_preprocessor import NumPyProcessor
 
     assert device in (
@@ -53,7 +103,7 @@ def get_model_onnx(model_name: str, device: str, dtype: str, token: Optional[str
     with open(join(model_path, "config.json")) as f:
         config = load(f)
 
-    model = VLM_ONNX(model_path, config, device, dtype)
+    model = TextVisualEncoder(model_path, config, device, dtype)
     processor = NumPyProcessor(config, join(model_path, "tokenizer.json"))
 
     return model, processor
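
For orientation, here is a minimal usage sketch of the new fetching logic. It is an illustration, not part of the commit: the model name is an example, and it assumes the repository ships either a merged checkpoint (e.g. `torch_weight.pt`) or per-modality `text.pt` / `image.pt` files.

    from uform import get_checkpoint, get_model

    # Default: fetch both modalities; `get_checkpoint` prefers a merged
    # checkpoint and falls back to aggregating per-modality files.
    model, processor = get_model("unum-cloud/uform-vl-english")

    # Lower level: fetch only the text encoder's files. Plain strings match
    # the `str(x) + ".pt"` pattern used to build `separate_modality_names`.
    config_path, state, tokenizer_path = get_checkpoint(
        "unum-cloud/uform-vl-english",
        token=None,
        modalities=("text",),
    )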
13 changes: 8 additions & 5 deletions python/uform/onnx_models.py
@@ -23,7 +23,7 @@ def available_providers(device: str) -> Tuple[str, ...]:
     return cpu_providers
 
 
-class VisualEncoderONNX:
+class VisualEncoder:
     def __init__(self, model_path: str, device: str):
         """
         :param model_path: Path to onnx model
@@ -43,7 +43,7 @@ def __call__(self, images: ndarray) -> Tuple[ndarray, ndarray]:
         return self.session.run(None, {"images": images})
 
 
-class TextEncoderONNX:
+class TextEncoder:
     def __init__(self, text_encoder_path: str, reranker_path: str, device: str):
         """
         :param text_encoder_path: Path to onnx of text encoder
@@ -82,7 +82,7 @@ def forward_multimodal(
         )
 
 
-class VLM_ONNX:
+class TextVisualEncoder:
     def __init__(self, checkpoint_path: str, config: Dict, device: str, dtype: str):
         assert device in (
             "cpu",
@@ -103,13 +103,13 @@ def __init__(self, checkpoint_path: str, config: Dict, device: str, dtype: str):
         self._text_encoder_dim = config["text_encoder"]["dim"]
         self._image_encoder_dim = config["image_encoder"]["dim"]
 
-        self.text_encoder = TextEncoderONNX(
+        self.text_encoder = TextEncoder(
             join(checkpoint_path, f"text_encoder.onnx"),
             join(checkpoint_path, f"reranker.onnx"),
            device,
         )
 
-        self.image_encoder = VisualEncoderONNX(join(checkpoint_path, f"image_encoder.onnx"), device)
+        self.image_encoder = VisualEncoder(join(checkpoint_path, f"image_encoder.onnx"), device)
 
     def encode_image(
         self,
@@ -229,3 +229,6 @@ def embedding_dim(self) -> int:
     def multimodal_embedding_dim(self) -> int:
         """Dimensionality of multimodal joint embedding."""
         return self._text_encoder_dim
+
+
+VLM_ONNX = TextVisualEncoder  # legacy
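
The trailing alias is what keeps existing imports working. A quick sketch of the compatibility guarantee (nothing beyond the alias itself is assumed):

    from uform.onnx_models import TextVisualEncoder, VLM_ONNX

    # An alias, not a copy: the legacy name resolves to the renamed class.
    assert VLM_ONNX is TextVisualEncoder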
105 changes: 0 additions & 105 deletions python/uform/preprocessing.py

This file was deleted.

5 changes: 4 additions & 1 deletion python/uform/torch_models.py
@@ -353,7 +353,7 @@ def forward(self, x: Tensor, return_features: Optional[bool] = None) -> Tensor:
         return embeddings
 
 
-class VLM(nn.Module):
+class TextVisualEncoder(nn.Module):
     """
     Vision-Language Model for Multimodal embeddings.
     """
@@ -503,3 +503,6 @@ def embedding_dim(self) -> int:
     def multimodal_embedding_dim(self) -> int:
         """Dimensionality of multimodal joint embedding."""
         return self.text_encoder.dim
+
+
+VLM = TextVisualEncoder  # legacy
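
The PyTorch module gets the same treatment, so hypothetical downstream code written against the old name needs no change:

    from uform.torch_models import TextVisualEncoder, VLM

    # Instances of the renamed class still satisfy legacy isinstance checks.
    assert VLM is TextVisualEncoder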
