Skip to content

Commit

Permalink
make memory configurable, consistently truncate discord messages, fix…
Browse files Browse the repository at this point in the history
… action prompt
  • Loading branch information
ssube committed Jun 3, 2024
1 parent a970572 commit f25dd57
Show file tree
Hide file tree
Showing 9 changed files with 64 additions and 45 deletions.
2 changes: 1 addition & 1 deletion client/src/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ export interface StringParameter {
export interface NumberParameter {
type: 'number';
default?: number;
enum?: Array<string>;
enum?: Array<number>;
}

export type Parameter = BooleanParameter | NumberParameter | StringParameter;
Expand Down
20 changes: 16 additions & 4 deletions client/src/prompt.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ export function enumerateSignificantParameterValues(name: string, world: World)
}
}

export function convertSignificantParameter(name: string, parameter: Parameter, world: Maybe<World>): Parameter {
export function convertSignificantParameter<T extends Parameter>(name: string, parameter: T, world: Maybe<World>): T {
if (parameter.type === 'boolean') {
return parameter;
}
Expand All @@ -154,15 +154,27 @@ export function formatAction(action: string, parameters: Record<string, boolean
return `~${action}:${Object.entries(parameters).map(([name, value]) => `${name}=${value}`).join(',')}`;
}

export function getEnumOrDefault<T>(defaultValue: Maybe<T>, enumValues: Maybe<Array<T>>, evenMoreDefault: T): T {
if (doesExist(defaultValue)) {
return defaultValue;
}

if (doesExist(enumValues)) {
return enumValues[0];
}

return evenMoreDefault;
}

export function makeDefaultParameterValues(parameters: Record<string, Parameter>) {
return Object.entries(parameters).reduce((acc, [name, parameter]) => {
switch (parameter.type) {
case 'boolean':
return { ...acc, [name]: mustDefault(parameter.default, false) };
return { ...acc, [name]: getEnumOrDefault(parameter.default, [], false) };
case 'number':
return { ...acc, [name]: mustDefault(parameter.default, 0) };
return { ...acc, [name]: getEnumOrDefault(parameter.default, parameter.enum, 0) };
case 'string':
return { ...acc, [name]: mustDefault(parameter.default, '') };
return { ...acc, [name]: getEnumOrDefault(parameter.default, parameter.enum, '') };
default:
return acc;
}
Expand Down
12 changes: 6 additions & 6 deletions taleweave/actions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
broadcast,
get_agent_for_character,
get_character_agent_for_name,
get_game_config,
get_prompt,
world_context,
)
Expand All @@ -22,8 +23,6 @@

logger = getLogger(__name__)

MAX_CONVERSATION_STEPS = 2


def action_examine(target: str) -> str:
"""
Expand Down Expand Up @@ -173,7 +172,8 @@ def action_ask(character: str, question: str) -> str:
character: The name of the character to ask. You cannot ask yourself questions.
question: The question to ask them.
"""
# capture references to the current character and room, because they will be overwritten
config = get_game_config()

with action_context() as (action_room, action_character):
# sanity checks
question_character, question_agent = get_character_agent_for_name(character)
Expand Down Expand Up @@ -216,7 +216,7 @@ def action_ask(character: str, question: str) -> str:
end_prompt,
echo_function=action_tell.__name__,
echo_parameter="message",
max_length=MAX_CONVERSATION_STEPS,
max_length=config.world.character.conversation_limit,
)

if result:
Expand All @@ -233,7 +233,7 @@ def action_tell(character: str, message: str) -> str:
character: The name of the character to tell. You cannot talk to yourself.
message: The message to tell them.
"""
# capture references to the current character and room, because they will be overwritten
config = get_game_config()

with action_context() as (action_room, action_character):
# sanity checks
Expand Down Expand Up @@ -268,7 +268,7 @@ def action_tell(character: str, message: str) -> str:
end_prompt,
echo_function=action_tell.__name__,
echo_parameter="message",
max_length=MAX_CONVERSATION_STEPS,
max_length=config.world.character.conversation_limit,
)

if result:
Expand Down
20 changes: 11 additions & 9 deletions taleweave/bot/discord.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,12 @@ async def broadcast_event(message: str | GameEvent):
event_messages[event_message.id] = message


def truncate(text: str, length: int = 1000) -> str:
if len(text) > length:
return text[:length] + "..."
return text


def embed_from_event(event: GameEvent) -> Embed | None:
if isinstance(event, GenerateEvent):
return embed_from_generate(event)
Expand Down Expand Up @@ -357,7 +363,7 @@ def embed_from_action(event: ActionEvent):

def embed_from_reply(event: ReplyEvent):
reply_embed = Embed(title=event.room.name, description=event.speaker.name)
reply_embed.add_field(name="Reply", value=event.text)
reply_embed.add_field(name="Reply", value=truncate(event.text))
return reply_embed


Expand All @@ -367,12 +373,8 @@ def embed_from_generate(event: GenerateEvent) -> Embed:


def embed_from_result(event: ResultEvent):
text = event.result
if len(text) > 1000:
text = text[:1000] + "..."

result_embed = Embed(title=event.room.name, description=event.character.name)
result_embed.add_field(name="Result", value=text)
result_embed.add_field(name="Result", value=truncate(event.result))
return result_embed


Expand All @@ -384,14 +386,14 @@ def embed_from_player(event: PlayerEvent):
title = format_prompt("discord_leave_title", event=event)
description = format_prompt("discord_leave_result", event=event)

player_embed = Embed(title=title, description=description)
player_embed = Embed(title=title, description=truncate(description))
return player_embed


def embed_from_prompt(event: PromptEvent):
# TODO: ping the player
prompt_embed = Embed(title=event.room.name, description=event.character.name)
prompt_embed.add_field(name="Prompt", value=event.prompt)
prompt_embed.add_field(name="Prompt", value=truncate(event.prompt))
return prompt_embed


Expand All @@ -400,5 +402,5 @@ def embed_from_status(event: StatusEvent):
title=event.room.name if event.room else "",
description=event.character.name if event.character else "",
)
status_embed.add_field(name="Status", value=event.text)
status_embed.add_field(name="Status", value=truncate(event.text))
return status_embed
2 changes: 1 addition & 1 deletion taleweave/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ def snapshot_system(world: World, turn: int, data: None = None) -> None:
set_dungeon_master(world_builder)

# start the sim
logger.debug("simulating world: %s", world)
logger.debug("simulating world: %s", world.name)
simulate_world(
world,
turns=args.turns,
Expand Down
26 changes: 12 additions & 14 deletions taleweave/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
set_current_world,
set_game_systems,
)
from taleweave.errors import ActionError
from taleweave.game_system import GameSystem
from taleweave.models.entity import Character, Room, World
from taleweave.models.event import ActionEvent, ResultEvent
Expand Down Expand Up @@ -117,12 +118,9 @@ def result_parser(value, **kwargs):
# TODO: only emit valid actions that parse and run correctly, and try to avoid parsing the JSON twice
event = ActionEvent.from_json(value, room, character)
else:
# TODO: this path should be removed and throw
# logger.warning(
# "invalid action, emitting as result event - this is a bug somewhere"
# )
# event = ResultEvent(value, room, character)
raise ValueError("invalid non-JSON action")
raise ActionError(
"Your last reply was not valid JSON. Please try again and reply with a valid function call in JSON format."
)

broadcast(event)

Expand Down Expand Up @@ -216,14 +214,14 @@ def prompt_character_planning(
while not stop_condition(current=i):
result = loop_retry(
agent,
get_prompt("world_simulate_character_planning"),
context={
"event_count": event_count,
"events_prompt": events_prompt,
"note_count": note_count,
"notes_prompt": notes_prompt,
"room_summary": summarize_room(room, character),
},
format_prompt(
"world_simulate_character_planning",
event_count=event_count,
events_prompt=events_prompt,
note_count=note_count,
notes_prompt=notes_prompt,
room_summary=summarize_room(room, character),
),
result_parser=result_parser,
stop_condition=stop_condition,
toolbox=planner_toolbox,
Expand Down
11 changes: 7 additions & 4 deletions taleweave/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@
from packit.agent import Agent, agent_easy_connect
from pydantic import RootModel

from taleweave.context import get_all_character_agents, set_character_agent
from taleweave.context import (
get_all_character_agents,
get_game_config,
set_character_agent,
)
from taleweave.models.entity import World
from taleweave.player import LocalPlayer

MEMORY_LIMIT = 25 # 10


def create_agents(
world: World,
Expand Down Expand Up @@ -69,6 +71,7 @@ def snapshot_world(world: World, turn: int):
def restore_memory(
data: Sequence[str | Dict[str, str]]
) -> deque[str | AIMessage | HumanMessage | SystemMessage]:
config = get_game_config()
memories = []

for memory in data:
Expand All @@ -85,7 +88,7 @@ def restore_memory(
elif memory_type == "ai":
memories.append(AIMessage(content=memory_content))

return deque(memories, maxlen=MEMORY_LIMIT)
return deque(memories, maxlen=config.world.character.memory_limit)


def save_world(world, filename):
Expand Down
3 changes: 2 additions & 1 deletion taleweave/systems/digest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from taleweave.game_system import FormatPerspective, GameSystem
from taleweave.models.entity import Character, Room, World, WorldEntity
from taleweave.models.event import ActionEvent, GameEvent
from taleweave.utils.prompt import format_str
from taleweave.utils.search import find_containing_room

logger = getLogger(__name__)
Expand All @@ -22,7 +23,7 @@ def create_turn_digest(
if prompt_key in library.prompts:
try:
template = library.prompts[prompt_key]
message = template.format(event=event)
message = format_str(template, event=event)
messages.append(message)
except Exception:
logger.exception("error formatting digest event: %s", event)
Expand Down
13 changes: 8 additions & 5 deletions taleweave/utils/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,17 @@
from jinja2 import Environment

from taleweave.context import get_prompt_library
from taleweave.utils.string import and_list, or_list
from taleweave.utils.world import describe_entity, name_entity

logger = getLogger(__name__)

jinja_env = Environment()
jinja_env.filters["describe"] = describe_entity
jinja_env.filters["name"] = name_entity
jinja_env.filters["and_list"] = and_list
jinja_env.filters["or_list"] = or_list


def format_prompt(prompt_key: str, **kwargs) -> str:
try:
Expand All @@ -19,9 +26,5 @@ def format_prompt(prompt_key: str, **kwargs) -> str:


def format_str(template_str: str, **kwargs) -> str:
env = Environment()
env.filters["describe"] = describe_entity
env.filters["name"] = name_entity

template = env.from_string(template_str)
template = jinja_env.from_string(template_str)
return template.render(**kwargs)

0 comments on commit f25dd57

Please sign in to comment.