Skip to content

Commit

Permalink
add token count and cost estimation
Browse files Browse the repository at this point in the history
  • Loading branch information
Luncenok committed Oct 24, 2024
1 parent ab162a9 commit 7c19b67
Show file tree
Hide file tree
Showing 14 changed files with 112 additions and 3 deletions.
1 change: 1 addition & 0 deletions game_state_end_4omini.json

Large diffs are not rendered by default.

15 changes: 15 additions & 0 deletions scratchpad/response.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
content='Next steps:\n\n1. **Complete the task**: Empty the cafeteria trash.\n2. **Move to Medbay**: After completing the task, head to the Medbay to start working on the wiring task.\n3. **Begin the wiring task**: Once in Medbay, start the task to fix wiring. \n\nThis plan focuses on completing your current objective efficiently while also progressing towards your other tasks.'
additional_kwargs={'refusal': None}
response_metadata={
'token_usage': {
'completion_tokens': 82, 'prompt_tokens': 998, 'total_tokens': 1080, 'completion_tokens_details':
{'audio_tokens': None, 'reasoning_tokens': 0},
'prompt_tokens_details':
{'audio_tokens': None, 'cached_tokens': 0}
},
'model_name': 'gpt-4o-mini-2024-07-18',
'system_fingerprint': 'fp_482c22a7bc',
'finish_reason': 'stop',
'logprobs': None}
id='run-45031354-22f3-4a4b-a8f5-9dfac7dd5ecc-0'
usage_metadata={'input_tokens': 998, 'output_tokens': 82, 'total_tokens': 1080, 'input_token_details': {'cache_read': 0}, 'output_token_details': {'reasoning': 0}}
2 changes: 1 addition & 1 deletion src/demo_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def main():
return

game_engine.init_game()
game_engine.state.set_stage(GamePhase.MAIN_MENU) # pause the game at the main menu
# game_engine.state.set_stage(GamePhase.MAIN_MENU) # pause the game at the main menu

if game_engine.state.game_stage == GamePhase.MAIN_MENU:
st.warning("Viewing the game state only. No actions are being performed.")
Expand Down
6 changes: 6 additions & 0 deletions src/game/agents/adventure_agent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import List, Any

from langchain_openai import ChatOpenAI

from game.consts import ASCII_MAP
from .base_agent import Agent
from langchain.schema import HumanMessage
Expand All @@ -8,6 +10,7 @@


class AdventureAgent(Agent):
llm: ChatOpenAI
def update_state(
self,
observations: str,
Expand All @@ -33,6 +36,8 @@ def create_plan(self) -> str:
current_location=self.state.current_location,
)
plan = self.llm.invoke([HumanMessage(content=plan_prompt)])
print(plan.usage_metadata)
self.add_token_usage(plan.usage_metadata)
return plan_prompt, plan.content.strip()

def choose_action(self, plan: str) -> int:
Expand All @@ -45,6 +50,7 @@ def choose_action(self, plan: str) -> int:
plan=plan,
)
chosen_action = self.llm.invoke([HumanMessage(content=action_prompt)])
self.add_token_usage(chosen_action.usage_metadata)
chosen_action = chosen_action.content.strip()
return action_prompt, self.check_action_valid(chosen_action)

Expand Down
19 changes: 19 additions & 0 deletions src/game/agents/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@
from typing import Any, List
from pydantic import BaseModel, Field

from game.agents.usage_metadata import UsageMetadata
from game.consts import TOKEN_COSTS


class AgentState(BaseModel):
history: str = Field(default_factory=str)
current_tasks: List[str] = Field(default_factory=list)
available_actions: List[str] = Field(default_factory=list)
messages: List[str] = Field(default_factory=list)
current_location: str = Field(default_factory=str)
token_usage: UsageMetadata = Field(default_factory=UsageMetadata)

def to_dict(self):
return {
Expand All @@ -17,6 +21,7 @@ def to_dict(self):
"available_actions": self.available_actions,
"messages": self.messages,
"current_location": self.current_location,
"token_usage": self.token_usage.to_dict(),
}


