From d9641b0480a87fae390dfcde73ef17ba39e933e4 Mon Sep 17 00:00:00 2001 From: Abirdcfly Date: Mon, 4 Dec 2023 23:10:42 +0800 Subject: [PATCH] fix: Set pgLockID constant and Simplify createOrGetCollection func Signed-off-by: Abirdcfly --- vectorstores/pgvector/options.go | 5 ++-- vectorstores/pgvector/pgvector.go | 50 +++++++++++++++++-------------- 2 files changed, 30 insertions(+), 25 deletions(-) diff --git a/vectorstores/pgvector/options.go b/vectorstores/pgvector/options.go index 4b2e300e3..702a82ccd 100644 --- a/vectorstores/pgvector/options.go +++ b/vectorstores/pgvector/options.go @@ -5,6 +5,7 @@ import ( "fmt" "os" + "github.com/jackc/pgx/v5" "github.com/tmc/langchaingo/embeddings" ) @@ -52,14 +53,14 @@ func WithCollectionName(name string) Option { // WithEmbeddingTableName is an option for specifying the embedding table name. func WithEmbeddingTableName(name string) Option { return func(p *Store) { - p.embeddingTableName = name + p.embeddingTableName = pgx.Identifier{name}.Sanitize() } } // WithCollectionTableName is an option for specifying the collection table name. func WithCollectionTableName(name string) Option { return func(p *Store) { - p.collectionTableName = name + p.collectionTableName = pgx.Identifier{name}.Sanitize() } } diff --git a/vectorstores/pgvector/pgvector.go b/vectorstores/pgvector/pgvector.go index 7f0bdcf33..122ba3171 100644 --- a/vectorstores/pgvector/pgvector.go +++ b/vectorstores/pgvector/pgvector.go @@ -14,6 +14,19 @@ import ( "github.com/tmc/langchaingo/vectorstores" ) +const ( + // pgLockIDEmbeddingTable is used for advisory lock to fix issue arising from concurrent + // creation of the embedding table. The same value represents the same lock. + pgLockIDEmbeddingTable = 1573678846307946494 + // pgLockIDCollectionTable is used for advisory lock to fix issue arising from concurrent + // creation of the collection table. The same value represents the same lock. 
+ pgLockIDCollectionTable = 1573678846307946495 + // pgLockIDExtension is used for advisory lock to fix issue arising from concurrent creation + // of the vector extension. The value is deliberately set to the same as python langchain + // https://github.com/langchain-ai/langchain/blob/v0.0.340/libs/langchain/langchain/vectorstores/pgvector.py#L167 + pgLockIDExtension = 1573678846307946496 +) + var ( ErrEmbedderWrongNumberVectors = errors.New("number of vectors from embedder does not match number of documents") ErrInvalidScoreThreshold = errors.New("score threshold must be between 0 and 1") @@ -65,7 +78,7 @@ func New(ctx context.Context, opts ...Option) (Store, error) { return Store{}, err } } - if store.collectionUUID, err = store.createOrGetCollection(ctx); err != nil { + if err = store.createOrGetCollection(ctx); err != nil { return Store{}, err } return store, nil @@ -83,7 +96,7 @@ func (s Store) createVectorExtensionIfNotExists(ctx context.Context) error { // https://github.com/langchain-ai/langchain/issues/12933 // For more information see: // https://www.postgresql.org/docs/16/explicit-locking.html#ADVISORY-LOCKS - if _, err := tx.Exec(ctx, "SELECT pg_advisory_xact_lock(1573678846307946495)"); err != nil { + if _, err := tx.Exec(ctx, "SELECT pg_advisory_xact_lock($1)", pgLockIDExtension); err != nil { return err } if _, err := tx.Exec(ctx, "CREATE EXTENSION IF NOT EXISTS vector"); err != nil { @@ -104,7 +117,7 @@ func (s Store) createCollectionTableIfNotExists(ctx context.Context) error { // https://github.com/langchain-ai/langchain/issues/12933 // For more information see: // https://www.postgresql.org/docs/16/explicit-locking.html#ADVISORY-LOCKS - if _, err = tx.Exec(ctx, "SELECT pg_advisory_xact_lock(1573678846307946494)"); err != nil { + if _, err = tx.Exec(ctx, "SELECT pg_advisory_xact_lock($1)", pgLockIDCollectionTable); err != nil { return err } sql := fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s ( @@ -130,7 +143,7 @@ func (s Store) 
createEmbeddingTableIfNotExists(ctx context.Context) error { // https://github.com/langchain-ai/langchain/issues/12933 // For more information see: // https://www.postgresql.org/docs/16/explicit-locking.html#ADVISORY-LOCKS - if _, err := tx.Exec(ctx, "SELECT pg_advisory_xact_lock(1573678846307946493)"); err != nil { + if _, err := tx.Exec(ctx, "SELECT pg_advisory_xact_lock($1)", pgLockIDEmbeddingTable); err != nil { return err } sql := fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s ( @@ -237,11 +250,11 @@ FROM ( WHERE %s ORDER BY data.distance -LIMIT %d`, s.embeddingTableName, +LIMIT $2`, s.embeddingTableName, s.embeddingTableName, s.collectionTableName, s.embeddingTableName, s.collectionTableName, s.collectionTableName, collectionName, - whereQuery, numDocuments) - rows, err := tx.Query(ctx, sql, pgvector.NewVector(embedderData)) + whereQuery) + rows, err := tx.Query(ctx, sql, pgvector.NewVector(embedderData), numDocuments) if err != nil { return nil, err } @@ -274,30 +287,21 @@ func (s Store) DropTables(ctx context.Context) error { } func (s Store) RemoveCollection(ctx context.Context) error { - _, err := s.conn.Exec(ctx, fmt.Sprintf(`DELETE FROM %s WHERE name = '%s'`, s.collectionTableName, s.collectionName)) + _, err := s.conn.Exec(ctx, fmt.Sprintf(`DELETE FROM %s WHERE name = $1`, s.collectionTableName), s.collectionName) return err } -func (s Store) createOrGetCollection(ctx context.Context) (string, error) { +func (s *Store) createOrGetCollection(ctx context.Context) error { sql := fmt.Sprintf(`INSERT INTO %s (uuid, name, cmetadata) VALUES($1, $2, $3) ON CONFLICT DO NOTHING`, s.collectionTableName) - _, err := s.conn.Exec(ctx, sql, uuid.New().String(), s.collectionName, s.collectionMetadata) - if err != nil { - return "", err + if _, err := s.conn.Exec(ctx, sql, uuid.New().String(), s.collectionName, s.collectionMetadata); err != nil { + return err } sql = fmt.Sprintf(`SELECT uuid FROM %s WHERE name = $1 ORDER BY name limit 1`, s.collectionTableName) - rows, err 
:= s.conn.Query(ctx, sql, s.collectionName) - if err != nil { - return "", err - } - defer rows.Close() - var collectionUUID string - for rows.Next() { - if err = rows.Scan(&collectionUUID); err != nil { - return "", err - } + if err := s.conn.QueryRow(ctx, sql, s.collectionName).Scan(&s.collectionUUID); err != nil { + return err } - return collectionUUID, nil + return nil } // getOptions applies given options to default Options and returns it