Skip to content

Commit

Permalink
Persist embeddings components to specified schema, closes #829
Browse files Browse the repository at this point in the history
  • Loading branch information
davidmezzetti committed Dec 4, 2024
1 parent 92489ac commit 3f7bb71
Show file tree
Hide file tree
Showing 12 changed files with 114 additions and 9 deletions.
2 changes: 2 additions & 0 deletions docs/embeddings/configuration/ann.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ The `torch` backend supports the same options. The only difference is that the v
pgvector:
url: database url connection string, alternatively can be set via
ANN_URL environment variable
schema: database schema to store vectors - defaults to being
determined by the database
table: database table to store vectors - defaults to `vectors`
efconstruction: ef_construction param (int) - defaults to 200
m: M param for init_index (int) - defaults to 16
Expand Down
10 changes: 10 additions & 0 deletions docs/embeddings/configuration/database.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,23 @@ Add custom storage engines via setting this parameter to the fully resolvable cl

Content storage specific settings are set with a corresponding configuration object having the same name as the content storage engine (i.e. duckdb or sqlite). These are optional and set to defaults if omitted.

### client
```yaml
schema: default database schema for the session - defaults to being
determined by the database
```

Additional settings for client-server databases. These settings also apply when `content` is set to a database connection URL.

### sqlite
```yaml
sqlite:
wal: enable write-ahead logging - allows concurrent read/write operations,
defaults to false
```

Additional settings for SQLite.

## objects
```yaml
objects: boolean|image|pickle
Expand Down
2 changes: 2 additions & 0 deletions docs/embeddings/configuration/graph.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ The `rdbms` backend has the following additional settings.
```yaml
url: database url connection string, alternatively can be set via the
GRAPH_URL environment variable
schema: database schema to store graph - defaults to being
determined by the database
nodes: table to store node data, defaults to `nodes`
edges: table to store edge data, defaults to `edges`
```
Expand Down
10 changes: 9 additions & 1 deletion docs/embeddings/configuration/scoring.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,19 @@ The following covers the available options.

## method
```yaml
method: bm25|tfidf|sif|custom
method: bm25|tfidf|sif|pgtext|custom
```
Sets the scoring method. Add custom scoring via setting this parameter to the fully resolvable class string.
### pgtext
```yaml
schema: database schema to store keyword index - defaults to being
determined by the database
```
Additional settings for Postgres full-text keyword indexes.
## terms
```yaml
terms: boolean|dict
Expand Down
19 changes: 18 additions & 1 deletion src/python/txtai/ann/pgvector.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from sqlalchemy import create_engine, delete, text, Column, Index, Integer, MetaData, StaticPool, Table
from sqlalchemy.orm import Session
from sqlalchemy.schema import CreateSchema

PGVECTOR = True
except ImportError:
Expand All @@ -34,9 +35,15 @@ def __init__(self, config):

# Initialize pgvector extension
self.database = Session(self.engine)
self.database.execute(text("CREATE EXTENSION IF NOT EXISTS vector" if self.engine.dialect.name == "postgresql" else "SELECT 1"))
self.sqldialect(text("CREATE EXTENSION IF NOT EXISTS vector"))
self.database.commit()

# Set default schema, if necessary
schema = self.setting("schema")
if schema:
self.sqldialect(CreateSchema(schema, if_not_exists=True))
self.sqldialect(text(f"SET search_path TO {schema},public"))

# Table instance
self.table = None

Expand Down Expand Up @@ -141,3 +148,13 @@ def settings(self):
"""

return {"m": self.setting("m", 16), "ef_construction": self.setting("efconstruction", 200)}

def sqldialect(self, sql):
    """
    Executes a SQL statement when the current engine dialect is Postgres.

    For any other dialect (e.g. SQLite used in unit tests), a harmless
    placeholder query is run instead so callers don't need dialect checks.

    Args:
        sql: SQL to execute
    """

    # Substitute a no-op statement for non-Postgres databases
    statement = sql if self.engine.dialect.name == "postgresql" else text("SELECT 1")
    self.database.execute(statement)
23 changes: 22 additions & 1 deletion src/python/txtai/database/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
try:
from sqlalchemy import StaticPool, Text, cast, create_engine, insert, text as textsql
from sqlalchemy.orm import Session, aliased
from sqlalchemy.schema import CreateSchema

from .schema import Base, Batch, Document, Object, Section, SectionBase, Score

Expand Down Expand Up @@ -115,7 +116,15 @@ def connect(self, path=None):
engine = create_engine(content, poolclass=StaticPool, echo=False, json_serializer=lambda x: x)

# Create database session
return Session(engine)
database = Session(engine)

# Set default schema, if necessary
schema = self.config.get("schema")
if schema:
self.sqldialect(database, engine, CreateSchema(schema, if_not_exists=True))
self.sqldialect(database, engine, textsql(f"SET search_path TO {schema}"))

return database

def getcursor(self):
return Cursor(self.connection)
Expand All @@ -126,6 +135,18 @@ def rows(self):
def addfunctions(self):
return

def sqldialect(self, database, engine, sql):
    """
    Executes a SQL statement based on the current SQL dialect.

    The statement is only run when the engine dialect is Postgres; for any
    other database a no-op query is executed in its place.

    Args:
        database: current database
        engine: database engine
        sql: SQL to execute
    """

    if engine.dialect.name == "postgresql":
        database.execute(sql)
    else:
        # No-op for databases that don't use the Postgres schema statements
        database.execute(textsql("SELECT 1"))


class Cursor:
"""
Expand Down
22 changes: 19 additions & 3 deletions src/python/txtai/graph/rdbms.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from grand import Graph
from grand.backends import SQLBackend, InMemoryCachedBackend

