Skip to content

Commit

Permalink
Add stream pooling
Browse files — browse the repository at this point in the history
  • Loading branch information
anacrolix committed Mar 12, 2019
1 parent 07573a0 commit dc67e63
Show file tree
Hide file tree
Showing 5 changed files with 208 additions and 267 deletions.
9 changes: 4 additions & 5 deletions dht.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,10 @@ type IpfsDHT struct {
ctx context.Context
proc goprocess.Process

strmap map[peer.ID]*messageSender
smlk sync.Mutex

plk sync.Mutex
streamPoolMu sync.Mutex
streamPool map[peer.ID]map[*poolStream]struct{}

plk sync.Mutex
protocols []protocol.ID // DHT protocols
}

Expand Down Expand Up @@ -143,12 +142,12 @@ func makeDHT(ctx context.Context, h host.Host, dstore ds.Batching, protocols []p
self: h.ID(),
peerstore: h.Peerstore(),
host: h,
strmap: make(map[peer.ID]*messageSender),
ctx: ctx,
providers: providers.NewProviderManager(ctx, h.ID(), dstore),
birth: time.Now(),
routingTable: rt,
protocols: protocols,
streamPool: make(map[peer.ID]map[*poolStream]struct{}),
}
}

Expand Down
248 changes: 23 additions & 225 deletions dht_net.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"context"
"fmt"
"io"
"sync"
"time"

