Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ankit/add documentation #15

Merged
merged 6 commits into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backend/requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
black==24.8.0
griffe-pydantic==1.0.0
pytest==8.3.2
pytest-flask==1.3.0
ruff==0.6.4
Expand Down
9 changes: 5 additions & 4 deletions backend/spielberg/agents/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,23 @@ class AgentStatus:


class AgentResponse(BaseModel):
"""Data model for respones from agents."""

status: str = AgentStatus.SUCCESS
message: str = ""
data: dict = {}


class BaseAgent(ABC):
"""Base class for all agents"""
"""Interface for all agents. All agents should inherit from this class."""

def __init__(self, session: Session, **kwargs):
self.session: Session = session
self.output_message: OutputMessage = self.session.output_message

def get_parameters(self):
function_inferrer = FunctionInferrer.infer_from_function_reference(
self.run
)
"""Return the automatically inferred parameters for the function using the dcstring of the function."""
function_inferrer = FunctionInferrer.infer_from_function_reference(self.run)
function_json = function_inferrer.to_json_schema()
parameters = function_json.get("parameters")
if not parameters:
Expand Down
6 changes: 3 additions & 3 deletions backend/spielberg/core/reasoning.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(
input_message: InputMessage,
session: Session,
):
"""Initialize the ReasoningEngine.
"""Initialize the ReasoningEngine with the input message and session.

:param input_message: The input message to the reasoning engine.
:param session: The session instance.
Expand All @@ -72,7 +72,7 @@ def register_agents(self, agents: List[BaseAgent]):
self.agents.extend(agents)

def build_context(self):
"""Build the context for the reasoning engine and Agents."""
"""Build the context for the reasoning engine it adds the information about the video or collection to the reasoning context."""
input_context = ContextMessage(
content=self.input_message.content, role=RoleTypes.user
)
Expand Down Expand Up @@ -111,7 +111,7 @@ def build_context(self):
self.session.reasoning_context.append(input_context)

def run_agent(self, agent_name: str, *args, **kwargs) -> AgentResponse:
"""Run an agent.
"""Run an agent with the given name and arguments.

:param str agent_name: The name of the agent to run
:param args: The arguments to pass to the agent
Expand Down
58 changes: 41 additions & 17 deletions backend/spielberg/core/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,16 @@


class RoleTypes(str, Enum):
"""Role types for the context message."""

system = "system"
user = "user"
assistant = "assistant"
tool = "tool"


class MsgStatus(str, Enum):
"""Output message status."""
"""Message status for the message, for loading state."""

progress = "progress"
success = "success"
Expand All @@ -27,20 +29,24 @@ class MsgStatus(str, Enum):


class MsgType(str, Enum):
"""Message type."""
"""Message type for the message. input is for the user input and output is for the spielberg output."""

input = "input"
output = "output"


class ContentType(str, Enum):
"""Content type for the content in the input/output message."""

text = "text"
video = "video"
image = "image"
search_results = "search_results"


class BaseContent(BaseModel):
"""Base content class for the content in the message."""

