Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Time Travel Fixes #364

Merged
merged 11 commits into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 58 additions & 22 deletions agentops/time_travel.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
@singleton
class TimeTravel:
def __init__(self):
self._completion_overrides_map = {}
self._prompt_override_map = {}
self._completion_overrides = {}

script_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(script_dir)
Expand All @@ -19,12 +18,9 @@ def __init__(self):
try:
with open(cache_path, "r") as file:
time_travel_cache_json = json.load(file)
self._completion_overrides_map = time_travel_cache_json.get(
self._completion_overrides = time_travel_cache_json.get(
"completion_overrides"
)
self._prompt_override_map = time_travel_cache_json.get(
"prompt_override"
)
except FileNotFoundError:
return

Expand All @@ -36,7 +32,7 @@ def fetch_time_travel_id(ttd_id):
if ttd_res.code != 200:
raise Exception(f"Failed to fetch TTD with status code {ttd_res.code}")

prompt_to_returns_map = {
completion_overrides = {
"completion_overrides": {
(
str({"messages": item["prompt"]["messages"]})
Expand All @@ -47,7 +43,7 @@ def fetch_time_travel_id(ttd_id):
}
}
with open("agentops_time_travel.json", "w") as file:
json.dump(prompt_to_returns_map, file, indent=4)
json.dump(completion_overrides, file, indent=4)

set_time_travel_active_state(True)
except ApiServerException as e:
Expand All @@ -60,20 +56,60 @@ def fetch_completion_override_from_time_travel_cache(kwargs):
if not check_time_travel_active():
return

if TimeTravel()._completion_overrides_map:
search_prompt = str({"messages": kwargs["messages"]})
result_from_cache = TimeTravel()._completion_overrides_map.get(search_prompt)
return result_from_cache


def fetch_prompt_override_from_time_travel_cache(kwargs):
if not check_time_travel_active():
return
if TimeTravel()._completion_overrides:
return find_cache_hit(kwargs["messages"], TimeTravel()._completion_overrides)


# NOTE: This is specific to the messages: [{'role': '...', 'content': '...'}, ...] format
def find_cache_hit(prompt_messages, completion_overrides):
if not isinstance(prompt_messages, (list, tuple)):
print(
"Time Travel Error - unexpected type for prompt_messages. Expected 'list' or 'tuple'. Got ",
type(prompt_messages),
)
return None

if not isinstance(completion_overrides, dict):
print(
"Time Travel Error - unexpected type for completion_overrides. Expected 'dict'. Got ",
type(completion_overrides),
)
return None
for key, value in completion_overrides.items():
try:
completion_override_dict = eval(key)
if not isinstance(completion_override_dict, dict):
print(
"Time Travel Error - unexpected type for completion_override_dict. Expected 'dict'. Got ",
type(completion_override_dict),
)
continue

if TimeTravel()._prompt_override_map:
search_prompt = str({"messages": kwargs["messages"]})
result_from_cache = TimeTravel()._prompt_override_map.get(search_prompt)
return json.loads(result_from_cache)
cached_messages = completion_override_dict.get("messages")
if not isinstance(cached_messages, list):
print(
"Time Travel Error - unexpected type for cached_messages. Expected 'list'. Got ",
type(cached_messages),
)
continue

if len(cached_messages) != len(prompt_messages):
continue

if all(
isinstance(a, dict)
and isinstance(b, dict)
and a.get("content") == b.get("content")
for a, b in zip(prompt_messages, cached_messages)
):
return value
except (SyntaxError, ValueError, TypeError) as e:
print(
f"Time Travel Error - Error processing completion_overrides item: {e}"
)
except Exception as e:
print(f"Time Travel Error - Unexpected error in find_cache_hit: {e}")
return None


def check_time_travel_active():
Expand Down Expand Up @@ -114,7 +150,7 @@ def set_time_travel_active_state(is_active: bool):

if is_active:
manage_time_travel_state(activated=True)
print("AgentOps: Time Travel Activated")
print("🖇 AgentOps: Time Travel Activated")
else:
manage_time_travel_state(activated=False)
print("🖇 AgentOps: Time Travel Deactivated")
Expand Down
6 changes: 3 additions & 3 deletions examples/openai-gpt.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,10 @@
"outputs": [],
"source": [
"message = [{\"role\": \"user\", \"content\": \"Write a 12 word poem about secret agents.\"}]\n",
"res = openai.chat.completions.create(\n",
"response = openai.chat.completions.create(\n",
" model=\"gpt-3.5-turbo\", messages=message, temperature=0.5, stream=False\n",
")\n",
"print(res.choices[0].message[\"content\"])"
"print(response.choices[0].message.content)"
]
},
{
Expand Down Expand Up @@ -283,7 +283,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.4"
"version": "3.12.3"
}
},
"nbformat": 4,
Expand Down
7 changes: 4 additions & 3 deletions examples/recording-events.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,11 @@
"\n",
"openai = OpenAI()\n",
"\n",
"message = ({\"role\": \"user\", \"content\": \"Hello\"},)\n",
"messages = [{\"role\": \"user\", \"content\": \"Hello\"}]\n",
"response = openai.chat.completions.create(\n",
" model=\"gpt-3.5-turbo\", messages=message, temperature=0.5\n",
")"
" model=\"gpt-3.5-turbo\", messages=messages, temperature=0.5\n",
")\n",
"print(response.choices[0].message.content)"
]
},
{
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ dependencies = [
"psutil==5.9.8",
"packaging==23.2",
"termcolor==2.4.0",
"PyYAML==6.0.1"
]
[project.optional-dependencies]
dev = [
Expand Down
2 changes: 1 addition & 1 deletion tests/core_manual_tests/time_travel.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
load_dotenv()
client = OpenAI()

agentops.init(tags=["TTD Test", openai.__version__])
agentops.init(default_tags=["TTD Test", openai.__version__])

try:
chat_completion_1 = client.chat.completions.create(
Expand Down
5 changes: 2 additions & 3 deletions tests/test_time_travel.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,13 @@ class TestTimeTravel(unittest.TestCase):
@patch(
"builtins.open",
new_callable=mock_open,
read_data='{"completion_overrides": {}, "prompt_override": {}}',
read_data='{"completion_overrides": {}}',
)
def test_init(self, mock_open, mock_abspath, mock_dirname):
mock_abspath.return_value = "/path/to/script"
mock_dirname.return_value = "/path/to"
instance = TimeTravel()
self.assertEqual(instance._completion_overrides_map, {})
self.assertEqual(instance._prompt_override_map, {})
self.assertEqual(instance._completion_overrides, {})

@patch("os.path.dirname")
@patch("os.path.abspath")
Expand Down
Loading