diff --git a/src/raggy/documents.py b/src/raggy/documents.py index ae41041..b65dadf 100644 --- a/src/raggy/documents.py +++ b/src/raggy/documents.py @@ -4,7 +4,7 @@ from typing import Annotated from jinja2 import Environment, Template -from pydantic import BaseModel, ConfigDict, Field, model_validator +from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator from raggy.utilities.ids import generate_prefixed_uuid from raggy.utilities.text import count_tokens, extract_keywords, hash_text, split_text @@ -32,11 +32,18 @@ class Document(BaseModel): text: str = Field(..., description="Document text content.") embedding: list[float] | None = Field(default=None) - metadata: DocumentMetadata = Field(default_factory=DocumentMetadata) + metadata: DocumentMetadata | dict = Field(default_factory=DocumentMetadata) tokens: int | None = Field(default=None) keywords: list[str] = Field(default_factory=list) + @field_validator("metadata", mode="before") + @classmethod + def ensure_metadata(cls, v): + if isinstance(v, dict): + return DocumentMetadata(**v) + return v + @model_validator(mode="after") def ensure_tokens(self): if self.tokens is None: diff --git a/src/raggy/loaders/web.py b/src/raggy/loaders/web.py index 8d67a0e..74d0ed8 100644 --- a/src/raggy/loaders/web.py +++ b/src/raggy/loaders/web.py @@ -109,11 +109,11 @@ async def response_to_document(self, response: Response) -> Document: """Convert an HTTP response to a Document.""" return Document( text=await self.get_document_text(response), - metadata={ - "link": str(response.url), - "source": self.source_type, - "document_type": self.document_type, - }, + metadata=dict( + link=str(response.url), + source=self.source_type, + document_type=self.document_type, + ), ) async def get_document_text(self, response: Response) -> str: @@ -131,7 +131,20 @@ async def get_document_text(self, response: Response) -> str: class SitemapLoader(URLLoader): - """A loader that loads URLs from a sitemap.""" + """A loader that loads URLs from a sitemap. + Attributes: + include: A list of strings or regular expressions. Only URLs that match one of these will be included. + exclude: A list of strings or regular expressions. URLs that match one of these will be excluded. + url_loader: The loader to use for loading the URLs. + Examples: + Load all URLs from a sitemap: + ```python + from raggy.loaders.web import SitemapLoader + loader = SitemapLoader(urls=["https://controlflow.ai/sitemap.xml"]) + documents = await loader.load() + print(documents) + ``` + """ include: list[str | re.Pattern] = Field(default_factory=list) exclude: list[str | re.Pattern] = Field(default_factory=list)