Expand All @@ -37,6 +42,20 @@ def update_state(
@abstractmethod
def act(self) -> Any:
pass

def add_token_usage(self, msg: dict):
    """Accumulate one LLM response's token usage into this agent's running totals.

    `msg` is a langchain `usage_metadata` dict, e.g.:
    {'input_tokens': 998, 'output_tokens': 82, 'total_tokens': 1080,
     'input_token_details': {'cache_read': 0}, 'output_token_details': {'reasoning': 0}}

    Non-OpenAI providers (e.g. Gemini, which TOKEN_COSTS supports) may omit
    `input_token_details` or return no usage metadata at all, so every lookup
    falls back to 0 instead of raising KeyError.
    """
    if not msg:
        return  # provider returned no usage metadata; nothing to record
    self.state.token_usage.input_tokens += msg.get("input_tokens", 0)
    self.state.token_usage.output_tokens += msg.get("output_tokens", 0)
    self.state.token_usage.total_tokens += msg.get("total_tokens", 0)
    self.state.token_usage.cache_read += msg.get("input_token_details", {}).get("cache_read", 0)
    self.update_cost()

def update_cost(self):
    """Recompute the total dollar cost from the cumulative token counters.

    BUG FIX: the previous version did `cost += <cumulative tokens> * rate` on
    every call, re-charging all previously counted tokens each time an LLM
    response was recorded (quadratic over-counting). Since the token counters
    are already cumulative, the cost must be recomputed from scratch, not
    accumulated.

    NOTE(review): `cache_read` tokens are also included in OpenAI's
    `input_tokens`, so charging both full input rate and the cache rate may
    slightly over-estimate — confirm against the billing semantics intended
    for TOKEN_COSTS.
    """
    rates = TOKEN_COSTS[self.llm.model_name]
    usage = self.state.token_usage
    usage.cost = (
        usage.input_tokens * rates["input_tokens"]
        + usage.output_tokens * rates["output_tokens"]
        + usage.cache_read * rates["cache_read"]
    )


def to_dict(self):
llm_data = "human"
Expand Down
2 changes: 2 additions & 0 deletions src/game/agents/discussion_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def create_discussion_points(self) -> str:
statements=self.state.messages,
)
discussion_points = self.llm.invoke([HumanMessage(content=discussion_prompt)])
self.add_token_usage(discussion_points.usage_metadata)
self.responses.append(f"Discussion points: {discussion_points.content.strip()}")
return discussion_prompt, discussion_points.content.strip()

Expand All @@ -35,6 +36,7 @@ def respond_to_statements(self, statements: str, points: str) -> str:
statements=statements,
)
response = self.llm.invoke([HumanMessage(content=response_prompt)])
self.add_token_usage(response.usage_metadata)
self.responses.append(response.content.strip())
return response_prompt, response.content.strip()

Expand Down
18 changes: 18 additions & 0 deletions src/game/agents/usage_metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from pydantic import BaseModel, Field


class UsageMetadata(BaseModel):
    """Cumulative LLM token usage and estimated dollar cost for one agent/player.

    BUG FIX: `BaseModel` was previously imported from `openai`, which only
    works via an incidental re-export of pydantic's class; import it from
    `pydantic` directly, alongside `Field`.
    """

    input_tokens: int = Field(default=0)    # prompt tokens sent to the model
    output_tokens: int = Field(default=0)   # completion tokens returned
    total_tokens: int = Field(default=0)    # input + output, as reported by the API
    cache_read: int = Field(default=0)      # prompt tokens served from the provider cache
    cost: float = Field(default=0.0)        # estimated cost in USD

    def to_dict(self):
        """Return a plain-dict snapshot suitable for JSON serialization."""
        return {
            "input_tokens": self.input_tokens,
            "output_tokens": self.output_tokens,
            "total_tokens": self.total_tokens,
            "cache_read": self.cache_read,
            "cost": self.cost,
        }
