Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

special case LocalDocument more #170

Merged
merged 6 commits into from
Nov 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 24 additions & 26 deletions ragna/core/_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import time
import uuid
from pathlib import Path
from typing import TYPE_CHECKING, Any, Iterator, Optional, Type, TypeVar
from typing import TYPE_CHECKING, Any, Iterator, Optional, Type, TypeVar, Union

import jwt
from pydantic import BaseModel
Expand Down Expand Up @@ -75,49 +75,47 @@ def extract_pages(self) -> Iterator[Page]:


class LocalDocument(Document):
def __init__(
self,
path: Optional[str | Path] = None,
"""Document class for files on the local file system.

!!! tip

This object is usually not instantiated manually, but rather through
[ragna.core.LocalDocument.from_path][].
"""

@classmethod
def from_path(
cls,
path: Union[str, Path],
*,
id: Optional[uuid.UUID] = None,
name: Optional[str] = None,
metadata: Optional[dict[str, Any]] = None,
handler: Optional[DocumentHandler] = None,
) -> None:
"""Document class for files on the local file system.
) -> LocalDocument:
"""Create a [ragna.core.LocalDocument][] from a path.

Args:
path: Path to a file.
path: Local path to the file.
id: ID of the document. If omitted, one is generated.
name: Name of the document. If omitted, is inferred from the `path` or the
`metadata`.
metadata: Metadata of the document. If not included, `path` will be added
under the `"path"` key.
metadata: Optional metadata of the document.
handler: Document handler. If omitted, a builtin handler is selected based
on the suffix of the `path`.

Raises:
RagnaException: If `path` is omitted and and also not passed as part of
`metadata`.
RagnaException: If `path` is passed directly and as part of `metadata`.
RagnaException: If `metadata` is passed and contains a `"path"` key.
"""
if metadata is None:
metadata = {}
metadata_path = metadata.get("path")

if path is None and metadata_path is None:
elif "path" in metadata:
raise RagnaException(
"Path was neither passed directly or as part of the metadata"
"The metadata already includes a 'path' key. "
"Did you mean to instantiate the class directly?"
)
elif path is not None and metadata_path is not None:
raise RagnaException("Path was passed directly and as part of the metadata")
elif path is not None:
metadata["path"] = str(Path(path).expanduser().resolve())

if name is None:
name = Path(metadata["path"]).name
path = Path(path).expanduser().resolve()
metadata["path"] = str(path)

super().__init__(id=id, name=name, metadata=metadata, handler=handler)
return cls(id=id, name=path.name, metadata=metadata, handler=handler)

@property
def path(self) -> Path:
Expand Down
29 changes: 15 additions & 14 deletions ragna/core/_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from ._components import Assistant, Component, Message, MessageRole, SourceStorage
from ._config import Config
from ._document import Document
from ._document import Document, LocalDocument
from ._queue import Queue
from ._utils import RagnaException, default_user, merge_models

Expand Down Expand Up @@ -46,14 +46,14 @@ def chat(
"""Create a new [ragna.core.Chat][].

Args:
documents: Documents to use. Items that are not [ragna.core.Document][]s are
passed as positional argument to the configured document class.
documents: Documents to use.

!!! note

The default configuration uses [ragna.core.LocalDocument][] as
document class. It accepts a path as positional input to create it.
Thus, in this configuration you can pass paths as documents.
The default configuration uses [ragna.core.LocalDocument][]. If
that is the case, [ragna.core.LocalDocument.from_path][] is invoked
on any non-[ragna.core.Document][] inputs. Thus, in this
configuration you can pass paths directly.
source_storage: Source storage to use. If [str][] can be the
[ragna.core.Component.display_name][] of any configured source
storage.
Expand Down Expand Up @@ -104,14 +104,14 @@ class Chat:

Args:
rag: The RAG workflow this chat is associated with.
documents: Documents to use. Items that are not [ragna.core.Document][]s are
passed as positional argument to the configured document class.
documents: Documents to use.

!!! note

The default configuration uses [ragna.core.LocalDocument][] as document
class. It accepts a path as positional input to create it. Thus, in
this configuration you can pass paths as documents.
The default configuration uses [ragna.core.LocalDocument][]. If that is
the case, [ragna.core.LocalDocument.from_path][] is invoked on any
non-[ragna.core.Document][] inputs. Thus, in this configuration you can
pass paths directly.
source_storage: Source storage to use. If [str][] can be the
[ragna.core.Component.display_name][] of any configured source storage.
assistant: Assistant to use. If [str][] can be the
Expand Down Expand Up @@ -220,9 +220,10 @@ def _parse_documents(self, documents: Iterable[Any]) -> list[Document]:
documents_ = []
for document in documents:
if not isinstance(document, Document):
document = self._rag.config.core.document(
document # type: ignore[misc, call-arg]
)
if issubclass(self._rag.config.core.document, LocalDocument):
document = LocalDocument.from_path(document)
else:
raise RagnaException("Input is not a document", document=document)

if not document.is_readable():
raise RagnaException(
Expand Down
5 changes: 3 additions & 2 deletions tests/core/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,12 @@ def check_core(config):
document_path = document_root / "test.txt"
with open(document_path, "w") as file:
file.write("!\n")
document = ragna.core.LocalDocument.from_path(document_path)

async def core():
rag = Rag(config)
chat = rag.chat(
documents=[document_path],
documents=[document],
source_storage=RagnaDemoSourceStorage,
assistant=RagnaDemoAssistant,
)
Expand All @@ -56,4 +57,4 @@ async def core():

assert isinstance(answer, ragna.core.Message)
assert answer.role is ragna.core.MessageRole.ASSISTANT
assert {source.document.name for source in answer.sources} == {document_path.name}
assert {source.document.name for source in answer.sources} == {document.name}
37 changes: 24 additions & 13 deletions tests/core/test_rag.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,39 @@
import asyncio

import pydantic
import pytest

from ragna import Rag, assistants, source_storages
from ragna.core import LocalDocument


@pytest.fixture
@pytest.fixture()
def demo_document(tmp_path, request):
path = tmp_path / "demo_document.txt"
with open(path, "w") as file:
file.write(f"{request.node.name}\n")
return path
return LocalDocument.from_path(path)


def test_chat_params_extra(demo_document):
async def main():
async with Rag().chat(
documents=[demo_document],
class TestChat:
def chat(self, documents, **params):
return Rag().chat(
documents=documents,
source_storage=source_storages.RagnaDemoSourceStorage,
assistant=assistants.RagnaDemoAssistant,
not_supported_parameter="arbitrary_value",
):
pass
**params,
)

def test_extra_params(self, demo_document):
with pytest.raises(pydantic.ValidationError, match="not_supported_parameter"):
self.chat(
documents=[demo_document], not_supported_parameter="arbitrary_value"
)

def test_document_path(self, demo_document):
chat = self.chat(documents=[demo_document.path])

assert len(chat.documents) == 1
document = chat.documents[0]

with pytest.raises(pydantic.ValidationError, match="not_supported_parameter"):
asyncio.run(main())
assert isinstance(document, LocalDocument)
assert document.path == demo_document.path
assert document.name == demo_document.name
4 changes: 2 additions & 2 deletions tests/source_storages/test_source_storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@ def test_smoke(tmp_path, source_storage_cls):
with open(path, "w") as file:
file.write(f"This is irrelevant information for the {idx}. time!\n")

documents.append(LocalDocument(path))
documents.append(LocalDocument.from_path(path))

secret = "Ragna"
path = document_root / "secret.txt"
with open(path, "w") as file:
file.write(f"The secret is {secret}!\n")

documents.insert(len(documents) // 2, LocalDocument(path))
documents.insert(len(documents) // 2, LocalDocument.from_path(path))

config = Config(local_cache_root=tmp_path)
source_storage = source_storage_cls(config)
Expand Down
Loading