add NVIDIA NIM inference adapter #355
Conversation
Force-pushed from a5760c0 to 2a25ace.
# the root directory of this source tree.

from ._config import NVIDIAConfig
from ._nvidia import NVIDIAInferenceAdapter
this should be a dynamic import within get_adapter_impl()
-- we want configs to be manipulated without needing implementation dependencies.
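For context, here is a minimal sketch of the pattern being asked for; the exact get_adapter_impl signature and the initialize() call are assumptions, not this adapter's actual code:

```python
# The config stays importable on its own; the implementation (and its runtime
# dependencies) is only imported when the adapter is actually constructed.
from ._config import NVIDIAConfig


async def get_adapter_impl(config: NVIDIAConfig, _deps):
    # Dynamic import: loading/manipulating NVIDIAConfig never drags in the
    # implementation's dependencies.
    from ._nvidia import NVIDIAInferenceAdapter

    adapter = NVIDIAInferenceAdapter(config)
    await adapter.initialize()
    return adapter
```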
ptal
Thank you for this PR. So good!
Re: testing, I'd like to have a reproducible e2e test (à la what we have in providers/tests/inference/test_text_inference.py and providers/tests/inference/test_vision_inference.py) -- just having an nvidia-specific fixture there which could then be invoked as
pytest -s -v --providers inference=nvidia test_text_inference.py --env ...
would be great.
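As a rough illustration of what such a fixture might look like (a simplified sketch, not the real ProviderFixture plumbing in the test harness; the fixture name, provider_type string, and config shape are all assumptions):

```python
import os

import pytest


@pytest.fixture(scope="session")
def inference_nvidia() -> dict:
    # Skip cleanly when no key is available so the shared e2e suite still
    # runs for the other providers.
    api_key = os.environ.get("NVIDIA_API_KEY")
    if not api_key:
        pytest.skip("NVIDIA_API_KEY not set; skipping NVIDIA inference tests")
    return {
        "provider_id": "nvidia",
        "provider_type": "remote::nvidia",
        "config": {"api_key": api_key},
    }
```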
)

@property
def is_hosted(self) -> bool:
this could be is_nvidia_hosted, perhaps?
it's really an internal thing. i've removed it from the NVIDIAConfig api entirely.
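For context, the internal check being described might amount to something like the following sketch; the field names, default URL, and pydantic base class are assumptions based on the hosted endpoint named in the PR description, not the actual NVIDIAConfig:

```python
from typing import Optional

from pydantic import BaseModel


class NVIDIAConfig(BaseModel):
    url: str = "https://integrate.api.nvidia.com"
    api_key: Optional[str] = None

    @property
    def _is_hosted(self) -> bool:
        # Hosted NIM vs. a locally downloaded NIM container, inferred from
        # the configured URL.
        return "integrate.api.nvidia.com" in self.url
```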
@@ -0,0 +1,182 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
this is a nitpick, you can ignore it if you feel strongly. we don't usually do underscores in files in the repo - at least not yet. we don't even strongly enforce what symbols get exported out of a module (that part is a bit sad admittedly.) could you make the files not have starting underscores?
my inclination is to be cautious about the exported symbols, but it's important to be cohesive w/ the project. i'll change these. ptal.
from llama_models.datatypes import SamplingParams
from llama_models.llama3.api.datatypes import (
    InterleavedTextMedia,
thanks for the explicit imports. we will be code-modding all our other code to do this sane thing soon :)
i spent so much time trying to figure out which classes were coming from which packages 😆
        CoreModelId.llama3_2_90b_vision_instruct.value,
    ),
    # TODO(mf): how do we handle Nemotron models?
    # "Llama3.1-Nemotron-51B-Instruct" -> "meta/llama-3.1-nemotron-51b-instruct",
is there a "base" llama model this model would correspond most closely with? we like to know it because we try to format tools, etc. in a way which the model will work best with. this isn't strictly necessary if the provider / API works very robustly with tool calling, etc. but so far given our experience with various "openai" wrapper APIs, it has been spotty.
nvidia/llama-3.1-nemotron-51b-instruct (typo in my comment) is https://build.nvidia.com/nvidia/llama-3_1-nemotron-51b-instruct/modelcard
there's now a 70b variant at https://build.nvidia.com/nvidia/llama-3_1-nemotron-70b-instruct/modelcard
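For what it's worth, the closest base is probably Llama 3.1 70B Instruct, since Nemotron-51B was derived from it by pruning/distillation. A hypothetical registry entry could look like the sketch below; the dict shape and the import are illustrative, not this PR's code:

```python
from llama_models.datatypes import CoreModelId

# Hypothetical mapping, not part of this PR: register Nemotron against the
# base Llama model it was derived from, so prompt and tool formatting follow
# that model's conventions.
NEMOTRON_ALIASES = {
    "Llama3.1-Nemotron-51B-Instruct": (
        "nvidia/llama-3.1-nemotron-51b-instruct",
        CoreModelId.llama3_1_70b_instruct.value,
    ),
}
```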
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
        raise NotImplementedError()
any chance this could be done? it's OK if not, but we have gone back and filled in many of the missing completion() methods now
let me come back and add it in another PR, same for embedding
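In case it's useful for that follow-up, here is a minimal sketch of a raw completion against a hosted NIM, assuming the OpenAI-compatible /v1/completions route; the helper, model id, and base URL path are illustrative, not this adapter's code:

```python
from openai import AsyncOpenAI


async def raw_completion(prompt: str, api_key: str) -> str:
    # Hosted NIMs expose an OpenAI-compatible API; a downloaded NIM container
    # would use its local base_url instead.
    client = AsyncOpenAI(
        base_url="https://integrate.api.nvidia.com/v1", api_key=api_key
    )
    response = await client.completions.create(
        model="meta/llama-3.1-8b-instruct",  # illustrative model id
        prompt=prompt,
        max_tokens=128,
    )
    return response.choices[0].text
```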
@ashwinb - Are we planning to support that endpoint long-term?
It seems the industry is moving away from it. E.g., OpenAI marked their completions api as legacy https://platform.openai.com/docs/api-reference/completions and Anthropic too: https://docs.anthropic.com/en/api/complete
I'm wondering if it's a good idea to move away from the endpoint as well while llama-stack is still in its early days.
Otherwise it might be painful to maintain it for a long time + deprecate + remove.
Reposting response here for visibility:
Thanks for all your contributions @aidando73 -- we want to keep supporting completions at least for now because we believe having raw access to a model is just as important. Unlike other providers, Llamas are open source and people play with them and iterate in a variety of ways. The kinds of manipulations we do internally for the chat_completion endpoint may not always be what users intend. Sometimes they just want a carefully formatted prompt to hit the model directly.
On that theme, I think it would be great if Groq could build completions endpoints on their end too. But until that time, NotImplementedError() would have to do.
        ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
    ]:
        if tool_prompt_format:
            warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring")
❤️
i've updated the PR description to note that this does not cover structured output, vision models, embedding or completion apis. if it's ok, i'll follow up with PRs to add those features.
@ashwinb i find
it's an accuracy test because it checks the value of first/last name, birth year, and num seasons. i find that -
suggestions (not mutually exclusive) -
@mattf I agree with your comments on
Looks good to me. Merging!
out_of_tokens = "out_of_tokens"
"""

# TODO(mf): are end_of_turn and end_of_message semantics correct?
I'm implementing a Groq adapter and wondering this as well
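For anyone else hitting the same question, the usual reading is: end_of_turn corresponds to Llama 3's <|eot_id|> (the model finished its turn), end_of_message to <|eom_id|> (the model ended a message but expects more, e.g. a tool result), and out_of_tokens means generation was truncated. A hedged sketch of the kind of finish-reason mapping under discussion, not this adapter's actual code:

```python
from llama_models.llama3.api.datatypes import StopReason


def convert_finish_reason(finish_reason: str) -> StopReason:
    # OpenAI-style finish reasons mapped onto Llama stop reasons; the
    # "tool_calls" -> end_of_message choice is an assumption.
    return {
        "stop": StopReason.end_of_turn,
        "length": StopReason.out_of_tokens,
        "tool_calls": StopReason.end_of_message,
    }.get(finish_reason, StopReason.end_of_turn)
```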
):
    yield ChatCompletionResponseEventType.start
    while True:
        yield ChatCompletionResponseEventType.progress
This is such a cool trick
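A guess at the trick being praised: the generator yields start once and then progress forever, so it can be zipped against the incoming chunk stream and the first chunk automatically becomes the start event. A toy, self-contained illustration (names simplified from the diff):

```python
from typing import Iterator


def event_types() -> Iterator[str]:
    # First chunk pairs with "start", every later chunk with "progress".
    yield "start"
    while True:
        yield "progress"


chunks = ["Hello,", "world", "!"]
for event_type, chunk in zip(event_types(), chunks):
    print(event_type, chunk)
# start Hello,
# progress world
# progress !
```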
aidando73 mentioned this pull request in #609 (Groq inference adapter): "A lot of inspiration taken from @mattf's good work at #355."
What does this PR do?
this PR adds a basic inference adapter to NVIDIA NIMs
what it does -
- chat completion api
- tool calls
- streaming
- structured output
- logprobs
- support hosted NIM on integrate.api.nvidia.com
- support downloaded NIM containers
what it does not do -
- completion api
- embedding api
- vision models
- builtin tools
- have certainty that sampling strategies are correct
Feature/Issue validation/testing/test plan
pytest -s -v --providers inference=nvidia llama_stack/providers/tests/inference/ --env NVIDIA_API_KEY=...
all tests should pass. there are pydantic v1 warnings.
Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [x] Did you read the contributor guideline (https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section?
- [ ] Was this discussed/approved via a Github issue? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
- [x] Did you write any new necessary tests?
Thanks for contributing 🎉!