Skip to content

Commit

Permalink
feat: add gpustack model provider (langgenius#10158)
Browse files Browse the repository at this point in the history
  • Loading branch information
gitlawr authored and JunXu01 committed Nov 9, 2024
1 parent b67ada5 commit 2330768
Show file tree
Hide file tree
Showing 17 changed files with 705 additions and 1 deletion.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
10 changes: 10 additions & 0 deletions api/core/model_runtime/model_providers/gpustack/gpustack.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import logging

from core.model_runtime.model_providers.__base.model_provider import ModelProvider

logger = logging.getLogger(__name__)


class GPUStackProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
pass
120 changes: 120 additions & 0 deletions api/core/model_runtime/model_providers/gpustack/gpustack.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
provider: gpustack
label:
en_US: GPUStack
icon_small:
en_US: icon_s_en.png
icon_large:
en_US: icon_l_en.png
supported_model_types:
- llm
- text-embedding
- rerank
configurate_methods:
- customizable-model
model_credential_schema:
model:
label:
en_US: Model Name
zh_Hans: 模型名称
placeholder:
en_US: Enter your model name
zh_Hans: 输入模型名称
credential_form_schemas:
- variable: endpoint_url
label:
zh_Hans: 服务器地址
en_US: Server URL
type: text-input
required: true
placeholder:
zh_Hans: 输入 GPUStack 的服务器地址,如 http://192.168.1.100
en_US: Enter the GPUStack server URL, e.g. http://192.168.1.100
- variable: api_key
label:
en_US: API Key
type: secret-input
required: true
placeholder:
zh_Hans: 输入您的 API Key
en_US: Enter your API Key
- variable: mode
show_on:
- variable: __model_type
value: llm
label:
en_US: Completion mode
type: select
required: false
default: chat
placeholder:
zh_Hans: 选择补全类型
en_US: Select completion type
options:
- value: completion
label:
en_US: Completion
zh_Hans: 补全
- value: chat
label:
en_US: Chat
zh_Hans: 对话
- variable: context_size
label:
zh_Hans: 模型上下文长度
en_US: Model context size
required: true
type: text-input
default: "8192"
placeholder:
zh_Hans: 输入您的模型上下文长度
en_US: Enter your Model context size
- variable: max_tokens_to_sample
label:
zh_Hans: 最大 token 上限
en_US: Upper bound for max tokens
show_on:
- variable: __model_type
value: llm
default: "8192"
type: text-input
- variable: function_calling_type
show_on:
- variable: __model_type
value: llm
label:
en_US: Function calling
type: select
required: false
default: no_call
options:
- value: function_call
label:
en_US: Function Call
zh_Hans: Function Call
- value: tool_call
label:
en_US: Tool Call
zh_Hans: Tool Call
- value: no_call
label:
en_US: Not Support
zh_Hans: 不支持
- variable: vision_support
show_on:
- variable: __model_type
value: llm
label:
zh_Hans: Vision 支持
en_US: Vision Support
type: select
required: false
default: no_support
options:
- value: support
label:
en_US: Support
zh_Hans: 支持
- value: no_support
label:
en_US: Not Support
zh_Hans: 不支持
Empty file.
45 changes: 45 additions & 0 deletions api/core/model_runtime/model_providers/gpustack/llm/llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from collections.abc import Generator

from yarl import URL

from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import (
PromptMessage,
PromptMessageTool,
)
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import (
OAIAPICompatLargeLanguageModel,
)


class GPUStackLanguageModel(OAIAPICompatLargeLanguageModel):
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
user: str | None = None,
) -> LLMResult | Generator:
return super()._invoke(
model,
credentials,
prompt_messages,
model_parameters,
tools,
stop,
stream,
user,
)

def validate_credentials(self, model: str, credentials: dict) -> None:
self._add_custom_parameters(credentials)
super().validate_credentials(model, credentials)

@staticmethod
def _add_custom_parameters(credentials: dict) -> None:
credentials["endpoint_url"] = str(URL(credentials["endpoint_url"]) / "v1-openai")
credentials["mode"] = "chat"
Empty file.
146 changes: 146 additions & 0 deletions api/core/model_runtime/model_providers/gpustack/rerank/rerank.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
from json import dumps
from typing import Optional

import httpx
from requests import post
from yarl import URL

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import (
AIModelEntity,
FetchFrom,
ModelPropertyKey,
ModelType,
)
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel


class GPUStackRerankModel(RerankModel):
"""
Model class for GPUStack rerank model.
"""

def _invoke(
self,
model: str,
credentials: dict,
query: str,
docs: list[str],
score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None,
) -> RerankResult:
"""
Invoke rerank model
:param model: model name
:param credentials: model credentials
:param query: search query
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n documents to return
:param user: unique user id
:return: rerank result
"""
if len(docs) == 0:
return RerankResult(model=model, docs=[])

endpoint_url = credentials["endpoint_url"]
headers = {
"Authorization": f"Bearer {credentials.get('api_key')}",
"Content-Type": "application/json",
}

data = {"model": model, "query": query, "documents": docs, "top_n": top_n}

try:
response = post(
str(URL(endpoint_url) / "v1" / "rerank"),
headers=headers,
data=dumps(data),
timeout=10,
)
response.raise_for_status()
results = response.json()

rerank_documents = []
for result in results["results"]:
index = result["index"]
if "document" in result:
text = result["document"]["text"]
else:
text = docs[index]

rerank_document = RerankDocument(
index=index,
text=text,
score=result["relevance_score"],
)

if score_threshold is None or result["relevance_score"] >= score_threshold:
rerank_documents.append(rerank_document)

return RerankResult(model=model, docs=rerank_documents)
except httpx.HTTPStatusError as e:
raise InvokeServerUnavailableError(str(e))

def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:
self._invoke(
model=model,
credentials=credentials,
query="What is the capital of the United States?",
docs=[
"Carson City is the capital city of the American state of Nevada. At the 2010 United States "
"Census, Carson City had a population of 55,274.",
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
"are a political division controlled by the United States. Its capital is Saipan.",
],
score_threshold=0.8,
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))

@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
"""
return {
InvokeConnectionError: [httpx.ConnectError],
InvokeServerUnavailableError: [httpx.RemoteProtocolError],
InvokeRateLimitError: [],
InvokeAuthorizationError: [httpx.HTTPStatusError],
InvokeBadRequestError: [httpx.RequestError],
}

def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
"""
generate custom model entities from credentials
"""
entity = AIModelEntity(
model=model,
label=I18nObject(en_US=model),
model_type=ModelType.RERANK,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))},
)

return entity
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from typing import Optional

from yarl import URL

from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.text_embedding_entities import (
TextEmbeddingResult,
)
from core.model_runtime.model_providers.openai_api_compatible.text_embedding.text_embedding import (
OAICompatEmbeddingModel,
)


class GPUStackTextEmbeddingModel(OAICompatEmbeddingModel):
"""
Model class for GPUStack text embedding model.
"""

def _invoke(
self,
model: str,
credentials: dict,
texts: list[str],
user: Optional[str] = None,
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
) -> TextEmbeddingResult:
return super()._invoke(model, credentials, texts, user, input_type)

def validate_credentials(self, model: str, credentials: dict) -> None:
self._add_custom_parameters(credentials)
super().validate_credentials(model, credentials)

@staticmethod
def _add_custom_parameters(credentials: dict) -> None:
credentials["endpoint_url"] = str(URL(credentials["endpoint_url"]) / "v1-openai")
Loading

0 comments on commit 2330768

Please sign in to comment.