Skip to content

Commit

Permalink
fix player history and state, remove ascii map from prompt
Browse files Browse the repository at this point in the history
  • Loading branch information
Luncenok committed Oct 25, 2024
1 parent 16d01f4 commit 24883b6
Show file tree
Hide file tree
Showing 6 changed files with 11 additions and 21 deletions.
3 changes: 1 addition & 2 deletions src/llm_postor/game/agents/adventure_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,8 @@ def create_plan(self) -> str:
plan_prompt = ADVENTURE_PLAN_TEMPLATE.format(
player_name=self.player_name,
player_role=self.role,
ASCII_MAP=ASCII_MAP,
history=self.state.history,
tasks=self.state.current_tasks,
tasks=[str(task) for task in self.state.current_tasks],
actions="<action>"
+ "</action><action>".join(self.state.available_actions)
+ "</action>",
Expand Down
6 changes: 4 additions & 2 deletions src/llm_postor/game/agents/base_agent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from typing import Any, List
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field

from llm_postor.game.agents.usage_metadata import UsageMetadata
Expand All @@ -26,7 +27,7 @@ def to_dict(self):

class Agent(ABC, BaseModel):
# has to be Any because of MagicMock. TODO: Fix test integration with pydanitc
llm: Any # Optional[ChatOpenAI | ChatGoogleGenerativeAI] = None
llm: ChatOpenAI = None
state: AgentState = Field(default_factory=AgentState)
responses: List[str] = Field(default_factory=list)
player_name: str = ""
Expand Down Expand Up @@ -54,7 +55,8 @@ def add_token_usage(self, msg: dict):
def update_cost(self):
for_model = self.llm.model_name
if for_model not in TOKEN_COSTS:
for_model = "gpt-4o-mini"
print(f"Model {for_model} not found in TOKEN_COSTS. defaulting to openai/gpt-4o-mini")
for_model = "openai/gpt-4o-mini"

self.state.token_usage.cost += self.state.token_usage.input_tokens * TOKEN_COSTS[for_model]["input_tokens"]
self.state.token_usage.cost += self.state.token_usage.output_tokens * TOKEN_COSTS[for_model]["output_tokens"]
Expand Down
2 changes: 1 addition & 1 deletion src/llm_postor/game/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
"cache_read": 0.075 / million,
"output_tokens": 0.6 / million,
},
"google/gemini-1.5-flash-exp": {
"google/gemini-flash-1.5-exp": {
"input_tokens": 0,
"cache_read": 0,
"output_tokens": 0,
Expand Down
8 changes: 4 additions & 4 deletions src/llm_postor/game/gui_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,18 +121,18 @@ def _display_location(self, player: Player):
)

def _display_action_taken(self, player: Player):
action = player.history.rounds[-1].response
action = player.state.response
if action.isdigit():
st.write(f"Action Taken: {player.history.rounds[-1].actions[int(action)]}")
st.write(f"Action Taken: {player.state.actions[int(action)]}")
else:
st.write(f"Action Taken: {action}")

def _display_action_result(self, player: Player):
st.write(f"Action Result: {player.history.rounds[-1].action_result}")
st.write(f"Action Result: {player.state.action_result}")

def _display_recent_actions(self, player: Player):
st.write("Seen Actions:")
for action in player.history.rounds[-1].seen_actions:
for action in player.state.seen_actions:
st.write(f"- {action}")

def _display_map(self, game_state: GameState):
Expand Down
3 changes: 0 additions & 3 deletions src/llm_postor/game/llm_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,6 @@
Your Current Objectives:
{tasks}
Current Map Layout:
{ASCII_MAP}
Actions You Can Take:
{actions}
Expand Down
10 changes: 1 addition & 9 deletions src/llm_postor/game/models/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,7 @@ def to_dict(self):


class PlayerHistory(BaseModel):
rounds: List[RoundData] = Field(
default_factory=lambda: [
RoundData(
location=GameLocation.LOC_CAFETERIA,
stage=GamePhase.MAIN_MENU,
life=PlayerState.ALIVE,
)
]
)
rounds: List[RoundData] = Field(default_factory=list)

def add_round(self, round_data: RoundData):
self.rounds.append(round_data)
Expand Down

0 comments on commit 24883b6

Please sign in to comment.