From 24883b6c5607ff1af9cd63bed79586fd4106f76e Mon Sep 17 00:00:00 2001 From: Luncenok Date: Fri, 25 Oct 2024 23:55:58 +0200 Subject: [PATCH] fix player history and state, remove ascii map from prompt --- src/llm_postor/game/agents/adventure_agent.py | 3 +-- src/llm_postor/game/agents/base_agent.py | 6 ++++-- src/llm_postor/game/consts.py | 2 +- src/llm_postor/game/gui_handler.py | 8 ++++---- src/llm_postor/game/llm_prompts.py | 3 --- src/llm_postor/game/models/history.py | 10 +--------- 6 files changed, 11 insertions(+), 21 deletions(-) diff --git a/src/llm_postor/game/agents/adventure_agent.py b/src/llm_postor/game/agents/adventure_agent.py index 4fb794f..991f98b 100644 --- a/src/llm_postor/game/agents/adventure_agent.py +++ b/src/llm_postor/game/agents/adventure_agent.py @@ -24,9 +24,8 @@ def create_plan(self) -> str: plan_prompt = ADVENTURE_PLAN_TEMPLATE.format( player_name=self.player_name, player_role=self.role, - ASCII_MAP=ASCII_MAP, history=self.state.history, - tasks=self.state.current_tasks, + tasks=[str(task) for task in self.state.current_tasks], actions="" + "".join(self.state.available_actions) + "", diff --git a/src/llm_postor/game/agents/base_agent.py b/src/llm_postor/game/agents/base_agent.py index 5f40b63..dce555a 100644 --- a/src/llm_postor/game/agents/base_agent.py +++ b/src/llm_postor/game/agents/base_agent.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod from typing import Any, List +from langchain_openai import ChatOpenAI from pydantic import BaseModel, Field from llm_postor.game.agents.usage_metadata import UsageMetadata @@ -26,7 +27,7 @@ def to_dict(self): class Agent(ABC, BaseModel): # has to be Any because of MagicMock. TODO: Fix test integration with pydanitc - llm: Any # Optional[ChatOpenAI | ChatGoogleGenerativeAI] = None + llm: ChatOpenAI = None state: AgentState = Field(default_factory=AgentState) responses: List[str] = Field(default_factory=list) player_name: str = "" @@ -54,7 +55,8 @@ def add_token_usage(self, msg: dict): def update_cost(self): for_model = self.llm.model_name if for_model not in TOKEN_COSTS: - for_model = "gpt-4o-mini" + print(f"Model {for_model} not found in TOKEN_COSTS. defaulting to openai/gpt-4o-mini") + for_model = "openai/gpt-4o-mini" self.state.token_usage.cost += self.state.token_usage.input_tokens * TOKEN_COSTS[for_model]["input_tokens"] self.state.token_usage.cost += self.state.token_usage.output_tokens * TOKEN_COSTS[for_model]["output_tokens"] diff --git a/src/llm_postor/game/consts.py b/src/llm_postor/game/consts.py index 6465e71..aada2a4 100644 --- a/src/llm_postor/game/consts.py +++ b/src/llm_postor/game/consts.py @@ -50,7 +50,7 @@ "cache_read": 0.075 / million, "output_tokens": 0.6 / million, }, - "google/gemini-1.5-flash-exp": { + "google/gemini-flash-1.5-exp": { "input_tokens": 0, "cache_read": 0, "output_tokens": 0, diff --git a/src/llm_postor/game/gui_handler.py b/src/llm_postor/game/gui_handler.py index b444e6c..eeac826 100644 --- a/src/llm_postor/game/gui_handler.py +++ b/src/llm_postor/game/gui_handler.py @@ -121,18 +121,18 @@ def _display_location(self, player: Player): ) def _display_action_taken(self, player: Player): - action = player.history.rounds[-1].response + action = player.state.response if action.isdigit(): - st.write(f"Action Taken: {player.history.rounds[-1].actions[int(action)]}") + st.write(f"Action Taken: {player.state.actions[int(action)]}") else: st.write(f"Action Taken: {action}") def _display_action_result(self, player: Player): - st.write(f"Action Result: {player.history.rounds[-1].action_result}") + st.write(f"Action Result: {player.state.action_result}") def _display_recent_actions(self, player: Player): st.write("Seen Actions:") - for action in player.history.rounds[-1].seen_actions: + for action in player.state.seen_actions: st.write(f"- {action}") def _display_map(self, game_state: GameState): diff --git a/src/llm_postor/game/llm_prompts.py b/src/llm_postor/game/llm_prompts.py index 133b461..ecd4077 100644 --- a/src/llm_postor/game/llm_prompts.py +++ b/src/llm_postor/game/llm_prompts.py @@ -37,9 +37,6 @@ Your Current Objectives: {tasks} -Current Map Layout: -{ASCII_MAP} - Actions You Can Take: {actions} diff --git a/src/llm_postor/game/models/history.py b/src/llm_postor/game/models/history.py index 9e1a4f3..6111fb8 100644 --- a/src/llm_postor/game/models/history.py +++ b/src/llm_postor/game/models/history.py @@ -46,15 +46,7 @@ def to_dict(self): class PlayerHistory(BaseModel): - rounds: List[RoundData] = Field( - default_factory=lambda: [ - RoundData( - location=GameLocation.LOC_CAFETERIA, - stage=GamePhase.MAIN_MENU, - life=PlayerState.ALIVE, - ) - ] - ) + rounds: List[RoundData] = Field(default_factory=list) def add_round(self, round_data: RoundData): self.rounds.append(round_data)