Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add custom embedder #2236

Merged
merged 56 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
56 commits
Select commit Hold shift + click to select a range
0d76e67
Feat: Add custom embedder
vonodiripsa Jun 12, 2024
9afa431
Corrected Names and file location
Jun 13, 2024
ee8f6f9
Code style corrections
Jun 14, 2024
09af0e0
Source temp fixes
Jun 14, 2024
a485010
Formating
Jun 14, 2024
eae2aee
First test
Jun 14, 2024
2cdfb59
Name changes
Jun 14, 2024
88d34b8
With two models
Jun 14, 2024
052f39b
Source style corrections
Jun 16, 2024
ac7bc67
Name change
Jun 16, 2024
52a1581
Name change
Jun 16, 2024
1bd45ab
Merge init scripts
Jun 17, 2024
cce365a
Removed extra file
Jun 17, 2024
8d2ed06
Added result output and _ correction
Jun 18, 2024
6b7798d
Formatted
Jun 18, 2024
a8b4dc9
Merge branch 'microsoft:master' into add-demo
bvonodiripsa Jun 18, 2024
dff46fe
Runtime flag update and load class from file (not from synapse.ml..)
Jun 18, 2024
5f569ce
Use built synapse.ml package instead of file
Jun 18, 2024
3098645
Clean imports and the class
Jun 27, 2024
e2e2014
Corrected edge cases (slam dataframe or no gpu)
Jun 28, 2024
710d9a6
Added check for cuda
Jun 28, 2024
42d8a07
Added synapse.ml.nn.KNN to run on CPU
Jul 1, 2024
7280ea7
add some small fixes to namespaces
mhamilton723 Jul 3, 2024
c63ab5b
Merge branch 'master' into add-demo
bvonodiripsa Jul 10, 2024
ac5828c
formatted
Jul 10, 2024
ff7b42d
Merge branch 'master' into add-demo
mhamilton723 Jul 17, 2024
96b81e2
Merge branch 'master' into add-demo
mhamilton723 Jul 17, 2024
f0c3b49
corrected default batch size
Jul 19, 2024
4047867
Added test
Jul 19, 2024
7894e9d
Corrected build errors
Jul 20, 2024
48433cd
style fixes
Jul 20, 2024
6989322
Style corrections
Jul 20, 2024
c58d929
More style corrections
Jul 20, 2024
9c54a66
Added extra row
Jul 20, 2024
776dfeb
Corrected comparison results image
Jul 22, 2024
52d52df
Corrected init and test
Jul 26, 2024
959f04e
Style correction
Jul 26, 2024
9fefefe
Updated notebook image link
Jul 26, 2024
b39218b
Added sentence_transformers for testing
Jul 31, 2024
f058cb2
trying to fix testing
Jul 31, 2024
ab0895b
style change
Jul 31, 2024
9c096ed
removed style spaces
Jul 31, 2024
62d10fc
Style again...
Jul 31, 2024
538c05c
added pyspark
Aug 1, 2024
4006266
remove pyspark
Aug 1, 2024
63ef878
corrected utest
Aug 1, 2024
72b4595
Reverse style change
Aug 1, 2024
575f020
change data size
Aug 1, 2024
fc66d0d
comment a line
Aug 1, 2024
e8d308d
Corrected init_spark()
Aug 1, 2024
07a67d3
Style and added SQLContext
Aug 1, 2024
2599ae0
Corrected result_df and remove old image
Aug 1, 2024
2b3f3d4
Corrected sidebars.js
Aug 2, 2024
be95231
Merge branch 'master' into add-demo
mhamilton723 Aug 2, 2024
b19a895
match web names
Aug 6, 2024
5a993cf
Merge branch 'master' into add-demo
mhamilton723 Aug 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,351 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Import necessary libraries
import numpy as np
import torch
import pyspark.sql.functions as F
import tensorrt as trt
import logging
import warnings
import sys
import datetime
import pytz
from tqdm import tqdm, trange
from numpy import ndarray
from torch import Tensor
from typing import List, Union

import model_navigator as nav
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import batch_to_device
from pyspark.ml.functions import predict_batch_udf
from faker import Faker
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we remove this dep as previously discussed? you can add a little fake data passage and replicate it if you need


