Skip to content

Commit

Permalink
feat: update the xinf tool's API key to optional
Browse files Browse the repository at this point in the history
  • Loading branch information
hwzhuhao committed Oct 8, 2024
1 parent b933c9d commit 9d9448a
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,15 @@ def _invoke(
model = self.runtime.credentials.get("model", None)
if not model:
return self.create_text_message("Please input model")

api_key = self.runtime.credentials.get("api_key") or "abc"
headers = {"Authorization": f"Bearer {api_key}"}
# set model
try:
url = str(URL(base_url) / "sdapi" / "v1" / "options")
response = post(
url,
json={"sd_model_checkpoint": model},
headers={"Authorization": f"Bearer {self.runtime.credentials['api_key']}"},
headers=headers,
)
if response.status_code != 200:
raise ToolProviderCredentialValidationError("Failed to set model, please tell user to set model")
Expand Down Expand Up @@ -257,14 +258,15 @@ def img2img(
draw_options["prompt"] = f"{lora},{prompt}"
else:
draw_options["prompt"] = prompt

api_key = self.runtime.credentials.get("api_key") or "abc"
headers = {"Authorization": f"Bearer {api_key}"}
try:
url = str(URL(base_url) / "sdapi" / "v1" / "img2img")
response = post(
url,
json=draw_options,
timeout=120,
headers={"Authorization": f"Bearer {self.runtime.credentials['api_key']}"},
headers=headers,
)
if response.status_code != 200:
return self.create_text_message("Failed to generate image")
Expand Down Expand Up @@ -298,14 +300,15 @@ def text2img(
else:
draw_options["prompt"] = prompt
draw_options["override_settings"]["sd_model_checkpoint"] = model

api_key = self.runtime.credentials.get("api_key") or "abc"
headers = {"Authorization": f"Bearer {api_key}"}
try:
url = str(URL(base_url) / "sdapi" / "v1" / "txt2img")
response = post(
url,
json=draw_options,
timeout=120,
headers={"Authorization": f"Bearer {self.runtime.credentials['api_key']}"},
headers=headers,
)
if response.status_code != 200:
return self.create_text_message("Failed to generate image")
Expand Down
14 changes: 10 additions & 4 deletions api/core/tools/provider/builtin/xinference/xinference.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,18 @@

class XinferenceProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
base_url = credentials.get("base_url")
api_key = credentials.get("api_key")
model = credentials.get("model")
base_url = credentials.get("base_url", "").removesuffix("/")
api_key = credentials.get("api_key", "")
if not api_key:
api_key = "abc"
credentials["api_key"] = api_key
model = credentials.get("model", "")
if not base_url or not model:
raise ToolProviderCredentialValidationError("Xinference base_url and model is required")
headers = {"Authorization": f"Bearer {api_key}"}
res = requests.post(
f"{base_url}/sdapi/v1/options",
headers={"Authorization": f"Bearer {api_key}"},
headers=headers,
json={"sd_model_checkpoint": model},
)
if res.status_code != 200:
Expand Down
2 changes: 1 addition & 1 deletion api/core/tools/provider/builtin/xinference/xinference.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ credentials_for_provider:
zh_Hans: 请输入你的模型名称
api_key:
type: secret-input
required: true
required: false
label:
en_US: API Key
zh_Hans: Xinference 服务器的 API Key
Expand Down

0 comments on commit 9d9448a

Please sign in to comment.