Improve: Multi-GPU support in Py
ashvardanian committed Apr 24, 2024
1 parent d00204f commit 917a4a8
Showing 8 changed files with 182 additions and 93 deletions.
28 changes: 24 additions & 4 deletions python/README.md
@@ -99,13 +99,33 @@ For that pick the encoder of the model you want to run in parallel, and wrap it

```python
from uform import get_model, Modality
+import torch.nn as nn

-encoders, processors = uform.get_model('unum-cloud/uform-vl-english-small', backend='torch', device='gpu')
+processors, models = get_model('unum-cloud/uform-vl-english-small', backend='torch')

-encoder_image = encoders[Modality.IMAGE_ENCODER]
-encoder_image = nn.DataParallel(encoder_image)
-
-_, res = encoder_image(images, 0)
+model_text = models[Modality.TEXT_ENCODER]
+model_image = models[Modality.IMAGE_ENCODER]
+processor_text = processors[Modality.TEXT_ENCODER]
+processor_image = processors[Modality.IMAGE_ENCODER]
+
+model_text.return_features = False
+model_image.return_features = False
+model_text_parallel = nn.DataParallel(model_text)
+model_image_parallel = nn.DataParallel(model_image)
```

Since we are now dealing with the PyTorch `nn.DataParallel` wrapper, make sure to call the `forward` method (instead of `encode`) to compute the embeddings, and use the `.detach().cpu().numpy()` sequence to bring the results back to more Pythonic NumPy arrays. Note that `nn.DataParallel` splits each input batch across the available GPUs, so throughput improves with larger batches.

```python
from typing import List
from PIL.Image import Image

def get_image_embedding(images: List[Image]):
    preprocessed = processor_image(images)
    embedding = model_image_parallel.forward(preprocessed)
    return embedding.detach().cpu().numpy()

def get_text_embedding(texts: List[str]):
    preprocessed = processor_text(texts)
    embedding = model_text_parallel.forward(preprocessed)
    return embedding.detach().cpu().numpy()
```
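For a quick smoke test of the parallel wrappers, a minimal sketch (assuming a couple of local images and matching captions; the file names below are hypothetical) could look like this:

```python
from PIL import Image

# Hypothetical inputs: any same-length lists of PIL images and strings work
images = [Image.open(path) for path in ["photo-1.jpg", "photo-2.jpg"]]
texts = ["a red panda at the zoo", "a bedroom with light furniture"]

image_vectors = get_image_embedding(images)  # NumPy array, shape (2, embedding_dim)
text_vectors = get_text_embedding(texts)     # NumPy array, shape (2, embedding_dim)
```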

### ONNX and CUDA
175 changes: 109 additions & 66 deletions python/scripts/test_encoders.py
175 changes: 109 additions & 66 deletions python/scripts/test_encoders.py
@@ -1,3 +1,4 @@
from functools import wraps
from typing import Tuple
import requests
from io import BytesIO
@@ -7,7 +8,7 @@
import numpy as np
from PIL import Image

-from uform import Modality, get_model, get_model_onnx
+from uform import Modality, get_model, ExecutionProviderError

# PyTorch is a very heavy dependency, so we may want to skip these tests if it's not installed
try:
@@ -49,6 +50,21 @@
token = file.read().strip()


def skip_on(exception, reason="No good reason :)"):
    def decorator_func(f):
        @wraps(f)
        def wrapper(*args, **kwargs):
            try:
                # Try to run the test
                return f(*args, **kwargs)
            except exception:
                pytest.skip(reason)

        return wrapper

    return decorator_func
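# Usage sketch (mirrors the decorators applied to the ONNX tests below): a raised
# ExecutionProviderError makes the test skip cleanly instead of failing.
#
#     @skip_on(ExecutionProviderError, reason="Missing execution provider")
#     def test_with_optional_provider():
#         ...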


def cosine_similarity(x, y) -> float:
    if not isinstance(x, np.ndarray):
        x = x.detach().numpy()
@@ -61,7 +77,7 @@ def cosine_similarity(x, y) -> float:
    return np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y))


-def cross_references_image_and_text_embeddings(text_to_embedding, image_to_embedding):
+def cross_references_image_and_text_embeddings(text_to_embedding, image_to_embedding, batch_size_multiple: int = 1):
"""Test if the embeddings of text and image are semantically similar
using a small set of example text-image pairs."""

Expand All @@ -80,30 +96,27 @@ def cross_references_image_and_text_embeddings(text_to_embedding, image_to_embed
"https://github.com/ashvardanian/ashvardanian/blob/master/demos/light-bedroom-furniture.jpg?raw=true",
"https://github.com/ashvardanian/ashvardanian/blob/master/demos/louvre-at-night.jpg?raw=true",
]
assert len(texts) == len(image_urls), "Number of texts and images should be the same."

-    text_embeddings = []
-    image_embeddings = []
-
-    for text, image_url in zip(texts, image_urls):
-        # Download and open the image
-        response = requests.get(image_url)
-        image = Image.open(BytesIO(response.content))
-
-        # Get embeddings
-        text_embedding = text_to_embedding(text)
-        image_embedding = image_to_embedding(image)
-
-        text_embeddings.append(text_embedding)
-        image_embeddings.append(image_embedding)
+    images = [Image.open(BytesIO(requests.get(image_url).content)) for image_url in image_urls]
+    count_pairs = len(texts)
+
+    # Ensure we have a sufficiently large batch
+    texts = texts * batch_size_multiple
+    images = images * batch_size_multiple
+
+    # Compute the embeddings in a batched fashion
+    text_embeddings = text_to_embedding(texts)
+    image_embeddings = image_to_embedding(images)

    # Evaluate cosine similarity
-    for i in range(len(texts)):
+    for i in range(count_pairs):
        pair_similarity = cosine_similarity(text_embeddings[i], image_embeddings[i])
        other_text_similarities = [
-            cosine_similarity(text_embeddings[j], image_embeddings[i]) for j in range(len(texts)) if j != i
+            cosine_similarity(text_embeddings[j], image_embeddings[i]) for j in range(count_pairs) if j != i
        ]
        other_image_similarities = [
-            cosine_similarity(text_embeddings[i], image_embeddings[j]) for j in range(len(texts)) if j != i
+            cosine_similarity(text_embeddings[i], image_embeddings[j]) for j in range(count_pairs) if j != i
        ]

        assert pair_similarity > max(
@@ -171,79 +184,109 @@ def test_torch_many_embeddings(model_name: str, batch_size: int):
@pytest.mark.skipif(not onnx_available, reason="ONNX is not installed")
@pytest.mark.parametrize("model_name", onnx_models)
@pytest.mark.parametrize("device", ["CPUExecutionProvider"])
@skip_on(ExecutionProviderError, reason="Missing execution provider")
def test_onnx_one_embedding(model_name: str, device: str):

-    from uform.onnx_encoders import ExecutionProviderError
-
-    try:
-
-        processors, models = get_model(model_name, token=token, device=device, backend="onnx")
-        model_text = models[Modality.TEXT_ENCODER]
-        model_image = models[Modality.IMAGE_ENCODER]
-        processor_text = processors[Modality.TEXT_ENCODER]
-        processor_image = processors[Modality.IMAGE_ENCODER]
-
-        text = "a small red panda in a zoo"
-        image_path = "assets/unum.png"
-
-        image = Image.open(image_path)
-        image_data = processor_image(image)
-        text_data = processor_text(text)
-
-        image_features, image_embedding = model_image.encode(image_data)
-        text_features, text_embedding = model_text.encode(text_data)
-
-        assert image_embedding.shape[0] == 1, "Image embedding batch size is not 1"
-        assert text_embedding.shape[0] == 1, "Text embedding batch size is not 1"
-
-        # Nested functions are easier to debug than lambdas
-        def get_image_embedding(image_data):
-            features, embedding = model_image.encode(processor_image(image_data))
-            return embedding
-
-        def get_text_embedding(text_data):
-            features, embedding = model_text.encode(processor_text(text_data))
-            return embedding
-
-        # Test if the model outputs actually make sense
-        cross_references_image_and_text_embeddings(get_text_embedding, get_image_embedding)
-
-    except ExecutionProviderError as e:
-        pytest.skip(f"Execution provider error: {e}")
+    processors, models = get_model(model_name, token=token, device=device, backend="onnx")
+    model_text = models[Modality.TEXT_ENCODER]
+    model_image = models[Modality.IMAGE_ENCODER]
+    processor_text = processors[Modality.TEXT_ENCODER]
+    processor_image = processors[Modality.IMAGE_ENCODER]
+
+    text = "a small red panda in a zoo"
+    image_path = "assets/unum.png"
+
+    image = Image.open(image_path)
+    image_data = processor_image(image)
+    text_data = processor_text(text)
+
+    image_features, image_embedding = model_image.encode(image_data)
+    text_features, text_embedding = model_text.encode(text_data)
+
+    assert image_embedding.shape[0] == 1, "Image embedding batch size is not 1"
+    assert text_embedding.shape[0] == 1, "Text embedding batch size is not 1"
+
+    # Nested functions are easier to debug than lambdas
+    def get_image_embedding(image_data):
+        features, embedding = model_image.encode(processor_image(image_data))
+        return embedding
+
+    def get_text_embedding(text_data):
+        features, embedding = model_text.encode(processor_text(text_data))
+        return embedding
+
+    # Test if the model outputs actually make sense
+    cross_references_image_and_text_embeddings(get_text_embedding, get_image_embedding)


@pytest.mark.skipif(not onnx_available, reason="ONNX is not installed")
@pytest.mark.parametrize("model_name", onnx_models)
@pytest.mark.parametrize("batch_size", [1, 2])
@pytest.mark.parametrize("device", ["CPUExecutionProvider"])
@skip_on(ExecutionProviderError, reason="Missing execution provider")
def test_onnx_many_embeddings(model_name: str, batch_size: int, device: str):

-    from uform.onnx_encoders import ExecutionProviderError
-
-    try:
-
-        processors, models = get_model(model_name, token=token, device=device, backend="onnx")
-        model_text = models[Modality.TEXT_ENCODER]
-        model_image = models[Modality.IMAGE_ENCODER]
-        processor_text = processors[Modality.TEXT_ENCODER]
-        processor_image = processors[Modality.IMAGE_ENCODER]
-
-        texts = ["a small red panda in a zoo"] * batch_size
-        image_paths = ["assets/unum.png"] * batch_size
-
-        images = [Image.open(path) for path in image_paths]
-        image_data = processor_image(images)
-        text_data = processor_text(texts)
-
-        image_embeddings = model_image.encode(image_data, return_features=False)
-        text_embeddings = model_text.encode(text_data, return_features=False)
-
-        assert image_embeddings.shape[0] == batch_size, "Image embedding is unexpected"
-        assert text_embeddings.shape[0] == batch_size, "Text embedding is unexpected"
-
-    except ExecutionProviderError as e:
-        pytest.skip(f"Execution provider error: {e}")
+    processors, models = get_model(model_name, token=token, device=device, backend="onnx")
+    model_text = models[Modality.TEXT_ENCODER]
+    model_image = models[Modality.IMAGE_ENCODER]
+    processor_text = processors[Modality.TEXT_ENCODER]
+    processor_image = processors[Modality.IMAGE_ENCODER]
+
+    texts = ["a small red panda in a zoo"] * batch_size
+    image_paths = ["assets/unum.png"] * batch_size
+
+    images = [Image.open(path) for path in image_paths]
+    image_data = processor_image(images)
+    text_data = processor_text(texts)
+
+    image_embeddings = model_image.encode(image_data, return_features=False)
+    text_embeddings = model_text.encode(text_data, return_features=False)
+
+    assert image_embeddings.shape[0] == batch_size, "Image embedding is unexpected"
+    assert text_embeddings.shape[0] == batch_size, "Text embedding is unexpected"


@pytest.mark.skipif(not torch_available, reason="PyTorch is not installed")
@pytest.mark.parametrize("model_name", torch_models[:1])
def test_torch_multi_gpu(model_name: str):

    count_cuda_devices = torch.cuda.device_count()
    if count_cuda_devices < 2:
        pytest.skip("Not enough CUDA devices to run multi-GPU test")

    processors, models = get_model(model_name, token=token, backend="torch", device="cuda")
    model_text = models[Modality.TEXT_ENCODER]
    model_image = models[Modality.IMAGE_ENCODER]
    processor_text = processors[Modality.TEXT_ENCODER]
    processor_image = processors[Modality.IMAGE_ENCODER]

    import torch.nn as nn

    model_text.return_features = False
    model_image.return_features = False
    model_text_parallel = nn.DataParallel(model_text)
    model_image_parallel = nn.DataParallel(model_image)

    # Nested functions are easier to debug than lambdas
    def get_image_embedding(image_data):
        preprocessed = processor_image(image_data)
        embedding = model_image_parallel.forward(preprocessed)
        return embedding.detach().cpu().numpy()

    def get_text_embedding(text_data):
        preprocessed = processor_text(text_data)
        embedding = model_text_parallel.forward(preprocessed)
        return embedding.detach().cpu().numpy()

    # Test if the model outputs actually make sense; multiplying the batch size by
    # the number of CUDA devices ensures every GPU receives a non-empty shard.
    cross_references_image_and_text_embeddings(
        get_text_embedding,
        get_image_embedding,
        batch_size_multiple=count_cuda_devices,
    )
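# To run only the multi-GPU test from the repository root (hypothetical invocation,
# assuming pytest is installed and at least two CUDA devices are visible):
#
#     pytest python/scripts/test_encoders.py -s -x -k test_torch_multi_gpu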


if __name__ == "__main__":
-    pytest.main(["-s", "-x", __file__])
+    # If you want to run this test file individually, you can do so by running:
+    # pytest.main(["-s", "-x", __file__])
+    pass
10 changes: 1 addition & 9 deletions python/uform/__init__.py
@@ -1,17 +1,9 @@
from os.path import join, exists
from typing import Dict, Optional, Tuple, Literal, Union, Callable
from enum import Enum

from huggingface_hub import snapshot_download, utils

-from uform.onnx_encoders import ExecutionProviderError
-
-
-class Modality(Enum):
-    TEXT_ENCODER = "text_encoder"
-    IMAGE_ENCODER = "image_encoder"
-    VIDEO_ENCODER = "video_encoder"
-    TEXT_DECODER = "text_decoder"
+from uform.shared import ExecutionProviderError, Modality


def _normalize_modalities(modalities: Tuple[str, Modality]) -> Tuple[Modality]:
6 changes: 4 additions & 2 deletions python/uform/numpy_processors.py
@@ -6,6 +6,8 @@
from tokenizers import Tokenizer
import numpy as np

from uform.shared import read_config


class TextProcessor:
    def __init__(self, config_path: PathLike, tokenizer_path: PathLike):
@@ -14,7 +16,7 @@ def __init__(self, config_path: PathLike, tokenizer_path: PathLike):
        :param tokenizer_path: path to tokenizer file
        """

-        config = json.load(open(config_path, "r"))
+        config = read_config(config_path)
if "text_encoder" in config:
config = config["text_encoder"]

@@ -60,7 +62,7 @@ def __init__(self, config_path: PathLike, tokenizer_path: PathLike = None):
        :param tensor_type: which tensors to return, either pt (PyTorch) or np (NumPy)
        """

-        config = json.load(open(config_path, "r"))
+        config = read_config(config_path)
if "image_encoder" in config:
config = config["image_encoder"]

4 changes: 1 addition & 3 deletions python/uform/onnx_encoders.py
@@ -5,9 +5,7 @@
import onnxruntime as ort
from numpy import ndarray


-class ExecutionProviderError(Exception):
-    """Exception raised when a requested execution provider is not available."""
+from uform.shared import ExecutionProviderError


def available_providers(device: Optional[str]) -> Tuple[str, ...]:
26 changes: 26 additions & 0 deletions python/uform/shared.py
@@ -0,0 +1,26 @@
from enum import Enum
from typing import Union
from os import PathLike
import json


class Modality(Enum):
    TEXT_ENCODER = "text_encoder"
    IMAGE_ENCODER = "image_encoder"
    VIDEO_ENCODER = "video_encoder"
    TEXT_DECODER = "text_decoder"


class ExecutionProviderError(Exception):
    """Exception raised when a requested execution provider is not available."""


ConfigOrPath = Union[PathLike, str, object]


def read_config(path_or_object: ConfigOrPath) -> object:
    if isinstance(path_or_object, (PathLike, str)):
        with open(path_or_object, "r") as f:
            return json.load(f)
    else:
        return path_or_object
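# Usage sketch: `read_config` accepts either a path to a JSON file or an
# already-parsed object, so processors can be built from both (the path below
# is hypothetical):
#
#     config = read_config("models/image_encoder/config.json")      # loaded from disk
#     config = read_config({"image_encoder": {"image_size": 224}})  # passed through as-is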