Skip to content

Commit

Permalink
add ggml based ctransformer models (#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
bluecoconut authored May 16, 2023
1 parent 5a1f0f1 commit 5fcf51d
Showing 1 changed file with 66 additions and 0 deletions.
66 changes: 66 additions & 0 deletions lambdaprompt/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ def set_backend(backend_name):
backends['completion'] = MPT7BInstructCompletion()
elif backend_name == 'StarCoder':
backends['completion'] = StarCoderCompletion()
elif backend_name == 'StarCoderGGML':
backends['completion'] = StarCoderGGMLQuantizedCompletion()
elif backend_name == 'SantaCoderGGML':
backends['completion'] = SantaCoderGGMLQuantizedCompletion()
elif backend_name == 'GPT3' or backend_name == 'OpenAI':
backends['completion'] = OpenAICompletion()
backends['chat'] = OpenAIChat()
Expand Down Expand Up @@ -142,6 +146,68 @@ def parse_response(self, answer):
return answer["choices"][0]["message"]['content']


class CTransformersBackend(Backend):
# https://github.com/marella/ctransformers

class Parameters(Backend.Parameters):
max_new_tokens: int = 200
temperature: float = 0.01
top_p: float = 0.92
top_k: int = 0
repetition_penalty: float = 1.1
stop: Optional[Union[str, List[str]]]

def __init__(self, model_name, model_type, **param_override):
try:
from ctransformers import AutoModelForCausalLM
except ImportError:
raise ImportError("You must install ctransformers to use this backend (`pip install ctransformers`)")
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
model_type=model_type)
super().__init__(**param_override)

def preprocess(self, prompt):
return prompt

async def __call__(self, prompt, **kwargs):
prompt = self.preprocess(prompt)
genkwargs = self.parse_param(**kwargs)
max_new_tokens = genkwargs.pop("max_new_tokens")

tokens = self.model.tokenize(prompt)
stop = genkwargs.pop("stop", None) or []
if isinstance(stop, str):
stop = [stop]
end_ids = [self.model.tokenize(x) for x in stop]

def should_stop(response_tokens):
for end in end_ids:
if all(x == y for x, y in zip(response_tokens[-len(end):], end)):
return True
if len(response_tokens) >= max_new_tokens:
return True
return False

response = []
for token in self.model.generate(tokens, **genkwargs):
response.append(token)
if should_stop(response):
break

return self.model.detokenize(response)


class StarCoderGGMLQuantizedCompletion(CTransformersBackend):
def __init__(self, **kwargs):
super().__init__("nouamanetazi/starcoder-ggml", model_type='starcoder', **kwargs)


class SantaCoderGGMLQuantizedCompletion(CTransformersBackend):
def __init__(self, **kwargs):
super().__init__("danforbes/santacoder-ggml-q4_1", model_type='starcoder', **kwargs)


class HuggingFaceBackend(Backend):
class Parameters(Backend.Parameters):
temperature: float = 0.01
Expand Down

0 comments on commit 5fcf51d

Please sign in to comment.