Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvements to JSON handling for local LLMs #269

Merged
merged 4 commits into from
Nov 3, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions memgpt/local_llm/json_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import json


def extract_first_json(string):
    """Return the first complete JSON object parsed out of *string*.

    Handles the case of two JSON objects back-to-back ("run-on JSON") by
    counting opening vs closing braces to find where the first object ends.

    NOTE(review): braces inside JSON string values are not accounted for —
    a value like ``{"a": "}"}`` will confuse the depth counter.

    Args:
        string: Raw text expected to contain at least one JSON object.

    Returns:
        The decoded first JSON object (typically a dict).

    Raises:
        json.JSONDecodeError: If no opening brace is found, or if the
            brace-matched span fails to decode.
    """
    depth = 0
    start_index = None

    for i, char in enumerate(string):
        if char == "{":
            if depth == 0:
                start_index = i
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0 and start_index is not None:
                try:
                    return json.loads(string[start_index : i + 1])
                except json.JSONDecodeError as e:
                    # BUG FIX: JSONDecodeError.__init__ requires (msg, doc, pos);
                    # the original passed only a message, which itself raised
                    # TypeError instead of the intended decode error.
                    raise json.JSONDecodeError(
                        f"Matched closing bracket, but decode failed with error: {str(e)}",
                        string,
                        start_index,
                    ) from e

    print("No valid JSON object found.")
    raise json.JSONDecodeError("Couldn't find starting bracket", string, 0)


def add_missing_heartbeat(llm_json):
    """Insert a heartbeat request into an assistant message that needs one.

    Heuristic: when the assistant calls a function other than send_message
    immediately after a user message (prev message['role'] == 'user'), the
    call is probably a retrieval or insertion step, so we likely want a
    heartbeat so that MemGPT can eventually reply with send_message.

    Expected message shape::

        {
            "role": "assistant",
            "content": ...,
            "function_call": {
                "name": ...,
                "arguments": {
                    "arg1": val1,
                    ...
                }
            }
        }

    Raises:
        NotImplementedError: Always — implementation deferred to a follow-up PR.
    """
    raise NotImplementedError
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add this in a separate PR to address #245



def clean_json(raw_llm_output, messages=None, functions=None):
    """Try a series of fallback hacks to parse the data coming out of the LLM.

    Args:
        raw_llm_output: Raw string emitted by the LLM, hopefully containing JSON.
        messages: Currently unused; presumably reserved for context-aware
            repairs (e.g. heartbeat insertion) — TODO confirm.
        functions: Currently unused; same as ``messages``.

    Returns:
        The decoded JSON object.

    Raises:
        json.JSONDecodeError: If every repair strategy fails.
    """
    try:
        # Fast path: output is already valid JSON.
        return json.loads(raw_llm_output)
    except json.JSONDecodeError:
        pass

    try:
        # Common local-LLM failure mode: the final closing brace got truncated.
        return json.loads(raw_llm_output + "}")
    except json.JSONDecodeError:
        pass

    # Last resort: brace-match to pull the first complete object out of a
    # possible "run-on" of concatenated JSON objects (with the closing-brace
    # patch applied here too). Original wrapped this in a pointless
    # `except: raise`; let its error propagate directly instead.
    return extract_first_json(raw_llm_output + "}")
12 changes: 5 additions & 7 deletions memgpt/local_llm/llm_chat_completion_wrappers/airoboros.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json

from .wrapper_base import LLMChatCompletionWrapper
from ..json_parser import clean_json
from ...errors import LLMJSONParsingError


Expand Down Expand Up @@ -184,9 +185,9 @@ def output_to_chat_completion_response(self, raw_llm_output):
raw_llm_output = "{" + raw_llm_output

try:
function_json_output = json.loads(raw_llm_output)
function_json_output = clean_json(raw_llm_output)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually this should be added to the other wrappers too

except Exception as e:
raise Exception(f"Failed to decode JSON from LLM output:\n{raw_llm_output}")
raise Exception(f"Failed to decode JSON from LLM output:\n{raw_llm_output} - error\n{str(e)}")
try:
function_name = function_json_output["function"]
function_parameters = function_json_output["params"]
Expand Down Expand Up @@ -393,12 +394,9 @@ def output_to_chat_completion_response(self, raw_llm_output):
raw_llm_output = "{" + raw_llm_output

try:
function_json_output = json.loads(raw_llm_output)
function_json_output = clean_json(raw_llm_output)
except Exception as e:
try:
function_json_output = json.loads(raw_llm_output + "\n}")
except:
raise Exception(f"Failed to decode JSON from LLM output:\n{raw_llm_output}")
raise Exception(f"Failed to decode JSON from LLM output:\n{raw_llm_output} - error\n{str(e)}")
try:
function_name = function_json_output["function"]
function_parameters = function_json_output["params"]
Expand Down
Loading