Skip to content

Commit

Permalink
Now allowing extra fields in prompts / tools (instead of ignoring the…
Browse files Browse the repository at this point in the history
…m altogether), but raising warnings when we encounter them. (#21)

Co-authored-by: Thejas N U <77475353+ThejasNU@users.noreply.github.com>
  • Loading branch information
glennga and ThejasNU authored Nov 22, 2024
1 parent 5f75759 commit a579e7e
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 5 deletions.
4 changes: 4 additions & 0 deletions libs/agentc_cli/tests/test_click.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,12 @@
from agentc_core.defaults import DEFAULT_TOOL_CATALOG_NAME
from agentc_testing.repo import ExampleRepoKind
from agentc_testing.repo import initialize_repo
from agentc_testing.server import isolated_server_factory
from unittest.mock import patch

# This is to keep ruff from falsely flagging this as unused.
_ = isolated_server_factory


@pytest.mark.smoke
def test_index(tmp_path):
Expand Down
12 changes: 12 additions & 0 deletions libs/agentc_core/agentc_core/prompt/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@


class ToolSearchMetadata(pydantic.BaseModel):
model_config = pydantic.ConfigDict(extra="allow")
name: typing.Optional[str] = None
query: typing.Optional[str] = None
annotations: typing.Optional[str] = None
Expand All @@ -42,6 +43,7 @@ def name_or_query_must_be_specified(self):

class _BaseFactory(abc.ABC):
class PromptMetadata(pydantic.BaseModel):
model_config = pydantic.ConfigDict(extra="allow")
name: str
description: str
record_kind: typing.Literal[RecordKind.RawPrompt, RecordKind.JinjaPrompt]
Expand Down Expand Up @@ -80,6 +82,11 @@ class Factory(_BaseFactory):
def __iter__(self) -> typing.Iterable["RawPromptDescriptor"]:
front_matter, prompt_text = self._get_prompt_metadata()
metadata = RawPromptDescriptor.Factory.PromptMetadata.model_validate(front_matter)
if metadata.__pydantic_extra__:
logger.warning(
f"Extra fields found in {self.filename.name}: {metadata.__pydantic_extra__}. "
f"We will ignore these."
)
descriptor_args = {
"name": metadata.name,
"description": metadata.description,
Expand Down Expand Up @@ -111,6 +118,11 @@ class Factory(_BaseFactory):
def __iter__(self) -> typing.Iterable["JinjaPromptDescriptor"]:
front_matter, prompt_text = self._get_prompt_metadata()
metadata = JinjaPromptDescriptor.Factory.PromptMetadata.model_validate(front_matter)
if metadata.__pydantic_extra__:
logger.warning(
f"Extra fields found in {self.filename.name}: {metadata.__pydantic_extra__}. "
f"We will ignore these."
)
descriptor_args = {
"name": metadata.name,
"description": metadata.description,
Expand Down
2 changes: 1 addition & 1 deletion libs/agentc_core/agentc_core/record/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def is_tool(self) -> bool:
class RecordDescriptor(pydantic.BaseModel):
"""This model represents a tool's persistable description or metadata."""

model_config = pydantic.ConfigDict(validate_assignment=True, use_enum_values=True)
model_config = pydantic.ConfigDict(validate_assignment=True, use_enum_values=True, extra="allow")

record_kind: typing.Literal[
RecordKind.PythonFunction,
Expand Down
30 changes: 26 additions & 4 deletions libs/agentc_core/agentc_core/tool/descriptor/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __iter__(self) -> typing.Iterable["PythonToolDescriptor"]:
for _, tool in inspect.getmembers(imported_module):
if not is_tool(tool):
continue
yield PythonToolDescriptor(
record_descriptor = PythonToolDescriptor(
record_kind=RecordKind.PythonFunction,
name=get_name(tool),
description=get_description(tool),
Expand All @@ -65,6 +65,12 @@ def __iter__(self) -> typing.Iterable["PythonToolDescriptor"]:
version=self.version,
annotations=get_annotations(tool),
)
if record_descriptor.__pydantic_extra__:
logger.warning(
f"Extra fields found in {self.filename.name} for tool {get_name(tool)}: "
f"{record_descriptor.__pydantic_extra__.keys()}. We will ignore these."
)
yield record_descriptor


class SQLPPQueryToolDescriptor(RecordDescriptor):
Expand All @@ -76,7 +82,7 @@ class SQLPPQueryToolDescriptor(RecordDescriptor):

class Factory(_BaseFactory):
class Metadata(pydantic.BaseModel, JSONSchemaValidatingMixin):
model_config = pydantic.ConfigDict(frozen=True, use_enum_values=True)
model_config = pydantic.ConfigDict(frozen=True, use_enum_values=True, extra="allow")

# Below, we enumerate all fields that appear in a .sqlpp file.
name: str
Expand Down Expand Up @@ -110,6 +116,11 @@ def __iter__(self) -> typing.Iterable["SQLPPQueryToolDescriptor"]:
elif len(matches) != 1:
logger.warning("More than one multi-line comment found. Using first comment.")
metadata = SQLPPQueryToolDescriptor.Factory.Metadata.model_validate(yaml.safe_load(matches[0]))
if metadata.__pydantic_extra__:
logger.warning(
f"Extra fields found in {self.filename.name}: {metadata.__pydantic_extra__}. "
f"We will ignore these."
)

# Now, generate a single SQL++ tool descriptor.
yield SQLPPQueryToolDescriptor(
Expand Down Expand Up @@ -145,7 +156,7 @@ class VectorSearchMetadata(pydantic.BaseModel):

class Factory(_BaseFactory):
class Metadata(pydantic.BaseModel, JSONSchemaValidatingMixin):
model_config = pydantic.ConfigDict(frozen=True, use_enum_values=True)
model_config = pydantic.ConfigDict(frozen=True, use_enum_values=True, extra="allow")

# Below, we enumerate all fields that appear in a .yaml file for semantic search.
record_kind: typing.Literal[RecordKind.SemanticSearch]
Expand Down Expand Up @@ -181,6 +192,11 @@ def name_should_be_valid_identifier(cls, v: str):
def __iter__(self) -> typing.Iterable["SemanticSearchToolDescriptor"]:
with self.filename.open("r") as fp:
metadata = SemanticSearchToolDescriptor.Factory.Metadata.model_validate(yaml.safe_load(fp))
if metadata.__pydantic_extra__:
logger.warning(
f"Extra fields found in {self.filename.name}: {metadata.__pydantic_extra__}. "
f"We will ignore these."
)
yield SemanticSearchToolDescriptor(
record_kind=RecordKind.SemanticSearch,
name=metadata.name,
Expand Down Expand Up @@ -340,10 +356,11 @@ def default(self, obj):

class Factory(_BaseFactory):
class Metadata(pydantic.BaseModel):
model_config = pydantic.ConfigDict(frozen=True, use_enum_values=True)
model_config = pydantic.ConfigDict(frozen=True, use_enum_values=True, extra="allow")

# Note: we cannot validate this model in isolation (we need the referencing descriptor as well).
class OpenAPIMetadata(pydantic.BaseModel):
model_config = pydantic.ConfigDict(extra="allow")
filename: typing.Optional[str | None] = None
url: typing.Optional[str | None] = None
operations: list["HTTPRequestToolDescriptor.OperationMetadata"]
Expand All @@ -356,6 +373,11 @@ class OpenAPIMetadata(pydantic.BaseModel):
def __iter__(self) -> typing.Iterable["HTTPRequestToolDescriptor"]:
with self.filename.open("r") as fp:
metadata = HTTPRequestToolDescriptor.Factory.Metadata.model_validate(yaml.safe_load(fp))
if metadata.__pydantic_extra__:
logger.warning(
f"Extra fields found in {self.filename.name}: {metadata.__pydantic_extra__}. "
f"We will ignore these."
)
for operation in metadata.open_api.operations:
operation_handle = HTTPRequestToolDescriptor.validate_operation(
source_filename=self.filename,
Expand Down

0 comments on commit a579e7e

Please sign in to comment.