Added text-generation-webui inference support #221

Open · wants to merge 3 commits into main
31 changes: 31 additions & 0 deletions README.md
@@ -140,6 +140,37 @@ character_maker(

The prompt above typically takes just over 2.5 seconds to complete on an A6000 GPU when using LLaMA 7B. If we were to run the same prompt adapted to be a single generation call (the standard practice today), it takes about 5 seconds to complete (roughly 4 seconds of token generation and 1 second of prompt processing). *This means Guidance acceleration delivers a 2x speedup over the standard approach for this prompt.* In practice, the exact speed-up factor depends on the format of your specific prompt and the size of your model (larger models benefit more). Acceleration is also only supported for Transformers LLMs at the moment. See the [notebook](https://github.com/microsoft/guidance/blob/main/notebooks/guidance_acceleration.ipynb) for more details.


## text-generation-webui inference (TGWUI)

The `TGWUI` class integrates guidance with [text-generation-webui](https://github.com/oobabooga/text-generation-webui), which is easy to set up and makes it possible to run larger models with less VRAM. Because the integration talks to the server over HTTP, guidance and the [text-generation-webui](https://github.com/oobabooga/text-generation-webui) inference server do not need to run on the same machine.



````python

# point guidance at a running text-generation-webui API server
guidance.llm = guidance.llms.TGWUI("http://127.0.0.1:9000")

# define the prompt
character_maker = guidance("""The following is a character profile for a Soccer Game in JSON format.
```json
{
    "Nationality": "{{nationality}}",
    "league": "{{league}}",
    "name": "{{gen 'name'}}",
    "age": {{gen 'age' pattern='[0-9]+' stop=','}},
    "overall": {{gen 'overall' pattern='[0-9]+' stop=','}},
    "description": "{{gen 'description' temperature=1.25}}",
}```""")

# generate a character
character_maker(
    nationality="Türkiye",
    league="Premier League"
)
````
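
Since all communication happens over HTTP, the server can just as easily run on another machine; point `TGWUI` at whatever host and port your [text-generation-webui](https://github.com/oobabooga/text-generation-webui) API server exposes (the address below is only a placeholder):

````python
# hypothetical remote inference server; substitute your own host and port
guidance.llm = guidance.llms.TGWUI("http://192.168.1.42:9000")
````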


## Token healing (<a href="https://github.com/microsoft/guidance/blob/main/notebooks/art_of_prompt_design/prompt_boundaries_and_token_healing.ipynb">notebook</a>)

The standard greedy tokenizations used by most language models introduce a subtle and powerful bias that can have all kinds of unintended consequences for your prompts. Using a process we call "token healing" `guidance` automatically removes these surprising biases, freeing you to focus on designing the prompts you want without worrying about tokenization artifacts.
1 change: 1 addition & 0 deletions guidance/llms/__init__.py
@@ -2,6 +2,7 @@
from ._transformers import Transformers
from ._mock import Mock
from ._llm import LLM, LLMSession, SyncSession
from ._tgwui import TGWUI
from ._deep_speed import DeepSpeed
from . import transformers
from . import caches
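
This one-line export makes the new backend selectable the same way as the existing ones; a minimal sketch (the model name and URL are placeholders, not part of this change):

````python
import guidance

# local Transformers backend (already exported above)
# guidance.llm = guidance.llms.Transformers("gpt2")

# remote text-generation-webui backend added by this change
guidance.llm = guidance.llms.TGWUI("http://127.0.0.1:9000")
````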
125 changes: 125 additions & 0 deletions guidance/llms/_tgwui.py
@@ -0,0 +1,125 @@
import requests
from typing import Any, Dict
from ._llm import LLM, LLMSession


class TGWUI(LLM):
    def __init__(self, base_url, chat_mode=False):
        self.chat_mode = chat_mode  # by default models are not in role-based chat mode
        self.base_url = base_url
        self.model_name = "unknown"

    def __getitem__(self, key):
        """Gets an attribute from the LLM."""
        return getattr(self, key)

    def session(self, asynchronous=False):
        """Creates a session for the LLM."""
        return TGWUISession(self)

    def encode(self, string, **kwargs):
        """Tokenizes a string by calling the server's /api/v1/encode endpoint."""
        args = {"text": string, "kwargs": kwargs}
        try:
            response = requests.post(self.base_url + '/api/v1/encode', json=args)
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            print(f'Encode request failed with error {e}')
            return None

        # expected shape: {"results": [{"tokens": [[...token ids...]]}]}
        resp = response.json()
        if 'results' in resp and len(resp['results']) > 0 and 'tokens' in resp['results'][0]:
            return resp['results'][0]['tokens'][0]
        else:
            print('Unexpected response format')
            return None


    def decode(self, tokens, **kwargs):
        """Decodes token ids by calling the server's /api/v1/decode endpoint."""
        args = {"tokens": tokens, "kwargs": kwargs}
        try:
            response = requests.post(self.base_url + '/api/v1/decode', json=args)
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            print(f'Decode request failed with error {e}')
            return None

        resp = response.json()
        if 'results' in resp and len(resp['results']) > 0 and 'ids' in resp['results'][0]:
            return resp['results'][0]['ids']
        else:
            print('Unexpected response format')
            return None


class TGWUISession(LLMSession):
    def __init__(self, llm):
        self.llm = llm
        self._call_counts = {}  # tracks the number of repeated identical calls to the LLM with non-zero temperature

    def __enter__(self):
        return self

    async def __call__(
        self, prompt, stop=None, stop_regex=None, temperature=None, n=1, max_tokens=1000, logprobs=None,
        top_p=1.0, echo=False, logit_bias=None, token_healing=None, pattern=None, stream=None,
        cache_seed=0, caching=None, **completion_kwargs
    ):
        args = {
            "prompt": prompt, "stop": stop, "stop_regex": stop_regex, "temperature": temperature, "n": n,
            "max_tokens": max_tokens, "logprobs": logprobs, "top_p": top_p, "echo": echo, "logit_bias": logit_bias,
            "token_healing": token_healing, "pattern": pattern, "stream": stream, "cache_seed": cache_seed,
            "completion_kwargs": completion_kwargs
        }

        try:
            # note: this is a blocking request even though __call__ is a coroutine
            response = requests.post(self.llm.base_url + '/api/v1/call', json=args)
            response.raise_for_status()  # raises an HTTPError if the response was an error
        except requests.exceptions.RequestException as e:  # catches any kind of request exception
            print(f'Request failed with error {e}')
            return None

        resp = response.json()
        if 'choices' in resp and len(resp['choices']) > 0:
            return resp['choices'][0]['text']
        else:
            print('Unexpected response format')
            return None

    def __exit__(self, exc_type, exc_value, traceback):
        pass

    def _gen_key(self, args_dict):
        del args_dict["self"]  # skip the "self" arg
        return "_---_".join([str(v) for v in ([args_dict[k] for k in args_dict] + [self.llm.model_name, self.llm.__class__.__name__, self.llm.cache_version])])

    def _cache_params(self, args_dict) -> Dict[str, Any]:
        """get the parameters for generating the cache key"""
        key = self._gen_key(args_dict)
        # if we have non-zero temperature we include the call count in the cache key
        if args_dict.get("temperature", 0) > 0:
            args_dict["call_count"] = self._call_counts.get(key, 0)

            # increment the call count
            self._call_counts[key] = args_dict["call_count"] + 1
        args_dict["model_name"] = self.llm.model_name
        args_dict["cache_version"] = self.llm.cache_version
        args_dict["class_name"] = self.llm.__class__.__name__

        return args_dict
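
For reference, a minimal sketch of exercising the session API directly, outside of a guidance template. It assumes a text-generation-webui server that implements the `/api/v1/call` endpoint used above; the URL and prompt are placeholders, and because `__call__` is a coroutine it has to be awaited (or driven with `asyncio.run`):

````python
import asyncio

from guidance.llms import TGWUI

llm = TGWUI("http://127.0.0.1:9000")  # placeholder server address

with llm.session() as session:
    # __call__ issues a blocking requests.post and returns the generated text, or None on error
    text = asyncio.run(session("Once upon a time", max_tokens=32, temperature=0.7))
    print(text)
````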