diff --git a/comps/guardrails/README.md b/comps/guardrails/README.md
index 1107e7fe47..4407a4647e 100644
--- a/comps/guardrails/README.md
+++ b/comps/guardrails/README.md
@@ -38,9 +38,9 @@ export LANGCHAIN_TRACING_V2=true
 export LANGCHAIN_API_KEY=${your_langchain_api_key}
 export LANGCHAIN_PROJECT="opea/gaurdrails"
 volume=$PWD/data
-model_id="meta-llama/LlamaGuard-7b"
-docker pull ghcr.io/huggingface/tgi-gaudi:1.2.1
-docker run -p 8088:80 -v $volume:/data --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --ipc=host -e HTTPS_PROXY=$https_proxy -e HTTP_PROXY=$https_proxy ghcr.io/huggingface/tgi-gaudi:1.2.1 --model-id $model_id
+model_id="meta-llama/Meta-Llama-Guard-2-8B"
+docker pull ghcr.io/huggingface/tgi-gaudi:2.0.1
+docker run -p 8088:80 -v $volume:/data --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --ipc=host -e HTTPS_PROXY=$https_proxy -e HTTP_PROXY=$https_proxy ghcr.io/huggingface/tgi-gaudi:2.0.1 --model-id $model_id
 ```
 
 ## 1.3 Verify the TGI Gaudi Service
@@ -57,7 +57,7 @@ curl 127.0.0.1:8088/generate \
 Optional: If you have deployed a Guardrails model with TGI Gaudi Service other than default model (i.e., `meta-llama/LlamaGuard-7b`) [from section 1.2](## 1.2 Start TGI Gaudi Service), you will need to add the eviornment variable `SAFETY_GUARD_MODEL_ID` containing the model id. For example, the following informs the Guardrails Service the deployed model used LlamaGuard2:
 
 ```bash
-export SAFETY_GUARD_MODEL_ID="meta-llama/Meta-Llama-Guard-2-8"
+export SAFETY_GUARD_MODEL_ID="meta-llama/Meta-Llama-Guard-2-8B"
 ```
 
 ```bash
diff --git a/comps/guardrails/langchain/guardrails_tgi_gaudi.py b/comps/guardrails/langchain/guardrails_tgi_gaudi.py
index 03d1935053..34cecf5d85 100644
--- a/comps/guardrails/langchain/guardrails_tgi_gaudi.py
+++ b/comps/guardrails/langchain/guardrails_tgi_gaudi.py
@@ -90,6 +90,6 @@ def safety_guard(input: TextDoc) -> TextDoc:
         repetition_penalty=1.03,
     )
     # chat engine for server-side prompt templating
-    llm_engine_hf = ChatHuggingFace(llm=llm_guard)
+    llm_engine_hf = ChatHuggingFace(llm=llm_guard, model_id=safety_guard_model)
     print("guardrails - router] LLM initialized.")
     opea_microservices["opea_service@guardrails_tgi_gaudi"].start()
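
Note on the Python change: passing `model_id=safety_guard_model` matters because `ChatHuggingFace` otherwise tries to infer the model id behind the LLM it wraps, which can fail or pick the wrong chat template when the wrapped `HuggingFaceEndpoint` points at a bare TGI URL instead of a Hugging Face Hub repo. The sketch below illustrates the pattern outside the microservice; the import paths, the `SAFETY_GUARD_ENDPOINT` variable, and the generation parameters other than `repetition_penalty` are assumptions for illustration, not the exact contents of `guardrails_tgi_gaudi.py`.

```python
# Minimal sketch: wiring ChatHuggingFace to a TGI Gaudi endpoint with an
# explicit model_id. Only SAFETY_GUARD_MODEL_ID, safety_guard_model, and the
# ChatHuggingFace call mirror the diff above; everything else is assumed.
import os

from langchain_community.chat_models.huggingface import ChatHuggingFace
from langchain_community.llms import HuggingFaceEndpoint

safety_guard_model = os.getenv("SAFETY_GUARD_MODEL_ID", "meta-llama/Meta-Llama-Guard-2-8B")
safety_guard_endpoint = os.getenv("SAFETY_GUARD_ENDPOINT", "http://localhost:8088")  # assumed variable name

# Plain text-generation client against the TGI Gaudi container started in the README.
llm_guard = HuggingFaceEndpoint(
    endpoint_url=safety_guard_endpoint,
    max_new_tokens=100,
    temperature=0.01,
    repetition_penalty=1.03,
)

# Pinning model_id skips ChatHuggingFace's model-id resolution (which may query
# the Hugging Face Hub) and ensures the LlamaGuard 2 chat template is applied
# when the guardrails service formats prompts.
llm_engine_hf = ChatHuggingFace(llm=llm_guard, model_id=safety_guard_model)
```

The README edits keep the docs consistent with the same change: the default guard model moves to `meta-llama/Meta-Llama-Guard-2-8B` and the TGI Gaudi image is bumped to `2.0.1`, so the documented `docker run` command serves the model that the microservice now names explicitly.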