From 118d1b50a4d9e0c8df4d18736a6603cff3168766 Mon Sep 17 00:00:00 2001 From: Thorsten Klein Date: Fri, 19 Apr 2024 14:10:44 +0200 Subject: [PATCH 1/3] add: RemoveDocument to delete a single document from a collection --- collection.go | 43 +++++++++++++++++++++++----- collection_test.go | 70 ++++++++++++++++++++++++++++++++++++++++++++++ persistence.go | 16 +++++++++++ 3 files changed, 122 insertions(+), 7 deletions(-) diff --git a/collection.go b/collection.go index b0765cd..44f0d5e 100644 --- a/collection.go +++ b/collection.go @@ -236,19 +236,37 @@ func (c *Collection) AddDocument(ctx context.Context, doc Document) error { // Persist the document if c.persistDirectory != "" { - safeID := hash2hex(doc.ID) - docPath := filepath.Join(c.persistDirectory, safeID) - docPath += ".gob" - if c.compress { - docPath += ".gz" - } + docPath := c.getDocPath(doc.ID) err := persist(docPath, doc, c.compress, "") if err != nil { - return fmt.Errorf("couldn't persist document: %w", err) + return fmt.Errorf("couldn't persist document to %q: %w", docPath, err) + } + } + + return nil +} + +// RemoveDocument removes a document from the collection. +func (c *Collection) RemoveDocument(_ context.Context, documentID string) error { + if documentID == "" { + return errors.New("documentID is empty") + } + + c.documentsLock.Lock() + defer c.documentsLock.Unlock() + delete(c.documents, documentID) + + // Remove the document from disk + if c.persistDirectory != "" { + docPath := c.getDocPath(documentID) + err := remove(docPath) + if err != nil { + return fmt.Errorf("couldn't remove document at %q: %w", docPath, err) } } return nil + } // Count returns the number of documents in the collection. @@ -350,3 +368,14 @@ func (c *Collection) QueryEmbedding(ctx context.Context, queryEmbedding []float3 // Return the top nResults return res, nil } + +// getDocPath generates the path to the document file. 
+func (c *Collection) getDocPath(docID string) string { + safeID := hash2hex(docID) + docPath := filepath.Join(c.persistDirectory, safeID) + docPath += ".gob" + if c.compress { + docPath += ".gz" + } + return docPath +} diff --git a/collection_test.go b/collection_test.go index fc37689..5d75223 100644 --- a/collection_test.go +++ b/collection_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "math/rand" + "os" "slices" "strconv" "testing" @@ -420,6 +421,75 @@ func TestCollection_Count(t *testing.T) { } } +func TestCollection_RemoveDocument(t *testing.T) { + // Create persistent collection + tmpdir, err := os.MkdirTemp(os.TempDir(), "chromem-test-*") + if err != nil { + t.Fatal("expected no error, got", err) + } + db, err := NewPersistentDB(tmpdir, false) + if err != nil { + t.Fatal("expected no error, got", err) + } + name := "test" + metadata := map[string]string{"foo": "bar"} + vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` + embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { + return vectors, nil + } + c, err := db.CreateCollection(name, metadata, embeddingFunc) + if err != nil { + t.Fatal("expected no error, got", err) + } + if c == nil { + t.Fatal("expected collection, got nil") + } + + // Add documents + ids := []string{"1", "2"} + metadatas := []map[string]string{{"foo": "bar"}, {"a": "b"}} + contents := []string{"hello world", "hallo welt"} + err = c.Add(context.Background(), ids, nil, metadatas, contents) + if err != nil { + t.Fatal("expected nil, got", err) + } + + // Check count + if c.Count() != 2 { + t.Fatal("expected 2, got", c.Count()) + } + + // Check number of files in the persist directory + d, err := os.ReadDir(c.persistDirectory) + if err != nil { + t.Fatal("expected nil, got", err) + } + if len(d) != 3 { // 2 documents + 1 metadata file + t.Fatal("expected 2 files in persist_dir, got", len(d)) + } + + // Remove document + err = c.RemoveDocument(context.Background(), 
"1") + if err != nil { + t.Fatal("expected nil, got", err) + } + + // Check count + if c.Count() != 1 { + t.Fatal("expected 1, got", c.Count()) + } + + // Check number of files in the persist directory + d, err = os.ReadDir(c.persistDirectory) + if err != nil { + t.Fatal("expected nil, got", err) + } + if len(d) != 2 { // 1 document + 1 metadata file + t.Fatal("expected 1 file in persist_dir, got", len(d)) + } + +} + // Global var for assignment in the benchmark to avoid compiler optimizations. var globalRes []Result diff --git a/persistence.go b/persistence.go index 4a385d4..5748afd 100644 --- a/persistence.go +++ b/persistence.go @@ -225,3 +225,19 @@ func read(filePath string, obj any, encryptionKey string) error { return nil } + +// remove removes a file at the given path. If the file doesn't exist, it's a no-op. +func remove(filePath string) error { + if filePath == "" { + return fmt.Errorf("file path is empty") + } + + err := os.Remove(filePath) + if err != nil { + if !errors.Is(err, fs.ErrNotExist) { + return fmt.Errorf("couldn't remove file %q: %w", filePath, err) + } + } + + return nil +} From 78061fb72402cf02ce9e771fa3ec1ba90eae23e8 Mon Sep 17 00:00:00 2001 From: Thorsten Klein Date: Mon, 22 Apr 2024 20:06:38 +0200 Subject: [PATCH 2/3] change: add filters and variadic ids params to renamed delete func --- collection.go | 58 +++++++++++++++++++++++++++++++++++++--------- collection_test.go | 56 ++++++++++++++++++++++++++++++-------------- 2 files changed, 85 insertions(+), 29 deletions(-) diff --git a/collection.go b/collection.go index 44f0d5e..ec4a776 100644 --- a/collection.go +++ b/collection.go @@ -246,22 +246,58 @@ func (c *Collection) AddDocument(ctx context.Context, doc Document) error { return nil } -// RemoveDocument removes a document from the collection. 
-func (c *Collection) RemoveDocument(_ context.Context, documentID string) error { - if documentID == "" { - return errors.New("documentID is empty") +// Delete removes document(s) from the collection. +// +// - where: Conditional filtering on metadata. Optional. +// - whereDocument: Conditional filtering on documents. Optional. +// - ids: The ids of the documents to delete. Only used when where and whereDocument are both nil; an error is returned if where, whereDocument and ids are all empty. +func (c *Collection) Delete(_ context.Context, where, whereDocument map[string]string, ids ...string) error { + + // must have at least one of where, whereDocument or ids + if len(where) == 0 && len(whereDocument) == 0 && len(ids) == 0 { + return fmt.Errorf("must have at least one of where, whereDocument or ids") + } + + if len(c.documents) == 0 { + return nil + } + + for k := range whereDocument { + if !slices.Contains(supportedFilters, k) { + return errors.New("unsupported whereDocument operator") + } + } + + var docIDs []string + + if where != nil || whereDocument != nil { + // metadata + content filters + filteredDocs := filterDocs(c.documents, where, whereDocument) + for _, doc := range filteredDocs { + docIDs = append(docIDs, doc.ID) + } + } else { + docIDs = ids + } + + // No-op if no docs are left + if len(docIDs) == 0 { + return nil } c.documentsLock.Lock() defer c.documentsLock.Unlock() - delete(c.documents, documentID) - // Remove the document from disk - if c.persistDirectory != "" { - docPath := c.getDocPath(documentID) - err := remove(docPath) - if err != nil { - return fmt.Errorf("couldn't remove document at %q: %w", docPath, err) + for _, docID := range docIDs { + delete(c.documents, docID) + + // Remove the document from disk + if c.persistDirectory != "" { + docPath := c.getDocPath(docID) + err := remove(docPath) + if err != nil { + return fmt.Errorf("couldn't remove document at %q: %w", docPath, err) + } } } diff --git a/collection_test.go b/collection_test.go index 5d75223..e37e511 100644 --- a/collection_test.go +++ b/collection_test.go 
@@ -421,7 +421,7 @@ func TestCollection_Count(t *testing.T) { } } -func TestCollection_RemoveDocument(t *testing.T) { +func TestCollection_Delete(t *testing.T) { // Create persistent collection tmpdir, err := os.MkdirTemp(os.TempDir(), "chromem-test-*") if err != nil { @@ -446,47 +446,67 @@ func TestCollection_RemoveDocument(t *testing.T) { } // Add documents - ids := []string{"1", "2"} - metadatas := []map[string]string{{"foo": "bar"}, {"a": "b"}} - contents := []string{"hello world", "hallo welt"} + ids := []string{"1", "2", "3", "4"} + metadatas := []map[string]string{{"foo": "bar"}, {"a": "b"}, {"foo": "bar"}, {"e": "f"}} + contents := []string{"hello world", "hallo welt", "bonjour le monde", "hola mundo"} err = c.Add(context.Background(), ids, nil, metadatas, contents) if err != nil { t.Fatal("expected nil, got", err) } // Check count - if c.Count() != 2 { - t.Fatal("expected 2, got", c.Count()) + if c.Count() != 4 { + t.Fatal("expected 4 documents, got", c.Count()) } // Check number of files in the persist directory d, err := os.ReadDir(c.persistDirectory) + if err != nil { t.Fatal("expected nil, got", err) } - if len(d) != 3 { // 2 documents + 1 metadata file - t.Fatal("expected 2 files in persist_dir, got", len(d)) + if len(d) != 5 { // 4 documents + 1 metadata file + t.Fatal("expected 4 document files + 1 metadata file in persist_dir, got", len(d)) } - // Remove document - err = c.RemoveDocument(context.Background(), "1") + checkCount := func(expected int) { + // Check count + if c.Count() != expected { + t.Fatalf("expected %d documents, got %d", expected, c.Count()) + } + + // Check number of files in the persist directory + d, err = os.ReadDir(c.persistDirectory) + if err != nil { + t.Fatal("expected nil, got", err) + } + if len(d) != expected+1 { // expected documents + 1 metadata file + t.Fatalf("expected %d document files + 1 metadata file in persist_dir, got %d", expected, len(d)) + } + } + + // Test 1 - Remove document by ID: should delete one document + 
err = c.Delete(context.Background(), nil, nil, "4") if err != nil { t.Fatal("expected nil, got", err) } + checkCount(3) - // Check count - if c.Count() != 1 { - t.Fatal("expected 1, got", c.Count()) + // Test 2 - Remove document by metadata + err = c.Delete(context.Background(), map[string]string{"foo": "bar"}, nil) + if err != nil { + t.Fatal("expected nil, got", err) } - // Check number of files in the persist directory - d, err = os.ReadDir(c.persistDirectory) + checkCount(1) + + // Test 3 - Remove document by content + err = c.Delete(context.Background(), nil, map[string]string{"$contains": "hallo welt"}) if err != nil { t.Fatal("expected nil, got", err) } - if len(d) != 2 { // 1 document + 1 metadata file - t.Fatal("expected 1 file in persist_dir, got", len(d)) - } + + checkCount(0) } From 1da54cf3d4b57514febdd5608529165d4e7f8bf0 Mon Sep 17 00:00:00 2001 From: Thorsten Klein Date: Mon, 22 Apr 2024 21:08:31 +0200 Subject: [PATCH 3/3] change: lock docs earlier during deletion process --- collection.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/collection.go b/collection.go index ec4a776..e01540a 100644 --- a/collection.go +++ b/collection.go @@ -270,6 +270,9 @@ func (c *Collection) Delete(_ context.Context, where, whereDocument map[string]s var docIDs []string + c.documentsLock.Lock() + defer c.documentsLock.Unlock() + if where != nil || whereDocument != nil { // metadata + content filters filteredDocs := filterDocs(c.documents, where, whereDocument) @@ -285,9 +288,6 @@ func (c *Collection) Delete(_ context.Context, where, whereDocument map[string]s return nil } - c.documentsLock.Lock() - defer c.documentsLock.Unlock() - for _, docID := range docIDs { delete(c.documents, docID)