Skip to content

Commit

Permalink
fix: Set pgLockID constant and Simplify createOrGetCollection func
Browse files Browse the repository at this point in the history
Signed-off-by: Abirdcfly <fp544037857@gmail.com>
  • Loading branch information
Abirdcfly committed Dec 5, 2023
1 parent 1346747 commit d9641b0
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 25 deletions.
5 changes: 3 additions & 2 deletions vectorstores/pgvector/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"os"

"github.com/jackc/pgx/v5"
"github.com/tmc/langchaingo/embeddings"
)

Expand Down Expand Up @@ -52,14 +53,14 @@ func WithCollectionName(name string) Option {
// WithEmbeddingTableName is an option for specifying the embedding table name.
func WithEmbeddingTableName(name string) Option {
return func(p *Store) {
p.embeddingTableName = name
p.embeddingTableName = pgx.Identifier{name}.Sanitize()
}
}

// WithCollectionTableName is an option for specifying the collection table name.
func WithCollectionTableName(name string) Option {
return func(p *Store) {
p.collectionTableName = name
p.collectionTableName = pgx.Identifier{name}.Sanitize()
}
}

Expand Down
50 changes: 27 additions & 23 deletions vectorstores/pgvector/pgvector.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,19 @@ import (
"github.com/tmc/langchaingo/vectorstores"
)

const (
// pgLockIDEmbeddingTable is the advisory lock ID used to fix issues arising from concurrent
// creation of the embedding table. The same value represents the same lock.
pgLockIDEmbeddingTable = 1573678846307946494
// pgLockIDCollectionTable is the advisory lock ID used to fix issues arising from concurrent
// creation of the collection table. The same value represents the same lock.
pgLockIDCollectionTable = 1573678846307946495
// pgLockIDExtension is the advisory lock ID used to fix issues arising from concurrent creation
// of the vector extension. The value is deliberately set to the same value as Python langchain:
// https://github.com/langchain-ai/langchain/blob/v0.0.340/libs/langchain/langchain/vectorstores/pgvector.py#L167
pgLockIDExtension = 1573678846307946496
)

var (
ErrEmbedderWrongNumberVectors = errors.New("number of vectors from embedder does not match number of documents")
ErrInvalidScoreThreshold = errors.New("score threshold must be between 0 and 1")
Expand Down Expand Up @@ -65,7 +78,7 @@ func New(ctx context.Context, opts ...Option) (Store, error) {
return Store{}, err
}
}
if store.collectionUUID, err = store.createOrGetCollection(ctx); err != nil {
if err = store.createOrGetCollection(ctx); err != nil {
return Store{}, err
}
return store, nil
Expand All @@ -83,7 +96,7 @@ func (s Store) createVectorExtensionIfNotExists(ctx context.Context) error {
// https://github.com/langchain-ai/langchain/issues/12933
// For more information see:
// https://www.postgresql.org/docs/16/explicit-locking.html#ADVISORY-LOCKS
if _, err := tx.Exec(ctx, "SELECT pg_advisory_xact_lock(1573678846307946495)"); err != nil {
if _, err := tx.Exec(ctx, "SELECT pg_advisory_xact_lock($1)", pgLockIDExtension); err != nil {
return err
}
if _, err := tx.Exec(ctx, "CREATE EXTENSION IF NOT EXISTS vector"); err != nil {
Expand All @@ -104,7 +117,7 @@ func (s Store) createCollectionTableIfNotExists(ctx context.Context) error {
// https://github.com/langchain-ai/langchain/issues/12933
// For more information see:
// https://www.postgresql.org/docs/16/explicit-locking.html#ADVISORY-LOCKS
if _, err = tx.Exec(ctx, "SELECT pg_advisory_xact_lock(1573678846307946494)"); err != nil {
if _, err = tx.Exec(ctx, "SELECT pg_advisory_xact_lock($1)", pgLockIDCollectionTable); err != nil {
return err
}
sql := fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s (
Expand All @@ -130,7 +143,7 @@ func (s Store) createEmbeddingTableIfNotExists(ctx context.Context) error {
// https://github.com/langchain-ai/langchain/issues/12933
// For more information see:
// https://www.postgresql.org/docs/16/explicit-locking.html#ADVISORY-LOCKS
if _, err := tx.Exec(ctx, "SELECT pg_advisory_xact_lock(1573678846307946493)"); err != nil {
if _, err := tx.Exec(ctx, "SELECT pg_advisory_xact_lock($1)", pgLockIDEmbeddingTable); err != nil {
return err
}
sql := fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s (
Expand Down Expand Up @@ -237,11 +250,11 @@ FROM (
WHERE %s
ORDER BY
data.distance
LIMIT %d`, s.embeddingTableName,
LIMIT $2`, s.embeddingTableName,
s.embeddingTableName,
s.collectionTableName, s.embeddingTableName, s.collectionTableName, s.collectionTableName, collectionName,
whereQuery, numDocuments)
rows, err := tx.Query(ctx, sql, pgvector.NewVector(embedderData))
whereQuery)
rows, err := tx.Query(ctx, sql, pgvector.NewVector(embedderData), numDocuments)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -274,30 +287,21 @@ func (s Store) DropTables(ctx context.Context) error {
}

func (s Store) RemoveCollection(ctx context.Context) error {
_, err := s.conn.Exec(ctx, fmt.Sprintf(`DELETE FROM %s WHERE name = '%s'`, s.collectionTableName, s.collectionName))
_, err := s.conn.Exec(ctx, fmt.Sprintf(`DELETE FROM %s WHERE name = $1`, s.collectionTableName), s.collectionName)
return err
}

func (s Store) createOrGetCollection(ctx context.Context) (string, error) {
func (s *Store) createOrGetCollection(ctx context.Context) error {
sql := fmt.Sprintf(`INSERT INTO %s (uuid, name, cmetadata)
VALUES($1, $2, $3) ON CONFLICT DO NOTHING`, s.collectionTableName)
_, err := s.conn.Exec(ctx, sql, uuid.New().String(), s.collectionName, s.collectionMetadata)
if err != nil {
return "", err
if _, err := s.conn.Exec(ctx, sql, uuid.New().String(), s.collectionName, s.collectionMetadata); err != nil {
return err
}
sql = fmt.Sprintf(`SELECT uuid FROM %s WHERE name = $1 ORDER BY name limit 1`, s.collectionTableName)
rows, err := s.conn.Query(ctx, sql, s.collectionName)
if err != nil {
return "", err
}
defer rows.Close()
var collectionUUID string
for rows.Next() {
if err = rows.Scan(&collectionUUID); err != nil {
return "", err
}
if err := s.conn.QueryRow(ctx, sql, s.collectionName).Scan(&s.collectionUUID); err != nil {
return err
}
return collectionUUID, nil
return nil
}

// getOptions applies given options to default Options and returns it
Expand Down

0 comments on commit d9641b0

Please sign in to comment.