model_config = ConfigDict(
arbitrary_types_allowed=True,
use_enum_values=True,
Expand All @@ -54,6 +60,8 @@ class BaseContent(BaseModel):


class TextContent(BaseContent):
"""Text content model class for text content."""

text: str = ""
type: ContentType = ContentType.text

Expand All @@ -72,6 +80,8 @@ class VideoData(BaseModel):


class VideoContent(BaseContent):
"""Video content model class for video content."""

video: Optional[VideoData] = None
type: ContentType = ContentType.video

Expand All @@ -87,6 +97,8 @@ class ImageData(BaseModel):


class ImageContent(BaseContent):
"""Image content model class for image content."""

image: Optional[ImageData] = None
type: ContentType = ContentType.image

Expand Down Expand Up @@ -116,13 +128,7 @@ class SearchResultsContent(BaseContent):


class BaseMessage(BaseModel):
"""Base message class.

:param str session_id: Session is of the messages
:param str conv_id: Conversation id
:param int msg_id: (optional) Message id
:param msg_type: Type of the message
"""
"""Base message class for the input/output message. All the input/output messages will be inherited from this class."""

model_config = ConfigDict(
arbitrary_types_allowed=True,
Expand All @@ -145,37 +151,37 @@ class BaseMessage(BaseModel):


class InputMessage(BaseMessage):
"""Input message to the agent

:param BaseDB db: Database instance
:param MsgType msg_type: :class:`MsgType` of the message
"""
"""Input message from the user. This class is used to create the input message from the user."""

db: BaseDB
msg_type: MsgType = MsgType.input

def publish(self):
"""Store the message in the database. for conversation history."""
self.db.add_or_update_msg_to_conv(**self.model_dump(exclude={"db"}))


class OutputMessage(BaseMessage):
"""Output message from the agent"""
"""Output message from the spielberg. This class is used to create the output message from the spielberg."""

db: BaseDB = Field(exclude=True)
msg_type: MsgType = MsgType.output
status: MsgStatus = MsgStatus.progress

def update_status(self, status: MsgStatus):
"""Update the status of the message and publish the message to the socket. for loading state."""
self.status = status
self._publish()

def push_update(self):
"""Publish the message to the socket."""
try:
emit("chat", self.model_dump(), namespace="/chat")
except Exception as e:
print(f"Error in emitting message: {str(e)}")

def publish(self):
"""Store the message in the database. for conversation history and publish the message to the socket."""
self._publish()

def _publish(self):
Expand All @@ -187,7 +193,7 @@ def _publish(self):


class ContextMessage(BaseModel):
"""Context message class."""
"""Context message class. This class is used to create the context message for the reasoning context."""

model_config = ConfigDict(
arbitrary_types_allowed=True,
Expand All @@ -201,6 +207,7 @@ class ContextMessage(BaseModel):
role: RoleTypes = RoleTypes.system

def to_llm_msg(self):
"""Convert the context message to the llm message."""
msg = {
"role": self.role,
"content": self.content,
Expand All @@ -222,10 +229,13 @@ def to_llm_msg(self):

@classmethod
def from_json(cls, json_data):
"""Create the context message from the json data."""
return cls(**json_data)


class Session:
"""A class to manage and interact with a session in the database. The session is used to store the conversation and reasoning context messages."""

def __init__(
self,
db: BaseDB,
Expand All @@ -250,12 +260,14 @@ def __init__(
self.get_context_messages()

def save_context_messages(self):
"""Save the reasoning context messages to the database."""
context = {
"reasoning": [message.to_llm_msg() for message in self.reasoning_context],
}
self.db.add_or_update_context_msg(self.session_id, context)

def get_context_messages(self):
"""Get the reasoning context messages from the database."""
if not self.reasoning_context:
context = self.db.get_context_messages(self.session_id)
self.reasoning_context = [
Expand All @@ -266,9 +278,18 @@ def get_context_messages(self):
return self.reasoning_context

def create(self):
"""Create a new session in the database."""
self.db.create_session(**self.__dict__)

def new_message(self, msg_type: MsgType = MsgType.output, **kwargs):
def new_message(
self, msg_type: MsgType = MsgType.output, **kwargs
) -> Union[InputMessage, OutputMessage]:
"""Returns a new input/output message object.

:param MsgType msg_type: The type of the message, input or output.
:param dict kwargs: The message attributes.
:return: The input/output message object.
"""
if msg_type == MsgType.input:
return InputMessage(
db=self.db,
Expand All @@ -284,13 +305,16 @@ def new_message(self, msg_type: MsgType = MsgType.output, **kwargs):
)

def get(self):
"""Get the session from the database."""
session = self.db.get_session(self.session_id)
conversation = self.db.get_conversations(self.session_id)
session["conversation"] = conversation
return session

def get_all(self):
"""Get all the sessions from the database."""
return self.db.get_sessions()

def delete(self):
"""Delete the session from the database."""
return self.db.delete_session(self.session_id)
2 changes: 2 additions & 0 deletions backend/spielberg/db/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@


class BaseDB(ABC):
"""Interface for all databases. It provides a common interface for all databases to follow."""

@abstractmethod
def create_session(
self, session_id: str, video_id: str = None, collection_id: str = None
Expand Down
Loading