1 change: 1 addition & 0 deletions src/game/agents/voting_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def choose_action(self, discussion_log: str) -> int:
),
)
chosen_action = self.llm.invoke([HumanMessage(content=action_prompt)])
self.add_token_usage(chosen_action.usage_metadata)
chosen_action_str = chosen_action.content.strip()
self.responses.append(f"Chosen vote: {chosen_action_str}")
vote = self.check_action_valid(chosen_action_str)
Expand Down
18 changes: 18 additions & 0 deletions src/game/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,21 @@
| | | /
|/----------------|--------------|--------------|--------|-------/
"""
# Per-token API prices in USD (provider price sheets quote per 1M tokens).
# NOTE(review): keys must exactly match `llm.model_name` at runtime — responses
# have been observed with model names like 'gpt-4o-mini-2024-07-18', which
# would not match 'gpt-4o-mini'; verify how the ChatOpenAI model name is set.
million = 1_000_000  # name kept lowercase for backward compatibility with existing imports
TOKEN_COSTS = {
    "gpt-4o": {
        "input_tokens": 2.5 / million,
        "cache_read": 1.25 / million,
        "output_tokens": 10 / million,
    },
    "gpt-4o-mini": {
        "input_tokens": 0.15 / million,
        "cache_read": 0.075 / million,
        "output_tokens": 0.6 / million,
    },
    # Gemini usage is free-tier here, so it contributes nothing to cost totals.
    "gemini-1.5-flash": {
        "input_tokens": 0,
        "cache_read": 0,
        "output_tokens": 0,
    },
}
4 changes: 2 additions & 2 deletions src/game/game_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def perform_step(self) -> bool:
start = self.state.round_of_discussion_start
now = self.state.round_number
max = game_consts.NUM_CHATS
self.state.log_action(f"Discussion{now-start}/{max}: round: {now}. Player to act: {self.state.player_to_act_next}")
self.state.log_action(f"Discussion ({now-start+1}/{max}): round: {now}. Player to act: {self.state.player_to_act_next}")
self.perform_discussion_step()
else:
print("Game is in MAIN_MENU stage - read_only mode")
Expand Down Expand Up @@ -341,7 +341,7 @@ def broadcast_observation(self, key: str, message: str) -> None:

def broadcast_message(self, message: str) -> None:
"""Broadcasts a chat message to all alive players."""
self.broadcast("chat", message)
self.broadcast_observation("chat", message)

def mark_dead_players_as_reported(self) -> None:
"""Marks all dead players as reported to avoid double reporting."""
Expand Down
13 changes: 13 additions & 0 deletions src/game/game_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,19 @@ def get_player_targets(self, player: Player) -> List[Player]:
and other_player.state.life == PlayerState.ALIVE
and other_player.state.location == player.state.location
]

def get_total_cost(self) -> dict:
    """Summarize estimated LLM spend across all players.

    BUG FIX: the return annotation said `-> int` but the method has always
    returned a dict of floats; the annotation now matches the actual value.

    Returns a dict with one `"<name>_cost"` entry per player plus:
      - "total_cost": sum of all player costs
      - "average_per_round": total / rounds elapsed (round_number is 0-based,
        hence the +1)
      - "average_per_player": total / number of players
      - "average_per_round_per_player": total / number of player-actions taken
        so far (full rounds * players + actions in the current round)

    NOTE(review): assumes `self.players` is non-empty; an empty player list
    would raise ZeroDivisionError — confirm this cannot occur after init.
    """
    output = {}
    for player in self.players:
        output[f"{player.name}_cost"] = player.state.token_usage.cost
    total_cost = sum(output.values())
    output["total_cost"] = total_cost
    output["average_per_round"] = total_cost / (self.round_number + 1)
    output["average_per_player"] = total_cost / len(self.players)
    output["average_per_round_per_player"] = total_cost / (
        len(self.players) * (self.round_number) + self.player_to_act_next + 1
    )
    return output



def to_dict(self):
return {
Expand Down
1 change: 1 addition & 0 deletions src/game/gui_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def update_gui(self, game_state: GameState):
with sidebar:
self._display_short_player_info(player, st)
with self.game_log_placeholder.container():
st.json(game_state.get_total_cost())
st.text("\n".join(game_state.playthrough))
annotated_text(
"[Warek]: I agree that ",
Expand Down
3 changes: 3 additions & 0 deletions src/game/models/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from game.models.engine import GameLocation, GamePhase
from game.models.tasks import Task
from game.agents.usage_metadata import UsageMetadata


class PlayerState(str, Enum):
Expand All @@ -18,6 +19,7 @@ class RoundData(BaseModel):
life: PlayerState = PlayerState.ALIVE
tasks: List[Task] = Field(default_factory=list)
llm_responses: List[str] = Field(default_factory=list)
token_usage: UsageMetadata = Field(default_factory=UsageMetadata)
prompt: str = ""
actions: List[str] = Field(default_factory=list)
response: str = ""
Expand All @@ -31,6 +33,7 @@ def to_dict(self):
"location": self.location.value,
"stage": self.stage.value,
"life": self.life.value,
"token_usage": self.token_usage.to_dict(),
"tasks": [str(task) for task in self.tasks],
"llm_responses": self.llm_responses,
"prompt": self.prompt,
Expand Down
12 changes: 12 additions & 0 deletions src/game/players/ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI

from game.agents.usage_metadata import UsageMetadata


class AIPlayer(Player):
llm_model_name: str
Expand Down Expand Up @@ -44,6 +46,7 @@ def prompt_action(self, actions: List[str]) -> int:
)
prompts, chosen_action = self.adventure_agent.act()
self.state.llm_responses = self.adventure_agent.responses
self.add_token_usage(self.adventure_agent.state.token_usage)
self.state.response = str(chosen_action)
self.state.prompt = prompts
return chosen_action
Expand All @@ -54,6 +57,7 @@ def prompt_discussion(self) -> str:
self.discussion_agent.update_state(observations=history, messages=statements)
message_prompt, message = self.discussion_agent.act()
self.state.llm_responses = self.discussion_agent.responses
self.add_token_usage(self.discussion_agent.state.token_usage)
self.state.response = message
self.state.prompt = message_prompt
return message
Expand All @@ -65,9 +69,17 @@ def prompt_vote(self, voting_actions: List[str]) -> int:
)
vote_prompt, vote = self.voting_agent.choose_action(self.get_message_str())
self.state.llm_responses = self.voting_agent.responses
self.add_token_usage(self.voting_agent.state.token_usage)
self.state.response = str(vote)
self.state.prompt = vote_prompt
return vote

def add_token_usage(self, usage: UsageMetadata):
    """Fold a sub-agent's accumulated token usage into this player's totals."""
    totals = self.state.token_usage
    # Every counter on UsageMetadata is additive, so accumulate them uniformly.
    for field in ("input_tokens", "output_tokens", "total_tokens", "cache_read", "cost"):
        setattr(totals, field, getattr(totals, field) + getattr(usage, field))

def __str__(self):
return self.name
Expand Down

0 comments on commit 7c19b67

Please sign in to comment.