forked from langgenius/dify
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add gpustack model provider (langgenius#10158)
- Loading branch information
Showing
17 changed files
with
705 additions
and
1 deletion.
There are no files selected for viewing
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
15 changes: 15 additions & 0 deletions
15
api/core/model_runtime/model_providers/gpustack/_assets/icon_l_en.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added
BIN
+56.6 KB
api/core/model_runtime/model_providers/gpustack/_assets/icon_s_en.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
11 changes: 11 additions & 0 deletions
11
api/core/model_runtime/model_providers/gpustack/_assets/icon_s_en.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
10 changes: 10 additions & 0 deletions
10
api/core/model_runtime/model_providers/gpustack/gpustack.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
import logging | ||
|
||
from core.model_runtime.model_providers.__base.model_provider import ModelProvider | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class GPUStackProvider(ModelProvider):
    def validate_provider_credentials(self, credentials: dict) -> None:
        """
        Validate provider-level credentials.

        Intentionally a no-op: GPUStack is declared as a customizable-model
        provider (``configurate_methods: [customizable-model]`` in
        gpustack.yaml), so credentials are validated per model by the
        individual model classes rather than once at the provider level.

        :param credentials: provider credentials (unused here)
        """
        pass
120 changes: 120 additions & 0 deletions
120
api/core/model_runtime/model_providers/gpustack/gpustack.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
provider: gpustack
label:
  en_US: GPUStack
icon_small:
  en_US: icon_s_en.png
icon_large:
  # Fixed: this commit ships only icon_l_en.svg in _assets (there is no
  # icon_l_en.png), so referencing the .png would render a broken icon.
  en_US: icon_l_en.svg
supported_model_types:
  - llm
  - text-embedding
  - rerank
configurate_methods:
  - customizable-model
model_credential_schema:
  model:
    label:
      en_US: Model Name
      zh_Hans: 模型名称
    placeholder:
      en_US: Enter your model name
      zh_Hans: 输入模型名称
  credential_form_schemas:
    - variable: endpoint_url
      label:
        zh_Hans: 服务器地址
        en_US: Server URL
      type: text-input
      required: true
      placeholder:
        zh_Hans: 输入 GPUStack 的服务器地址,如 http://192.168.1.100
        en_US: Enter the GPUStack server URL, e.g. http://192.168.1.100
    - variable: api_key
      label:
        en_US: API Key
      type: secret-input
      required: true
      placeholder:
        zh_Hans: 输入您的 API Key
        en_US: Enter your API Key
    - variable: mode
      show_on:
        - variable: __model_type
          value: llm
      label:
        en_US: Completion mode
      type: select
      required: false
      default: chat
      placeholder:
        zh_Hans: 选择补全类型
        en_US: Select completion type
      options:
        - value: completion
          label:
            en_US: Completion
            zh_Hans: 补全
        - value: chat
          label:
            en_US: Chat
            zh_Hans: 对话
    - variable: context_size
      label:
        zh_Hans: 模型上下文长度
        en_US: Model context size
      required: true
      type: text-input
      default: "8192"
      placeholder:
        zh_Hans: 输入您的模型上下文长度
        en_US: Enter your Model context size
    - variable: max_tokens_to_sample
      label:
        zh_Hans: 最大 token 上限
        en_US: Upper bound for max tokens
      show_on:
        - variable: __model_type
          value: llm
      default: "8192"
      type: text-input
    - variable: function_calling_type
      show_on:
        - variable: __model_type
          value: llm
      label:
        en_US: Function calling
      type: select
      required: false
      default: no_call
      options:
        - value: function_call
          label:
            en_US: Function Call
            zh_Hans: Function Call
        - value: tool_call
          label:
            en_US: Tool Call
            zh_Hans: Tool Call
        - value: no_call
          label:
            en_US: Not Support
            zh_Hans: 不支持
    - variable: vision_support
      show_on:
        - variable: __model_type
          value: llm
      label:
        zh_Hans: Vision 支持
        en_US: Vision Support
      type: select
      required: false
      default: no_support
      options:
        - value: support
          label:
            en_US: Support
            zh_Hans: 支持
        - value: no_support
          label:
            en_US: Not Support
            zh_Hans: 不支持
Empty file.
45 changes: 45 additions & 0 deletions
45
api/core/model_runtime/model_providers/gpustack/llm/llm.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
from collections.abc import Generator | ||
|
||
from yarl import URL | ||
|
||
from core.model_runtime.entities.llm_entities import LLMResult | ||
from core.model_runtime.entities.message_entities import ( | ||
PromptMessage, | ||
PromptMessageTool, | ||
) | ||
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import ( | ||
OAIAPICompatLargeLanguageModel, | ||
) | ||
|
||
|
||
class GPUStackLanguageModel(OAIAPICompatLargeLanguageModel):
    """
    LLM implementation for GPUStack, delegating to the OpenAI-compatible base
    model after rewriting the endpoint URL to GPUStack's OpenAI-compatible
    API prefix (``/v1-openai``).
    """

    def _invoke(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: list[PromptMessageTool] | None = None,
        stop: list[str] | None = None,
        stream: bool = True,
        user: str | None = None,
    ) -> LLMResult | Generator:
        """
        Invoke the model through the OpenAI-compatible endpoint.

        :param model: model name
        :param credentials: model credentials (must contain ``endpoint_url``)
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool/function calling
        :param stop: stop words
        :param stream: whether to stream the response
        :param user: unique user id
        :return: full response or a stream-response chunk generator
        """
        # Fix: rewrite credentials here as well, not only in
        # validate_credentials — otherwise inference hits the bare server URL
        # without the /v1-openai suffix.
        self._add_custom_parameters(credentials)
        return super()._invoke(
            model,
            credentials,
            prompt_messages,
            model_parameters,
            tools,
            stop,
            stream,
            user,
        )

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """Validate credentials by delegating to the OpenAI-compatible base."""
        self._add_custom_parameters(credentials)
        super().validate_credentials(model, credentials)

    @staticmethod
    def _add_custom_parameters(credentials: dict) -> None:
        """
        Point ``endpoint_url`` at GPUStack's OpenAI-compatible prefix and
        force chat mode. Idempotent, so the dict may be rewritten more than
        once (e.g. validate_credentials followed by _invoke).
        """
        base = URL(credentials["endpoint_url"])
        if base.name != "v1-openai":
            credentials["endpoint_url"] = str(base / "v1-openai")
        credentials["mode"] = "chat"
Empty file.
146 changes: 146 additions & 0 deletions
146
api/core/model_runtime/model_providers/gpustack/rerank/rerank.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
from json import dumps | ||
from typing import Optional | ||
|
||
import httpx | ||
from requests import post | ||
from yarl import URL | ||
|
||
from core.model_runtime.entities.common_entities import I18nObject | ||
from core.model_runtime.entities.model_entities import ( | ||
AIModelEntity, | ||
FetchFrom, | ||
ModelPropertyKey, | ||
ModelType, | ||
) | ||
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult | ||
from core.model_runtime.errors.invoke import ( | ||
InvokeAuthorizationError, | ||
InvokeBadRequestError, | ||
InvokeConnectionError, | ||
InvokeError, | ||
InvokeRateLimitError, | ||
InvokeServerUnavailableError, | ||
) | ||
from core.model_runtime.errors.validate import CredentialsValidateFailedError | ||
from core.model_runtime.model_providers.__base.rerank_model import RerankModel | ||
|
||
|
||
class GPUStackRerankModel(RerankModel):
    """
    Model class for GPUStack rerank model.

    Talks to GPUStack's ``/v1/rerank`` endpoint (Jina-style rerank API).
    """

    def _invoke(
        self,
        model: str,
        credentials: dict,
        query: str,
        docs: list[str],
        score_threshold: Optional[float] = None,
        top_n: Optional[int] = None,
        user: Optional[str] = None,
    ) -> RerankResult:
        """
        Invoke rerank model
        :param model: model name
        :param credentials: model credentials
        :param query: search query
        :param docs: docs for reranking
        :param score_threshold: score threshold
        :param top_n: top n documents to return
        :param user: unique user id
        :return: rerank result
        """
        if len(docs) == 0:
            return RerankResult(model=model, docs=[])

        endpoint_url = credentials["endpoint_url"]
        headers = {
            "Authorization": f"Bearer {credentials.get('api_key')}",
            "Content-Type": "application/json",
        }

        data = {"model": model, "query": query, "documents": docs, "top_n": top_n}

        try:
            # Fix: use httpx (not requests) so raise_for_status() raises
            # httpx.HTTPStatusError, matching the except clause below and
            # the httpx-based _invoke_error_mapping.
            response = httpx.post(
                str(URL(endpoint_url) / "v1" / "rerank"),
                headers=headers,
                content=dumps(data),
                timeout=10,
            )
            response.raise_for_status()
            results = response.json()

            rerank_documents = []
            for result in results["results"]:
                index = result["index"]
                # The server may omit the document text; fall back to the
                # caller-supplied docs by index.
                if "document" in result:
                    text = result["document"]["text"]
                else:
                    text = docs[index]

                rerank_document = RerankDocument(
                    index=index,
                    text=text,
                    score=result["relevance_score"],
                )

                if score_threshold is None or result["relevance_score"] >= score_threshold:
                    rerank_documents.append(rerank_document)

            return RerankResult(model=model, docs=rerank_documents)
        except httpx.HTTPStatusError as e:
            raise InvokeServerUnavailableError(str(e))

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials
        :param model: model name
        :param credentials: model credentials
        :return:
        """
        try:
            self._invoke(
                model=model,
                credentials=credentials,
                query="What is the capital of the United States?",
                docs=[
                    "Carson City is the capital city of the American state of Nevada. At the 2010 United States "
                    "Census, Carson City had a population of 55,274.",
                    "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
                    "are a political division controlled by the United States. Its capital is Saipan.",
                ],
                score_threshold=0.8,
            )
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        """
        return {
            InvokeConnectionError: [httpx.ConnectError],
            InvokeServerUnavailableError: [httpx.RemoteProtocolError],
            InvokeRateLimitError: [],
            InvokeAuthorizationError: [httpx.HTTPStatusError],
            InvokeBadRequestError: [httpx.RequestError],
        }

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
        """
        generate custom model entities from credentials
        """
        entity = AIModelEntity(
            model=model,
            label=I18nObject(en_US=model),
            model_type=ModelType.RERANK,
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            # Fix: default to the schema default "8192" instead of raising
            # TypeError when context_size is absent from the credentials.
            model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", "8192"))},
        )

        return entity
Empty file.
35 changes: 35 additions & 0 deletions
35
api/core/model_runtime/model_providers/gpustack/text_embedding/text_embedding.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
from typing import Optional | ||
|
||
from yarl import URL | ||
|
||
from core.entities.embedding_type import EmbeddingInputType | ||
from core.model_runtime.entities.text_embedding_entities import ( | ||
TextEmbeddingResult, | ||
) | ||
from core.model_runtime.model_providers.openai_api_compatible.text_embedding.text_embedding import ( | ||
OAICompatEmbeddingModel, | ||
) | ||
|
||
|
||
class GPUStackTextEmbeddingModel(OAICompatEmbeddingModel):
    """
    Model class for GPUStack text embedding model.

    Delegates to the OpenAI-compatible base after rewriting the endpoint URL
    to GPUStack's OpenAI-compatible API prefix (``/v1-openai``).
    """

    def _invoke(
        self,
        model: str,
        credentials: dict,
        texts: list[str],
        user: Optional[str] = None,
        input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
    ) -> TextEmbeddingResult:
        """
        Invoke the embedding model through the OpenAI-compatible endpoint.

        :param model: model name
        :param credentials: model credentials (must contain ``endpoint_url``)
        :param texts: texts to embed
        :param user: unique user id
        :param input_type: embedding input type (document or query)
        :return: embeddings result
        """
        # Fix: rewrite credentials here as well, not only in
        # validate_credentials — otherwise inference hits the bare server URL
        # without the /v1-openai suffix.
        self._add_custom_parameters(credentials)
        return super()._invoke(model, credentials, texts, user, input_type)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """Validate credentials by delegating to the OpenAI-compatible base."""
        self._add_custom_parameters(credentials)
        super().validate_credentials(model, credentials)

    @staticmethod
    def _add_custom_parameters(credentials: dict) -> None:
        """
        Point ``endpoint_url`` at GPUStack's OpenAI-compatible prefix.
        Idempotent, so the dict may be rewritten more than once.
        """
        base = URL(credentials["endpoint_url"])
        if base.name != "v1-openai":
            credentials["endpoint_url"] = str(base / "v1-openai")
Oops, something went wrong.