Factor out streamPool
anacrolix committed Mar 15, 2019
1 parent 44da402 commit 3c5594b
Showing 4 changed files with 79 additions and 64 deletions.
dht.go (4 changes: 1 addition & 3 deletions)
@@ -56,8 +56,7 @@ type IpfsDHT struct {
     ctx  context.Context
     proc goprocess.Process

-    streamPoolMu sync.Mutex
-    streamPool   map[peer.ID]map[*poolStream]struct{}
+    streamPool streamPool

     plk       sync.Mutex
     protocols []protocol.ID // DHT protocols
@@ -140,7 +139,6 @@ func makeDHT(ctx context.Context, h host.Host, dstore ds.Batching, protocols []p
         self:      h.ID(),
         peerstore: h.Peerstore(),
         host:      h,
-        streamPool: make(map[peer.ID]map[*poolStream]struct{}),
         ctx:       ctx,
         providers: providers.NewProviderManager(ctx, h.ID(), dstore),
         birth:     time.Now(),
dht_net.go (19 changes: 11 additions & 8 deletions)
@@ -109,7 +109,7 @@ func (dht *IpfsDHT) handleNewMessage(s inet.Stream) bool {
 // sendRequest sends out a request, but also makes sure to
 // measure the RTT for latency measurements.
 func (dht *IpfsDHT) sendRequest(ctx context.Context, p peer.ID, req *pb.Message) (*pb.Message, error) {
-    ps, _, err := dht.getPoolStream(ctx, p)
+    ps, err := dht.getStream(ctx, p)
     if err != nil {
         return nil, err
     }
@@ -120,7 +120,7 @@ func (dht *IpfsDHT) sendRequest(ctx context.Context, p peer.ID, req *pb.Message)
         return nil, err
     }
     onReply := func(reply *pb.Message) {
-        dht.putPoolStream(ps, p)
+        dht.streamPool.put(ps, p)
         dht.updateFromMessage(ctx, p, reply)
         dht.peerstore.RecordLatency(p, time.Since(start))
     }
@@ -141,20 +141,16 @@ func (dht *IpfsDHT) sendRequest(ctx context.Context, p peer.ID, req *pb.Message)
     }
 }

-func (dht *IpfsDHT) newStream(ctx context.Context, p peer.ID) (inet.Stream, error) {
-    return dht.host.NewStream(ctx, p, dht.protocols...)
-}
-
 // sendMessage sends out a message
 func (dht *IpfsDHT) sendMessage(ctx context.Context, p peer.ID, pmes *pb.Message) (err error) {
-    ps, _, err := dht.getPoolStream(ctx, p)
+    ps, err := dht.getStream(ctx, p)
     if err != nil {
         return
     }
     err = ps.send(pmes)
     if err == nil {
         // Put the stream back in the pool, because we're not waiting for a reply.
-        dht.putPoolStream(ps, p)
+        dht.streamPool.put(ps, p)
     } else {
         // Destroy the stream, because we don't intend to use it again.
         // Presumably it's in a bad state if we had an error while sending a message.
@@ -163,6 +159,13 @@ func (dht *IpfsDHT) sendMessage(ctx context.Context, p peer.ID, pmes *pb.Message
     return err
 }

+func (dht *IpfsDHT) getStream(ctx context.Context, p peer.ID) (*stream, error) {
+    if ps, ok := dht.streamPool.get(ctx, p); ok {
+        return ps, nil
+    }
+    return dht.newStream(ctx, p)
+}
+
 func (dht *IpfsDHT) updateFromMessage(ctx context.Context, p peer.ID, mes *pb.Message) error {
     // Make sure that this node is actually a DHT server, not just a client.
     protos, err := dht.peerstore.SupportsProtocols(p, dht.protocolStrs()...)
stream_pooling.go → stream.go (65 changes: 12 additions & 53 deletions)
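To make the refactor easier to follow, here is a rough sketch of the request path after this change: a stream comes out of the pool (or is freshly opened), the request is written, and the stream only goes back into the pool once a reply arrives. The doRequest wrapper is purely illustrative and not part of this commit; it assumes the getStream helper added above and the stream and streamPool types from the files below.

// Illustrative only; not code from this commit.
func (dht *IpfsDHT) doRequest(ctx context.Context, p peer.ID, req *pb.Message) (*pb.Message, error) {
    ps, err := dht.getStream(ctx, p) // pooled stream if a healthy one exists, else a new one
    if err != nil {
        return nil, err
    }
    replyChan, err := ps.request(ctx, req)
    if err != nil {
        ps.reset() // the stream is suspect; do not return it to the pool
        return nil, err
    }
    select {
    case reply := <-replyChan:
        dht.streamPool.put(ps, p) // healthy streams are pooled for reuse
        return reply, nil
    case <-ctx.Done():
        ps.reset()
        return nil, ctx.Err()
    }
}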
@@ -11,67 +11,26 @@ import (
     "golang.org/x/xerrors"
 )

-func (dht *IpfsDHT) getPoolStream(ctx context.Context, p peer.ID) (_ *poolStream, reused bool, _ error) {
-    dht.streamPoolMu.Lock()
-    for ps := range dht.streamPool[p] {
-        dht.deletePoolStreamLocked(ps, p)
-        if ps.err() != nil {
-            // Stream went bad and hasn't deleted itself yet.
-            continue
-        }
-        dht.streamPoolMu.Unlock()
-        return ps, true, nil
-    }
-    dht.streamPoolMu.Unlock()
-    ps, err := dht.newPoolStream(ctx, p)
-    return ps, false, err
-}
-
-func (dht *IpfsDHT) putPoolStream(ps *poolStream, p peer.ID) {
-    dht.streamPoolMu.Lock()
-    defer dht.streamPoolMu.Unlock()
-    if ps.err() != nil {
-        return
-    }
-    if dht.streamPool[p] == nil {
-        dht.streamPool[p] = make(map[*poolStream]struct{})
-    }
-    dht.streamPool[p][ps] = struct{}{}
-}
-
-func (dht *IpfsDHT) deletePoolStream(ps *poolStream, p peer.ID) {
-    dht.streamPoolMu.Lock()
-    dht.deletePoolStreamLocked(ps, p)
-    dht.streamPoolMu.Unlock()
-}
-
-func (dht *IpfsDHT) deletePoolStreamLocked(ps *poolStream, p peer.ID) {
-    delete(dht.streamPool[p], ps)
-    if len(dht.streamPool) == 0 {
-        delete(dht.streamPool, p)
-    }
-}
-
-func (dht *IpfsDHT) newPoolStream(ctx context.Context, p peer.ID) (*poolStream, error) {
-    s, err := dht.newStream(ctx, p)
+func (dht *IpfsDHT) newStream(ctx context.Context, p peer.ID) (*stream, error) {
+    s, err := dht.host.NewStream(ctx, p, dht.protocols...)
     if err != nil {
         return nil, xerrors.Errorf("opening stream: %w", err)
     }
-    ps := &poolStream{
+    ps := &stream{
         stream: s,
         w:      newBufferedDelimitedWriter(s),
         r:      pbio.NewDelimitedReader(s, inet.MessageSizeMax),
         m:      make(chan chan *pb.Message, 1),
     }
     go func() {
         ps.reader()
-        dht.deletePoolStream(ps, p)
+        dht.streamPool.delete(ps, p)
         ps.reset()
     }()
     return ps, nil
 }

-type poolStream struct {
+type stream struct {
     stream interface {
         Reset() error
     }
@@ -85,11 +44,11 @@ type poolStream struct {
     readerErr error
 }

-func (me *poolStream) reset() {
+func (me *stream) reset() {
     me.stream.Reset()
 }

-func (me *poolStream) send(m *pb.Message) (err error) {
+func (me *stream) send(m *pb.Message) (err error) {
     if err := me.w.WriteMsg(m); err != nil {
         return xerrors.Errorf("writing message: %w", err)
     }
@@ -99,7 +58,7 @@ func (me *poolStream) send(m *pb.Message) (err error) {
     return nil
 }

-func (me *poolStream) request(ctx context.Context, req *pb.Message) (<-chan *pb.Message, error) {
+func (me *stream) request(ctx context.Context, req *pb.Message) (<-chan *pb.Message, error) {
     replyChan := make(chan *pb.Message, 1)
     me.mu.Lock()
     if err := me.errLocked(); err != nil {
@@ -118,7 +77,7 @@ }
 }

 // Handles the error returned from the read loop.
-func (me *poolStream) reader() {
+func (me *stream) reader() {
     err := me.readLoop()
     me.mu.Lock()
     me.readerErr = err
@@ -130,7 +89,7 @@ }
 }

 // Reads from the stream until something is wrong.
-func (me *poolStream) readLoop() error {
+func (me *stream) readLoop() error {
     for {
         var m pb.Message
         err := me.r.ReadMsg(&m)
@@ -146,14 +105,14 @@ }
     }
 }

-func (me *poolStream) err() error {
+func (me *stream) err() error {
     me.mu.Lock()
     defer me.mu.Unlock()
     return me.errLocked()
 }

 // A stream has gone bad when the reader has given up.
-func (me *poolStream) errLocked() error {
+func (me *stream) errLocked() error {
     if me.readerErr != nil {
         return xerrors.Errorf("reader: %w", me.readerErr)
     }
stream_pool.go (55 changes: 55 additions & 0 deletions, new file)
@@ -0,0 +1,55 @@
+package dht
+
+import (
+    "context"
+    "sync"
+
+    peer "github.com/libp2p/go-libp2p-peer"
+)
+
+type streamPool struct {
+    mu sync.Mutex
+    m  map[peer.ID]map[*stream]struct{}
+}
+
+func (sp *streamPool) get(ctx context.Context, p peer.ID) (*stream, bool) {
+    sp.mu.Lock()
+    defer sp.mu.Unlock()
+    for ps := range sp.m[p] {
+        sp.deleteLocked(ps, p)
+        if ps.err() != nil {
+            // Stream went bad and hasn't deleted itself yet.
+            continue
+        }
+        return ps, true
+    }
+    return nil, false
+}
+
+func (sp *streamPool) put(ps *stream, p peer.ID) {
+    sp.mu.Lock()
+    defer sp.mu.Unlock()
+    if ps.err() != nil {
+        return
+    }
+    if sp.m == nil {
+        sp.m = make(map[peer.ID]map[*stream]struct{})
+    }
+    if sp.m[p] == nil {
+        sp.m[p] = make(map[*stream]struct{})
+    }
+    sp.m[p][ps] = struct{}{}
+}
+
+func (sp *streamPool) delete(ps *stream, p peer.ID) {
+    sp.mu.Lock()
+    sp.deleteLocked(ps, p)
+    sp.mu.Unlock()
+}
+
+func (sp *streamPool) deleteLocked(ps *stream, p peer.ID) {
+    delete(sp.m[p], ps)
+    if len(sp.m) == 0 {
+        delete(sp.m, p)
+    }
+}
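One design point worth noting: the zero value of streamPool is ready to use, since put allocates the inner maps lazily and get tolerates a nil map; this is presumably why the explicit make(...) initialiser could be dropped from makeDHT in dht.go above. A minimal sketch of that behaviour (illustrative only, not code from this commit; the function name is hypothetical):

// zeroValuePoolSketch shows that streamPool needs no constructor.
func zeroValuePoolSketch(ctx context.Context, p peer.ID) {
    var sp streamPool // zero value, sp.m is nil
    if _, ok := sp.get(ctx, p); !ok {
        // Ranging over the nil inner map iterates zero times, so get
        // simply reports a miss and the caller opens a new stream.
    }
    // put allocates sp.m and sp.m[p] on first use, then pools the stream.
}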
