forked from golang/go
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implementation golang#18802 (comment) This CL is for a better understanding of the API based on the check-out/check-in model. Change-Id: I7fdef164291cbb064f593faabee53e5221d008da
- Loading branch information
1 parent
9e8ea56
commit 3233aa7
Showing
5 changed files
with
266 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
package runtime | ||
|
||
import ( | ||
"internal/runtime/atomic" | ||
"unsafe" | ||
) | ||
|
||
type ShardedValue struct { | ||
// NewShard is a function that produces new shards of type T. | ||
NewShard func() | ||
|
||
id uintptr | ||
} | ||
|
||
func findShrad(p *p, s uintptr) (shard unsafe.Pointer, index int) { | ||
for i := range p.shardp { | ||
if p.shardp[i].pool == s { | ||
return p.shardp[i].shard, i | ||
} | ||
} | ||
return nil, -1 | ||
} | ||
|
||
var genShardId uintptr | ||
|
||
func updateShard(v unsafe.Pointer, f func(unsafe.Pointer) unsafe.Pointer) { | ||
s := (*ShardedValue)(v) | ||
id := atomic.Loaduintptr(&s.id) | ||
if id == 0 { | ||
atomic.Casuintptr(&s.id, 0, atomic.Xadduintptr(&genShardId, 1)) | ||
id = atomic.Loaduintptr(&s.id) | ||
} | ||
p := getg().m.p.ptr() | ||
shard, index := findShrad(p, id) | ||
if index == -1 { | ||
p.shardp = append(p.shardp, struct { | ||
shard unsafe.Pointer | ||
pool uintptr | ||
}{ | ||
pool: id, | ||
shard: f(nil), | ||
}) | ||
return | ||
} | ||
p.shardp[index].shard = f(shard) | ||
KeepAlive(s) | ||
} | ||
|
||
func valueShard(v unsafe.Pointer, yield func(unsafe.Pointer), Getret func() unsafe.Pointer) { | ||
s := (*ShardedValue)(v) | ||
stw := stopTheWorld(stwShardRead) | ||
once := false | ||
for i := range allp { | ||
v, index := findShrad(allp[i], s.id) | ||
if v != nil { | ||
yield(v) | ||
} | ||
if index != -1 { | ||
if once { | ||
allp[i].shardp[index].shard = nil | ||
} else { | ||
allp[i].shardp[index].shard = Getret() | ||
once = true | ||
} | ||
} | ||
} | ||
startTheWorld(stw) | ||
} | ||
|
||
func drainShard(v unsafe.Pointer, yield func(unsafe.Pointer)) { | ||
s := (*ShardedValue)(v) | ||
stw := stopTheWorld(stwShardRead) | ||
for i := range allp { | ||
v, _ := findShrad(allp[i], s.id) | ||
if v != nil { | ||
yield(v) | ||
} | ||
} | ||
startTheWorld(stw) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
package sync | ||
|
||
import ( | ||
"unsafe" | ||
) | ||
|
||
// ShardedValue is a pool of values that all represent a small piece of a single | ||
// conceptual value. These values must implement Shard on themselves. | ||
// | ||
// The purpose of a ShardedValue is to enable the creation of scalable data structures | ||
// that may be updated in shards that are local to the goroutine without any | ||
// additional synchronization. In practice, shards are bound to lower-level | ||
// scheduling resources, such as OS threads and CPUs, for efficiency. | ||
// | ||
// The zero value is ready for use. | ||
type ShardedValue[T Shard[T]] struct { | ||
// NewShard is a function that produces new shards of type T. | ||
NewShard func() T | ||
|
||
id uintptr | ||
} | ||
|
||
// Update acquires a shard and passes it to the provided function. | ||
// The function returns the new shard value to return to the pool. | ||
// Update is safe to call from multiple goroutines. | ||
// Callers are encouraged to keep the provided function as short | ||
// as possible, and are discouraged from blocking within them. | ||
func (s *ShardedValue[T]) Update(f func(value T) T) { | ||
updateShard(unsafe.Pointer(s), func(p unsafe.Pointer) unsafe.Pointer { | ||
if p == nil { | ||
t := s.NewShard() | ||
t = f(t) | ||
return unsafe.Pointer(&t) | ||
} | ||
t := (*T)(p) | ||
*t = f(*t) | ||
return p | ||
}) | ||
} | ||
|
||
// Value snapshots all values in the pool and returns the result of merging them all into | ||
// a single value. This single value is guaranteed to represent a consistent snapshot of | ||
// merging all outstanding shards at some point in time between when the call is made | ||
// and when it returns. This single value is immediately added back into the pool as a | ||
// single shard before being returned. | ||
func (s *ShardedValue[T]) Value() (ret T) { | ||
valueShard(unsafe.Pointer(s), func(p unsafe.Pointer) { | ||
t := (*T)(p) | ||
ret = ret.Merge(*t) | ||
}, func() unsafe.Pointer { return unsafe.Pointer(&ret) }) | ||
return | ||
} | ||
|
||
// Drain snapshots all values in the pool and returns the result of merging them all into | ||
// a single value. This single value is guaranteed to represent a consistent snapshot of | ||
// merging all outstanding shards at some point in time between when the call is made | ||
// and when it returns. Unlike Value, this single value is not added back to the pool. | ||
func (s *ShardedValue[T]) Drain() (ret T) { | ||
drainShard(unsafe.Pointer(s), func(p unsafe.Pointer) { | ||
t := (*T)(p) | ||
ret = ret.Merge(*t) | ||
}) | ||
return | ||
} | ||
|
||
// Shard is an interface implemented by types whose values can be merged | ||
// with values of the same type. | ||
type Shard[T any] interface { | ||
Merge(other T) T | ||
} | ||
|
||
//go:linkname updateShard runtime.updateShard | ||
func updateShard(s unsafe.Pointer, f func(unsafe.Pointer) unsafe.Pointer) | ||
|
||
//go:linkname valueShard runtime.valueShard | ||
func valueShard(v unsafe.Pointer, yield func(unsafe.Pointer), Getret func() unsafe.Pointer) | ||
|
||
//go:linkname drainShard runtime.drainShard | ||
func drainShard(v unsafe.Pointer, yield func(unsafe.Pointer)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
package sync_test | ||
|
||
import ( | ||
. "sync" | ||
"sync/atomic" | ||
"testing" | ||
) | ||
|
||
func TestShardCounter(t *testing.T) { | ||
var s Counter | ||
s.sp.NewShard = func() counterInt { | ||
return counterInt(0) | ||
} | ||
s.Add(1) | ||
if v := s.Value(); v != 1 { | ||
t.Fatalf("got %d , want %d", v, 1) | ||
} | ||
var wg WaitGroup | ||
for range 100 { | ||
wg.Add(1) | ||
go func() { | ||
defer wg.Done() | ||
s.Add(1) | ||
}() | ||
} | ||
wg.Wait() | ||
if v := s.Value(); v != 101 { | ||
t.Fatalf("got %d , want %d", v, 101) | ||
} | ||
} | ||
|
||
type Counter struct { | ||
sp ShardedValue[counterInt] | ||
} | ||
|
||
func (c *Counter) Add(value int) { | ||
c.sp.Update(func(v counterInt) counterInt { | ||
return counterInt(int(v) + value) | ||
}) | ||
} | ||
|
||
func (c *Counter) Value() int { | ||
return int(c.sp.Value()) | ||
} | ||
|
||
type counterInt int | ||
|
||
func (a counterInt) Merge(b counterInt) counterInt { | ||
return a + b | ||
} | ||
|
||
func TestShardDrain(t *testing.T) { | ||
var s Counter | ||
s.sp.NewShard = func() counterInt { | ||
return counterInt(0) | ||
} | ||
s.Add(1) | ||
if v := s.sp.Drain(); v != 1 { | ||
t.Fatalf("got %d , want %d", v, 1) | ||
} | ||
var wg WaitGroup | ||
for range 100 { | ||
wg.Add(1) | ||
go func() { | ||
defer wg.Done() | ||
s.Add(1) | ||
}() | ||
} | ||
wg.Wait() | ||
if v := s.sp.Drain(); v != 101 { | ||
t.Fatalf("got %d , want %d", v, 101) | ||
} | ||
} | ||
|
||
func BenchmarkCounter(b *testing.B) { | ||
b.Run("atomic/int64", func(b *testing.B) { | ||
i := int64(0) | ||
b.RunParallel(func(p *testing.PB) { | ||
for p.Next() { | ||
for range 100000 { | ||
atomic.AddInt64(&i, 1) | ||
} | ||
|
||
} | ||
}) | ||
}) | ||
b.Run("sync/int64", func(b *testing.B) { | ||
b.RunParallel(func(p *testing.PB) { | ||
c := Counter{} | ||
c.sp.NewShard = func() counterInt { | ||
return counterInt(0) | ||
} | ||
for p.Next() { | ||
for range 100000 { | ||
c.Add(1) | ||
} | ||
} | ||
}) | ||
}) | ||
} |