implement embedding generation in supported inference providers #589

Merged

Conversation

@dineshyv (Contributor) commented Dec 9, 2024

What does this PR do?

This PR adds the ability to generate embeddings in all supported inference providers.
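
As a rough illustration of what this enables, here is a hedged sketch of calling the embeddings API against any of these providers. The `embeddings(model_id, contents)` signature mirrors the snippets reviewed below; the import path and the `EmbeddingsResponse.embeddings` attribute are assumptions for illustration.

```python
# Hedged usage sketch, not the PR's actual code: one embedding vector per
# input text, each of length EMBEDDING_DIMENSION.
from llama_stack.apis.inference import Inference  # import path assumed


async def embed_texts(inference: Inference, model_id: str, texts: list[str]):
    response = await inference.embeddings(model_id=model_id, contents=texts)
    return response.embeddings  # attribute name assumed
```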

Test Plan

pytest -v -s llama_stack/providers/tests/inference/test_embeddings.py -k "bedrock" --inference-model="amazon.titan-embed-text-v2:0"  --env EMBEDDING_DIMENSION=1024


pytest -v -s -k "vllm" --inference-model="intfloat/e5-mistral-7b-instruct" llama_stack/providers/tests/inference/test_embeddings.py --env EMBEDDING_DIMENSION=4096 --env VLLM_URL="http://localhost:9798/v1"


pytest -v -s --inference-model="nomic-ai/nomic-embed-text-v1.5"  llama_stack/providers/tests/inference/test_embeddings.py  -k "fireworks"  --env FIREWORKS_API_KEY=<API_KEY>--env EMBEDDING_DIMENSION=128

pytest -v -s --inference-model="togethercomputer/m2-bert-80M-2k-retrieval"  llama_stack/providers/tests/inference/test_embeddings.py  -k "together"  --env TOGETHER_API_KEY=<API_KEY>--env EMBEDDING_DIMENSION=768


pytest -v -s -k "ollama"  --inference-model="all-minilm:v8"  llama_stack/providers/tests/inference/test_embeddings.py --env EMBEDDING_DIMENSION=384

 torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" --inference-model="sentence-transformers/all-MiniLM-L6-v2"  llama_stack/providers/tests/inference/test_embeddings.py --env EMBEDDING_DIMENSION=384


@facebook-github-bot added the "CLA Signed" label on Dec 9, 2024
@dineshyv force-pushed the support-embeddings-inference branch from 9f13e0b to 600b5a7 on December 11, 2024 00:58
Base automatically changed from add-model-type to main December 11, 2024 18:16
@dineshyv changed the base branch from main to revert-605-revert-588-add-model-type on December 11, 2024 18:18
@dineshyv force-pushed the support-embeddings-inference branch from 600b5a7 to e167e9e on December 11, 2024 19:15
raise NotImplementedError()
model = await self.model_store.get_model(model_id)
r = self._get_client().embeddings.create(
    model=model.provider_resource_id, input=contents
Contributor:

does this actually work with the InterleavedTextMedia type as is? that seems impossible to believe actually. What if there's an image in there?

Contributor Author:

Most providers accept a str or List[str] for input. I think most of the embedding models are text-based and don't even support images. I will update this to handle the input based on the supported input types.

Contributor Author:

just realized we have interleaved_text_media_as_str. using that.
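
For context, a rough sketch of what such a flattening step looks like; `interleaved_text_media_as_str` itself is not shown in this thread, so the helper below is a hypothetical stand-in.

```python
# Hypothetical stand-in for interleaved_text_media_as_str: collapse
# interleaved content to plain text, rejecting media, before calling a
# text-only embeddings endpoint.
def as_plain_text(content) -> str:
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        return " ".join(as_plain_text(item) for item in content)
    raise ValueError("media content is not supported by text-only embedding models")
```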

async def register_model(self, model: Model) -> Model:
    # ollama does not have embedding models running. Check if the model is in list of available models.
Contributor:

I don't understand the sub-comment "does not have embedding models running" -- why are embedding models somehow special? why is this same block of code not applicable to other models here also?

Contributor Author:

This part is a bit weird. For regular models, you need an explicit `ollama run model_name` to be able to use the model, but for embedding models you don't need to do a run; you can directly call the embeddings API as long as the embedding model has been pulled.
But the issue is that embedding models won't show up in `ollama ps`, so we need to do a list to get the currently downloaded/pulled models.
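
A hedged sketch of the workaround being described here, assuming the async `ollama` client's `list()` call (the programmatic analogue of `ollama list`); the response fields are assumptions as well.

```python
# Sketch only: embedding models never appear in `ollama ps`, so validate
# registration against the pulled-model list instead.
async def register_model(self, model: Model) -> Model:
    res = await self.client.list()  # analogous to `ollama list`
    pulled = {m["name"] for m in res["models"]}
    if model.provider_resource_id not in pulled:
        raise ValueError(
            f"{model.provider_resource_id} is not pulled; run `ollama pull` first"
        )
    return model
```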

Contributor:

@dineshyv I think we should talk to the Ollama folks around this on the Discord and see if they have any suggestions.

@ashwinb (Contributor) left a comment:

A bunch of comments inline

@dineshyv (Contributor Author) left a comment:

Addressed feedback. For now, I have made all providers assert that there is no media in content. Adding multimodal support is a bit involved depending on the provider, and I would like to tackle that as a follow-up.
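
A minimal sketch of that guard, with both helper names invented for illustration:

```python
# Sketch of the interim behavior: providers assert contents are text-only
# and defer multimodal embedding support to a follow-up.
def content_has_media(content) -> bool:
    if isinstance(content, list):
        return any(content_has_media(item) for item in content)
    return not isinstance(content, str)


def assert_text_only(contents) -> None:
    assert not any(content_has_media(c) for c in contents), \
        "media is not yet supported for embedding generation"
```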

@ashwinb (Contributor) commented Dec 12, 2024:

> Addressed feedback. For now, I have made all providers assert that there is no media in content. Adding multimodal support is a bit involved depending on the provider, and I would like to tackle that as a follow-up.

100% agreed.

from pydantic import BaseModel


class SentenceTransformersInferenceConfig(BaseModel): ...
Contributor:

this has no model field and yet the impl looks for it. BUG!

Contributor Author:

ah in the check_model function? that function is unused. removing it.



class SentenceTransformersInferenceImpl(
SentenceTransformerEmbeddingMixin,
Contributor:

you don't need to act on it, but I don't think we need a mixin because there isn't much state. we could just have free-floating utility functions for this thing.
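
For illustration, a sketch of that suggestion, with a module-level cache standing in for the mixin's state; all names besides the `sentence_transformers` library are assumed.

```python
# Sketch of the free-floating-utility alternative to the mixin. The only
# state worth keeping is the loaded model, cached at module level.
from sentence_transformers import SentenceTransformer

_MODEL_CACHE: dict[str, SentenceTransformer] = {}


def embed_texts(model_name: str, texts: list[str]) -> list[list[float]]:
    if model_name not in _MODEL_CACHE:
        _MODEL_CACHE[model_name] = SentenceTransformer(model_name)
    return [vector.tolist() for vector in _MODEL_CACHE[model_name].encode(texts)]
```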

@ashwinb (Contributor) left a comment:

approving, but there's one bug in sentence transformer config

@@ -52,6 +53,13 @@ def available_providers() -> List[ProviderSpec]:
module="llama_stack.providers.inline.inference.vllm",
config_class="llama_stack.providers.inline.inference.vllm.VLLMConfig",
),
InlineProviderSpec(
api=Api.inference,
provider_type="inline::sentence-transformers",
Contributor:

Should we provide inline::sentence-transformers as a default inference provider for backward compatibility?

Contributor Author:

meta-reference already supports embedding generation through sentence transformers, so it's not needed. This is for cases like TGI, which need embedding generation only.

Contributor:

@yanxi0830 inline::meta-reference also implements this functionality and is how we get backwards compatibility. But are you saying this for distributions which don't use inline::meta-reference?

Contributor:

@dineshyv basically Xi might be suggesting that our templates should have two inference providers now for backwards compat. That does complicate things a bit, because you need to register the embedding model with a specific provider, but it's a good point so the client continues to work as is.

@yanxi0830 (Contributor) commented Dec 12, 2024:

E.g. how do we get the embeddings from all-MiniLM-L6-v2 when we are using together for inference?

Contributor Author:

Yes, will update templates in a follow-up PR.

Contributor Author:

actually, faiss can work with together's embedding models.

Contributor Author:

The explicit sentence transformer provider is only for cases like TGI where there is no way to have a hosted embedding model.

# What does this PR do?
Moves all the memory providers to use the inference API and improves the
memory tests to set up the inference stack correctly and use the
embedding models.
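
A minimal sketch of that delegation, with all names assumed for illustration; the point is that memory providers no longer embed locally.

```python
# Sketch only: a memory provider asking the inference API for embeddings
# instead of loading an embedding model itself.
async def _embed_chunks(self, chunks: list[str]) -> list[list[float]]:
    response = await self.inference_api.embeddings(
        model_id=self.embedding_model,
        contents=chunks,
    )
    return response.embeddings
```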


## Test Plan
torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference"
--inference-model="Llama3.2-3B-Instruct"
--embedding-model="sentence-transformers/all-MiniLM-L6-v2"
llama_stack/providers/tests/inference/test_embeddings.py --env
EMBEDDING_DIMENSION=384


pytest -v -s llama_stack/providers/tests/memory/test_memory.py
--providers="inference=together,memory=weaviate"
--embedding-model="togethercomputer/m2-bert-80M-2k-retrieval" --env
EMBEDDING_DIMENSION=768 --env TOGETHER_API_KEY=<API-KEY> --env
WEAVIATE_API_KEY=foo --env WEAVIATE_CLUSTER_URL=bar
 
pytest -v -s llama_stack/providers/tests/memory/test_memory.py
--providers="inference=together,memory=chroma"
--embedding-model="togethercomputer/m2-bert-80M-2k-retrieval" --env
EMBEDDING_DIMENSION=768 --env TOGETHER_API_KEY=<API-KEY> --env
CHROMA_HOST=localhost --env CHROMA_PORT=8000

pytest -v -s llama_stack/providers/tests/memory/test_memory.py
--providers="inference=together,memory=pgvector"
--embedding-model="togethercomputer/m2-bert-80M-2k-retrieval" --env
PGVECTOR_DB=postgres --env PGVECTOR_USER=postgres --env
PGVECTOR_PASSWORD=mysecretpassword --env PGVECTOR_HOST=0.0.0.0 --env
EMBEDDING_DIMENSION=768 --env TOGETHER_API_KEY=<API-KEY>

pytest -v -s llama_stack/providers/tests/memory/test_memory.py
--providers="inference=together,memory=faiss"
--embedding-model="togethercomputer/m2-bert-80M-2k-retrieval" --env
EMBEDDING_DIMENSION=768 --env TOGETHER_API_KEY=<API-KEY>
@dineshyv merged commit 9d082d9 into revert-605-revert-588-add-model-type on Dec 12, 2024
2 checks passed
@dineshyv deleted the support-embeddings-inference branch on December 12, 2024 19:17
dineshyv added a commit that referenced this pull request Dec 12, 2024
This PR adds the ability to generate embeddings in all supported
inference providers.

dineshyv added a commit that referenced this pull request Dec 12, 2024
This PR does the following:
1) Adds the ability to generate embeddings in all supported inference providers.
2) Moves all the memory providers to use the inference API and improves the memory tests to set up the inference stack correctly and use the embedding models.

This is a merge from #589 and #598