from sqlalchemy import text, StaticPool
from sqlalchemy import create_engine, text, StaticPool
from sqlalchemy.schema import CreateSchema

ORM = True
except ImportError:
Expand Down Expand Up @@ -85,11 +86,26 @@ def connect(self):
Graph database instance
"""

# Keyword arguments for SQLAlchemy
kwargs = {"poolclass": StaticPool, "echo": False}
url = self.config.get("url", os.environ.get("GRAPH_URL"))

# Set default schema, if necessary
schema = self.config.get("schema")
if schema:
# Check that schema exists
engine = create_engine(url)
with engine.connect() as connection:
connection.execute(CreateSchema(schema, if_not_exists=True) if "postgresql" in url else text("SELECT 1"))

# Set default schema
kwargs["connect_args"] = {"options": f'-c search_path="{schema}"'} if "postgresql" in url else {}

backend = SQLBackend(
db_url=self.config.get("url", os.environ.get("GRAPH_URL")),
db_url=url,
node_table_name=self.config.get("nodes", "nodes"),
edge_table_name=self.config.get("edges", "edges"),
sqlalchemy_kwargs={"poolclass": StaticPool, "echo": False},
sqlalchemy_kwargs=kwargs,
)

# pylint: disable=W0212
Expand Down
17 changes: 17 additions & 0 deletions src/python/txtai/scoring/pgtext.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from sqlalchemy import Column, Computed, Index, Integer, MetaData, StaticPool, Table, Text
from sqlalchemy.dialects.postgresql import TSVECTOR
from sqlalchemy.orm import Session
from sqlalchemy.schema import CreateSchema

PGTEXT = True
except ImportError:
Expand Down Expand Up @@ -118,6 +119,12 @@ def initialize(self, recreate=False):
self.engine = create_engine(self.config.get("url", os.environ.get("SCORING_URL")), poolclass=StaticPool, echo=False)
self.database = Session(self.engine)

# Set default schema, if necessary
schema = self.config.get("schema")
if schema:
self.sqldialect(CreateSchema(schema, if_not_exists=True))
self.sqldialect(text(f"SET search_path TO {schema}"))

# Table name
table = self.config.get("table", "scoring")

Expand Down Expand Up @@ -149,3 +156,13 @@ def initialize(self, recreate=False):
# Create table and index
self.table.create(self.engine, checkfirst=True)
index.create(self.engine, checkfirst=True)

def sqldialect(self, sql):
    """
    Executes a SQL statement based on the current SQL dialect.

    Runs sql only against Postgres engines; other dialects receive a
    placeholder query so the call is always safe.

    Args:
        sql: SQL to execute
    """

    dialect = self.engine.dialect.name
    self.database.execute(sql if dialect == "postgresql" else text("SELECT 1"))
2 changes: 1 addition & 1 deletion test/python/testann.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def testPGVector(self, query):

# Create ANN
path = os.path.join(tempfile.gettempdir(), "pgvector.sqlite")
ann = ANNFactory.create({"backend": "pgvector", "pgvector": {"url": f"sqlite:///{path}"}, "dimensions": 240})
ann = ANNFactory.create({"backend": "pgvector", "pgvector": {"url": f"sqlite:///{path}", "schema": "txtai"}, "dimensions": 240})

# Test indexing
ann.index(data)
Expand Down
12 changes: 12 additions & 0 deletions test/python/testdatabase/testclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,15 @@ def setUp(self):
self.backend = f"sqlite:///{path}"

self.embeddings.config["content"] = self.backend

def testSchema(self):
    """
    Test database creation with a specified schema
    """

    # Build an index with content persisted to a custom database schema
    embeddings = Embeddings(path="sentence-transformers/nli-mpnet-base-v2", content=self.backend, schema="txtai")
    embeddings.index(self.data)

    # Top search hit should match the expected data element
    match = embeddings.search("feel good story", 1)[0]
    self.assertEqual(match["text"], self.data[4])
2 changes: 1 addition & 1 deletion test/python/testgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def testDatabase(self):

# Generate graph database
path = os.path.join(tempfile.gettempdir(), "graph.sqlite")
graph = GraphFactory.create({"backend": "rdbms", "url": f"sqlite:///{path}"})
graph = GraphFactory.create({"backend": "rdbms", "url": f"sqlite:///{path}", "schema": "txtai"})

# Initialize the graph
graph.initialize()
Expand Down
2 changes: 1 addition & 1 deletion test/python/testscoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def testPGText(self, query):

# Create scoring
path = os.path.join(tempfile.gettempdir(), "pgtext.sqlite")
scoring = ScoringFactory.create({"method": "pgtext", "url": f"sqlite:///{path}"})
scoring = ScoringFactory.create({"method": "pgtext", "url": f"sqlite:///{path}", "schema": "txtai"})
scoring.index((uid, {"text": text}, tags) for uid, text, tags in self.data)

# Run search and validate correct result returned
Expand Down

0 comments on commit 3f7bb71

Please sign in to comment.