diff --git a/examples/sagemaker/run-nomic-embed-text.ipynb b/examples/sagemaker/run-nomic-embed-text.ipynb index 9c3fb412..f9bff2a3 100644 --- a/examples/sagemaker/run-nomic-embed-text.ipynb +++ b/examples/sagemaker/run-nomic-embed-text.ipynb @@ -145,7 +145,7 @@ } ], "source": [ - "response = embed_text(texts, endpoint_name, region_name=region_name, batch_size=32, dimensionality=128)\n", + "response = embed_text(texts, endpoint_name, region_name=region_name, batch_size=32, dimensionality=128, task_type=\"search_document\")\n", "embeddings = response[\"embeddings\"]\n", "np.array(embeddings).shape" ] diff --git a/nomic/aws/sagemaker.py b/nomic/aws/sagemaker.py index efd8eb0b..479b074f 100644 --- a/nomic/aws/sagemaker.py +++ b/nomic/aws/sagemaker.py @@ -38,26 +38,6 @@ def parse_sagemaker_response(response): return resp["embeddings"] -def preprocess_texts(texts: List[str], task_type: str = "search_document"): - """ - Preprocess a list of texts for embedding using a sagemaker model. - - Args: - texts: List of texts to be embedded. - task_type: The task type to use when embedding. One of `search_query`, `search_document`, `classification`, `clustering` - - Returns: - List of texts formatted for sagemaker embedding. - """ - assert task_type in [ - "search_query", - "search_document", - "classification", - "clustering", - ], f"Invalid task type: {task_type}" - return [f"{task_type}: {text}" for text in texts] - - def batch_transform_text( s3_input_path: str, s3_output_path: str, @@ -157,7 +137,13 @@ def embed_text( logger.warning("No texts to embed.") return None - texts = preprocess_texts(texts, task_type) + assert task_type in [ + "search_query", + "search_document", + "classification", + "clustering", + ], f"Invalid task type: {task_type}" + assert dimensionality in ( 64, 128, @@ -175,6 +161,7 @@ def embed_text( "texts": texts[i : i + batch_size], "binary": binary, "dimensionality": dimensionality, + "task_type": task_type, } ) response = client.invoke_endpoint(EndpointName=sagemaker_endpoint, Body=batch, ContentType="application/json") diff --git a/setup.py b/setup.py index 7ef94bb1..cc348e5f 100644 --- a/setup.py +++ b/setup.py @@ -48,7 +48,7 @@ "pylint", "pytest", "isort", - "pyright", + "pyright<=1.1.377", "myst-parser", "mkdocs-material", "mkautodoc",