from pyspark.ml import Transformer
from pyspark.ml.param.shared import HasInputCol, HasOutputCol, Param, Params
from pyspark.sql.functions import col, struct, rand
from pyspark.sql.types import (
StructType,
StructField,
IntegerType,
StringType,
ArrayType,
FloatType,
)

class HuggingFaceSentenceEmbedder(Transformer, HasInputCol, HasOutputCol):
    """
    Custom transformer that extends PySpark's Transformer class to
    perform sentence embedding using a model with optional TensorRT acceleration.

    Params:
        inputCol:  name of the column holding the sentences to embed.
        outputCol: name of the column that receives the embeddings
                   (declared as ``ArrayType(FloatType())``).
        runtime:   one of 'cpu', 'cuda' or 'tensorrt' (default 'cpu').
        batchSize: batch size handed to ``predict_batch_udf`` (default 16).
        modelName: model identifier passed to ``SentenceTransformer``.
    """

    # Runtime values accepted by the constructor and by setRuntime().
    _SUPPORTED_RUNTIMES = ("cpu", "cuda", "tensorrt")

    # Define additional parameters
    runtime = Param(
        Params._dummy(),
        "runtime",
        "Specifies the runtime environment: cpu, cuda, or tensorrt",
    )
    batchSize = Param(Params._dummy(), "batchSize", "Batch size for embeddings", int)
    modelName = Param(Params._dummy(), "modelName", "Full Model Name parameter")

    class _SentenceTransformerNavigator(SentenceTransformer):
        """
        Inner class extending SentenceTransformer to override the encode method
        with additional functionality and optimizations (mainly to eliminate RecursiveErrors).
        """

        def encode(
            self,
            sentences: Union[str, List[str]],
            batch_size: int = 64,
            sentence_length: int = 512,
            show_progress_bar: bool = False,
            output_value: str = "sentence_embedding",
            convert_to_numpy: bool = True,
            convert_to_tensor: bool = False,
            device: str = None,
            normalize_embeddings: bool = False,
        ) -> Union[List[Tensor], ndarray, Tensor]:
            """
            Encode sentences into embeddings with optional configurations.

            Args:
                sentences: a single sentence or a list of sentences.
                batch_size: number of sentences per forward pass.
                sentence_length: accepted for signature compatibility; it is
                    not read by this implementation.
                show_progress_bar: show a tqdm bar over the batches
                    (``None`` is treated as ``True``).
                output_value: key of the model output to return, either
                    "sentence_embedding" or "token_embeddings". A falsy value
                    falls back to "sentence_embedding".
                convert_to_numpy: return a numpy array (ignored when
                    ``convert_to_tensor`` is true).
                convert_to_tensor: return one stacked tensor.
                device: device to run on; defaults to ``self._target_device``.
                normalize_embeddings: L2-normalise sentence embeddings.

            Returns:
                A numpy array, a stacked tensor, or a list of tensors,
                depending on the ``convert_to_*`` flags. When the input was a
                single string, only its embedding is returned.
            """
            self.eval()
            show_progress_bar = (
                show_progress_bar if show_progress_bar is not None else True
            )
            convert_to_numpy = convert_to_numpy and not convert_to_tensor
            # A falsy output_value (including None) always means sentence
            # embeddings here, so no "return every model output" mode exists.
            output_value = output_value or "sentence_embedding"

            # Handle input as a list of sentences
            input_was_string = isinstance(sentences, str) or not hasattr(
                sentences, "__len__"
            )
            if input_was_string:
                sentences = [sentences]

            # Determine the device to use for computation
            device = device or self._target_device
            self.to(device)

            # Sort by decreasing text length so each batch holds sentences of
            # similar size; the original order is restored after the loop.
            all_embeddings = []
            length_sorted_idx = np.argsort(
                [-self._text_length(sen) for sen in sentences]
            )
            sentences_sorted = [sentences[idx] for idx in length_sorted_idx]

            # Process sentences in batches
            for start_index in trange(
                0,
                len(sentences),
                batch_size,
                desc="Batches",
                disable=not show_progress_bar,
            ):
                sentences_batch = sentences_sorted[
                    start_index : start_index + batch_size
                ]
                features = self.tokenize(sentences_batch)
                features = batch_to_device(features, device)

                # Perform forward pass and gather embeddings
                with torch.no_grad():
                    out_features = self(features)

                    if output_value == "token_embeddings":
                        embeddings = []
                        for token_emb, attention in zip(
                            out_features[output_value],
                            out_features["attention_mask"],
                        ):
                            # Walk back from the end to the last position the
                            # attention mask marks as non-zero and cut the
                            # token embeddings there.
                            last_mask_id = len(attention) - 1
                            while (
                                last_mask_id > 0
                                and attention[last_mask_id].item() == 0
                            ):
                                last_mask_id -= 1
                            embeddings.append(token_emb[0 : last_mask_id + 1])
                    else:
                        embeddings = out_features[output_value]
                        embeddings = embeddings.detach()
                        if normalize_embeddings:
                            embeddings = torch.nn.functional.normalize(
                                embeddings, p=2, dim=1
                            )
                        if convert_to_numpy:
                            embeddings = embeddings.cpu()

                    all_embeddings.extend(embeddings)

            # Restore original order of sentences
            all_embeddings = [
                all_embeddings[idx] for idx in np.argsort(length_sorted_idx)
            ]
            if convert_to_tensor:
                # torch.stack rejects an empty list, so return an empty
                # tensor when there was nothing to encode.
                all_embeddings = (
                    torch.stack(all_embeddings) if all_embeddings else torch.Tensor()
                )
            elif convert_to_numpy:
                all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])

            if input_was_string:
                all_embeddings = all_embeddings[0]

            return all_embeddings

    def __init__(
        self,
        inputCol=None,
        outputCol=None,
        runtime=None,
        batchSize=16,
        modelName=None,
    ):
        """
        Initialize the HuggingFaceSentenceEmbedder.

        Args:
            inputCol: name of the input column.
            outputCol: name of the output column.
            runtime: 'cpu', 'cuda' or 'tensorrt'; ``None`` keeps the 'cpu' default.
            batchSize: batch size for the prediction UDF.
            modelName: model identifier passed to ``SentenceTransformer``.

        Raises:
            ValueError: if ``runtime`` is given and is not a supported value.
        """
        super().__init__()
        self._setDefault(
            runtime="cpu",
            modelName=modelName,
            batchSize=16,
        )
        # Params._set stores None as an explicit value, which would mask the
        # defaults above, so only forward arguments that were actually given.
        explicit = {
            "inputCol": inputCol,
            "outputCol": outputCol,
            "batchSize": batchSize,
            "modelName": modelName,
        }
        self._set(**{name: value for name, value in explicit.items() if value is not None})
        if runtime is not None:
            # Go through the setter so the value is validated consistently.
            self.setRuntime(runtime)

    def setBatchSize(self, value):
        """Set the batch size used by the prediction UDF and return self."""
        self._set(batchSize=value)
        return self

    def getBatchSize(self):
        """Return the configured batch size (or its default)."""
        return self.getOrDefault(self.batchSize)

    def setRuntime(self, value):
        """
        Sets the runtime environment for the model.
        Supported values: 'cpu', 'cuda', 'tensorrt'

        Returns self so calls can be chained, like the other setters.

        Raises:
            ValueError: if ``value`` is not one of the supported runtimes.
        """
        if value not in self._SUPPORTED_RUNTIMES:
            raise ValueError(
                "Invalid runtime specified. Choose from 'cpu', 'cuda', 'tensorrt'"
            )
        self._set(runtime=value)
        return self

    def getRuntime(self):
        """Return the configured runtime (or its default)."""
        return self.getOrDefault(self.runtime)

    def setModelName(self, value):
        """Set the model name and return self."""
        self._set(modelName=value)
        return self

    def getModelName(self):
        """Return the configured model name (or its default)."""
        return self.getOrDefault(self.modelName)

    def _optimize(self, model):
        """
        Optimize the model using Model Navigator with TensorRT configuration.

        Builds synthetic batches with Faker and feeds them to ``nav.optimize``
        through the navigator-aware ``encode`` implementation.
        """
        conf = nav.OptimizeConfig(
            target_formats=(nav.Format.TENSORRT,),
            runners=("TensorRT",),
            optimization_profile=nav.OptimizationProfile(max_batch_size=64),
            custom_configs=[
                nav.TorchConfig(autocast=True),
                nav.TorchScriptConfig(autocast=True),
                nav.TensorRTConfig(
                    precision=(nav.TensorRTPrecision.FP16,),
                    onnx_parser_flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM.value],
                ),
            ],
        )

        def _gen_size_chunk():
            """
            Generate chunks of different batch sizes and sentence lengths.
            """
            for batch_size in [64]:
                for sentence_length in [20, 300, 512]:
                    yield (batch_size, sentence_length)

        def _get_dataloader(repeat_times: int = 2):
            """
            Create a data loader with synthetic data using Faker.

            Yields ``(index, (sentences, encode_kwargs))`` tuples, repeating
            every (batch size, sentence length) combination ``repeat_times`` times.
            """
            faker = Faker()
            i = 0
            for batch_size, chunk_size in _gen_size_chunk():
                for _ in range(repeat_times):
                    yield (
                        i,
                        (
                            [
                                " ".join(faker.words(chunk_size))
                                for _ in range(batch_size)
                            ],
                            {"show_progress_bar": False},
                        ),
                    )
                    i += 1

        def _encode(x, **kwargs):
            # Call encode unbound so the navigator-wrapped module is used as
            # ``self`` by the overridden implementation.
            return self._SentenceTransformerNavigator.encode(model, x, **kwargs)

        repeat_times = 2
        # The loader yields every size combination repeat_times times; the
        # progress bar total has to account for that.
        total_batches = sum(1 for _ in _gen_size_chunk()) * repeat_times
        nav.optimize(
            _encode,
            dataloader=tqdm(_get_dataloader(repeat_times), total=total_batches),
            config=conf,
        )

    def _predict_batch_fn(self):
        """
        Create and return a function for batch prediction.

        The model is stored in the module-level global ``model`` and is only
        loaded when that global does not exist yet.

        NOTE: because of that check, a later call in the same process with a
        different modelName or runtime reuses the model loaded first.
        """
        global model
        runtime = self.getRuntime()
        if "model" not in globals():
            modelName = self.getModelName()
            if runtime == "tensorrt":
                # Use the last path component so names without an
                # organisation prefix (no "/") work as well.
                moduleName = modelName.split("/")[-1]
                model = self._SentenceTransformerNavigator(modelName).eval()
                model = nav.Module(model, name=moduleName)
                try:
                    nav.load_optimized()
                except Exception:
                    # Loading failed, so run the optimisation and retry.
                    self._optimize(model)
                    nav.load_optimized()
            else:
                model = SentenceTransformer(modelName).eval()
                if runtime == "cuda":
                    model = model.cuda()
                else:
                    model = model.to("cpu")

        def predict(inputs):
            """
            Predict method to encode inputs using the model.
            """
            with torch.no_grad():
                output = model.encode(
                    inputs.tolist(), convert_to_tensor=False, show_progress_bar=False
                )

            return output

        return predict

    def _transform(self, dataset, spark=None):
        """
        Apply the transformation to the input dataset.

        Args:
            dataset: input DataFrame containing the input column.
            spark: unused; kept for backward compatibility with existing callers.

        Returns:
            ``dataset`` with the output column added.
        """
        input_col = self.getInputCol()
        output_col = self.getOutputCol()

        encode = predict_batch_udf(
            self._predict_batch_fn,
            return_type=ArrayType(FloatType()),
            batch_size=self.getBatchSize(),
        )
        return dataset.withColumn(output_col, encode(input_col))

    def transform(self, dataset, spark=None):
        """
        Public method to transform the dataset.

        The second argument is forwarded to ``_transform`` and otherwise ignored.
        """
        return self._transform(dataset, spark)

    def copy(self, extra=None):
        """
        Create a copy of the transformer.

        Delegates to the inherited ``Params.copy``, which copies the instance
        together with its param values and merges ``extra`` into the copy.
        """
        return super().copy(extra)

# Example usage:
# data = input data frame
# transformer = HuggingFaceSentenceEmbedder(inputCol="combined", outputCol="embeddings", modelName="sentence-transformers/all-MiniLM-L6-v2")
# result = transformer.transform(data)
# result.show()
Empty file.
Loading