feat: Add custom embedder #2236
base: master
Conversation
```python
    self,
    inputCol=None,
    outputCol=None,
    useTRTFlag=None,
```
nit: `useTRTFlag` -> `runtime`: "cpu", "gpu", "tensorrt"; default "cpu"
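To make the suggestion concrete, here is a minimal sketch of validating such a `runtime` parameter; the `VALID_RUNTIMES` tuple and `validate_runtime` helper are hypothetical names, not part of the PR:

```python
# Hypothetical helper illustrating a runtime parameter with the three
# suggested values and a default of "cpu".
VALID_RUNTIMES = ("cpu", "gpu", "tensorrt")

def validate_runtime(runtime: str = "cpu") -> str:
    """Return the runtime if it is one of the supported values."""
    if runtime not in VALID_RUNTIMES:
        raise ValueError(f"runtime must be one of {VALID_RUNTIMES}, got {runtime!r}")
    return runtime
```

In a Spark ML transformer this would typically live behind a `Param` plus a `self._setDefault(runtime="cpu")` call rather than a free function.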
```python
# Define additional parameters
useTRT = Param(Params._dummy(), "useTRT", "True if use TRT acceleration")
driverOnly = Param(
```
nit: remove the `driverOnly` code
```python
inputCol="combined",
outputCol="embeddings",
```
look at other examples in the library of proper defaults for these columns
```python
for batch_size in [64, 32, 16, 8, 4, 2, 1]:
    for sentence_length in [20, 300, 512]:
        yield (batch_size, sentence_length)
```
make these magic numbers parameters with defaults
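One way to apply this, keeping the old literals as the defaults; the function name `batch_configs` is illustrative, not from the PR:

```python
def batch_configs(batch_sizes=(64, 32, 16, 8, 4, 2, 1),
                  sentence_lengths=(20, 300, 512)):
    # Defaults reproduce the previously hard-coded values; callers can override.
    for batch_size in batch_sizes:
        for sentence_length in sentence_lengths:
            yield (batch_size, sentence_length)
```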
```python
"""
Create a data loader with synthetic data using Faker.
"""
faker = Faker()
```
nit: let's try to remove this dependency
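A possible stdlib-only replacement: build synthetic sentences from a small word list with `random` instead of Faker. The word list and helper name here are made up purely for illustration:

```python
import random

# Tiny stand-in vocabulary; any handful of words works for synthetic input.
_WORDS = ["food", "review", "tasty", "bland", "service", "fresh", "price", "quality"]

def fake_sentences(n, words_per_sentence=8, seed=42):
    """Return n deterministic pseudo-random sentences, no Faker required."""
    rng = random.Random(seed)
    return [
        " ".join(rng.choice(_WORDS) for _ in range(words_per_sentence)).capitalize() + "."
        for _ in range(n)
    ]
```

Because the generator is seeded, the synthetic data is reproducible across runs, which Faker does not give you for free.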
```python
    for sentence_length in [20, 300, 512]:
        yield (batch_size, sentence_length)


def get_dataloader(repeat_times: int = 2):
```
nit: _get_dataloader
```python
    func, dataloader=tqdm(get_dataloader(), total=total_batches), config=conf
)


def run_on_driver(self, queries, spark):
```
likewise, prefix with `_`
```python
    """
    return self._defaultCopy(extra)


def load_data_food_reviews(self, spark, path=None, limit=1000):
```
move this code into the demo
```python
class SuppressLogging:
    def __init__(self):
        self._original_stderr = None

    def start(self):
        """Start suppressing logging by redirecting sys.stderr to /dev/null."""
        if self._original_stderr is None:
            self._original_stderr = sys.stderr
            sys.stderr = open('/dev/null', 'w')

    def stop(self):
        """Stop suppressing logging and restore sys.stderr."""
        if self._original_stderr is not None:
            sys.stderr.close()
            sys.stderr = self._original_stderr
            self._original_stderr = None
```
remove
```python
    FloatType,
)


class EmbeddingTransformer(Transformer, HasInputCol, HasOutputCol):
```
nit: HuggingFaceSentenceEmbedder
Also name the file HuggingFaceSentenceEmbedder.py
```python
modelName="intfloat/e5-large-v2",
moduleName="e5-large-v2",
```
nit: no defaults here, and try to make this `moduleName` thing go away
```python
    Initialize the EmbeddingTransformer with input/output columns and optional TRT flag.
    """
    super(EmbeddingTransformer, self).__init__()
    self._setDefault(
```
try it on some other models from: https://sbert.net/docs/sentence_transformer/pretrained_models.html
tools/init_scripts/init_retriever.sh (Outdated)
```shell
/databricks/python/bin/pip install --extra-index-url https://pypi.nvidia.com cudf-cu11~=${RAPIDS_VERSION} cuml-cu11~=${RAPIDS_VERSION} pylibraft-cu11~=${RAPIDS_VERSION} rmm-cu11~=${RAPIDS_VERSION}

# install model navigator
/databricks/python/bin/pip install --extra-index-url https://pypi.nvidia.com onnxruntime-gpu==1.16.3 "tensorrt==9.3.0.post12.dev1" "triton-model-navigator<1" "sentence_transformers~=2.2.2" "faker" "urllib3<2"
```
nit: remove faker
```python
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import batch_to_device
from pyspark.ml.functions import predict_batch_udf
from faker import Faker
```
can we remove this dep as previously discussed? You can add a little fake data passage and replicate it if you need to.
```python
logging.getLogger("sentence_transformers.SentenceTransformer").setLevel(logging.ERROR)
mlflow.autolog(disable=True)

# Record the start time
start_time = datetime.now()

print(f"Demo started")
```
probably don't need this stuff for a minimal demo
```python
warnings.filterwarnings("ignore", category=UserWarning, module="tritonclient.grpc")
import logging

logging.getLogger("py4j").setLevel(logging.ERROR)
```
do you need this line?
```python
number_of_input_rows = 999
# Shuffle the DataFrame with a fixed seed
seed = 42

# Check if the row count is less than 10
if number_of_input_rows <= 0 or number_of_input_rows >= 1000000:
    raise ValueError(f"Limit is {number_of_input_rows}, which should be less than 1M.")

if number_of_input_rows > 1000:

    # Cross-join the DataFrame with itself to create n x n pairs for string concatenation (synthetic data)
    cross_joined_df = df.crossJoin(
        df.withColumnRenamed("combined", "combined_")
    )

    # Create a new column 'result_vector' by concatenating the two source vectors
    tmp_df = cross_joined_df.withColumn(
        "result_vector",
        F.concat(F.col("combined"), F.lit(". \n"), F.col("combined_")),
    )

    # Select only the necessary columns and show the result
    tmp_df = tmp_df.select("result_vector")
    df = tmp_df.withColumnRenamed("result_vector", "combined").withColumn(
        "id", monotonically_increasing_id()
    )

df = df.limit(number_of_input_rows).orderBy(rand(seed)).repartition(10).cache()

print(f"Loaded: {number_of_input_rows} rows")
```
we probably can remove the cross-join stuff for the demo; I would rather use a large dataset and subset it than a small dataset and augment it
```python
# dataTransformer = HuggingFaceSentenceEmbedder(modelName="intfloat/e5-large-v2", inputCol="combined", outputCol="embeddings", runtime="tensorrt")
dataTransformer = HuggingFaceSentenceEmbedder(modelName="sentence-transformers/all-MiniLM-L6-v2", inputCol="combined", outputCol="embeddings", runtime="tensorrt")
```
nit: dataTransformer -> embedder
```python
# dataTransformer = HuggingFaceSentenceEmbedder(modelName="intfloat/e5-large-v2", inputCol="combined", outputCol="embeddings", runtime="tensorrt")
dataTransformer = HuggingFaceSentenceEmbedder(modelName="sentence-transformers/all-MiniLM-L6-v2", inputCol="combined", outputCol="embeddings", runtime="tensorrt")

all_embeddings = dataTransformer.transform(df).cache()
```
nit: all_embeddings -> embeddings
```python
queries = ["desserts", "disgusting"]
ids = [1, 2]

# Combine the data into a list of tuples
data = list(zip(ids, queries))

# Define the schema for the DataFrame
schema = StructType([
    StructField("id", IntegerType(), nullable=False),
    StructField("query", StringType(), nullable=False)
])

# Create the DataFrame
qDf = spark.createDataFrame(data, schema)
```
you can probably make this smaller by saying:

```python
test_data = spark.createDataFrame([("desserts", 1), ("disgusting", 2)], ["query", "id"])
```
```python
qDf = spark.createDataFrame(data, schema)

# queryTransformer = HuggingFaceSentenceEmbedder(modelName="intfloat/e5-large-v2", inputCol="query", outputCol="embeddings", runtime="cpu")
queryTransformer = HuggingFaceSentenceEmbedder(modelName="sentence-transformers/all-MiniLM-L6-v2", inputCol="query", outputCol="embeddings", runtime="cpu")
```
nit: use the embedder you already made above
```python
# dataTransformer = HuggingFaceSentenceEmbedder(modelName="intfloat/e5-large-v2", inputCol="combined", outputCol="embeddings", runtime="tensorrt")
dataTransformer = HuggingFaceSentenceEmbedder(modelName="sentence-transformers/all-MiniLM-L6-v2", inputCol="combined", outputCol="embeddings", runtime="tensorrt")
```
nit: make the model name a param and pass it in; give people a few models to try
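For instance, the demo could expose a few candidate models near the top. The list below combines the model already used in the PR with two entries from the sbert pretrained-models page linked earlier; the variable names are illustrative, not from the PR:

```python
# Candidate sentence-transformers models for users to try; pick one below.
CANDIDATE_MODELS = [
    "sentence-transformers/all-MiniLM-L6-v2",
    "sentence-transformers/all-mpnet-base-v2",
    "intfloat/e5-large-v2",
]

model_name = CANDIDATE_MODELS[0]  # swap the index to try a different model
```

The embedder construction further down would then take `modelName=model_name` instead of a hard-coded string.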
```python
rapids_knn = ApproximateNearestNeighbors(k=5)
rapids_knn.setInputCol("embeddings").setIdCol("id")

rapids_knn_model = rapids_knn.fit(all_embeddings.select("id", "embeddings"))
```
nit: you can make this a single statement with parentheses and dot chaining
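The fluent shape the comment has in mind, shown against a tiny stand-in class. The real `ApproximateNearestNeighbors` comes from spark-rapids-ml; this stub exists only so the chaining pattern, where each setter returns `self`, can run standalone:

```python
class ApproximateNearestNeighbors:
    # Minimal stub mimicking the spark-rapids-ml builder API for illustration.
    def __init__(self, k):
        self.k = k

    def setInputCol(self, col):
        self.input_col = col
        return self  # returning self is what enables dot chaining

    def setIdCol(self, col):
        self.id_col = col
        return self

# Single statement with parentheses and dot chaining, as suggested:
rapids_knn = (
    ApproximateNearestNeighbors(k=5)
    .setInputCol("embeddings")
    .setIdCol("id")
)
```

With the real class one would append `.fit(all_embeddings.select("id", "embeddings"))` to the same chain.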
## Step 6: Find top k Nearest Neighbors

We will use the fast IVF-Flat ANN algorithm from RAPIDS
let's link to the page explaining this algorithm
```python
print(f"Demo finished")

# Record the end time
end_time = datetime.now()

# Calculate the duration
duration = end_time - start_time

# Optionally, display the duration in seconds
duration_in_seconds = duration.total_seconds()
print(f"Application duration: {duration_in_seconds:.2f} seconds")
```
don't worry about timing for the demo; instead add a markdown cell with your timing results, maybe up top or down here. If you want the takeaway to be that this is ultra fast, here's where you can show people the results
/azp run

Azure Pipelines successfully started running 1 pipeline(s).
Codecov Report: All modified and coverable lines are covered by tests ✅

```
@@            Coverage Diff             @@
##           master    #2236      +/-   ##
==========================================
+ Coverage   84.43%   85.36%   +0.92%
==========================================
  Files         327      327
  Lines       16715    16742      +27
  Branches     1495     1509      +14
==========================================
+ Hits        14114    14291     +177
+ Misses       2601     2451     -150
```

View full report in Codecov by Sentry.
Related Issues/PRs

#xxx

What changes are proposed in this pull request?

Briefly describe the changes included in this Pull Request.

How is this patch tested?

Does this PR change any dependencies?

Does this PR add a new feature? If so, have you added samples on website?

- Add samples to the `website/docs/documentation` folder.
- Make sure you choose the correct class (`estimators`/`transformers`) and namespace.
- Make sure the `DocTable` points to the correct API link.
- Run `yarn run start` to make sure the website renders correctly.
- Add `<!--pytest-codeblocks:cont-->` before each python code block to enable auto-tests for python samples.
- Make sure the `WebsiteSamplesTests` job passes in the pipeline.