Add local LLM support (with function calling) #97

Merged
merged 21 commits into from
Oct 23, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions memgpt/local_llm/README.md
@@ -0,0 +1,81 @@
## How to connect MemGPT to non-OpenAI LLMs

**If you have a hosted ChatCompletion-compatible endpoint that works with function calling**:
- simply set `OPENAI_API_BASE` to the IP+port of your endpoint:

```sh
export OPENAI_API_BASE=...
```

Note: for this to work, the endpoint **MUST** support function calls. As of 10/22/2023, most ChatCompletion endpoints do **NOT** support function calls, so if you want to play with MemGPT and open models, you probably need to follow the instructions below.

## Integrating a function-call finetuned LLM with MemGPT

**If you have a hosted local model that is function-call finetuned**:
- implement a wrapper class for that model
- the wrapper class needs to implement two functions:
- one to go from ChatCompletion messages/functions schema to a prompt string
- and one to go from raw LLM outputs to a ChatCompletion response
- put that model behind a server (e.g. using WebUI) and set `OPENAI_API_BASE`

```python
class LLMChatCompletionWrapper(ABC):

    @abstractmethod
    def chat_completion_to_prompt(self, messages, functions):
        """Go from ChatCompletion to a single prompt string"""
        pass

    @abstractmethod
    def output_to_chat_completion_response(self, raw_llm_output):
        """Turn the LLM output string into a ChatCompletion response"""
        pass
```
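
To make the contract above concrete, here is a minimal hypothetical wrapper sketch (not part of this PR, and the prompt/output format it assumes is made up): it dumps the function schemas as JSON into the prompt and expects the model to reply with a `{"function": ..., "params": {...}}` object. The Airoboros example below does the same job with model-specific formatting.

```python
import json

from memgpt.local_llm.llm_chat_completion_wrappers.wrapper_base import LLMChatCompletionWrapper


class NaiveJSONWrapper(LLMChatCompletionWrapper):
    """Hypothetical wrapper: dump the schemas as JSON and ask the model to answer in JSON."""

    def chat_completion_to_prompt(self, messages, functions):
        # flatten the chat history, then append the schemas and an output-format instruction
        prompt = "\n".join(f"{m['role'].upper()}: {m['content']}" for m in messages)
        prompt += "\n\nAvailable functions (JSON schema):\n" + json.dumps(functions, indent=2)
        prompt += '\nRespond with JSON: {"function": <name>, "params": {...}}'
        prompt += "\nASSISTANT:"
        return prompt

    def output_to_chat_completion_response(self, raw_llm_output):
        # parse the model's JSON reply and repackage it as a ChatCompletion-style message
        parsed = json.loads(raw_llm_output)
        return {
            "role": "assistant",
            "content": None,
            "function_call": {
                "name": parsed["function"],
                "arguments": json.dumps(parsed["params"]),
            },
        }
```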

## Example with Airoboros LLM

To help you get started, we've implemented an example wrapper class for a popular llama2 model finetuned on function calling (airoboros). We want MemGPT to run well on open models as much as you do, so we'll be actively updating this page with more examples. Additionally, we welcome contributions from the community! If you find an open LLM that works well with MemGPT, please open a PR with a model wrapper and we'll merge it ASAP.

```python
class Airoboros21Wrapper(LLMChatCompletionWrapper):
"""Wrapper for Airoboros 70b v2.1: https://huggingface.co/jondurbin/airoboros-l2-70b-2.1"""

def chat_completion_to_prompt(self, messages, functions):
"""
Examples for how airoboros expects its prompt inputs: https://huggingface.co/jondurbin/airoboros-l2-70b-2.1#prompt-format
Examples for how airoboros expects to see function schemas: https://huggingface.co/jondurbin/airoboros-l2-70b-2.1#agentfunction-calling
"""

def output_to_chat_completion_response(self, raw_llm_output):
"""Turn raw LLM output into a ChatCompletion style response with:
"message" = {
"role": "assistant",
"content": ...,
"function_call": {
"name": ...
"arguments": {
"arg1": val1,
...
}
}
}
"""
```
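
For context, this is roughly how the wrapper gets driven by the new ChatCompletion proxy in this PR (`memgpt/local_llm/chat_completion_proxy.py`): convert the messages and function schemas to a prompt, POST it to the hosting server, and parse the raw completion back into a ChatCompletion-style message. The host/port below are assumptions for a local text-generation-webui server with its API enabled; adjust them for your setup.

```python
import requests

from memgpt.local_llm.llm_chat_completion_wrappers import airoboros
from memgpt.local_llm.webui_settings import SIMPLE

# toy ChatCompletion-style inputs (schema borrowed from the send_message example in this PR)
messages = [
    {"role": "system", "content": "You are MemGPT."},
    {"role": "user", "content": "Hi there!"},
]
functions = [{
    "name": "send_message",
    "description": "Sends a message to the human user",
    "parameters": {
        "type": "object",
        "properties": {"message": {"type": "string", "description": "Message contents."}},
        "required": ["message"],
    },
}]

HOST = "http://127.0.0.1:5000/api"  # assumption: default local text-generation-webui API address

wrapper = airoboros.Airoboros21Wrapper()

# 1. ChatCompletion-style messages/functions -> one prompt string in Airoboros' expected format
prompt = wrapper.chat_completion_to_prompt(messages, functions)

# 2. generate a completion from the hosted model
request = dict(SIMPLE, prompt=prompt)
raw_text = requests.post(f"{HOST}/v1/generate", json=request).json()["results"][0]["text"]

# 3. raw model output (a JSON string) -> ChatCompletion-style assistant message with a function_call
assistant_message = wrapper.output_to_chat_completion_response(raw_text)
```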

---

## Status of ChatCompletion w/ function calling and open LLMs

MemGPT uses function calling to do memory management. With OpenAI's ChatCompletion API, you can pass a function schema in the `functions` keyword argument, and the API response will include a `function_call` field containing the function name and the function arguments (generated JSON). Under the hood, the `functions` keyword is combined with the `messages` and `system` prompt to form one big string input to the transformer, and the output of the transformer is parsed to extract the JSON function call.
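
For reference, here is a minimal sketch of what that looks like against OpenAI's API (written against the pre-1.0 `openai` Python package that was current when this PR was opened; the schema is the `send_message` example from this PR):

```python
import json
import openai

functions = [{
    "name": "send_message",
    "description": "Sends a message to the human user",
    "parameters": {
        "type": "object",
        "properties": {"message": {"type": "string", "description": "Message contents."}},
        "required": ["message"],
    },
}]

response = openai.ChatCompletion.create(
    model="gpt-4",
    messages=[
        {"role": "system", "content": "You are MemGPT."},
        {"role": "user", "content": "Say hi to the user."},
    ],
    functions=functions,
    function_call="auto",
)

message = response["choices"][0]["message"]
if message.get("function_call"):
    name = message["function_call"]["name"]                    # e.g. "send_message"
    args = json.loads(message["function_call"]["arguments"])   # arguments arrive as a JSON string
```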

In the future, more open LLMs and LLM servers (that can host OpenAI-compatible ChatCompletion endpoints) may start including this parsing code as standard practice. In the meantime, when you see a model that says it supports "function calling", like Airoboros, it doesn't mean that you can just load Airoboros into a ChatCompletion-compatible endpoint like FastChat, use the same OpenAI API call, and have it just work.

(1) When an open LLM says it supports function calling, it usually means that the model was finetuned on some function-call data. Remember, transformers are just string-in, string-out, so there are many ways to format this function-call data. Airoboros formats the function schema in YAML style (see https://huggingface.co/jondurbin/airoboros-l2-70b-3.1.2#agentfunction-calling) and the output is in JSON style. To get this to work behind a ChatCompletion API, you still have to convert the `functions` keyword argument (containing the schema) into the model's expected schema style in the prompt (YAML for Airoboros), and you have to run some code to extract the function call (JSON for Airoboros) and package it cleanly as a `function_call` field in the response.

(2) Partly because of how complex it is to support function calling, most (all?) of the community projects that provide OpenAI ChatCompletion endpoints for arbitrary open LLMs do not support function calling, since doing so would require model-specific parsing code for each model.
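
To make point (1) concrete, here is a simplified sketch of the two conversions involved (roughly what the `Airoboros21Wrapper` added in this PR does; the helper names are illustrative, not part of the PR):

```python
import json

def schema_to_airoboros_style(schema):
    """OpenAI function schema (JSON) -> the YAML-style function description Airoboros was finetuned on."""
    lines = [f"{schema['name']}:", f"  description: {schema['description']}", "  params:"]
    for name, prop in schema["parameters"]["properties"].items():
        lines.append(f"    {name}: {prop['description']}")
    return "\n".join(lines)

def airoboros_output_to_function_call(raw_llm_output):
    """Airoboros replies with JSON like {"function": "send_message", "params": {"message": "Hi!"}};
    repackage it as a ChatCompletion-style function_call."""
    parsed = json.loads(raw_llm_output)
    return {
        "name": parsed["function"],
        "arguments": json.dumps(parsed["params"]),  # ChatCompletion carries arguments as a JSON string
    }
```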

## How can you run MemGPT with open LLMs that support function calling?

Because of the poor state of function-calling support in existing ChatCompletion API serving code, we instead provide a light wrapper on top of ChatCompletion that uses a parser specific to Airoboros. We hope that this example code will help the community add compatibility between MemGPT and more function-calling LLMs - we will also add more model support as we test more models and find those that work well enough to run MemGPT's function set.

To run the example of MemGPT with Airoboros, you'll need to host the model with some open LLM hosting code, for example Oobabooga's text-generation-webui (see here). Then, all you need to do is point MemGPT to this API endpoint. Now, instead of calling ChatCompletion on OpenAI's API, MemGPT will use its own ChatCompletion wrapper that parses the system prompt, messages, and function schemas into a format that Airoboros has been finetuned on, and once Airoboros generates a string output, MemGPT will parse the response to extract a potential function call (knowing what we know about Airoboros' expected function-call output).
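
For example, if you serve Airoboros locally with text-generation-webui's API enabled, the MemGPT-side configuration might look something like this (the host and port are assumptions for a default local setup; `BACKEND_TYPE` is the new environment variable read by `chat_completion_proxy.py` in this PR, and only `webui` is supported so far):

```sh
export OPENAI_API_BASE="http://127.0.0.1:5000/api"  # assumption: default text-generation-webui API address
export BACKEND_TYPE="webui"                         # read by memgpt/local_llm/chat_completion_proxy.py
```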
Empty file added memgpt/local_llm/__init__.py
Empty file.
88 changes: 88 additions & 0 deletions memgpt/local_llm/chat_completion_proxy.py
@@ -0,0 +1,88 @@
"""MemGPT sends a ChatCompletion request

Under the hood, we use the functions argument to turn
"""


"""Key idea: create drop-in replacement for agent's ChatCompletion call that runs on an OpenLLM backend"""

import os
import json
import requests

from .webui_settings import DETERMINISTIC, SIMPLE
from .llm_chat_completion_wrappers import airoboros

HOST = os.getenv('OPENAI_API_BASE')
HOST_TYPE = os.getenv('BACKEND_TYPE') # default None == ChatCompletion


class DotDict(dict):
    """Allow dot access on properties similar to OpenAI response object"""

    def __getattr__(self, attr):
        return self.get(attr)

    def __setattr__(self, key, value):
        self[key] = value


async def get_chat_completion(
    model,  # no model, since the model is fixed to whatever you set in your own backend
    messages,
    functions,
    function_call="auto",
):
    if function_call != "auto":
        raise ValueError(f"function_call == {function_call} not supported (auto only)")

    if True or model == 'airoboros_v2.1':
        # TODO: select a wrapper based on the model name; for now, always use the Airoboros wrapper
        llm_wrapper = airoboros.Airoboros21Wrapper()

    # First step: turn the message sequence into a prompt that the model expects
    prompt = llm_wrapper.chat_completion_to_prompt(messages, functions)
    # print(prompt)

    if HOST_TYPE != 'webui':
        raise ValueError(HOST_TYPE)

    request = SIMPLE
    request['prompt'] = prompt

    try:
        URI = f'{HOST}/v1/generate'
        response = requests.post(URI, json=request)
        if response.status_code == 200:
            # result = response.json()['results'][0]['history']
            result = response.json()
            # print(f"raw API response: {result}")
            result = result['results'][0]['text']
            print(f"json API response.text: {result}")
        else:
            raise Exception(f"API call got non-200 response code: {response.status_code}")

        # cleaned_result, chatcompletion_result = parse_st_json_output(result)
        chat_completion_result = llm_wrapper.output_to_chat_completion_response(result)
        print(json.dumps(chat_completion_result, indent=2))
        # print(cleaned_result)

        # unpack with response.choices[0].message.content
        response = DotDict({
            'model': None,
            'choices': [DotDict({
                'message': DotDict(chat_completion_result),
                'finish_reason': 'stop',  # TODO vary based on webui response
            })],
            'usage': DotDict({
                # TODO fix
                'prompt_tokens': 0,
                'completion_tokens': 0,
                'total_tokens': 0,
            })
        })
        return response

    except Exception as e:
        # TODO
        raise e
Empty file.
146 changes: 146 additions & 0 deletions memgpt/local_llm/llm_chat_completion_wrappers/airoboros.py
@@ -0,0 +1,146 @@
import json

from .wrapper_base import LLMChatCompletionWrapper


class Airoboros21Wrapper(LLMChatCompletionWrapper):
"""Wrapper for Airoboros 70b v2.1: https://huggingface.co/jondurbin/airoboros-l2-70b-2.1
"""

def __init__(self, simplify_json_content=True, include_assistant_prefix=True, clean_function_args=True):
self.simplify_json_content = simplify_json_content
self.include_assistant_prefix = include_assistant_prefix
self.clean_func_args = clean_function_args

def chat_completion_to_prompt(self, messages, functions):
"""Example for airoboros: https://huggingface.co/jondurbin/airoboros-l2-70b-2.1#prompt-format

A chat.
USER: {prompt}
ASSISTANT:

Functions support: https://huggingface.co/jondurbin/airoboros-l2-70b-2.1#agentfunction-calling

As an AI assistant, please select the most suitable function and parameters from the list of available functions below, based on the user's input. Provide your response in JSON format.

Input: I want to know how many times 'Python' is mentioned in my text file.

Available functions:
file_analytics:
description: This tool performs various operations on a text file.
params:
action: The operation we want to perform on the data, such as "count_occurrences", "find_line", etc.
filters:
keyword: The word or phrase we want to search for.

OpenAI functions schema style:

{
"name": "send_message",
"description": "Sends a message to the human user",
"parameters": {
"type": "object",
"properties": {
# https://json-schema.org/understanding-json-schema/reference/array.html
"message": {
"type": "string",
"description": "Message contents. All unicode (including emojis) are supported.",
},
},
"required": ["message"],
}
},
"""
prompt = ""

# System insturctions go first
assert messages[0]['role'] == 'system'
prompt += messages[0]['content']

# Next is the functions preamble
def create_function_description(schema):
# airorobos style
func_str = ""
func_str += f"{schema['name']}:"
func_str += f"\n description: {schema['description']}"
func_str += f"\n params:"
for param_k, param_v in schema['parameters']['properties'].items():
# TODO we're ignoring type
func_str += f"\n {param_k}: {param_v['description']}"
# TODO we're ignoring schema['parameters']['required']
return func_str

prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the user's input. Provide your response in JSON format."
prompt += f"\nAvailable functions:"
for function_dict in functions:
prompt += f"\n{create_function_description(function_dict)}"

# Last are the user/assistant messages
for message in messages[1:]:
assert message['role'] in ['user', 'assistant', 'function'], message

if message['role'] == 'user':
if self.simplify_json_content:
try:
content_json = json.loads(message['content'])
content_simple = content_json['message']
prompt += f"\nUSER: {content_simple}"
except:
prompt += f"\nUSER: {message['content']}"
elif message['role'] == 'assistant':
prompt += f"\nASSISTANT: {message['content']}"
elif message['role'] == 'function':
# TODO
continue
# prompt += f"\nASSISTANT: (function return) {message['content']}"
else:
raise ValueError(message)

if self.include_assistant_prefix:
# prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the user's input. Provide your response in JSON format."
cpacker marked this conversation as resolved.
Show resolved Hide resolved
prompt += f"\nASSISTANT:"

return prompt

def clean_function_args(self, function_name, function_args):
"""Some basic MemGPT-specific cleaning of function args"""
cleaned_function_name = function_name
cleaned_function_args = function_args.copy()

if function_name == 'send_message':
# strip request_heartbeat
cleaned_function_args.pop('request_heartbeat', None)

# TODO more cleaning to fix errors LLM makes
return cleaned_function_name, cleaned_function_args

def output_to_chat_completion_response(self, raw_llm_output):
"""Turn raw LLM output into a ChatCompletion style response with:
"message" = {
"role": "assistant",
"content": ...,
"function_call": {
"name": ...
"arguments": {
"arg1": val1,
...
}
}
}
"""
function_json_output = json.loads(raw_llm_output)
function_name = function_json_output['function']
function_parameters = function_json_output['params']

if self.clean_func_args:
function_name, function_parameters = self.clean_function_args(function_name, function_parameters)

message = {
'role': 'assistant',
'content': None,
'function_call': {
'name': function_name,
'arguments': json.dumps(function_parameters),
}
}
return message
14 changes: 14 additions & 0 deletions memgpt/local_llm/llm_chat_completion_wrappers/wrapper_base.py
@@ -0,0 +1,14 @@
from abc import ABC, abstractmethod


class LLMChatCompletionWrapper(ABC):

    @abstractmethod
    def chat_completion_to_prompt(self, messages, functions):
        """Go from ChatCompletion to a single prompt string"""
        pass

    @abstractmethod
    def output_to_chat_completion_response(self, raw_llm_output):
        """Turn the LLM output string into a ChatCompletion response"""
        pass
54 changes: 54 additions & 0 deletions memgpt/local_llm/webui_settings.py
@@ -0,0 +1,54 @@
DETERMINISTIC = {
    'max_new_tokens': 250,
    'do_sample': False,
    'temperature': 0,
    'top_p': 0,
    'typical_p': 1,
    'repetition_penalty': 1.18,
    'repetition_penalty_range': 0,
    'encoder_repetition_penalty': 1,
    'top_k': 1,
    'min_length': 0,
    'no_repeat_ngram_size': 0,
    'num_beams': 1,
    'penalty_alpha': 0,
    'length_penalty': 1,
    'early_stopping': False,
    'guidance_scale': 1,
    'negative_prompt': '',
    'seed': -1,
    'add_bos_token': True,
    'stopping_strings': [
        '\nUSER:',
        '\nASSISTANT:',
        # '\n' +
        # '</s>',
        # '<|',
        # '\n#',
        # '\n\n\n',
    ],
    'truncation_length': 4096,
    'ban_eos_token': False,
    'skip_special_tokens': True,
    'top_a': 0,
    'tfs': 1,
    'epsilon_cutoff': 0,
    'eta_cutoff': 0,
    'mirostat_mode': 2,
    'mirostat_tau': 4,
    'mirostat_eta': 0.1,
    'use_mancer': False
}

SIMPLE = {
    'stopping_strings': [
        '\nUSER:',
        '\nASSISTANT:',
        # '\n' +
        # '</s>',
        # '<|',
        # '\n#',
        # '\n\n\n',
    ],
    'truncation_length': 4096,
}