diff --git a/batch.go b/batch.go index 2afcffa7..52bac55a 100644 --- a/batch.go +++ b/batch.go @@ -3,6 +3,7 @@ package rosedb import ( "bytes" "fmt" + "github.com/rosedblabs/rosedb/v2/utils" "sync" "time" @@ -22,14 +23,15 @@ import ( // // You must call Commit method to commit the batch, otherwise the DB will be locked. type Batch struct { - db *DB - pendingWrites []*LogRecord // save the data to be written - options BatchOptions - mu sync.RWMutex - committed bool // whether the batch has been committed - rollbacked bool // whether the batch has been rollbacked - batchId *snowflake.Node - buffers []*bytebufferpool.ByteBuffer + db *DB + pendingWrites []*LogRecord // save the data to be written + pendingWritesMap map[uint64][]int // map record hash key to index, with open hashing + options BatchOptions + mu sync.RWMutex + committed bool // whether the batch has been committed + rollbacked bool // whether the batch has been rollbacked + batchId *snowflake.Node + buffers []*bytebufferpool.ByteBuffer } // NewBatch creates a new Batch instance. @@ -77,6 +79,9 @@ func (b *Batch) init(rdonly, sync bool, db *DB) *Batch { func (b *Batch) reset() { b.db = nil b.pendingWrites = b.pendingWrites[:0] + for key := range b.pendingWritesMap { + delete(b.pendingWritesMap, key) + } b.committed = false b.rollbacked = false // put all buffers back to the pool @@ -116,19 +121,12 @@ func (b *Batch) Put(key []byte, value []byte) error { b.mu.Lock() // write to pendingWrites - var record *LogRecord - // if the key exists in pendingWrites, update the value directly - for i := len(b.pendingWrites) - 1; i >= 0; i-- { - if bytes.Equal(key, b.pendingWrites[i].Key) { - record = b.pendingWrites[i] - break - } - } + var record = b.lookupPendingWrites(key) if record == nil { // if the key does not exist in pendingWrites, write a new record // the record will be put back to the pool when the batch is committed or rollbacked record = b.db.recordPool.Get().(*LogRecord) - b.pendingWrites = append(b.pendingWrites, record) + b.appendPendingWrites(key, record) } record.Key, record.Value = key, value @@ -152,19 +150,12 @@ func (b *Batch) PutWithTTL(key []byte, value []byte, ttl time.Duration) error { b.mu.Lock() // write to pendingWrites - var record *LogRecord - // if the key exists in pendingWrites, update the value directly - for i := len(b.pendingWrites) - 1; i >= 0; i-- { - if bytes.Equal(key, b.pendingWrites[i].Key) { - record = b.pendingWrites[i] - break - } - } + var record = b.lookupPendingWrites(key) if record == nil { // if the key does not exist in pendingWrites, write a new record // the record will be put back to the pool when the batch is committed or rollbacked record = b.db.recordPool.Get().(*LogRecord) - b.pendingWrites = append(b.pendingWrites, record) + b.appendPendingWrites(key, record) } record.Key, record.Value = key, value @@ -186,13 +177,7 @@ func (b *Batch) Get(key []byte) ([]byte, error) { now := time.Now().UnixNano() // get from pendingWrites b.mu.RLock() - var record *LogRecord - for i := len(b.pendingWrites) - 1; i >= 0; i-- { - if bytes.Equal(key, b.pendingWrites[i].Key) { - record = b.pendingWrites[i] - break - } - } + var record = b.lookupPendingWrites(key) b.mu.RUnlock() // if the record is in pendingWrites, return the value directly @@ -240,20 +225,19 @@ func (b *Batch) Delete(key []byte) error { b.mu.Lock() // only need key and type when deleting a value. var exist bool - for i := len(b.pendingWrites) - 1; i >= 0; i-- { - if bytes.Equal(key, b.pendingWrites[i].Key) { - b.pendingWrites[i].Type = LogRecordDeleted - b.pendingWrites[i].Value = nil - b.pendingWrites[i].Expire = 0 - exist = true - break - } + var record = b.lookupPendingWrites(key) + if record != nil { + record.Type = LogRecordDeleted + record.Value = nil + record.Expire = 0 + exist = true } if !exist { - b.pendingWrites = append(b.pendingWrites, &LogRecord{ + record = &LogRecord{ Key: key, Type: LogRecordDeleted, - }) + } + b.appendPendingWrites(key, record) } b.mu.Unlock() @@ -272,13 +256,7 @@ func (b *Batch) Exist(key []byte) (bool, error) { now := time.Now().UnixNano() // check if the key exists in pendingWrites b.mu.RLock() - var record *LogRecord - for i := len(b.pendingWrites) - 1; i >= 0; i-- { - if bytes.Equal(key, b.pendingWrites[i].Key) { - record = b.pendingWrites[i] - break - } - } + var record = b.lookupPendingWrites(key) b.mu.RUnlock() if record != nil { @@ -320,13 +298,7 @@ func (b *Batch) Expire(key []byte, ttl time.Duration) error { b.mu.Lock() defer b.mu.Unlock() - var record *LogRecord - for i := len(b.pendingWrites) - 1; i >= 0; i-- { - if bytes.Equal(key, b.pendingWrites[i].Key) { - record = b.pendingWrites[i] - break - } - } + var record = b.lookupPendingWrites(key) // if the key exists in pendingWrites, update the expiry time directly if record != nil { @@ -335,30 +307,30 @@ func (b *Batch) Expire(key []byte, ttl time.Duration) error { return ErrKeyNotFound } record.Expire = time.Now().Add(ttl).UnixNano() - } else { - // if the key does not exist in pendingWrites, get the value from wal - position := b.db.index.Get(key) - if position == nil { - return ErrKeyNotFound - } - chunk, err := b.db.dataFiles.Read(position) - if err != nil { - return err - } + return nil + } + // if the key does not exist in pendingWrites, get the value from wal + position := b.db.index.Get(key) + if position == nil { + return ErrKeyNotFound + } + chunk, err := b.db.dataFiles.Read(position) + if err != nil { + return err + } - now := time.Now() - record = decodeLogRecord(chunk) - // if the record is deleted or expired, we can assume that the key does not exist, - // and delete the key from the index - if record.Type == LogRecordDeleted || record.IsExpired(now.UnixNano()) { - b.db.index.Delete(key) - return ErrKeyNotFound - } - // now we get the value from wal, update the expiry time - // and rewrite the record to pendingWrites - record.Expire = now.Add(ttl).UnixNano() - b.pendingWrites = append(b.pendingWrites, record) + now := time.Now() + record = decodeLogRecord(chunk) + // if the record is deleted or expired, we can assume that the key does not exist, + // and delete the key from the index + if record.Type == LogRecordDeleted || record.IsExpired(now.UnixNano()) { + b.db.index.Delete(key) + return ErrKeyNotFound } + // now we get the value from wal, update the expiry time + // and rewrite the record to pendingWrites + record.Expire = now.Add(ttl).UnixNano() + b.appendPendingWrites(key, record) return nil } @@ -376,27 +348,17 @@ func (b *Batch) TTL(key []byte) (time.Duration, error) { b.mu.Lock() defer b.mu.Unlock() - // check if the key exists in pendingWrites - if len(b.pendingWrites) > 0 { - var record *LogRecord - for i := len(b.pendingWrites) - 1; i >= 0; i-- { - if bytes.Equal(key, b.pendingWrites[i].Key) { - record = b.pendingWrites[i] - break - } + var record = b.lookupPendingWrites(key) + if record != nil { + if record.Expire == 0 { + return -1, nil } - // if the key exists in pendingWrites, return the ttl directly - if record != nil { - if record.Expire == 0 { - return -1, nil - } - // return key not found if the record is deleted or expired - if record.Type == LogRecordDeleted || record.IsExpired(now.UnixNano()) { - return -1, ErrKeyNotFound - } - // now we get the valid expiry time, we can calculate the ttl - return time.Duration(record.Expire - now.UnixNano()), nil + // return key not found if the record is deleted or expired + if record.Type == LogRecordDeleted || record.IsExpired(now.UnixNano()) { + return -1, ErrKeyNotFound } + // now we get the valid expiry time, we can calculate the ttl + return time.Duration(record.Expire - now.UnixNano()), nil } // if the key does not exist in pendingWrites, get the value from wal @@ -410,7 +372,7 @@ func (b *Batch) TTL(key []byte) (time.Duration, error) { } // return key not found if the record is deleted or expired - record := decodeLogRecord(chunk) + record = decodeLogRecord(chunk) if record.Type == LogRecordDeleted { return -1, ErrKeyNotFound } @@ -443,48 +405,42 @@ func (b *Batch) Persist(key []byte) error { defer b.mu.Unlock() // if the key exists in pendingWrites, update the expiry time directly - var record *LogRecord - for i := len(b.pendingWrites) - 1; i >= 0; i-- { - if bytes.Equal(key, b.pendingWrites[i].Key) { - record = b.pendingWrites[i] - break - } - } - + var record = b.lookupPendingWrites(key) if record != nil { if record.Type == LogRecordDeleted && record.IsExpired(time.Now().UnixNano()) { return ErrKeyNotFound } record.Expire = 0 - } else { - // check if the key exists in index - position := b.db.index.Get(key) - if position == nil { - return ErrKeyNotFound - } - chunk, err := b.db.dataFiles.Read(position) - if err != nil { - return err - } + return nil + } - record := decodeLogRecord(chunk) - now := time.Now().UnixNano() - // check if the record is deleted or expired - if record.Type == LogRecordDeleted || record.IsExpired(now) { - b.db.index.Delete(record.Key) - return ErrKeyNotFound - } - // if the expiration time is 0, it means that the key has no expiration time, - // so we can return directly - if record.Expire == 0 { - return nil - } + // check if the key exists in index + position := b.db.index.Get(key) + if position == nil { + return ErrKeyNotFound + } + chunk, err := b.db.dataFiles.Read(position) + if err != nil { + return err + } - // set the expiration time to 0, and rewrite the record to wal - record.Expire = 0 - b.pendingWrites = append(b.pendingWrites, record) + record = decodeLogRecord(chunk) + now := time.Now().UnixNano() + // check if the record is deleted or expired + if record.Type == LogRecordDeleted || record.IsExpired(now) { + b.db.index.Delete(record.Key) + return ErrKeyNotFound + } + // if the expiration time is 0, it means that the key has no expiration time, + // so we can return directly + if record.Expire == 0 { + return nil } + // set the expiration time to 0, and rewrite the record to wal + record.Expire = 0 + b.appendPendingWrites(key, record) + return nil } @@ -602,8 +558,34 @@ func (b *Batch) Rollback() error { b.db.recordPool.Put(record) } b.pendingWrites = b.pendingWrites[:0] + for key := range b.pendingWritesMap { + delete(b.pendingWritesMap, key) + } } b.rollbacked = true return nil } + +// lookupPendingWrites if the key exists in pendingWrites, update the value directly +func (b *Batch) lookupPendingWrites(key []byte) (record *LogRecord) { + if len(b.pendingWritesMap) == 0 { + return + } + hashKey := utils.MemHash(key) + for _, entry := range b.pendingWritesMap[hashKey] { + if bytes.Compare(b.pendingWrites[entry].Key, key) == 0 { + return b.pendingWrites[entry] + } + } + return +} + +func (b *Batch) appendPendingWrites(key []byte, record *LogRecord) { + b.pendingWrites = append(b.pendingWrites, record) + if b.pendingWritesMap == nil { + b.pendingWritesMap = make(map[uint64][]int) + } + hashKey := utils.MemHash(key) + b.pendingWritesMap[hashKey] = append(b.pendingWritesMap[hashKey], len(b.pendingWrites)-1) +} diff --git a/benchmark/bench_test.go b/benchmark/bench_test.go index 2e55a692..1653f076 100644 --- a/benchmark/bench_test.go +++ b/benchmark/bench_test.go @@ -1,6 +1,7 @@ package benchmark import ( + "errors" "math/rand" "os" "testing" @@ -36,6 +37,14 @@ func BenchmarkPutGet(b *testing.B) { b.Run("get", bencharkGet) } +func BenchmarkBatchPutGet(b *testing.B) { + closer := openDB() + defer closer() + + b.Run("batchPut", benchmarkBatchPut) + b.Run("batchGet", benchmarkBatchGet) +} + func benchmarkPut(b *testing.B) { b.ResetTimer() b.ReportAllocs() @@ -46,6 +55,36 @@ func benchmarkPut(b *testing.B) { } } +func benchmarkBatchPut(b *testing.B) { + b.ResetTimer() + b.ReportAllocs() + + batch := db.NewBatch(rosedb.DefaultBatchOptions) + defer batch.Commit() + for i := 0; i < b.N; i++ { + err := batch.Put(utils.GetTestKey(i), utils.RandomValue(1024)) + assert.Nil(b, err) + } +} + +func benchmarkBatchGet(b *testing.B) { + for i := 0; i < 10000; i++ { + err := db.Put(utils.GetTestKey(i), utils.RandomValue(1024)) + assert.Nil(b, err) + } + + b.ResetTimer() + b.ReportAllocs() + batch := db.NewBatch(rosedb.DefaultBatchOptions) + defer batch.Commit() + for i := 0; i < b.N; i++ { + _, err := batch.Get(utils.GetTestKey(rand.Int())) + if err != nil && !errors.Is(err, rosedb.ErrKeyNotFound) { + b.Fatal(err) + } + } +} + func bencharkGet(b *testing.B) { for i := 0; i < 10000; i++ { err := db.Put(utils.GetTestKey(i), utils.RandomValue(1024)) @@ -57,7 +96,7 @@ func bencharkGet(b *testing.B) { for i := 0; i < b.N; i++ { _, err := db.Get(utils.GetTestKey(rand.Int())) - if err != nil && err != rosedb.ErrKeyNotFound { + if err != nil && !errors.Is(err, rosedb.ErrKeyNotFound) { b.Fatal(err) } } diff --git a/utils/hash.go b/utils/hash.go new file mode 100644 index 00000000..f568674a --- /dev/null +++ b/utils/hash.go @@ -0,0 +1,31 @@ +package utils + +import ( + _ "runtime" + "unsafe" +) + +type stringStruct struct { + str unsafe.Pointer + len int +} + +//go:noescape +//go:linkname memhash runtime.memhash +func memhash(p unsafe.Pointer, h, s uintptr) uintptr + +// MemHash is the hash function used by go map, it utilizes available hardware instructions(behaves +// as aeshash if aes instruction is available). +// NOTE: The hash seed changes for every process. So, this cannot be used as a persistent hash. +func MemHash(data []byte) uint64 { + ss := (*stringStruct)(unsafe.Pointer(&data)) + return uint64(memhash(ss.str, 0, uintptr(ss.len))) +} + +// MemHashString is the hash function used by go map, it utilizes available hardware instructions +// (behaves as aeshash if aes instruction is available). +// NOTE: The hash seed changes for every process. So, this cannot be used as a persistent hash. +func MemHashString(str string) uint64 { + ss := (*stringStruct)(unsafe.Pointer(&str)) + return uint64(memhash(ss.str, 0, uintptr(ss.len))) +} diff --git a/utils/hash.s b/utils/hash.s new file mode 100644 index 00000000..e69de29b