ggio "github.com/gogo/protobuf/io"
Expand Down Expand Up @@ -109,40 +108,42 @@ func (dht *IpfsDHT) handleNewMessage(s inet.Stream) bool {

// sendRequest sends out a request, but also makes sure to
// measure the RTT for latency measurements.
func (dht *IpfsDHT) sendRequest(ctx context.Context, p peer.ID, pmes *pb.Message) (*pb.Message, error) {

ms, err := dht.messageSenderForPeer(ctx, p)
func (dht *IpfsDHT) sendRequest(ctx context.Context, p peer.ID, req *pb.Message) (_ *pb.Message, err error) {
ps, _, err := dht.getPoolStream(ctx, p)
if err != nil {
return nil, err
return
}

start := time.Now()

rpmes, err := ms.SendRequest(ctx, pmes)
reply, err := ps.request(ctx, req)
if err != nil {
return nil, err
return
}

// update the peer (on valid msgs only)
dht.updateFromMessage(ctx, p, rpmes)

dht.updateFromMessage(ctx, p, reply)
dht.peerstore.RecordLatency(p, time.Since(start))
logger.Event(ctx, "dhtReceivedMessage", dht.self, p, rpmes)
return rpmes, nil
return reply, nil
}

// newStream opens a fresh libp2p stream to peer p, negotiating one of the
// DHT's configured protocol IDs.
func (dht *IpfsDHT) newStream(ctx context.Context, p peer.ID) (inet.Stream, error) {
	return dht.host.NewStream(ctx, p, dht.protocols...)
}

// sendMessage sends out a message
func (dht *IpfsDHT) sendMessage(ctx context.Context, p peer.ID, pmes *pb.Message) error {
ms, err := dht.messageSenderForPeer(ctx, p)
func (dht *IpfsDHT) sendMessage(ctx context.Context, p peer.ID, pmes *pb.Message) (err error) {
ps, _, err := dht.getPoolStream(ctx, p)
if err != nil {
return err
return
}

if err := ms.SendMessage(ctx, pmes); err != nil {
return err
err = ps.send(pmes)
if err == nil {
// Put the stream back in the pool, because we're not waiting for a reply.
dht.putPoolStream(ps, p)
} else {
// Destroy the stream, because we don't intend to use it again.
// Presumably it's in a bad state if we had an error while sending a message.
ps.reset()
}
logger.Event(ctx, "dhtSentMessage", dht.self, p, pmes)
return nil
return err
}

func (dht *IpfsDHT) updateFromMessage(ctx context.Context, p peer.ID, mes *pb.Message) error {
Expand All @@ -153,206 +154,3 @@ func (dht *IpfsDHT) updateFromMessage(ctx context.Context, p peer.ID, mes *pb.Me
}
return nil
}

// messageSenderForPeer returns the cached messageSender for p, creating one
// and registering it in dht.strmap if none exists yet.
//
// The map lock is released before the stream setup in prepOrInvalidate (which
// may dial the peer), so on failure we must re-take the lock and re-check the
// map: another goroutine may have replaced or removed our entry in the
// meantime.
func (dht *IpfsDHT) messageSenderForPeer(ctx context.Context, p peer.ID) (*messageSender, error) {
	dht.smlk.Lock()
	ms, ok := dht.strmap[p]
	if ok {
		// Fast path: an existing sender is reused as-is.
		dht.smlk.Unlock()
		return ms, nil
	}
	// Publish the new (not yet prepared) sender before unlocking, so
	// concurrent callers for the same peer share it.
	ms = &messageSender{p: p, dht: dht}
	dht.strmap[p] = ms
	dht.smlk.Unlock()

	if err := ms.prepOrInvalidate(ctx); err != nil {
		dht.smlk.Lock()
		defer dht.smlk.Unlock()

		if msCur, ok := dht.strmap[p]; ok {
			// Changed. Use the new one, old one is invalid and
			// not in the map so we can just throw it away.
			if ms != msCur {
				return msCur, nil
			}
			// Not changed, remove the now invalid stream from the
			// map.
			delete(dht.strmap, p)
		}
		// Invalid but not in map. Must have been removed by a disconnect.
		return nil, err
	}
	// All ready to go.
	return ms, nil
}

// messageSender wraps a single reusable stream to one peer; lk serializes
// all sends and request/response exchanges over it.
type messageSender struct {
	s   inet.Stream         // current stream; nil until prep succeeds or after a reset
	r   ggio.ReadCloser     // delimited reader over s
	w   bufferedWriteCloser // buffered delimited writer over s
	lk  sync.Mutex          // guards all fields; held for the whole send/request
	p   peer.ID             // remote peer this sender talks to
	dht *IpfsDHT

	invalid   bool // set by invalidate(); prep refuses to reinitialize afterwards
	singleMes int  // count of sends that only succeeded after a retry; past streamReuseTries we stop reusing streams
}

// invalidate marks this messageSender as permanently unusable and tears down
// any live stream. It is called before the sender is removed from the strmap,
// so it cannot be re-prepped afterwards and leave a stream dangling.
func (ms *messageSender) invalidate() {
	ms.invalid = true
	if s := ms.s; s != nil {
		ms.s = nil
		s.Reset()
	}
}

// prepOrInvalidate prepares the sender's stream under the lock; if the
// preparation fails, the sender is invalidated so it is never reused.
func (ms *messageSender) prepOrInvalidate(ctx context.Context) error {
	ms.lk.Lock()
	defer ms.lk.Unlock()
	err := ms.prep(ctx)
	if err != nil {
		ms.invalidate()
	}
	return err
}

// prep ensures a live stream (with its delimited reader and buffered writer)
// is available, dialing the peer if necessary. It fails if the sender has
// been invalidated.
func (ms *messageSender) prep(ctx context.Context) error {
	switch {
	case ms.invalid:
		return fmt.Errorf("message sender has been invalidated")
	case ms.s != nil:
		// Already have a usable stream.
		return nil
	}

	s, err := ms.dht.host.NewStream(ctx, ms.p, ms.dht.protocols...)
	if err != nil {
		return err
	}

	ms.r = ggio.NewDelimitedReader(s, inet.MessageSizeMax)
	ms.w = newBufferedDelimitedWriter(s)
	ms.s = s
	return nil
}

// streamReuseTries is the number of times we will try to reuse a stream to a
// given peer before giving up and reverting to the old one-message-per-stream
// behaviour (see the singleMes counter in SendMessage/SendRequest).
const streamReuseTries = 3

// SendMessage writes pmes to the peer's stream without waiting for a reply.
// A failed write resets the stream and is retried exactly once on a freshly
// prepped stream; once singleMes exceeds streamReuseTries, the stream is
// closed after each message instead of being kept for reuse.
func (ms *messageSender) SendMessage(ctx context.Context, pmes *pb.Message) error {
	ms.lk.Lock()
	defer ms.lk.Unlock()
	retry := false
	for {
		// (Re)establish the stream if needed.
		if err := ms.prep(ctx); err != nil {
			return err
		}

		if err := ms.writeMsg(pmes); err != nil {
			// Drop the broken stream; a retry will dial a new one.
			ms.s.Reset()
			ms.s = nil

			if retry {
				logger.Info("error writing message, bailing: ", err)
				return err
			} else {
				logger.Info("error writing message, trying again: ", err)
				retry = true
				continue
			}
		}

		logger.Event(ctx, "dhtSentMessage", ms.dht.self, ms.p, pmes)

		if ms.singleMes > streamReuseTries {
			// Reuse has failed too often; close this stream in the
			// background and fall back to one message per stream.
			go inet.FullClose(ms.s)
			ms.s = nil
		} else if retry {
			// Count sends that only succeeded on the retry.
			ms.singleMes++
		}

		return nil
	}
}

// SendRequest writes pmes to the peer's stream and waits (bounded by ctx and
// the read timeout) for the reply message. Write or read failures reset the
// stream and are retried exactly once on a freshly prepped stream; once
// singleMes exceeds streamReuseTries, the stream is closed after each
// exchange instead of being kept for reuse.
func (ms *messageSender) SendRequest(ctx context.Context, pmes *pb.Message) (*pb.Message, error) {
	ms.lk.Lock()
	defer ms.lk.Unlock()
	retry := false
	for {
		// (Re)establish the stream if needed.
		if err := ms.prep(ctx); err != nil {
			return nil, err
		}

		if err := ms.writeMsg(pmes); err != nil {
			// Drop the broken stream; a retry will dial a new one.
			ms.s.Reset()
			ms.s = nil

			if retry {
				logger.Info("error writing message, bailing: ", err)
				return nil, err
			} else {
				logger.Info("error writing message, trying again: ", err)
				retry = true
				continue
			}
		}

		mes := new(pb.Message)
		if err := ms.ctxReadMsg(ctx, mes); err != nil {
			ms.s.Reset()
			ms.s = nil

			if retry {
				logger.Info("error reading message, bailing: ", err)
				return nil, err
			} else {
				logger.Info("error reading message, trying again: ", err)
				retry = true
				continue
			}
		}

		logger.Event(ctx, "dhtSentMessage", ms.dht.self, ms.p, pmes)

		if ms.singleMes > streamReuseTries {
			// Reuse has failed too often; close this stream in the
			// background and fall back to one message per stream.
			go inet.FullClose(ms.s)
			ms.s = nil
		} else if retry {
			// Count exchanges that only succeeded on the retry.
			ms.singleMes++
		}

		return mes, nil
	}
}

// writeMsg marshals pmes onto the buffered delimited writer and flushes it
// out to the stream.
func (ms *messageSender) writeMsg(pmes *pb.Message) error {
	err := ms.w.WriteMsg(pmes)
	if err == nil {
		err = ms.w.Flush()
	}
	return err
}

// ctxReadMsg reads a single delimited message from the stream into mes,
// giving up when ctx is cancelled or after dhtReadMessageTimeout elapses.
// The read itself happens on a helper goroutine; the buffered channel lets
// it finish even if we have already returned.
func (ms *messageSender) ctxReadMsg(ctx context.Context, mes *pb.Message) error {
	done := make(chan error, 1)
	go func(r ggio.ReadCloser) {
		done <- r.ReadMsg(mes)
	}(ms.r)

	timer := time.NewTimer(dhtReadMessageTimeout)
	defer timer.Stop()

	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-timer.C:
		return ErrReadTimeout
	case err := <-done:
		return err
	}
}
22 changes: 0 additions & 22 deletions dht_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -458,28 +458,6 @@ func TestValueGetInvalid(t *testing.T) {
testSetGet("valid", "newer", nil)
}

// TestInvalidMessageSenderTracking checks that a failed dial to a bogus peer
// does not leave a stale messageSender behind in the strmap.
func TestInvalidMessageSenderTracking(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	dht := setupDHT(ctx, t, false)
	defer dht.Close()

	bogus := peer.ID("asdasd")
	if _, err := dht.messageSenderForPeer(ctx, bogus); err == nil {
		t.Fatal("that shouldnt have succeeded")
	}

	dht.smlk.Lock()
	senders := len(dht.strmap)
	dht.smlk.Unlock()
	if senders > 0 {
		t.Fatal("should have no message senders in map")
	}
}

func TestProvides(t *testing.T) {
// t.Skip("skipping test to debug another")
ctx, cancel := context.WithCancel(context.Background())
Expand Down
15 changes: 0 additions & 15 deletions notif.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,21 +93,6 @@ func (nn *netNotifiee) Disconnected(n inet.Network, v inet.Conn) {
}

dht.routingTable.Remove(p)

dht.smlk.Lock()
defer dht.smlk.Unlock()
ms, ok := dht.strmap[p]
if !ok {
return
}
delete(dht.strmap, p)

// Do this asynchronously as ms.lk can block for a while.
go func() {
ms.lk.Lock()
defer ms.lk.Unlock()
ms.invalidate()
}()
}

// OpenedStream is a no-op; the DHT takes no action when a stream is opened.
func (nn *netNotifiee) OpenedStream(n inet.Network, v inet.Stream) {}
Expand Down
Loading

0 comments on commit dc67e63

Please sign in to comment.