diff --git a/.github/workflows/test-and-build.yml b/.github/workflows/test-and-build.yml index 18fc299..dbaea36 100644 --- a/.github/workflows/test-and-build.yml +++ b/.github/workflows/test-and-build.yml @@ -31,7 +31,7 @@ jobs: strategy: matrix: go-version: - - '1.13' # oldest supported; named in go.mod + - '1.21' # oldest supported; named in go.mod - 'oldstable' - 'stable' env: diff --git a/go.mod b/go.mod index 2d509ee..907b492 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,7 @@ module github.com/hashicorp/go-memdb -go 1.13 +go 1.21 -require ( - github.com/hashicorp/go-immutable-radix v1.3.1 - github.com/hashicorp/golang-lru v0.5.4 // indirect -) +toolchain go1.22.2 + +require github.com/absolutelightning/go-immutable-adaptive-radix v1.0.94 diff --git a/go.sum b/go.sum index cf93b59..a84eda9 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,16 @@ -github.com/hashicorp/go-immutable-radix v1.3.1 h1:DKHmCUm2hRBK510BaiZlwvpD40f8bJFeZnpfm2KLowc= -github.com/hashicorp/go-immutable-radix v1.3.1/go.mod h1:0y9vanUI8NX6FsYoO3zeMjhV/C5i9g4Q3DwcSNZ4P60= -github.com/hashicorp/go-uuid v1.0.0 h1:RS8zrF7PhGwyNPOtxSClXXj9HA8feRnJzgnI1RJCSnM= -github.com/hashicorp/go-uuid v1.0.0/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= -github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= -github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc= -github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= +github.com/absolutelightning/go-immutable-adaptive-radix v1.0.91 h1:Y77uHatYySoeYv24GS/aAFl/yRW8IhAtd8sRkol9KVo= +github.com/absolutelightning/go-immutable-adaptive-radix v1.0.91/go.mod h1:+/nWgzXP46cOw2W+jLxu2lBmEny3bG1RqFXPKTWTLMU= +github.com/absolutelightning/go-immutable-adaptive-radix v1.0.94 h1:Hu6L+uqk3r9RUza1O+HLjpDou6gexIwwUI8iquSiyXc= +github.com/absolutelightning/go-immutable-adaptive-radix v1.0.94/go.mod h1:+/nWgzXP46cOw2W+jLxu2lBmEny3bG1RqFXPKTWTLMU= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8= +github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM= +golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/memdb.go b/memdb.go index 13cc6a8..ffd7085 100644 --- a/memdb.go +++ b/memdb.go @@ -6,11 +6,10 @@ package memdb import ( + adaptive "github.com/absolutelightning/go-immutable-adaptive-radix" "sync" "sync/atomic" "unsafe" - - "github.com/hashicorp/go-immutable-radix" ) // MemDB is an in-memory database providing Atomicity, Consistency, and @@ -45,7 +44,7 @@ func NewMemDB(schema *DBSchema) (*MemDB, error) { // Create the MemDB db := &MemDB{ schema: schema, - root: unsafe.Pointer(iradix.New()), + root: unsafe.Pointer(adaptive.NewRadixTree[any]()), primary: true, } if err := db.initialize(); err != nil { @@ -64,8 +63,8 @@ func (db *MemDB) DBSchema() *DBSchema { } // getRoot is used to do an atomic load of the root pointer -func (db *MemDB) getRoot() *iradix.Tree { - root := (*iradix.Tree)(atomic.LoadPointer(&db.root)) +func (db *MemDB) getRoot() *adaptive.RadixTree[any] { + root := (*adaptive.RadixTree[any])(atomic.LoadPointer(&db.root)) return root } @@ -104,7 +103,7 @@ func (db *MemDB) initialize() error { root := db.getRoot() for tName, tableSchema := range db.schema.Tables { for iName := range tableSchema.Indexes { - index := iradix.New() + index := adaptive.NewRadixTree[any]() path := indexPath(tName, iName) root, _, _ = root.Insert(path, index) } diff --git a/txn.go b/txn.go index f83f4fa..dc9dbe0 100644 --- a/txn.go +++ b/txn.go @@ -6,11 +6,10 @@ package memdb import ( "bytes" "fmt" + adaptive "github.com/absolutelightning/go-immutable-adaptive-radix" "strings" "sync/atomic" "unsafe" - - iradix "github.com/hashicorp/go-immutable-radix" ) const ( @@ -33,14 +32,14 @@ type tableIndex struct { type Txn struct { db *MemDB write bool - rootTxn *iradix.Txn + rootTxn *adaptive.Txn[any] after []func() // changes is used to track the changes performed during the transaction. If // it is nil at transaction start then changes are not tracked. changes Changes - modified map[tableIndex]*iradix.Txn + modified map[tableIndex]*adaptive.Txn[any] } // TrackChanges enables change tracking for the transaction. If called at any @@ -58,7 +57,7 @@ func (txn *Txn) TrackChanges() { // readableIndex returns a transaction usable for reading the given index in a // table. If the transaction is a write transaction with modifications, a clone of the // modified index will be returned. -func (txn *Txn) readableIndex(table, index string) *iradix.Txn { +func (txn *Txn) readableIndex(table, index string) *adaptive.Txn[any] { // Look for existing transaction if txn.write && txn.modified != nil { key := tableIndex{table, index} @@ -71,15 +70,15 @@ func (txn *Txn) readableIndex(table, index string) *iradix.Txn { // Create a read transaction path := indexPath(table, index) raw, _ := txn.rootTxn.Get(path) - indexTxn := raw.(*iradix.Tree).Txn() + indexTxn := raw.(*adaptive.RadixTree[any]).Txn() return indexTxn } // writableIndex returns a transaction usable for modifying the // given index in a table. -func (txn *Txn) writableIndex(table, index string) *iradix.Txn { +func (txn *Txn) writableIndex(table, index string) *adaptive.Txn[any] { if txn.modified == nil { - txn.modified = make(map[tableIndex]*iradix.Txn) + txn.modified = make(map[tableIndex]*adaptive.Txn[any]) } // Look for existing transaction @@ -92,7 +91,7 @@ func (txn *Txn) writableIndex(table, index string) *iradix.Txn { // Start a new transaction path := indexPath(table, index) raw, _ := txn.rootTxn.Get(path) - indexTxn := raw.(*iradix.Tree).Txn() + indexTxn := raw.(*adaptive.RadixTree[any]).Txn() // If we are the primary DB, enable mutation tracking. Snapshots should // not notify, otherwise we will trigger watches on the primary DB when @@ -603,8 +602,28 @@ func (txn *Txn) LastWatch(table, index string, args ...interface{}) (<-chan stru // Note that all values read in the transaction form a consistent snapshot // from the time when the transaction was created. func (txn *Txn) First(table, index string, args ...interface{}) (interface{}, error) { - _, val, err := txn.FirstWatch(table, index, args...) - return val, err + indexSchema, val, err := txn.getIndexValue(table, index, args...) + if err != nil { + return nil, err + } + + // Get the index itself + indexTxn := txn.readableIndex(table, indexSchema.Name) + + // Do an exact lookup + if indexSchema.Unique && val != nil && indexSchema.Name == index { + obj, ok := indexTxn.Get(val) + if !ok { + return nil, nil + } + return obj, nil + } + + // Handle non-unique index by using an iterator and getting the first value + iter := indexTxn.Root().Iterator() + iter.SeekPrefix(val) + _, value, _ := iter.Next() + return value, nil } // Last is used to return the last matching object for @@ -647,7 +666,7 @@ func (txn *Txn) LongestPrefix(table, index string, args ...interface{}) (interfa // Find the longest prefix match with the given index. indexTxn := txn.readableIndex(table, indexSchema.Name) - if _, value, ok := indexTxn.Root().LongestPrefix(val); ok { + if _, value, ok := indexTxn.LongestPrefix(val); ok { return value, nil } return nil, nil @@ -797,7 +816,7 @@ func (txn *Txn) GetReverse(table, index string, args ...interface{}) (ResultIter // See the documentation for ResultIterator to understand the behaviour of the // returned ResultIterator. func (txn *Txn) LowerBound(table, index string, args ...interface{}) (ResultIterator, error) { - indexIter, val, err := txn.getIndexIterator(table, index, args...) + indexIter, val, err := txn.getIndexLowerBoundIterator(table, index, args...) if err != nil { return nil, err } @@ -806,7 +825,7 @@ func (txn *Txn) LowerBound(table, index string, args ...interface{}) (ResultIter indexIter.SeekLowerBound(val) // Create an iterator - iter := &radixIterator{ + iter := &radixLowerBoundIterator{ iter: indexIter, } return iter, nil @@ -923,7 +942,23 @@ func (txn *Txn) Changes() Changes { return cs } -func (txn *Txn) getIndexIterator(table, index string, args ...interface{}) (*iradix.Iterator, []byte, error) { +func (txn *Txn) getIndexLowerBoundIterator(table, index string, args ...interface{}) (*adaptive.LowerBoundIterator[any], []byte, error) { + // Get the index value to scan + indexSchema, val, err := txn.getIndexValue(table, index, args...) + if err != nil { + return nil, nil, err + } + + // Get the index itself + indexTxn := txn.readableIndex(table, indexSchema.Name) + indexRoot := indexTxn.Root() + + // Get an iterator over the index + indexIter := indexRoot.LowerBoundIterator() + return indexIter, val, nil +} + +func (txn *Txn) getIndexIterator(table, index string, args ...interface{}) (*adaptive.Iterator[any], []byte, error) { // Get the index value to scan indexSchema, val, err := txn.getIndexValue(table, index, args...) if err != nil { @@ -939,7 +974,7 @@ func (txn *Txn) getIndexIterator(table, index string, args ...interface{}) (*ira return indexIter, val, nil } -func (txn *Txn) getIndexIteratorReverse(table, index string, args ...interface{}) (*iradix.ReverseIterator, []byte, error) { +func (txn *Txn) getIndexIteratorReverse(table, index string, args ...interface{}) (*adaptive.ReverseIterator[any], []byte, error) { // Get the index value to scan indexSchema, val, err := txn.getIndexValue(table, index, args...) if err != nil { @@ -963,11 +998,11 @@ func (txn *Txn) Defer(fn func()) { txn.after = append(txn.after, fn) } -// radixIterator is used to wrap an underlying iradix iterator. +// radixIterator is used to wrap an underlying adaptive iterator. // This is much more efficient than a sliceIterator as we are not // materializing the entire view. type radixIterator struct { - iter *iradix.Iterator + iter *adaptive.Iterator[any] watchCh <-chan struct{} } @@ -983,8 +1018,25 @@ func (r *radixIterator) Next() interface{} { return value } +type radixLowerBoundIterator struct { + iter *adaptive.LowerBoundIterator[any] + watchCh <-chan struct{} +} + +func (r *radixLowerBoundIterator) WatchCh() <-chan struct{} { + return r.watchCh +} + +func (r *radixLowerBoundIterator) Next() interface{} { + _, value, ok := r.iter.Next() + if !ok { + return nil + } + return value +} + type radixReverseIterator struct { - iter *iradix.ReverseIterator + iter *adaptive.ReverseIterator[any] watchCh <-chan struct{} }