[Inference] Use huggingface_hub inference client for TGI adapter #53
Conversation
I don't think the openai-compat layer works correctly with tool calling; I specifically needed to use the raw `text_generation` API. Can you show a test where you point the agentic system at this adapter?
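For context, calling huggingface_hub's `InferenceClient.text_generation` directly against a TGI endpoint looks roughly like the sketch below; the endpoint URL, prompt, and generation parameters are illustrative, not the adapter's actual values.

```python
# Sketch only: direct use of InferenceClient.text_generation, which exposes the
# raw prompt and token stream instead of going through an openai-compat chat layer.
from huggingface_hub import InferenceClient

client = InferenceClient(model="http://localhost:8080")  # assumed local TGI URL

# Stream detailed output so special tool-call tokens can be inspected as they arrive.
for chunk in client.text_generation(
    "<fully formatted prompt built by the adapter>",  # placeholder prompt
    max_new_tokens=256,
    stop_sequences=["<|eot_id|>"],  # illustrative stop token
    stream=True,
    details=True,
):
    print(chunk.token.text, end="")
```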
Hi @ashwinb, you were right: we need full control of the tokens sent to the model, which is not easy to get through the openai-compat layer. I tested the Brave Search tool calling with the server.
(This also works using the distribution.) Note: I enabled the Brave Search tool only in the agentic system client.
Output:
This is awesome, thank you so much. Just a few comments.
Hello @ashwinb, thanks for your review! I updated the adapter based on your comments. I split the TGI adapter into one for local or remote TGI endpoints and one for Hugging Face Inference Endpoints, due to differences in how the endpoint URL is retrieved.
I've tested both adapters for inference and agentic function calling. Please let me know if you have any suggestions for further improvements!
Looks good. Thanks once again.
I do have one important comment: the `get_namespace()` method pulls the `huggingface_hub` dependency into the config file. Please move that helper out of the config and into the impl file as a free-floating function.
(The following inline comment was left on the config file's imports, `from typing import Optional` and `from huggingface_hub import HfApi`.)
We want the config files to be extremely lightweight so one can use them outside of the larger containers. As such, they should only contain datatypes and no "behavior" or functions at all. I suggest you move at least the `get_namespace()` method out into the impl file where it is needed, so this dependency is not here.
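A minimal sketch, assuming the helper simply asks the Hub who owns the token, of what the relocated free-floating function in the impl file could look like (the PR's actual implementation may differ):

```python
from typing import Optional

from huggingface_hub import HfApi


def get_namespace(api_token: Optional[str] = None) -> str:
    # Assumed implementation: whoami() returns the account metadata for the
    # token; "name" is the username/organization that owns the Inference Endpoints.
    return HfApi(token=api_token).whoami()["name"]
```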
Fixed ✅
* Use huggingface_hub inference client for TGI inference
* Update the default value for TGI URL
* Use InferenceClient.text_generation for TGI inference
* Fixes post-review and split TGI adapter into local and Inference Endpoints ones
* Update CLI reference and add typing
* Rename TGI Adapter class
* Use HfApi to get the namespace when not provided in the hf endpoint name
* Remove unnecessary method argument
* Improve TGI adapter initialization condition
* Move helper into impl file + fix merging conflicts
Test Plan: First, start a TGI container with `meta-llama/Llama-Guard-3-8B` model serving on port 5099. See #53 and its description for how. Then run llama-stack with the following run config:

```
image_name: safety
docker_image: null
conda_env: safety
apis_to_serve:
- models
- inference
- shields
- safety
api_providers:
  inference:
    providers:
    - remote::tgi
  safety:
    providers:
    - meta-reference
  telemetry:
    provider_id: meta-reference
    config: {}
routing_table:
  inference:
  - provider_id: remote::tgi
    config:
      url: http://localhost:5099
      api_token: null
      hf_endpoint_name: null
    routing_key: Llama-Guard-3-8B
  safety:
  - provider_id: meta-reference
    config:
      llama_guard_shield:
        model: Llama-Guard-3-8B
        excluded_categories: []
        disable_input_check: false
        disable_output_check: false
      prompt_guard_shield: null
    routing_key: llama_guard
```

Now simply run `python -m llama_stack.apis.safety.client localhost <port>` and check that the llama_guard shield calls run correctly. (The injection_shield calls fail as expected since we have not set up a router for them.)
Why this PR? What does it do?
This PR refines the existing TGI integration (#52). Specifically:
Test Plan
I tested the refined TGI integration with both a local TGI endpoint and Hugging Face Inference Endpoints.
1. Set up the distribution
Locally, we first need to start the TGI container:
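For reference, a typical TGI launch looks like the following; the model id, host port, and cache path are placeholders rather than the exact values used in this test.

```bash
# Illustrative only: adjust the model id, host port, and volume to your setup.
docker run --gpus all --shm-size 1g -p 5009:80 \
  -v ~/.cache/huggingface:/data \
  ghcr.io/huggingface/text-generation-inference:latest \
  --model-id meta-llama/Meta-Llama-3.1-8B-Instruct
```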
Then we build a llama stack distribution:
llama stack build local-plus-tgi-inference --name 8b-instruct
This command creates a conda environment with the necessary dependencies. Then, you'll be prompted for the required configuration values.
2. Run the server
llama stack run local-plus-tgi-inference --name 8b-instruct --port 6001
3. Test the distribution
python -m llama_toolchain.inference.client localhost 6001