Skip to content

Commit

Permalink
sync: add ShardedValue
Browse files Browse the repository at this point in the history
Implementation golang#18802 (comment)

This CL is for a better understanding of the API based on the check-out/check-in model.

Change-Id: I7fdef164291cbb064f593faabee53e5221d008da
  • Loading branch information
qiulaidongfeng committed Aug 29, 2024
1 parent 9e8ea56 commit 3233aa7
Show file tree
Hide file tree
Showing 5 changed files with 266 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/runtime/proc.go
Original file line number Diff line number Diff line change
Expand Up @@ -1371,6 +1371,7 @@ const (
stwForTestReadMemStatsSlow // "ReadMemStatsSlow (test)"
stwForTestPageCachePagesLeaked // "PageCachePagesLeaked (test)"
stwForTestResetDebugLog // "ResetDebugLog (test)"
stwShardRead
)

func (r stwReason) String() string {
Expand Down Expand Up @@ -1402,6 +1403,7 @@ var stwReasonStrings = [...]string{
stwForTestReadMemStatsSlow: "ReadMemStatsSlow (test)",
stwForTestPageCachePagesLeaked: "PageCachePagesLeaked (test)",
stwForTestResetDebugLog: "ResetDebugLog (test)",
stwShardRead: "ShardRead",
}

// worldStop provides context from the stop-the-world required by the
Expand Down
5 changes: 5 additions & 0 deletions src/runtime/runtime2.go
Original file line number Diff line number Diff line change
Expand Up @@ -741,6 +741,11 @@ type p struct {
// gcStopTime is the nanotime timestamp that this P last entered _Pgcstop.
gcStopTime int64

shardp []struct {
shard unsafe.Pointer
pool uintptr
}

// Padding is no longer needed. False sharing is now not a worry because p is large enough
// that its size class is an integer multiple of the cache line size (for any of our architectures).
}
Expand Down
80 changes: 80 additions & 0 deletions src/runtime/shard.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package runtime

import (
"internal/runtime/atomic"
"unsafe"
)

type ShardedValue struct {
// NewShard is a function that produces new shards of type T.
NewShard func()

id uintptr
}

func findShrad(p *p, s uintptr) (shard unsafe.Pointer, index int) {
for i := range p.shardp {
if p.shardp[i].pool == s {
return p.shardp[i].shard, i
}
}
return nil, -1
}

var genShardId uintptr

func updateShard(v unsafe.Pointer, f func(unsafe.Pointer) unsafe.Pointer) {
s := (*ShardedValue)(v)
id := atomic.Loaduintptr(&s.id)
if id == 0 {
atomic.Casuintptr(&s.id, 0, atomic.Xadduintptr(&genShardId, 1))
id = atomic.Loaduintptr(&s.id)
}
p := getg().m.p.ptr()
shard, index := findShrad(p, id)
if index == -1 {
p.shardp = append(p.shardp, struct {
shard unsafe.Pointer
pool uintptr
}{
pool: id,
shard: f(nil),
})
return
}
p.shardp[index].shard = f(shard)
KeepAlive(s)
}

func valueShard(v unsafe.Pointer, yield func(unsafe.Pointer), Getret func() unsafe.Pointer) {
s := (*ShardedValue)(v)
stw := stopTheWorld(stwShardRead)
once := false
for i := range allp {
v, index := findShrad(allp[i], s.id)
if v != nil {
yield(v)
}
if index != -1 {
if once {
allp[i].shardp[index].shard = nil
} else {
allp[i].shardp[index].shard = Getret()
once = true
}
}
}
startTheWorld(stw)
}

func drainShard(v unsafe.Pointer, yield func(unsafe.Pointer)) {
s := (*ShardedValue)(v)
stw := stopTheWorld(stwShardRead)
for i := range allp {
v, _ := findShrad(allp[i], s.id)
if v != nil {
yield(v)
}
}
startTheWorld(stw)
}
79 changes: 79 additions & 0 deletions src/sync/shard.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package sync

import (
"unsafe"
)

// ShardedValue is a pool of values that all represent a small piece of a single
// conceptual value. These values must implement Shard on themselves.
//
// The purpose of a ShardedValue is to enable the creation of scalable data structures
// that may be updated in shards that are local to the goroutine without any
// additional synchronization. In practice, shards are bound to lower-level
// scheduling resources, such as OS threads and CPUs, for efficiency.
//
// The zero value is ready for use.
type ShardedValue[T Shard[T]] struct {
// NewShard is a function that produces new shards of type T.
NewShard func() T

id uintptr
}

// Update acquires a shard and passes it to the provided function.
// The function returns the new shard value to return to the pool.
// Update is safe to call from multiple goroutines.
// Callers are encouraged to keep the provided function as short
// as possible, and are discouraged from blocking within them.
func (s *ShardedValue[T]) Update(f func(value T) T) {
updateShard(unsafe.Pointer(s), func(p unsafe.Pointer) unsafe.Pointer {
if p == nil {
t := s.NewShard()
t = f(t)
return unsafe.Pointer(&t)
}
t := (*T)(p)
*t = f(*t)
return p
})
}

// Value snapshots all values in the pool and returns the result of merging them all into
// a single value. This single value is guaranteed to represent a consistent snapshot of
// merging all outstanding shards at some point in time between when the call is made
// and when it returns. This single value is immediately added back into the pool as a
// single shard before being returned.
func (s *ShardedValue[T]) Value() (ret T) {
valueShard(unsafe.Pointer(s), func(p unsafe.Pointer) {
t := (*T)(p)
ret = ret.Merge(*t)
}, func() unsafe.Pointer { return unsafe.Pointer(&ret) })
return
}

// Drain snapshots all values in the pool and returns the result of merging them all into
// a single value. This single value is guaranteed to represent a consistent snapshot of
// merging all outstanding shards at some point in time between when the call is made
// and when it returns. Unlike Value, this single value is not added back to the pool.
func (s *ShardedValue[T]) Drain() (ret T) {
drainShard(unsafe.Pointer(s), func(p unsafe.Pointer) {
t := (*T)(p)
ret = ret.Merge(*t)
})
return
}

// Shard is an interface implemented by types whose values can be merged
// with values of the same type.
type Shard[T any] interface {
Merge(other T) T
}

//go:linkname updateShard runtime.updateShard
func updateShard(s unsafe.Pointer, f func(unsafe.Pointer) unsafe.Pointer)

//go:linkname valueShard runtime.valueShard
func valueShard(v unsafe.Pointer, yield func(unsafe.Pointer), Getret func() unsafe.Pointer)

//go:linkname drainShard runtime.drainShard
func drainShard(v unsafe.Pointer, yield func(unsafe.Pointer))
100 changes: 100 additions & 0 deletions src/sync/shard_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package sync_test

import (
. "sync"
"sync/atomic"
"testing"
)

func TestShardCounter(t *testing.T) {
var s Counter
s.sp.NewShard = func() counterInt {
return counterInt(0)
}
s.Add(1)
if v := s.Value(); v != 1 {
t.Fatalf("got %d , want %d", v, 1)
}
var wg WaitGroup
for range 100 {
wg.Add(1)
go func() {
defer wg.Done()
s.Add(1)
}()
}
wg.Wait()
if v := s.Value(); v != 101 {
t.Fatalf("got %d , want %d", v, 101)
}
}

type Counter struct {
sp ShardedValue[counterInt]
}

func (c *Counter) Add(value int) {
c.sp.Update(func(v counterInt) counterInt {
return counterInt(int(v) + value)
})
}

func (c *Counter) Value() int {
return int(c.sp.Value())
}

type counterInt int

func (a counterInt) Merge(b counterInt) counterInt {
return a + b
}

func TestShardDrain(t *testing.T) {
var s Counter
s.sp.NewShard = func() counterInt {
return counterInt(0)
}
s.Add(1)
if v := s.sp.Drain(); v != 1 {
t.Fatalf("got %d , want %d", v, 1)
}
var wg WaitGroup
for range 100 {
wg.Add(1)
go func() {
defer wg.Done()
s.Add(1)
}()
}
wg.Wait()
if v := s.sp.Drain(); v != 101 {
t.Fatalf("got %d , want %d", v, 101)
}
}

func BenchmarkCounter(b *testing.B) {
b.Run("atomic/int64", func(b *testing.B) {
i := int64(0)
b.RunParallel(func(p *testing.PB) {
for p.Next() {
for range 100000 {
atomic.AddInt64(&i, 1)
}

}
})
})
b.Run("sync/int64", func(b *testing.B) {
b.RunParallel(func(p *testing.PB) {
c := Counter{}
c.sp.NewShard = func() counterInt {
return counterInt(0)
}
for p.Next() {
for range 100000 {
c.Add(1)
}
}
})
})
}

0 comments on commit 3233aa7

Please sign in to comment.