
feat: Add custom embedder #2236

Open · wants to merge 23 commits into base: master
Conversation

vonodiripsa (Contributor):

Related Issues/PRs

#xxx

What changes are proposed in this pull request?

Briefly describe the changes included in this Pull Request.

How is this patch tested?

  • I have written tests (not required for typo or doc fix) and confirmed the proposed feature/bug-fix/change works.

Does this PR change any dependencies?

  • No. You can skip this section.
  • Yes. Make sure the dependencies are resolved correctly, and list changes here.

Does this PR add a new feature? If so, have you added samples on the website?

  • No. You can skip this section.
  • Yes. Make sure you have added samples following the steps below.
  1. Find the corresponding markdown file for your new feature in the website/docs/documentation folder.
    Make sure you choose the correct class (estimators/transformers) and namespace.
  2. Follow the pattern in the markdown file and add another section for your new API, including PySpark, Scala (and potentially .NET) samples.
  3. Make sure the DocTable points to the correct API link.
  4. Navigate to the website folder and run yarn run start to make sure the website renders correctly.
  5. Don't forget to add <!--pytest-codeblocks:cont--> before each Python code block to enable auto-tests for Python samples.
  6. Make sure the WebsiteSamplesTests job passes in the pipeline.

self,
inputCol=None,
outputCol=None,
useTRTFlag=None,
Collaborator:
nit: useTRTFlag -> runtime: "cpu", "gpu", "tensorrt", default cpu
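
A minimal sketch of what the suggested change could look like, using the same PySpark Param machinery the diff already uses; the doc string and type converter here are illustrative assumptions, not the PR's final API:

```python
from pyspark.ml.param import Param, Params, TypeConverters

# Sketch: replace the boolean useTRTFlag with a string-valued `runtime`
# Param accepting "cpu", "gpu", or "tensorrt".
runtime = Param(
    Params._dummy(),
    "runtime",
    'compute backend to use: "cpu", "gpu", or "tensorrt"',
    typeConverter=TypeConverters.toString,
)

# Inside the transformer's __init__, the default would then be set with
# self._setDefault(runtime="cpu"), matching the suggested "default cpu".
```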


# Define additional parameters
useTRT = Param(Params._dummy(), "useTRT", "True if use TRT acceleration")
driverOnly = Param(
Collaborator:
nit: remove the driverOnly code

Comment on lines 210 to 211
inputCol="combined",
outputCol="embeddings",
Collaborator:
look at other examples of proper defaults for these columns in the library
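
For illustration, a hedged skeleton of how such defaults are typically declared in a PySpark transformer's constructor; the default column names below are placeholders, and the real values should be copied from the library's existing transformers:

```python
from pyspark.ml import Transformer
from pyspark.ml.param.shared import HasInputCol, HasOutputCol

class HuggingFaceSentenceEmbedder(Transformer, HasInputCol, HasOutputCol):
    def __init__(self):
        super().__init__()
        # Placeholder defaults; align these with the library's conventions.
        self._setDefault(inputCol="text", outputCol="embeddings")
```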

Comment on lines 306 to 308
for batch_size in [64, 32, 16, 8, 4, 2, 1]:
for sentence_length in [20, 300, 512]:
yield (batch_size, sentence_length)
Collaborator:
make these magic numbers parameters with defaults
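
One way to do that, sketched with a hypothetical helper name; the tuples carry the same values as the current loop, now overridable by callers:

```python
def _batch_configs(batch_sizes=(64, 32, 16, 8, 4, 2, 1),
                   sentence_lengths=(20, 300, 512)):
    """Yield (batch_size, sentence_length) pairs; the former magic
    numbers are now keyword defaults that callers can override."""
    for batch_size in batch_sizes:
        for sentence_length in sentence_lengths:
            yield (batch_size, sentence_length)
```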

"""
Create a data loader with synthetic data using Faker.
"""
faker = Faker()
Collaborator:
nit: let's try to remove this dependency
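
A small dependency-free stand-in for Faker, purely as a sketch (the names here are hypothetical): sample words from a fixed vocabulary to build synthetic sentences.

```python
import random

_WORDS = ["spark", "vector", "embedding", "index", "query", "batch", "model"]

def _fake_sentence(n_words, seed=42):
    # Deterministic synthetic text; good enough for exercising batch
    # shapes, with no external dependency.
    rng = random.Random(seed)
    return " ".join(rng.choice(_WORDS) for _ in range(n_words))
```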

for sentence_length in [20, 300, 512]:
yield (batch_size, sentence_length)

def get_dataloader(repeat_times: int = 2):
Collaborator:
nit: _get_dataloader

func, dataloader=tqdm(get_dataloader(), total=total_batches), config=conf
)

def run_on_driver(self, queries, spark):
Collaborator:
likewise, prefix with _

"""
return self._defaultCopy(extra)

def load_data_food_reviews(self, spark, path=None, limit=1000):
Collaborator:
move this code into the demo

Comment on lines 15 to 30
class SuppressLogging:
def __init__(self):
self._original_stderr = None

def start(self):
"""Start suppressing logging by redirecting sys.stderr to /dev/null."""
if self._original_stderr is None:
self._original_stderr = sys.stderr
sys.stderr = open('/dev/null', 'w')

def stop(self):
"""Stop suppressing logging and restore sys.stderr."""
if self._original_stderr is not None:
sys.stderr.close()
sys.stderr = self._original_stderr
self._original_stderr = None
Collaborator:
remove

FloatType,
)

class EmbeddingTransformer(Transformer, HasInputCol, HasOutputCol):
Collaborator:
nit: HuggingFaceSentenceEmbedder

Also name the file HuggingFaceSentenceEmbedder.py

Comment on lines 202 to 203
modelName="intfloat/e5-large-v2",
moduleName="e5-large-v2",
Collaborator:
nit: no defaults here, and try to make this moduleName thing go away

Initialize the EmbeddingTransformer with input/output columns and optional TRT flag.
"""
super(EmbeddingTransformer, self).__init__()
self._setDefault(
/databricks/python/bin/pip install --extra-index-url https://pypi.nvidia.com cudf-cu11~=${RAPIDS_VERSION} cuml-cu11~=${RAPIDS_VERSION} pylibraft-cu11~=${RAPIDS_VERSION} rmm-cu11~=${RAPIDS_VERSION}

# install model navigator
/databricks/python/bin/pip install --extra-index-url https://pypi.nvidia.com onnxruntime-gpu==1.16.3 "tensorrt==9.3.0.post12.dev1" "triton-model-navigator<1" "sentence_transformers~=2.2.2" "faker" "urllib3<2"
Collaborator:
nit: remove faker

@bvonodiripsa changed the title from "Feat: Add custom embedder" to "feat: Add custom embedder" on Jun 18, 2024
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import batch_to_device
from pyspark.ml.functions import predict_batch_udf
from faker import Faker
Collaborator:
can we remove this dependency as previously discussed? you can add a little fake data passage and replicate it if you need to
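
A sketch of that suggestion for the notebook, assuming the active `spark` session it already uses: hand-write a few sentences and replicate them to the desired row count.

```python
sentences = [
    "The desserts at this bakery are wonderful.",
    "The service was slow and the food was disgusting.",
    "Great coffee and friendly staff.",
]
# Replicate the small passage instead of generating text with Faker.
rows = [(i, sentences[i % len(sentences)]) for i in range(1000)]
df = spark.createDataFrame(rows, ["id", "combined"])
```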

Comment on lines 85 to 91
"logging.getLogger(\"sentence_transformers.SentenceTransformer\").setLevel(logging.ERROR)\n",
"mlflow.autolog(disable=True)\n",
"\n",
"# Record the start time\n",
"start_time = datetime.now()\n",
"\n",
"print(f\"Demo started\")"
Collaborator:
probably don't need this stuff for a minimal demo

"warnings.filterwarnings(\"ignore\", category=UserWarning, module=\"tritonclient.grpc\")\n",
"import logging\n",
"\n",
"logging.getLogger(\"py4j\").setLevel(logging.ERROR)\n",
Collaborator:
do you need this line?

Comment on lines 145 to 174
"number_of_input_rows = 999\n",
"# Shuffle the DataFrame with a fixed seed\n",
"seed = 42\n",
"\n",
"# Check if the row count is less than 10\n",
"if number_of_input_rows <= 0 or number_of_input_rows >= 1000000:\n",
" raise ValueError(f\"Limit is {number_of_input_rows}, which should be less than 1M.\")\n",
"\n",
"if number_of_input_rows > 1000:\n",
"\n",
" # Cross-join the DataFrame with itself to create n x n pairs for string concatenation (synthetic data)\n",
" cross_joined_df = df.crossJoin(\n",
" df.withColumnRenamed(\"combined\", \"combined_\")\n",
" )\n",
"\n",
" # Create a new column 'result_vector' by concatenating the two source vectors\n",
" tmp_df = cross_joined_df.withColumn(\n",
" \"result_vector\",\n",
" F.concat(F.col(\"combined\"), F.lit(\". \\n\"), F.col(\"combined_\")),\n",
" )\n",
"\n",
" # Select only the necessary columns and show the result\n",
" tmp_df = tmp_df.select(\"result_vector\")\n",
" df = tmp_df.withColumnRenamed(\"result_vector\", \"combined\").withColumn(\n",
" \"id\", monotonically_increasing_id()\n",
" )\n",
"\n",
"df = df.limit(number_of_input_rows).orderBy(rand(seed)).repartition(10).cache()\n",
"\n",
"print(f\"Loaded: {number_of_input_rows} rows\")"
Collaborator:
we probably can remove the cross-join stuff for the demo; I would rather use a large dataset and subset it than a small dataset and augment it
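
A sketch of the subset-first approach; `full_df` is a stand-in for a DataFrame read from an already-large source dataset:

```python
from pyspark.sql.functions import rand

number_of_input_rows = 999
df = (
    full_df  # assumption: loaded from a large source dataset
    .orderBy(rand(seed=42))  # shuffle with a fixed seed
    .limit(number_of_input_rows)
    .repartition(10)
    .cache()
)
```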

"outputs": [],
"source": [
"# dataTransformer = HuggingFaceSentenceEmbedder(modelName=\"intfloat/e5-large-v2\", inputCol=\"combined\", outputCol=\"embeddings\", runtime=\"tensorrt\")\n",
"dataTransformer = HuggingFaceSentenceEmbedder(modelName=\"sentence-transformers/all-MiniLM-L6-v2\", inputCol=\"combined\", outputCol=\"embeddings\", runtime=\"tensorrt\")\n",
Collaborator:
nit: dataTransformer -> embedder

"# dataTransformer = HuggingFaceSentenceEmbedder(modelName=\"intfloat/e5-large-v2\", inputCol=\"combined\", outputCol=\"embeddings\", runtime=\"tensorrt\")\n",
"dataTransformer = HuggingFaceSentenceEmbedder(modelName=\"sentence-transformers/all-MiniLM-L6-v2\", inputCol=\"combined\", outputCol=\"embeddings\", runtime=\"tensorrt\")\n",
"\n",
"all_embeddings = dataTransformer.transform(df).cache()"
Collaborator:
nit: all_embeddings -> embeddings

Comment on lines 253 to 266
"queries = [\"desserts\", \"disgusting\"]\n",
"ids = [1, 2]\n",
"\n",
"# Combine the data into a list of tuples\n",
"data = list(zip(ids, queries))\n",
"\n",
"# Define the schema for the DataFrame\n",
"schema = StructType([\n",
" StructField(\"id\", IntegerType(), nullable=False),\n",
" StructField(\"query\", StringType(), nullable=False)\n",
"])\n",
"\n",
"# Create the DataFrame\n",
"qDf = spark.createDataFrame(data, schema)\n",
Collaborator:
you can probably make this smaller by saying:

test_data = spark.createDataFrame([("desserts", 1), ("disgusting", 2)], ["query", "id"])

"qDf = spark.createDataFrame(data, schema)\n",
"\n",
"# queryTransformer = HuggingFaceSentenceEmbedder(modelName=\"intfloat/e5-large-v2\", inputCol=\"query\", outputCol=\"embeddings\", runtime=\"cpu\")\n",
"queryTransformer = HuggingFaceSentenceEmbedder(modelName=\"sentence-transformers/all-MiniLM-L6-v2\", inputCol=\"query\", outputCol=\"embeddings\", runtime=\"cpu\")\n",
Collaborator:
nit: use the embedder you already made above
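
For instance (a sketch, assuming the class exposes a setInputCol setter via its HasInputCol mixin):

```python
# Reuse the embedder built for the data, retargeted at the query column.
query_embeddings = embedder.setInputCol("query").transform(qDf)
```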

"outputs": [],
"source": [
"# dataTransformer = HuggingFaceSentenceEmbedder(modelName=\"intfloat/e5-large-v2\", inputCol=\"combined\", outputCol=\"embeddings\", runtime=\"tensorrt\")\n",
"dataTransformer = HuggingFaceSentenceEmbedder(modelName=\"sentence-transformers/all-MiniLM-L6-v2\", inputCol=\"combined\", outputCol=\"embeddings\", runtime=\"tensorrt\")\n",
Collaborator:
nit: make the model name a param and then pass it in; give people a few models to try
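
A sketch of that notebook cell; both model names already appear in this PR, so readers could toggle between them:

```python
model_name = "sentence-transformers/all-MiniLM-L6-v2"
# model_name = "intfloat/e5-large-v2"  # a larger model to try

embedder = HuggingFaceSentenceEmbedder(
    modelName=model_name,
    inputCol="combined",
    outputCol="embeddings",
    runtime="tensorrt",
)
```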

Comment on lines 327 to 330
"rapids_knn = ApproximateNearestNeighbors(k=5)\n",
"rapids_knn.setInputCol(\"embeddings\").setIdCol(\"id\")\n",
"\n",
"rapids_knn_model = rapids_knn.fit(all_embeddings.select(\"id\", \"embeddings\"))"
Collaborator:
nit: you can make this a single statement with parentheses and dot chaining
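
The suggestion as code, with the variable names following the earlier renaming nits:

```python
rapids_knn_model = (
    ApproximateNearestNeighbors(k=5)
    .setInputCol("embeddings")
    .setIdCol("id")
    .fit(embeddings.select("id", "embeddings"))
)
```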

"source": [
"## Step 6: Find top k Nearest Neighbors\n",
"\n",
"We will use fast ANN IVFFlat algorithm from Rapids"
Collaborator:
let's link to the page explaining this algo

Comment on lines 430 to 440
"print(f\"Demo finished\")\n",
"\n",
"# Record the end time\n",
"end_time = datetime.now()\n",
"\n",
"# Calculate the duration\n",
"duration = end_time - start_time\n",
"\n",
"# Optionally, display the duration in seconds\n",
"duration_in_seconds = duration.total_seconds()\n",
"print(f\"Application duration: {duration_in_seconds:.2f} seconds\")"
mhamilton723 (Collaborator) commented on Jun 24, 2024:

don't worry about timing for the demo; instead add a markdown cell with your timing results, maybe up top or down here. If you want the takeaway to be that this is ultra fast, here's where you can show people the results

mhamilton723 (Collaborator):

/azp run


Azure Pipelines successfully started running 1 pipeline(s).


codecov-commenter commented Jul 3, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 85.36%. Comparing base (440f18e) to head (7280ea7).
Report is 4 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #2236      +/-   ##
==========================================
+ Coverage   84.43%   85.36%   +0.92%     
==========================================
  Files         327      327              
  Lines       16715    16742      +27     
  Branches     1495     1509      +14     
==========================================
+ Hits        14114    14291     +177     
+ Misses       2601     2451     -150     

☔ View full report in Codecov by Sentry.
