Skip to content

Commit

Permalink
membuffer: fix memory leak in red-black tree (#1483)
Browse files Browse the repository at this point in the history
close #1375, ref pingcap/tidb#56837

Signed-off-by: you06 <you1474600@gmail.com>

Co-authored-by: cfzjywxk <cfzjywxk@gmail.com>
  • Loading branch information
you06 and cfzjywxk authored Nov 14, 2024
1 parent 70049ae commit 86678db
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 13 deletions.
2 changes: 1 addition & 1 deletion internal/unionstore/memdb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1231,7 +1231,7 @@ func testMemBufferCache(t *testing.T, buffer MemBuffer) {
}

func TestMemDBLeafFragmentation(t *testing.T) {
// RBT cannot pass the leaf fragmentation test.
testMemDBLeafFragmentation(t, newRbtDBWithContext())
testMemDBLeafFragmentation(t, newArtDBWithContext())
}

Expand Down
42 changes: 32 additions & 10 deletions internal/unionstore/rbt/rbt.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,12 @@ func (db *RBT) RevertVAddr(hdr *arena.MemdbVlogHdr) {
// If there are no flags associated with this key, we need to delete this node.
keptFlags := node.getKeyFlags().AndPersistent()
if keptFlags == 0 {
db.deleteNode(node)
node.markDelete()
db.count--
db.size -= int(node.klen)
} else {
node.setKeyFlags(keptFlags)
db.dirty = true
node.resetKeyFlags(keptFlags)
}
} else {
db.size += len(db.vlog.GetValue(hdr.OldValue))
Expand Down Expand Up @@ -279,7 +281,7 @@ func (db *RBT) SelectValueHistory(key []byte, predicate func(value []byte) bool)
// GetFlags returns the latest flags associated with key.
func (db *RBT) GetFlags(key []byte) (kv.KeyFlags, error) {
x := db.traverse(key, false)
if x.isNull() {
if x.isNull() || x.isDeleted() {
return 0, tikverr.ErrNotExist
}
return x.getKeyFlags(), nil
Expand Down Expand Up @@ -347,17 +349,22 @@ func (db *RBT) Set(key []byte, value []byte, ops ...kv.FlagsOp) error {
// the NeedConstraintCheckInPrewrite flag is temporary,
// every write to the node removes the flag unless it's explicitly set.
// This set must be in the latest stage so no special processing is needed.
var flags kv.KeyFlags
flags := x.GetKeyFlags()
if flags == 0 && x.vptr.IsNull() && x.isDeleted() {
x.unmarkDelete()
db.count++
db.size += int(x.klen)
}
if value != nil {
flags = kv.ApplyFlagsOps(x.getKeyFlags(), append([]kv.FlagsOp{kv.DelNeedConstraintCheckInPrewrite}, ops...)...)
flags = kv.ApplyFlagsOps(flags, append([]kv.FlagsOp{kv.DelNeedConstraintCheckInPrewrite}, ops...)...)
} else {
// an UpdateFlag operation, do not delete the NeedConstraintCheckInPrewrite flag.
flags = kv.ApplyFlagsOps(x.getKeyFlags(), ops...)
flags = kv.ApplyFlagsOps(flags, ops...)
}
if flags.AndPersistent() != 0 {
db.dirty = true
}
x.setKeyFlags(flags)
x.resetKeyFlags(flags)

if value == nil {
return nil
Expand Down Expand Up @@ -881,8 +888,11 @@ func (n *memdbNode) getKey() []byte {

const (
// bit 1 => red, bit 0 => black
nodeColorBit uint16 = 0x8000
nodeFlagsMask = ^nodeColorBit
nodeColorBit uint16 = 0x8000
// bit 1 => node is deleted, bit 0 => node is not deleted
// This flag is used to mark a node as deleted, so that we can reuse the node to avoid memory leak.
deleteFlag uint16 = 1 << 14
nodeFlagsMask = ^(nodeColorBit | deleteFlag)
)

func (n *memdbNode) GetKeyFlags() kv.KeyFlags {
Expand All @@ -893,10 +903,22 @@ func (n *memdbNode) getKeyFlags() kv.KeyFlags {
return kv.KeyFlags(n.flags & nodeFlagsMask)
}

func (n *memdbNode) setKeyFlags(f kv.KeyFlags) {
func (n *memdbNode) resetKeyFlags(f kv.KeyFlags) {
n.flags = (^nodeFlagsMask & n.flags) | uint16(f)
}

func (n *memdbNode) markDelete() {
n.flags = (nodeColorBit & n.flags) | deleteFlag
}

func (n *memdbNode) unmarkDelete() {
n.flags &= ^deleteFlag
}

func (n *memdbNode) isDeleted() bool {
return n.flags&deleteFlag != 0
}

// RemoveFromBuffer removes a record from the mem buffer. It should be only used for test.
func (db *RBT) RemoveFromBuffer(key []byte) {
x := db.traverse(key, false)
Expand Down
14 changes: 12 additions & 2 deletions internal/unionstore/rbt/rbt_iterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ func (i *RBTIterator) init() {
}
}

if i.isFlagsOnly() && !i.includeFlags {
if (i.isFlagsOnly() && !i.includeFlags) || (!i.curr.isNull() && i.curr.isDeleted()) {
err := i.Next()
_ = err // memdbIterator will never fail
}
Expand All @@ -142,7 +142,7 @@ func (i *RBTIterator) Flags() kv.KeyFlags {
func (i *RBTIterator) UpdateFlags(ops ...kv.FlagsOp) {
origin := i.curr.getKeyFlags()
n := kv.ApplyFlagsOps(origin, ops...)
i.curr.setKeyFlags(n)
i.curr.resetKeyFlags(n)
}

// HasValue returns false if it is flags only.
Expand Down Expand Up @@ -174,6 +174,10 @@ func (i *RBTIterator) Next() error {
i.curr = i.db.successor(i.curr)
}

if i.curr.isDeleted() {
continue
}

// We need to skip persistent flags only nodes.
if i.includeFlags || !i.isFlagsOnly() {
break
Expand All @@ -195,6 +199,9 @@ func (i *RBTIterator) seekToFirst() {
}

i.curr = y
for !i.curr.isNull() && i.curr.isDeleted() {
i.curr = i.db.successor(i.curr)
}
}

func (i *RBTIterator) seekToLast() {
Expand All @@ -207,6 +214,9 @@ func (i *RBTIterator) seekToLast() {
}

i.curr = y
for !i.curr.isNull() && i.curr.isDeleted() {
i.curr = i.db.predecessor(i.curr)
}
}

func (i *RBTIterator) seek(key []byte) {
Expand Down

0 comments on commit 86678db

Please sign in to comment.