Skip to content

Commit

Permalink
add cost estimation graph (0-shot to display, 1-shot to make future d…
Browse files Browse the repository at this point in the history
…otted)
  • Loading branch information
Luncenok committed Oct 27, 2024
1 parent e6d2732 commit c412a50
Showing 1 changed file with 77 additions and 2 deletions.
79 changes: 77 additions & 2 deletions src/llm_postor/game/gui_handler.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import json
import random
import uuid
from sklearn.linear_model import LinearRegression
import streamlit as st
from typing import List, Optional
from typing import Dict, List, Optional
from pydantic import BaseModel, Field
from streamlit.delta_generator import DeltaGenerator
from annotated_text import annotated_text
Expand Down Expand Up @@ -48,7 +49,14 @@ def display_gui(self, game_engine: GameEngine, chat_analyzer: ChatAnalyzer):
st.session_state.results = results
if "results" in st.session_state:
self._display_annotated_text(json.loads(st.session_state.results))
# st.json(game_engine.state.get_total_cost())
# Cost Visualization
cost_data = self.get_cost_data(game_engine)
estimated_cost_data = self.estimate_future_cost(cost_data, 10)
combined_cost_data = self.combine_data(cost_data, estimated_cost_data)
self.plot_cost(combined_cost_data)
st.text("Cost Breakdown:")
st.json(game_engine.state.get_total_cost(), expanded=False)
st.text("Raw Game State:")
st.json(game_engine.state.to_dict(), expanded=False)

if should_perform_step:
Expand Down Expand Up @@ -263,3 +271,70 @@ def _display_discussion_chat(self, players: List[Player]):
discussion_chat = "\n".join([x for x in player.get_chat_messages() if x.startswith(f"[{player.name}]")])
st.text_area(label="Discussion log:", value=discussion_chat)
return discussion_chat

def get_cost_data(self, game_engine: GameEngine) -> Dict[str, List[float]]:
"""Extracts cost data from player history."""
cost_data = {}
for player in game_engine.state.players:
costs = [round(r.token_usage.cost, 4) for r in player.history.rounds]
cost_data[player.name] = costs
return cost_data

def estimate_future_cost(self, cost_data: Dict[str, List[float]], rounds_to_forecast: int) -> Dict[str, List[float]]:
"""Estimates future cost using linear regression."""
estimated_cost_data = {}
for player_name, costs in cost_data.items():
# Prepare data for linear regression
X = [[i] for i in range(len(costs))]
y = costs
model = LinearRegression()
model.fit(X, y)

# Estimate future costs
estimated_costs = [round(model.predict([[i]])[0], 4) for i in range(len(costs), len(costs) + rounds_to_forecast)]
estimated_cost_data[player_name] = estimated_costs
return estimated_cost_data

def combine_data(self, cost_data: Dict[str, List[float]], estimated_cost_data: Dict[str, List[float]]) -> Dict[str, List[float]]:
"""Combines actual and estimated cost data."""
combined_cost_data = {}
for player_name in cost_data:
combined_cost_data[player_name] = cost_data[player_name] + estimated_cost_data[player_name]
return combined_cost_data

def plot_cost(self, cost_data: Dict[str, List[float]]):
"""Plots cost data using Plotly."""
fig = go.Figure()

for player_name, costs in cost_data.items():
# Separate actual and estimated costs
actual_costs = costs[:len(cost_data['Mateusz'])-10]
estimated_costs = costs[len(cost_data['Mateusz'])-10:]

# Plot actual costs as solid lines
fig.add_trace(go.Scatter(x=list(range(1, len(actual_costs) + 1)), y=actual_costs, name=player_name, mode='lines'))

# Plot estimated costs as dashed lines
fig.add_trace(go.Scatter(x=list(range(len(actual_costs), len(costs))), y=estimated_costs, name=player_name, mode='lines', line=dict(dash='dash')))

# Calculate total cost
total_costs = [sum(costs[i] for costs in cost_data.values()) for i in range(len(cost_data['Mateusz']))]

# Separate actual and estimated total costs
actual_total_costs = total_costs[:len(cost_data['Mateusz'])-10]
estimated_total_costs = total_costs[len(cost_data['Mateusz'])-10:]

# Plot actual total cost as solid lines
fig.add_trace(go.Scatter(x=list(range(1, len(actual_total_costs) + 1)), y=actual_total_costs, name="Total Cost", mode='lines'))

# Plot estimated total cost as dashed lines
fig.add_trace(go.Scatter(x=list(range(len(actual_total_costs), len(total_costs))), y=estimated_total_costs, name="Total Cost", mode='lines', line=dict(dash='dash')))

fig.update_layout(
title="Player Cost Over Rounds",
xaxis_title="Round Number",
yaxis_title="Cost",
legend_title="Players"
)

st.plotly_chart(fig)

0 comments on commit c412a50

Please sign in to comment.