
add NVIDIA NIM inference adapter #355

Merged
merged 19 commits into meta-llama:main from add-nvidia-inference-adapter on Nov 23, 2024

Conversation

mattf
Contributor

@mattf mattf commented Nov 1, 2024

What does this PR do?

this PR adds a basic inference adapter for NVIDIA NIMs (a usage sketch follows the lists below)

what it does -

  • chat completion api
    • tool calls
    • streaming
    • structured output
    • logprobs
  • support hosted NIM on integrate.api.nvidia.com
  • support downloaded NIM containers

what it does not do -

  • completion api
  • embedding api
  • vision models
  • builtin tools
  • have certainty that sampling strategies are correct
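
a minimal usage sketch of the chat completion api listed above, assuming a stack is running with inference routed to this adapter and NVIDIA_API_KEY set; the base_url, port, and model id are placeholders -

```python
# hedged sketch: call the chat completion api through a running llama stack
# whose inference provider is the nvidia adapter (port and model id are placeholders)
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5001")

response = client.inference.chat_completion(
    model_id="Llama3.1-8B-Instruct",
    messages=[{"role": "user", "content": "Hello from the NVIDIA NIM adapter!"}],
    # stream=True,  # streaming is also supported, per the list above
)
print(response)
```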

Feature/Issue validation/testing/test plan

pytest -s -v --providers inference=nvidia llama_stack/providers/tests/inference/ --env NVIDIA_API_KEY=...

all tests should pass. there are pydantic v1 warnings.

Before submitting

  • [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • [x] Did you read the contributor guideline, Pull Request section?
  • [ ] Was this discussed/approved via a Github issue? Please add a link to it if that's the case.
  • [ ] Did you make sure to update the documentation with your changes?
  • [x] Did you write any new necessary tests?

Thanks for contributing 🎉!

@facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) on Nov 1, 2024
@mattf force-pushed the add-nvidia-inference-adapter branch from a5760c0 to 2a25ace on November 19, 2024 16:38
# the root directory of this source tree.

from ._config import NVIDIAConfig
from ._nvidia import NVIDIAInferenceAdapter
Contributor

this should be a dynamic import within get_adapter_impl() -- we want configs to be manipulated without needing implementation dependencies.
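
A minimal sketch of that pattern, assuming the provider entry point is get_adapter_impl() and that the adapter exposes an async initialize(), as other adapters in the repo do:

```python
# hedged sketch: keep the module import-light and defer the implementation
# import until the adapter is actually constructed
from ._config import NVIDIAConfig


async def get_adapter_impl(config: NVIDIAConfig, _deps):
    # importing here means NVIDIAConfig can be loaded and manipulated without
    # pulling in the adapter's runtime dependencies
    from ._nvidia import NVIDIAInferenceAdapter

    if not isinstance(config, NVIDIAConfig):
        raise RuntimeError(f"Unexpected config type: {type(config)}")

    adapter = NVIDIAInferenceAdapter(config)
    await adapter.initialize()
    return adapter
```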

Contributor Author

ptal

Contributor

@ashwinb ashwinb left a comment

Thank you for this PR. So good!

Re: testing, I'd like to have a reproducible e2e test (ala what we have in providers/tests/inference/test_text_inference.py and providers/tests/inference/test_vision_inference.py) -- just having an nvidia specific fixture there which could then be invoked as

pytest -s -v --providers inference=nvidia test_text_inference.py --env ...

would be great.
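
A hedged sketch of what such an nvidia fixture might look like, assuming the ProviderFixture and Provider helpers that the other provider fixtures in that file use (names and fields are assumptions):

```python
import pytest

# ProviderFixture, Provider, and NVIDIAConfig are assumed to be imported the
# same way the existing provider fixtures in this file import their counterparts
@pytest.fixture(scope="session")
def inference_nvidia() -> ProviderFixture:
    return ProviderFixture(
        providers=[
            Provider(
                provider_id="nvidia",
                provider_type="remote::nvidia",
                config=NVIDIAConfig().model_dump(),
            )
        ],
    )
```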

)

@property
def is_hosted(self) -> bool:
Contributor

this could be is_nvidia_hosted perhaps?

Contributor Author

it's really an internal thing. i've removed it from the NVIDIAConfig api entirely.
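
for reference, a hedged sketch of how the check might look once it's internal only -

```python
# hedged sketch: derive "hosted" from the configured endpoint instead of
# exposing is_hosted on NVIDIAConfig; the `url` field name is an assumption
def _is_nvidia_hosted(config: NVIDIAConfig) -> bool:
    return "integrate.api.nvidia.com" in config.url
```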

@@ -0,0 +1,182 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
Contributor

this is a nitpick, you can ignore it if you feel strongly. we don't usually do underscores in files in the repo - at least not yet. we don't even strongly enforce what symbols get exported out of a module (that part is a bit sad, admittedly). could you make the files not have starting underscores?

Contributor Author

my inclination is to be cautious about the exported symbols, but it's important to be cohesive w/ the project. i'll change these. ptal.


from llama_models.datatypes import SamplingParams
from llama_models.llama3.api.datatypes import (
InterleavedTextMedia,
Contributor

thanks for the explicit imports. we will be code-modding all our other code to do this sane thing soon :)

Contributor Author

i spent so much time trying to figure out which classes were coming from which packages 😆

CoreModelId.llama3_2_90b_vision_instruct.value,
),
# TODO(mf): how do we handle Nemotron models?
# "Llama3.1-Nemotron-51B-Instruct" -> "meta/llama-3.1-nemotron-51b-instruct",
Contributor

is there a "base" llama model this model would correspond most closely with? we like to know it because we try to format tools, etc. in a way which the model will work best with. this isn't strictly necessary if the provider / API works very robustly with tool calling, etc. but so far given our experience with various "openai" wrapper APIs, it has been spotty.


stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
raise NotImplementedError()
Contributor

any chance this could be done? it's OK if not, but we have since gone back and filled in many of the missing completion() methods

Contributor Author

let me come back and add it in another PR, same for embedding

Contributor

@aidando73 aidando73 commented Dec 11, 2024

@ashwinb - Are we planning to support that endpoint long-term?

It seems the industry is moving away from it. E.g., OpenAI marked their completions API as legacy (https://platform.openai.com/docs/api-reference/completions) and Anthropic did too: https://docs.anthropic.com/en/api/complete

I'm wondering if it's a good idea to move away from the endpoint as well while llama-stack is still in its early days.

Otherwise it might be painful to maintain for a long time, then deprecate and remove.

Contributor

Reposting response here for visibility:

Thanks for all your contributions @aidando73 -- we want to keep supporting completions at least for now, because we believe having raw access to a model is just as important. Unlike other providers, Llamas are open source and people play with them and iterate in a variety of ways. The kinds of manipulations we do internally within a chat_completion endpoint may not always be what users intend. Sometimes they just want a carefully formatted prompt to hit the model directly.

On that theme, I think it would be great if Groq could build completions endpoints on their end too. But until that time, NotImplementedError() would have to do.
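
To make the distinction concrete, a hedged sketch of the kind of carefully formatted prompt a completion() caller might send directly, using the Llama 3 chat template tokens:

```python
# hedged sketch: with completion(), the caller formats the raw Llama 3 prompt
# themselves instead of handing structured messages to chat_completion()
raw_prompt = (
    "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
    "Who was the first person to walk on the moon?<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\n"
)
```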

ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
]:
if tool_prompt_format:
warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring")
Contributor

❤️

@mattf
Contributor Author

mattf commented Nov 21, 2024

Thank you for this PR. So good!

Re: testing, I'd like to have a reproducible e2e test (ala what we have in providers/tests/inference/test_text_inference.py and providers/tests/inference/test_vision_inference.py) -- just having an nvidia specific fixture there which could then be invoked as

pytest -s -v --providers inference=nvidia test_text_inference.py --env ...

would be great.

i've updated the PR description to note that this does not cover structured output, vision models, embedding or completion apis.

if it's ok, i'll follow up with PRs to add those features.

➜ pytest -s -v --providers inference=nvidia llama_stack/providers/tests/inference/test_{text,vision}_inference.py --env NVIDIA_API_KEY=... --inference-model Llama3.1-8B-Instruct
/home/matt/.conda/envs/stack/lib/python3.10/site-packages/pytest_asyncio/plugin.py:208: PytestDeprecationWarning: The configuration option "asyncio_default_fixture_loop_scope" is unset.
The event loop scope for asynchronous fixtures will default to the fixture caching scope. Future versions of pytest-asyncio will default the loop scope for asynchronous fixtures to function scope. Set the default fixture loop scope explicitly in order to avoid unexpected behavior in the future. Valid fixture loop scopes are: "function", "class", "module", "package", "session"

  warnings.warn(PytestDeprecationWarning(_DEFAULT_FIXTURE_LOOP_SCOPE_UNSET))
===================================================================================== test session starts ======================================================================================
platform linux -- Python 3.10.15, pytest-8.3.3, pluggy-1.5.0 -- /home/matt/.conda/envs/stack/bin/python
cachedir: .pytest_cache
rootdir: /home/matt/Documents/Repositories/meta-llama/llama-stack
configfile: pyproject.toml
plugins: anyio-4.6.2.post1, asyncio-0.24.0, httpx-0.34.0
asyncio: mode=strict, default_loop_scope=None
collected 11 items                                                                                                                                                                             

llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_model_list[-nvidia] Resolved 4 providers
 inner-inference => nvidia
 models => __routing_table__
 inference => __autorouted__
 inspect => __builtin__

Initializing NVIDIAInferenceAdapter(https://integrate.api.nvidia.com)...
Models: Llama3.1-8B-Instruct served by nvidia

PASSED
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion[-nvidia] SKIPPED (Other inference providers don't support completion() yet)
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completions_structured_output[-nvidia] SKIPPED (This test is not quite robust)
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_non_streaming[-nvidia] PASSED
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_structured_output[-nvidia] SKIPPED (Other inference providers don't support structured output yet)
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_streaming[-nvidia] PASSED
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling[-nvidia] PASSED
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling_streaming[-nvidia] PASSED
llama_stack/providers/tests/inference/test_vision_inference.py::TestVisionModelInference::test_vision_chat_completion_non_streaming[-nvidia-image0-expected_strings0] SKIPPED (Other...)
llama_stack/providers/tests/inference/test_vision_inference.py::TestVisionModelInference::test_vision_chat_completion_non_streaming[-nvidia-image1-expected_strings1] SKIPPED (Other...)
llama_stack/providers/tests/inference/test_vision_inference.py::TestVisionModelInference::test_vision_chat_completion_streaming[-nvidia] SKIPPED (Other inference providers don't su...)

========================================================================== 5 passed, 6 skipped, 10 warnings in 29.23s ==========================================================================

@mattf
Contributor Author

mattf commented Nov 22, 2024

@ashwinb i find test_structured_output to be flaky. it's both a functionality and an accuracy test -

        answer = AnswerFormat.model_validate_json(response.completion_message.content)
        assert answer.first_name == "Michael"
        assert answer.last_name == "Jordan"
        assert answer.year_of_birth == 1963
        assert answer.num_seasons_in_nba == 15

it's an accuracy test because it checks the value of first/last name, birth year, and num seasons.

i find that -

  • llama-3.1-8b-instruct and llama-3.2-3b-instruct pass the functionality portion
  • llama-3.2-3b-instruct consistently fails the accuracy portion (thinking MJ was in the NBA for 14 seasons)
  • llama-3.1-8b-instruct occasionally fails the accuracy portion

suggestions (not mutually exclusive) -

  1. turn the test into functionality only, skip the value checks
  2. split the test into a functionality version and an xfail accuracy version
  3. add context to the prompt so the llm can answer without relying on memorized facts (a rough sketch is below)
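
a rough sketch of option 3, with AnswerFormat mirroring the fields the test checks above -

```python
from pydantic import BaseModel


# mirrors the fields asserted in the test excerpt above
class AnswerFormat(BaseModel):
    first_name: str
    last_name: str
    year_of_birth: int
    num_seasons_in_nba: int


# put the facts in the prompt so the test exercises structured output
# rather than the model's recall
prompt = (
    "Michael Jordan was born in 1963. He played basketball for 15 seasons in the NBA. "
    "Please give me his first name, last name, year of birth, and number of seasons in the NBA."
)
```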

@mattf mattf requested a review from ashwinb November 22, 2024 20:52
@ashwinb
Contributor

ashwinb commented Nov 23, 2024

@mattf I agree with your comments on test_structured_output completely. I think the third option makes the most sense. I will update the test pronto.

Contributor

@ashwinb ashwinb left a comment

Looks good to me. Merging!

@ashwinb ashwinb merged commit 4e6c984 into meta-llama:main Nov 23, 2024
2 checks passed
SLR722 pushed a commit that referenced this pull request Nov 27, 2024
out_of_tokens = "out_of_tokens"
"""

# TODO(mf): are end_of_turn and end_of_message semantics correct?
Contributor

I'm implementing a Groq adapter and wondering this as well
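
For reference, a hedged sketch of one plausible mapping from OpenAI-style finish_reason values to the stop reasons above; the import path is assumed, and whether this correspondence is right is exactly what the TODO questions:

```python
from llama_models.llama3.api.datatypes import StopReason  # import path assumed


def convert_finish_reason(finish_reason: str) -> StopReason:
    # hedged mapping: openai-style finish_reason strings -> llama-stack stop reasons
    return {
        "stop": StopReason.end_of_turn,           # model finished its turn
        "length": StopReason.out_of_tokens,       # hit the max token budget
        "tool_calls": StopReason.end_of_message,  # model expects a tool response next
    }.get(finish_reason, StopReason.end_of_turn)
```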

):
yield ChatCompletionResponseEventType.start
while True:
yield ChatCompletionResponseEventType.progress
Contributor

This is such a cool trick
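
For anyone reading along, a hedged sketch of how the generator above tends to be consumed; ChatCompletionResponseEventType is the enum from the excerpt, and the surrounding loop is an assumption:

```python
# the generator yields `start` exactly once, then `progress` forever, so zipping
# it with the incoming chunks labels the first chunk as the start event and
# every later chunk as progress
def _event_types():
    yield ChatCompletionResponseEventType.start
    while True:
        yield ChatCompletionResponseEventType.progress

# for event_type, chunk in zip(_event_types(), stream_chunks):
#     ... emit an event with event_type for each chunk
```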

ashwinb pushed a commit that referenced this pull request Jan 3, 2025
# What does this PR do?

Contributes towards issue (#432)

- Groq text chat completions
- Streaming
- All the sampling params that Groq supports

A lot of inspiration taken from @mattf's good work at
#355

**What this PR does not do**

- Tool calls (Future PR)
- Adding llama-guard model
- See if we can add embeddings

### PR Train

- #609 👈 
- #630


## Test Plan

<details>

<summary>Environment</summary>

```bash
export GROQ_API_KEY=<api_key>

wget https://raw.githubusercontent.com/aidando73/llama-stack/240e6e2a9c20450ffdcfbabd800a6c0291f19288/build.yaml
wget https://raw.githubusercontent.com/aidando73/llama-stack/92c9b5297f9eda6a6e901e1adbd894e169dbb278/run.yaml

# Build and run environment
pip install -e . \
&& llama stack build --config ./build.yaml --image-type conda \
&& llama stack run ./run.yaml \
  --port 5001
```

</details>

<details>

<summary>Manual tests</summary>

Using this jupyter notebook to test manually:
https://github.com/aidando73/llama-stack/blob/2140976d76ee7ef46025c862b26ee87585381d2a/hello.ipynb

Use this code to test passing in the api key from provider_data

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(
    base_url="http://localhost:5001",
)

response = client.inference.chat_completion(
    model_id="Llama3.2-3B-Instruct",
    messages=[
        {"role": "user", "content": "Hello, world client!"},
    ],
    # Test passing in groq_api_key from the client
    # Need to comment out the groq_api_key in the run.yaml file
    x_llama_stack_provider_data='{"groq_api_key": "<api-key>"}',
    # stream=True,
)
response
```

</details>

<details>
<summary>Integration</summary>

`pytest llama_stack/providers/tests/inference/test_text_inference.py -v
-k groq`

(run in same environment)

```
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_model_list[llama_3b-groq] PASSED                 [  6%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion[llama_3b-groq] SKIPPED (Other inf...) [ 12%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion_structured_output[llama_3b-groq] SKIPPED [ 18%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_non_streaming[llama_3b-groq] PASSED [ 25%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_structured_output[llama_3b-groq] SKIPPED (Ot...) [ 31%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_streaming[llama_3b-groq] PASSED  [ 37%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling[llama_3b-groq] SKIPPED [ 43%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling_streaming[llama_3b-groq] SKIPPED [ 50%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_model_list[llama_8b-groq] PASSED                 [ 56%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion[llama_8b-groq] SKIPPED (Other inf...) [ 62%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion_structured_output[llama_8b-groq] SKIPPED [ 68%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_non_streaming[llama_8b-groq] PASSED [ 75%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_structured_output[llama_8b-groq] SKIPPED (Ot...) [ 81%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_streaming[llama_8b-groq] PASSED  [ 87%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling[llama_8b-groq] SKIPPED [ 93%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling_streaming[llama_8b-groq] SKIPPED [100%]

======================================= 6 passed, 10 skipped, 160 deselected, 7 warnings in 2.05s ========================================
```
</details>

<details>
<summary>Unit tests</summary>

`pytest llama_stack/providers/tests/inference/groq/ -v`

```
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_sets_model PASSED            [  5%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_converts_user_message PASSED [ 10%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_converts_system_message PASSED [ 15%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_converts_completion_message PASSED [ 20%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_does_not_include_logprobs PASSED [ 25%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_does_not_include_response_format PASSED [ 30%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_does_not_include_repetition_penalty PASSED [ 35%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_includes_stream PASSED       [ 40%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_n_is_1 PASSED                [ 45%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_if_max_tokens_is_0_then_it_is_not_included PASSED [ 50%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_includes_max_tokens_if_set PASSED [ 55%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_includes_temperature PASSED  [ 60%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertChatCompletionRequest::test_includes_top_p PASSED        [ 65%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertNonStreamChatCompletionResponse::test_returns_response PASSED [ 70%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertNonStreamChatCompletionResponse::test_maps_stop_to_end_of_message PASSED [ 75%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertNonStreamChatCompletionResponse::test_maps_length_to_end_of_message PASSED [ 80%]
llama_stack/providers/tests/inference/groq/test_groq_utils.py::TestConvertStreamChatCompletionResponse::test_returns_stream PASSED [ 85%]
llama_stack/providers/tests/inference/groq/test_init.py::TestGroqInit::test_raises_runtime_error_if_config_is_not_groq_config PASSED [ 90%]
llama_stack/providers/tests/inference/groq/test_init.py::TestGroqInit::test_returns_groq_adapter PASSED                            [ 95%]
llama_stack/providers/tests/inference/groq/test_init.py::TestGroqConfig::test_api_key_defaults_to_env_var PASSED                   [100%]

==================================================== 20 passed, 11 warnings in 0.08s =====================================================
```

</details>

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [x] Ran pre-commit to handle lint / formatting issues.
- [x] Read the [contributor
guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
      Pull Request section?
- [x] Updated relevant documentation
- [x] Wrote necessary unit or integration tests.