feat: Add custom embedder #2236
Merged
Changes from 18 commits
Commits
Show all changes
56 commits
0d76e67
Feat: Add custom embedder
vonodiripsa 9afa431
Corrected Names and file location
ee8f6f9
Code style corrections
09af0e0
Source temp fixes
a485010
Formating
eae2aee
First test
2cdfb59
Name changes
88d34b8
With two models
052f39b
Source style corrections
ac7bc67
Name change
52a1581
Name change
1bd45ab
Merge init scripts
cce365a
Removed extra file
8d2ed06
Added result output and _ correction
6b7798d
Formatted
a8b4dc9
Merge branch 'microsoft:master' into add-demo
bvonodiripsa dff46fe
Runtime flag update and load class from file (not from synapse.ml..)
5f569ce
Use built synapse.ml package instead of file
3098645
Clean imports and the class
e2e2014
Corrected edge cases (slam dataframe or no gpu)
710d9a6
Added check for cuda
42d8a07
Added synapse.ml.nn.KNN to run on CPU
7280ea7
add some small fixes to namespaces
mhamilton723 c63ab5b
Merge branch 'master' into add-demo
bvonodiripsa ac5828c
formatted
ff7b42d
Merge branch 'master' into add-demo
mhamilton723 96b81e2
Merge branch 'master' into add-demo
mhamilton723 f0c3b49
corrected default batch size
4047867
Added test
7894e9d
Corrected build errors
48433cd
style fixes
6989322
Style corrections
c58d929
More style corrections
9c54a66
Added extra row
776dfeb
Corrected comparison results image
52d52df
Corrected init and test
959f04e
Style correction
9fefefe
Updated notebook image link
b39218b
Added sentence_transformers for testing
f058cb2
trying to fix testing
ab0895b
style change
9c096ed
removed style spaces
62d10fc
Style again...
538c05c
added pyspark
4006266
remove pyspark
63ef878
corrected utest
72b4595
Reverse style change
575f020
change data size
fc66d0d
comment a line
e8d308d
Corrected init_spark()
07a67d3
Style and added SQLContext
2599ae0
Corrected result_df and remove old image
2b3f3d4
Corrected sidebars.js
be95231
Merge branch 'master' into add-demo
mhamilton723 b19a895
match web names
5a993cf
Merge branch 'master' into add-demo
mhamilton723
351 changes: 351 additions & 0 deletions
deep-learning/src/main/python/synapse/ml/HuggingFaceSentenceEmbedder.py
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,351 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# Import necessary libraries | ||
import numpy as np | ||
import torch | ||
import pyspark.sql.functions as F | ||
import tensorrt as trt | ||
import logging | ||
import warnings | ||
import sys | ||
import datetime | ||
import pytz | ||
from tqdm import tqdm, trange | ||
from numpy import ndarray | ||
from torch import Tensor | ||
from typing import List, Union | ||
|
||
import model_navigator as nav | ||
from sentence_transformers import SentenceTransformer | ||
from sentence_transformers.util import batch_to_device | ||
from pyspark.ml.functions import predict_batch_udf | ||
from faker import Faker | ||
|
||
from pyspark.ml import Transformer | ||
from pyspark.ml.param.shared import HasInputCol, HasOutputCol, Param, Params | ||
from pyspark.sql.functions import col, struct, rand | ||
from pyspark.sql.types import ( | ||
StructType, | ||
StructField, | ||
IntegerType, | ||
StringType, | ||
ArrayType, | ||
FloatType, | ||
) | ||
|
||
class HuggingFaceSentenceEmbedder(Transformer, HasInputCol, HasOutputCol): | ||
""" | ||
Custom transformer that extends PySpark's Transformer class to | ||
perform sentence embedding using a model with optional TensorRT acceleration. | ||
""" | ||
|
||
# Define additional parameters | ||
runtime = Param( | ||
Params._dummy(), | ||
"runtime", | ||
"Specifies the runtime environment: cpu, cuda, or tensorrt", | ||
) | ||
batchSize = Param(Params._dummy(), "batchSize", "Batch size for embeddings", int) | ||
modelName = Param(Params._dummy(), "modelName", "Full Model Name parameter") | ||
|
||
class _SentenceTransformerNavigator(SentenceTransformer): | ||
""" | ||
Inner class extending SentenceTransformer to override the encode method | ||
with additional functionality and optimizations (mainly to avoid RecursionError). | |
""" | ||
|
||
def encode( | ||
self, | ||
sentences: Union[str, List[str]], | ||
batch_size: int = 64, | ||
sentence_length: int = 512, | ||
show_progress_bar: bool = False, | ||
output_value: str = "sentence_embedding", | ||
convert_to_numpy: bool = True, | ||
convert_to_tensor: bool = False, | ||
device: str = None, | ||
normalize_embeddings: bool = False, | ||
) -> Union[List[Tensor], ndarray, Tensor]: | ||
""" | ||
Encode sentences into embeddings with optional configurations. | ||
""" | ||
self.eval() | ||
show_progress_bar = ( | ||
show_progress_bar if show_progress_bar is not None else True | ||
) | ||
convert_to_numpy = convert_to_numpy and not convert_to_tensor | ||
output_value = output_value or "sentence_embedding" | ||
|
||
# Handle input as a list of sentences | ||
input_was_string = isinstance(sentences, str) or not hasattr( | ||
sentences, "__len__" | ||
) | ||
if input_was_string: | ||
sentences = [sentences] | ||
|
||
# Determine the device to use for computation | ||
device = device or self._target_device | ||
self.to(device) | ||
|
||
# Initialize list for embeddings | ||
all_embeddings = [] | ||
length_sorted_idx = np.argsort( | ||
[-self._text_length(sen) for sen in sentences] | ||
) | ||
sentences_sorted = [sentences[idx] for idx in length_sorted_idx] | ||
|
||
# Process sentences in batches | ||
for start_index in trange( | ||
0, | ||
len(sentences), | ||
batch_size, | ||
desc="Batches", | ||
disable=not show_progress_bar, | ||
): | ||
sentences_batch = sentences_sorted[ | ||
start_index : start_index + batch_size | ||
] | ||
features = self.tokenize(sentences_batch) | ||
features = batch_to_device(features, device) | ||
|
||
# Perform forward pass and gather embeddings | ||
with torch.no_grad(): | ||
out_features = self(features) | ||
|
||
if output_value == "token_embeddings": | ||
embeddings = [] | ||
for token_emb, attention in zip( | ||
out_features[output_value], out_features["attention_mask"] | ||
): | ||
last_mask_id = len(attention) - 1 | ||
while ( | ||
last_mask_id > 0 and attention[last_mask_id].item() == 0 | ||
): | ||
last_mask_id -= 1 | ||
embeddings.append(token_emb[0 : last_mask_id + 1]) | ||
elif output_value is None: | ||
embeddings = [] | ||
for sent_idx in range(len(out_features["sentence_embedding"])): | ||
row = { | ||
name: out_features[name][sent_idx] | ||
for name in out_features | ||
} | ||
embeddings.append(row) | ||
else: | ||
embeddings = out_features[output_value] | ||
embeddings = embeddings.detach() | ||
if normalize_embeddings: | ||
embeddings = torch.nn.functional.normalize( | ||
embeddings, p=2, dim=1 | ||
) | ||
if convert_to_numpy: | ||
embeddings = embeddings.cpu() | ||
|
||
all_embeddings.extend(embeddings) | ||
|
||
# Restore original order of sentences | ||
all_embeddings = [ | ||
all_embeddings[idx] for idx in np.argsort(length_sorted_idx) | ||
] | ||
if convert_to_tensor: | ||
all_embeddings = torch.stack(all_embeddings) | ||
elif convert_to_numpy: | ||
all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings]) | ||
|
||
if input_was_string: | ||
all_embeddings = all_embeddings[0] | ||
|
||
return all_embeddings | ||
|
||
def __init__( | ||
self, | ||
inputCol=None, | ||
outputCol=None, | ||
runtime=None, | ||
batchSize=16, | ||
modelName=None, | ||
): | ||
""" | ||
Initialize the HuggingFaceSentenceEmbedder with input/output columns, runtime, batch size, and model name. | |
""" | ||
super(HuggingFaceSentenceEmbedder, self).__init__() | ||
self._setDefault( | ||
runtime="cpu", | ||
modelName=modelName, | ||
batchSize=16, | ||
) | ||
self._set( | ||
inputCol=inputCol, | ||
outputCol=outputCol, | ||
runtime=runtime, | ||
batchSize=batchSize, | ||
modelName=modelName, | ||
) | ||
|
||
# Setter method for batchSize | ||
def setBatchSize(self, value): | ||
self._set(batchSize=value) | ||
return self | ||
|
||
# Getter method for batchSize | ||
def getBatchSize(self): | ||
return self.getOrDefault(self.batchSize) | ||
|
||
# Sets the runtime environment for the model. | ||
# Supported values: 'cpu', 'cuda', 'tensorrt' | ||
def setRuntime(self, value): | ||
""" | ||
Sets the runtime environment for the model. | ||
Supported values: 'cpu', 'cuda', 'tensorrt' | ||
""" | ||
if value not in ["cpu", "cuda", "tensorrt"]: | |
raise ValueError( | |
"Invalid runtime specified. Choose from 'cpu', 'cuda', 'tensorrt'" | |
) | |
self._set(runtime=value) | |
return self | |
|
||
def getRuntime(self): | ||
return self.getOrDefault(self.runtime) | ||
|
||
# Setter method for modelName | ||
def setModelName(self, value): | ||
self._set(modelName=value) | ||
return self | ||
|
||
# Getter method for modelName | ||
def getModelName(self): | ||
return self.getOrDefault(self.modelName) | ||
|
||
# Optimize the model using Model Navigator with TensorRT configuration. | ||
def _optimize(self, model): | ||
conf = nav.OptimizeConfig( | ||
target_formats=(nav.Format.TENSORRT,), | ||
runners=("TensorRT",), | ||
optimization_profile=nav.OptimizationProfile(max_batch_size=64), | ||
custom_configs=[ | ||
nav.TorchConfig(autocast=True), | ||
nav.TorchScriptConfig(autocast=True), | ||
nav.TensorRTConfig( | ||
precision=(nav.TensorRTPrecision.FP16,), | ||
onnx_parser_flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM.value], | ||
), | ||
], | ||
) | ||
|
||
def _gen_size_chunk(): | ||
""" | ||
Generate chunks of different batch sizes and sentence lengths. | ||
""" | ||
for batch_size in [64]: | ||
for sentence_length in [20, 300, 512]: | ||
yield (batch_size, sentence_length) | ||
|
||
def _get_dataloader(repeat_times: int = 2): | ||
""" | ||
Create a data loader with synthetic data using Faker. | ||
""" | ||
faker = Faker() | ||
i = 0 | ||
for batch_size, chunk_size in _gen_size_chunk(): | ||
for _ in range(repeat_times): | ||
yield ( | ||
i, | ||
( | ||
[ | ||
" ".join(faker.words(chunk_size)) | ||
for _ in range(batch_size) | ||
], | ||
{"show_progress_bar": False}, | ||
), | ||
) | ||
i += 1 | ||
|
||
total_batches = len(list(_gen_size_chunk())) | ||
func = lambda x, **kwargs: self._SentenceTransformerNavigator.encode( | ||
model, x, **kwargs | ||
) | ||
nav.optimize( | ||
func, dataloader=tqdm(_get_dataloader(), total=total_batches), config=conf | ||
) | ||
|
||
def _predict_batch_fn(self): | ||
""" | ||
Create and return a function for batch prediction. | ||
""" | ||
runtime = self.getRuntime() | ||
if "model" not in globals(): | ||
global model | ||
modelName = self.getModelName() | ||
if runtime == "tensorrt": | ||
moduleName = modelName.split("/")[-1] | |
model = self._SentenceTransformerNavigator(modelName).eval() | ||
model = nav.Module(model, name=moduleName) | ||
try: | ||
nav.load_optimized() | ||
except Exception: | ||
self._optimize(model) | ||
nav.load_optimized() | ||
else: | ||
model = SentenceTransformer(modelName).eval() | ||
if runtime == "cuda": | ||
model = model.cuda() | ||
else: | ||
model = model.to("cpu") | ||
|
||
def predict(inputs): | ||
""" | ||
Predict method to encode inputs using the model. | ||
""" | ||
with torch.no_grad(): | ||
output = model.encode( | ||
inputs.tolist(), convert_to_tensor=False, show_progress_bar=False | ||
) | ||
|
||
return output | ||
|
||
return predict | ||
|
||
# Method to apply the transformation to the dataset | ||
def _transform(self, dataset, spark): | ||
""" | ||
Apply the transformation to the input dataset. | ||
""" | ||
input_col = self.getInputCol() | ||
output_col = self.getOutputCol() | ||
|
||
encode = predict_batch_udf( | ||
self._predict_batch_fn, | ||
return_type=ArrayType(FloatType()), | ||
batch_size=self.getBatchSize(), | ||
) | ||
return dataset.withColumn(output_col, encode(input_col)) | ||
|
||
def transform(self, dataset, spark=None): | ||
""" | ||
Public method to transform the dataset. | ||
""" | ||
return self._transform(dataset, spark) | ||
|
||
def copy(self, extra=None): | ||
""" | ||
Create a copy of the transformer. | ||
""" | ||
return self._defaultCopy(extra) | ||
|
||
# Example usage: | ||
# data = input data frame | ||
# transformer = HuggingFaceSentenceEmbedder(inputCol="combined", outputCol="embeddings") | |
# result = transformer.transform(data) | ||
# result.show() |
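The overridden `encode` method in this file sorts sentences from longest to shortest before batching, then restores the caller's order at the end. The sketch below isolates that pattern. It is an illustration rather than the PR's code: it uses plain Python in place of `np.argsort` so it runs without dependencies, and `encode_in_length_order` and `fake_encode` are hypothetical names, with `fake_encode` standing in for the model's forward pass.

```python
from typing import Callable, List, Sequence


def encode_in_length_order(
    sentences: Sequence[str],
    encode_batch: Callable[[List[str]], List[int]],
    batch_size: int = 2,
) -> List[int]:
    """Encode sentences longest-first in batches, then restore input order."""
    # Indices from longest to shortest sentence (the role np.argsort plays in the diff).
    order = sorted(range(len(sentences)), key=lambda i: -len(sentences[i]))
    sorted_sentences = [sentences[i] for i in order]

    # Encode batch by batch; results come back in sorted order.
    sorted_results: List[int] = []
    for start in range(0, len(sorted_sentences), batch_size):
        batch = sorted_sentences[start : start + batch_size]
        sorted_results.extend(encode_batch(batch))

    # Invert the permutation: the result at sorted position p belongs to input order[p].
    restored: List[int] = [0] * len(sentences)
    for position, original_index in enumerate(order):
        restored[original_index] = sorted_results[position]
    return restored


def fake_encode(batch: List[str]) -> List[int]:
    # Hypothetical stand-in for the model: "embeds" each sentence as its length.
    return [len(sentence) for sentence in batch]


sentences = ["bb", "a", "dddd", "ccc", "eeeee"]
result = encode_in_length_order(sentences, fake_encode)
```

Grouping sentences of similar length keeps padding within each batch small, and the inverse permutation ensures the output rows line up with the input rows regardless of how ties in length were ordered.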
Can we remove this dependency, as previously discussed? If you need data, you can add a short fake passage and replicate it.
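One way to act on this suggestion, assuming the dependency in question is the `Faker` import used by `_get_dataloader`, is to cycle through the words of a fixed passage. The sketch below is not the change that was merged. `PASSAGE`, `make_sentence`, and `calibration_batches` are hypothetical names, and the batch size, sentence lengths, and tuple shape are copied from the diff above.

```python
from itertools import cycle, islice
from typing import Dict, Iterator, List, Sequence, Tuple

# Hypothetical fixed passage standing in for Faker-generated words.
PASSAGE = "the quick brown fox jumps over the lazy dog"


def make_sentence(num_words: int) -> str:
    """Build a sentence of exactly num_words words by repeating the passage."""
    return " ".join(islice(cycle(PASSAGE.split()), num_words))


def calibration_batches(
    batch_size: int = 64,
    sentence_lengths: Sequence[int] = (20, 300, 512),
    repeat_times: int = 2,
) -> Iterator[Tuple[int, Tuple[List[str], Dict[str, bool]]]]:
    """Yield (index, (sentences, kwargs)) tuples shaped like _get_dataloader's output."""
    index = 0
    for num_words in sentence_lengths:
        for _ in range(repeat_times):
            batch = [make_sentence(num_words) for _ in range(batch_size)]
            yield index, (batch, {"show_progress_bar": False})
            index += 1


batches = list(calibration_batches())
```

Every sentence in a batch is identical here, which should be acceptable if the data only serves to exercise different input shapes during optimization. If varied content matters, the starting offset into the passage could be shifted per sentence.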