implement embedding generation in supported inference providers #589
Conversation
raise NotImplementedError()
model = await self.model_store.get_model(model_id)
r = self._get_client().embeddings.create(
    model=model.provider_resource_id, input=contents
does this actually work with the `InterleavedTextMedia` type as is? That seems impossible to believe actually. What if there's an image in there?
Most providers accept a str or List[str] for input. I think most of the embedding models are text based and don't even support images. I will update this to handle the input based on the supported input types.
Just realized we have `interleaved_text_media_as_str`. Using that.
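For reference, a minimal sketch of how that conversion could sit in front of an OpenAI-compatible embeddings call. The import path of `interleaved_text_media_as_str` and the helper shape below are assumptions for illustration, not the merged code.

```python
from typing import Any, List

from openai import OpenAI

# Assumed import path for the helper mentioned above.
from llama_stack.providers.utils.inference.prompt_adapter import (
    interleaved_text_media_as_str,
)


def embed_contents(client: OpenAI, model: str, contents: List[Any]) -> List[List[float]]:
    # Most hosted providers accept only str / List[str] as input, so flatten
    # any interleaved content into plain text before calling the endpoint.
    inputs = [interleaved_text_media_as_str(c) for c in contents]
    r = client.embeddings.create(model=model, input=inputs)
    return [d.embedding for d in r.data]
```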
Bedrock actually supports embeddings for images: https://aws.amazon.com/blogs/machine-learning/build-a-reverse-image-search-engine-with-amazon-titan-multimodal-embeddings-in-amazon-bedrock-and-aws-managed-services/#:~:text=To%20convert%20images%20to%20vectors,optimize%20for%20speed%20and%20performance.
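For context, a rough sketch of what a Titan multimodal embeddings call looks like via boto3. The model id and request body fields follow the AWS documentation linked above; treat them as assumptions, since this PR does not wire up multimodal embeddings.

```python
import base64
import json

import boto3

# Titan Multimodal Embeddings via the Bedrock runtime (not part of this PR).
client = boto3.client("bedrock-runtime", region_name="us-east-1")

with open("photo.png", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

response = client.invoke_model(
    modelId="amazon.titan-embed-image-v1",
    body=json.dumps({"inputText": "a photo of a llama", "inputImage": image_b64}),
)
embedding = json.loads(response["body"].read())["embedding"]
print(len(embedding))  # 1024-dimensional by default
```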
async def register_model(self, model: Model) -> Model:
    # ollama does not have embedding models running. Check if the model is in list of available models.
I don't understand the sub-comment "does not have embedding models running" -- why are embedding models somehow special? why is this same block of code not applicable to other models here also?
This part is a bit weird. For regular models, you need an explicit `ollama run model_name` to be able to use the model, but for embedding models you don't need to do a run; you can call the embeddings API directly as long as the embedding model has been pulled.
The issue is that embedding models won't show up in `ollama ps`, so we need to do a list to get the currently downloaded/pulled models.
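Roughly what that check could look like with the ollama Python client; the dict-shaped responses and field names below are assumptions for illustration, not the code in this PR.

```python
from ollama import AsyncClient


async def check_model_available(client: AsyncClient, name: str, is_embedding: bool) -> None:
    # Embedding models only need to be pulled (visible via list()), while
    # generative models must be running (visible via ps()). Response shapes
    # are assumed to be dicts containing a "models" list.
    if is_embedding:
        available = await client.list()
    else:
        available = await client.ps()
    names = {m.get("name") or m.get("model") for m in available.get("models", [])}
    if name not in names:
        hint = "ollama pull" if is_embedding else "ollama run"
        raise ValueError(f"Model {name} not found; try `{hint} {name}` first.")
```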
@dineshyv I think we should talk to the Ollama folks around this on the Discord and see if they have any suggestions.
A bunch of comments inline
Addressed feedback. For now, I have made all providers assert that there is no media in the content. Adding multimodal support is a bit involved depending on the provider, and I would like to tackle that as a follow-up.
100% agreed.
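A minimal sketch of such a guard, with a hypothetical placeholder media type; the real providers check the actual `InterleavedTextMedia` variants.

```python
from typing import List, Union


class ImageMedia:
    """Placeholder for the real media type in InterleavedTextMedia."""


def assert_text_only(contents: List[Union[str, ImageMedia]]) -> List[str]:
    # Reject any non-string item until multimodal embeddings are supported.
    texts: List[str] = []
    for item in contents:
        assert isinstance(item, str), "media is not supported for embeddings yet"
        texts.append(item)
    return texts
```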
from pydantic import BaseModel


class SentenceTransformersInferenceConfig(BaseModel): ...
this has no `model` field and yet the impl looks for it. BUG!
Ah, in the `check_model` function? That function is unused; removing it.
class SentenceTransformersInferenceImpl(
    SentenceTransformerEmbeddingMixin,
You don't need to act on it, but I don't think we need a mixin because there isn't much state. We could just have free-floating utility functions for this thing.
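For illustration, the free-floating-utility alternative could be as small as this, assuming the sentence-transformers package; none of these names come from the PR.

```python
from functools import lru_cache
from typing import List

from sentence_transformers import SentenceTransformer


@lru_cache(maxsize=4)
def _load_model(model_name: str) -> SentenceTransformer:
    # Cache loaded models so repeated embedding calls stay cheap.
    return SentenceTransformer(model_name)


def generate_embeddings(model_name: str, texts: List[str]) -> List[List[float]]:
    return _load_model(model_name).encode(texts).tolist()
```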
approving, but there's one bug in sentence transformer config
@@ -52,6 +53,13 @@ def available_providers() -> List[ProviderSpec]:
        module="llama_stack.providers.inline.inference.vllm",
        config_class="llama_stack.providers.inline.inference.vllm.VLLMConfig",
    ),
    InlineProviderSpec(
        api=Api.inference,
        provider_type="inline::sentence-transformers",
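For readers, the rest of that registry entry plausibly fills in the package and module paths along these lines; the exact values are assumptions, not the merged diff.

```python
# Assumed completion of the entry shown above (goes in the inference
# provider registry); field values are guesses, not the actual PR.
InlineProviderSpec(
    api=Api.inference,
    provider_type="inline::sentence-transformers",
    pip_packages=["sentence-transformers"],
    module="llama_stack.providers.inline.inference.sentence_transformers",
    config_class="llama_stack.providers.inline.inference.sentence_transformers.config.SentenceTransformersInferenceConfig",
),
```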
Should we provide `inline::sentence-transformers` as a default inference provider for backward compatibility?
meta-reference already supports embedding generation through sentence transformers, so it's not needed there. This is for cases like TGI which need embedding generation only.
@yanxi0830 `inline::meta-reference` also implements this functionality and is how we get backwards compatibility. But are you saying this for distributions which don't use `inline::meta-reference`?
@dineshyv basically Xi might be suggesting that our templates should have two inference providers now for backwards compat. That does complicate things a bit because you need to register the embedding model with a specific provider, but it's a good point so the client continues to work as is.
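For instance, with two providers in a template, the client would register the embedding model explicitly against the sentence-transformers provider while a remote provider serves generation. A rough sketch with the llama-stack client; the argument names, provider id, and port are assumptions, not from this PR.

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5001")

# Point the embedding model at the local sentence-transformers provider
# while a remote provider (e.g. together) handles generation.
client.models.register(
    model_id="all-MiniLM-L6-v2",
    provider_id="sentence-transformers",
    model_type="embedding",
    metadata={"embedding_dimension": 384},
)
```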
E.g. how do we get the embeddings from `all-MiniLM-L6-v2` when we are using `together` for inference?
Yes, will update templates in a follow-up PR.
Actually, faiss can work with Together's embedding models.
The explicit sentence-transformers provider is only for cases like TGI where there is no way to have a hosted embedding model.
# What does this PR do?

Moves all the memory providers to use the inference API and improves the memory tests to set up the inference stack correctly and use the embedding models.

## Test Plan

torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" --inference-model="Llama3.2-3B-Instruct" --embedding-model="sentence-transformers/all-MiniLM-L6-v2" llama_stack/providers/tests/inference/test_embeddings.py --env EMBEDDING_DIMENSION=384

pytest -v -s llama_stack/providers/tests/memory/test_memory.py --providers="inference=together,memory=weaviate" --embedding-model="togethercomputer/m2-bert-80M-2k-retrieval" --env EMBEDDING_DIMENSION=768 --env TOGETHER_API_KEY=<API-KEY> --env WEAVIATE_API_KEY=foo --env WEAVIATE_CLUSTER_URL=bar

pytest -v -s llama_stack/providers/tests/memory/test_memory.py --providers="inference=together,memory=chroma" --embedding-model="togethercomputer/m2-bert-80M-2k-retrieval" --env EMBEDDING_DIMENSION=768 --env TOGETHER_API_KEY=<API-KEY> --env CHROMA_HOST=localhost --env CHROMA_PORT=8000

pytest -v -s llama_stack/providers/tests/memory/test_memory.py --providers="inference=together,memory=pgvector" --embedding-model="togethercomputer/m2-bert-80M-2k-retrieval" --env PGVECTOR_DB=postgres --env PGVECTOR_USER=postgres --env PGVECTOR_PASSWORD=mysecretpassword --env PGVECTOR_HOST=0.0.0.0 --env EMBEDDING_DIMENSION=768 --env TOGETHER_API_KEY=<API-KEY>

pytest -v -s llama_stack/providers/tests/memory/test_memory.py --providers="inference=together,memory=faiss" --embedding-model="togethercomputer/m2-bert-80M-2k-retrieval" --env EMBEDDING_DIMENSION=768 --env TOGETHER_API_KEY=<API-KEY>
Merged commit 9d082d9 into revert-605-revert-588-add-model-type
This PR adds the ability to generate embeddings in all supported inference providers.

```
pytest -v -s llama_stack/providers/tests/inference/test_embeddings.py -k "bedrock" --inference-model="amazon.titan-embed-text-v2:0" --env EMBEDDING_DIMENSION=1024

pytest -v -s -k "vllm" --inference-model="intfloat/e5-mistral-7b-instruct" llama_stack/providers/tests/inference/test_embeddings.py --env EMBEDDING_DIMENSION=4096 --env VLLM_URL="http://localhost:9798/v1"

pytest -v -s --inference-model="nomic-ai/nomic-embed-text-v1.5" llama_stack/providers/tests/inference/test_embeddings.py -k "fireworks" --env FIREWORKS_API_KEY=<API_KEY> --env EMBEDDING_DIMENSION=128

pytest -v -s --inference-model="togethercomputer/m2-bert-80M-2k-retrieval" llama_stack/providers/tests/inference/test_embeddings.py -k "together" --env TOGETHER_API_KEY=<API_KEY> --env EMBEDDING_DIMENSION=768

pytest -v -s -k "ollama" --inference-model="all-minilm:v8" llama_stack/providers/tests/inference/test_embeddings.py --env EMBEDDING_DIMENSION=384

torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" --inference-model="sentence-transformers/all-MiniLM-L6-v2" llama_stack/providers/tests/inference/test_embeddings.py --env EMBEDDING_DIMENSION=384
```
This PR does the following:
1) Adds the ability to generate embeddings in all supported inference providers.
2) Moves all the memory providers to use the inference API and improves the memory tests to set up the inference stack correctly and use the embedding models.

This is a merge from #589 and #598.
# What does this PR do?
This PR adds the ability to generate embeddings in all supported inference providers.
## Test Plan