Support Oobabooga's text-generation-webui #278

Open · wants to merge 4 commits into base: main

12 changes: 12 additions & 0 deletions webui.py
@@ -227,6 +227,10 @@ class WebUI:
"info": "This is the URL of the Kobold AI Client API you want your worker to connect to. "
"You will probably be running your own Kobold AI Client, and you should enter the URL here.",
},
"is_oobabooga": {
"label": "Is Oobabooga",
"info": "If you are using the Oobabooga API, set this to true.",
},
"max_length": {
"label": "Maximum Length",
"info": "This is the maximum number of tokens your worker will generate per request.",
@@ -815,6 +819,13 @@ def initialise(self):
info=self._info("kai_url"),
)

config.default("is_oobabooga", False)
is_oobabooga = gr.Checkbox(
label=self._label("is_oobabooga"),
value=config.is_oobabooga,
info=self._info("is_oobabooga"),
)

config.default("max_length", 80)
max_length = gr.Slider(
0,
@@ -885,6 +896,7 @@ def initialise(self):
vram_to_leave_free,
scribe_name,
kai_url,
is_oobabooga,
max_length,
max_context_length,
branded_model,
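
Together with the CLI flag and environment variable introduced below, this checkbox gives three ways to enable the new mode. Note that the worker reads the setting as a string; the canonical parse (mirrored from the worker/bridge_data/scribe.py hunk in this PR) is shown here as a standalone sketch:

```python
import os

# The env var carries the string "true"/"false", not a Python bool, so this
# comparison is the whole parse (taken from worker/bridge_data/scribe.py).
is_oobabooga = os.environ.get("IS_OOBABOOGA", "false") == "true"
```

Anything other than the exact string "true" (including "True" or "1") leaves the option off.
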
7 changes: 7 additions & 0 deletions worker/argparser/scribe.py
@@ -27,4 +27,11 @@
help="The URL in which the KoboldAI Client API can be found.",
)

arg_parser.add_argument(
"--oobabooga",
action="store",
required=False,
help="Set to true if you want to use the Oobabooga API.",
)

args = arg_parser.parse_args()
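
One caveat with this flag: `action="store"` makes `--oobabooga` consume a string argument, and the check in worker/bridge_data/scribe.py below is a plain truthiness test, so `--oobabooga false` still enables Oobabooga mode — any non-empty value does. A hypothetical alternative, not part of this PR, that avoids the footgun:

```python
# Hypothetical variant, not in this PR: store_true yields a real boolean
# and needs no value on the command line.
arg_parser.add_argument(
    "--oobabooga",
    action="store_true",
    help="Connect to an Oobabooga text-generation-webui API instead of KoboldAI.",
)
```
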
24 changes: 18 additions & 6 deletions worker/bridge_data/scribe.py
@@ -16,6 +16,7 @@ def __init__(self):
self.kai_available = False
self.model = None
self.kai_url = "http://localhost:5000"
self.is_oobabooga = os.environ.get("IS_OOBABOOGA", "false") == "true"
self.max_length = int(os.environ.get("HORDE_MAX_LENGTH", "80"))
self.max_context_length = int(os.environ.get("HORDE_MAX_CONTEXT_LENGTH", "1024"))
self.branded_model = os.environ.get("HORDE_BRANDED_MODEL", "false") == "true"
@@ -36,6 +37,8 @@ def reload_data(self):
self.max_threads = 1
if args.kai_url:
self.kai_url = args.kai_url
if args.oobabooga:
self.is_oobabooga = True
if args.sfw:
self.nsfw = False
if args.blacklist:
@@ -55,7 +58,7 @@ def reload_data(self):
def validate_kai(self):
logger.debug("Retrieving settings from KoboldAI Client...")
try:
req = requests.get(self.kai_url + "/api/latest/model")
req = requests.get(self.kai_url + ("/v1/model" if self.is_oobabooga else "/api/latest/model"))
self.model = req.json()["result"]
# Normalize huggingface and local downloaded model names
if "/" not in self.model:
@@ -65,13 +68,22 @@
# self.max_context_length = req.json()["value"]
# req = requests.get(self.kai_url + "/api/latest/config/max_length")
# self.max_length = req.json()["value"]

if self.model not in self.softprompts:
req = requests.get(self.kai_url + "/api/latest/config/soft_prompts_list")
self.softprompts[self.model] = [sp["value"] for sp in req.json()["values"]]
req = requests.get(self.kai_url + "/api/latest/config/soft_prompt")
self.current_softprompt = req.json()["value"]
if self.is_oobabooga:
self.softprompts[self.model] = []
else:
req = requests.get(self.kai_url + "/api/latest/config/soft_prompts_list")
self.softprompts[self.model] = [sp["value"] for sp in req.json()["values"]]

if not self.is_oobabooga:
req = requests.get(self.kai_url + "/api/latest/config/soft_prompt")
self.current_softprompt = req.json()["value"]
except requests.exceptions.JSONDecodeError:
logger.error(f"Server {self.kai_url} is up but does not appear to be a KoboldAI server.")
if self.is_oobabooga:
logger.error(f"Server {self.kai_url} is up but does not appear to be a Oobabooga server.")
else:
logger.error(f"Server {self.kai_url} is up but does not appear to be a KoboldAI server.")
self.kai_available = False
return
except requests.exceptions.ConnectionError:
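
The branching above reduces to a small per-backend map: KoboldAI serves the model name at /api/latest/model and exposes softprompt routes, while Oobabooga's (legacy) API serves /v1/model and has no softprompt support, so the softprompt list is simply left empty. A minimal sketch of the model-name lookup, assuming — as the diff does — that both endpoints answer with {"result": "<model name>"}:

```python
import requests

def fetch_model_name(base_url: str, is_oobabooga: bool) -> str:
    """Return the backend's loaded model name.

    Assumes both APIs respond with {"result": "<model name>"}; the
    /v1/model route for Oobabooga is taken from this PR, not verified
    against text-generation-webui's docs.
    """
    endpoint = "/v1/model" if is_oobabooga else "/api/latest/model"
    resp = requests.get(base_url + endpoint, timeout=10)
    return resp.json()["result"]
```
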
66 changes: 52 additions & 14 deletions worker/jobs/scribe.py
@@ -31,6 +31,7 @@ def __init__(self, mm, bd, pop):
def start_job(self):
"""Starts a Scribe job from a pop request"""
logger.debug("Starting job in threadpool for model: {}", self.current_model)

super().start_job()
if self.status == JobStatus.FAULTED:
self.start_submit_thread()
@@ -51,26 +52,63 @@ def start_job(self):
)
time_state = time.time()
if self.requested_softprompt != self.bridge_data.current_softprompt:
requests.put(
self.bridge_data.kai_url + "/api/latest/config/soft_prompt/",
json={"value": self.requested_softprompt},
)
time.sleep(1) # Wait a second to unload the softprompt
if not self.bridge_data.is_oobabooga:
requests.put(
self.bridge_data.kai_url + "/api/latest/config/soft_prompt/",
json={"value": self.requested_softprompt},
)
time.sleep(1) # Wait a second to unload the softprompt

loop_retry = 0
gen_success = False
while not gen_success and loop_retry < 5:
try:
gen_req = requests.post(
self.bridge_data.kai_url + "/api/latest/generate/",
json=self.current_payload,
timeout=300,
)
gen_req = {
"json": {
"results": []
},
"status_code": 200,
}

if self.bridge_data.is_oobabooga:
for _ in range(self.current_payload["n"]):
req = requests.post(
self.bridge_data.kai_url + "/v1/generate",
json={
"prompt": self.current_payload["prompt"],
"max_new_tokens": self.current_payload["max_length"],
"temperature": self.current_payload["temperature"] if "temperature" in self.current_payload else 0.7,
"top_p": self.current_payload["top_p"] if "top_p" in self.current_payload else 0.9,
"top_k": self.current_payload["top_k"] if "top_k" in self.current_payload else 20,
"top_a": self.current_payload["top_a"] if "top_a" in self.current_payload else 0,
"repetition_penalty": self.current_payload["rep_pen"] if "rep_pen" in self.current_payload else 0,
"repetition_penalty_range": self.current_payload["rep_pen_range"] if "rep_pen_range" in self.current_payload else 0,
"penalty_alpha": self.current_payload["rep_pen_slope"] if "rep_pen_slope" in self.current_payload else 0,
"typical_p": self.current_payload["typical"] if "typical" in self.current_payload else 1,
"tfs": self.current_payload["tfs"] if "tfs" in self.current_payload else 1,
},
timeout=300,
)
if req.status_code != 200:
gen_req["status_code"] = req.status_code
break

gen_req["json"]["results"].append(req.json()["results"][0])
else:
req = requests.post(
self.bridge_data.kai_url + "/api/latest/generate/",
json=self.current_payload,
timeout=300,
)

gen_req["json"] = req.json()
gen_req["status_code"] = req.status_code
except (requests.exceptions.ConnectionError, requests.exceptions.ReadTimeout):
logger.error(f"Worker {self.bridge_data.kai_url} unavailable. Retrying in 3 seconds...")
loop_retry += 1
time.sleep(3)
continue
if type(gen_req.json()) is not dict:
if type(gen_req["json"]) is not dict:
logger.error(
(
f"KAI instance {self.bridge_data.kai_url} API unexpected response on generate: {gen_req}. "
@@ -80,22 +118,22 @@ def start_job(self):
time.sleep(3)
loop_retry += 1
continue
if gen_req.status_code == 503:
if gen_req["status_code"] == 503:
logger.debug(
f"KAI instance {self.bridge_data.kai_url} Busy (attempt {loop_retry}). Will try again...",
)
time.sleep(3)
loop_retry += 1
continue
if gen_req.status_code == 422:
if gen_req["status_code"] == 422:
logger.error(
f"KAI instance {self.bridge_data.kai_url} reported validation error.",
)
self.status = JobStatus.FAULTED
self.start_submit_thread()
return
try:
req_json = gen_req.json()
req_json = gen_req["json"]
except json.decoder.JSONDecodeError:
logger.error(
(
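
The heart of the new generate path is a field-by-field translation from the KoboldAI-style payload to Oobabooga's /v1/generate body, with one request per requested generation ("n") and a plain dict standing in for the parts of the requests response that the retry loop inspects. Pulled out as a standalone function for clarity (defaults copied from the diff; note that 0 as the repetition_penalty fallback is this PR's choice, not a documented Oobabooga default):

```python
def kai_to_ooba_payload(payload: dict) -> dict:
    """Translate a KoboldAI-style generation payload into the
    Oobabooga /v1/generate request body used above.

    Dict keys are Oobabooga's parameter names; values are read from
    the KoboldAI-style payload. Defaults mirror the diff.
    """
    return {
        "prompt": payload["prompt"],
        "max_new_tokens": payload["max_length"],
        "temperature": payload.get("temperature", 0.7),
        "top_p": payload.get("top_p", 0.9),
        "top_k": payload.get("top_k", 20),
        "top_a": payload.get("top_a", 0),
        "repetition_penalty": payload.get("rep_pen", 0),
        "repetition_penalty_range": payload.get("rep_pen_range", 0),
        "penalty_alpha": payload.get("rep_pen_slope", 0),
        "typical_p": payload.get("typical", 1),
        "tfs": payload.get("tfs", 1),
    }
```

One side effect of the dict shim worth noting: the trailing try/except json.decoder.JSONDecodeError around gen_req["json"] can no longer fire, since JSON decoding now happens earlier via req.json() inside the request loop.
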