From e2c4a6a60bf391b468d0a2097b56bd2ac26c1c25 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Sat, 13 Apr 2024 23:00:59 +0200 Subject: [PATCH 1/2] use custom JSON type for database for more generic support --- ragna/deploy/_api/database.py | 2 +- ragna/deploy/_api/orm.py | 30 ++++++++++++++++++++++++++++-- 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/ragna/deploy/_api/database.py b/ragna/deploy/_api/database.py index b4433a33..ae0bccb0 100644 --- a/ragna/deploy/_api/database.py +++ b/ragna/deploy/_api/database.py @@ -129,7 +129,7 @@ def _orm_to_schema_chat(chat: orm.Chat) -> schemas.Chat: documents=documents, source_storage=chat.source_storage, assistant=chat.assistant, - params=chat.params, # type: ignore[arg-type] + params=chat.params, ), messages=messages, prepared=chat.prepared, diff --git a/ragna/deploy/_api/orm.py b/ragna/deploy/_api/orm.py index 92b97a15..c0f5a8a4 100644 --- a/ragna/deploy/_api/orm.py +++ b/ragna/deploy/_api/orm.py @@ -1,9 +1,35 @@ +import json +from typing import Any + from sqlalchemy import Column, ForeignKey, Table, types +from sqlalchemy.engine import Dialect from sqlalchemy.orm import DeclarativeBase, relationship # type: ignore[attr-defined] from ragna.core import MessageRole +class Json(types.TypeDecorator): + """Universal JSON type which stores values as strings. + + This is needed because sqlalchemy.types.JSON only works for a limited subset of + databases. + """ + + impl = types.String + + cache_ok = True + + def process_bind_param(self, value: Any, dialect: Dialect) -> str: + return json.dumps(value) + + def process_result_value( + self, + value: str, + dialect: Dialect, # type: ignore[override] + ) -> Any: + return json.loads(value) + + class Base(DeclarativeBase): pass @@ -34,7 +60,7 @@ class Document(Base): name = Column(types.String) # Mind the trailing underscore here. Unfortunately, this is necessary, because # metadata without the underscore is reserved by SQLAlchemy - metadata_ = Column(types.JSON) + metadata_ = Column(Json) chats = relationship( "Chat", secondary=document_chat_association_table, @@ -59,7 +85,7 @@ class Chat(Base): ) source_storage = Column(types.String) assistant = Column(types.String) - params = Column(types.JSON) + params = Column(Json) messages = relationship("Message", cascade="all, delete") prepared = Column(types.Boolean) From 0cbd2d0f582388ad278362f92a2fc2cd43caca61 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Sat, 13 Apr 2024 23:55:13 +0200 Subject: [PATCH 2/2] fix lint --- ragna/deploy/_api/orm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ragna/deploy/_api/orm.py b/ragna/deploy/_api/orm.py index c0f5a8a4..7e3f8c8e 100644 --- a/ragna/deploy/_api/orm.py +++ b/ragna/deploy/_api/orm.py @@ -24,8 +24,8 @@ def process_bind_param(self, value: Any, dialect: Dialect) -> str: def process_result_value( self, - value: str, - dialect: Dialect, # type: ignore[override] + value: str, # type: ignore[override] + dialect: Dialect, ) -> Any: return json.loads(value)