From 18377323c9f962a2f52d7a27db00ceb2d32c3122 Mon Sep 17 00:00:00 2001 From: Matt Joiner Date: Wed, 20 Feb 2019 15:21:10 +1100 Subject: [PATCH 1/3] Replace message senders with a stream per message --- dht.go | 10 +- dht_net.go | 265 ++++++++-------------------------------------------- dht_test.go | 22 ----- notif.go | 15 --- 4 files changed, 44 insertions(+), 268 deletions(-) diff --git a/dht.go b/dht.go index 73ba7588e..dc08dd17f 100644 --- a/dht.go +++ b/dht.go @@ -58,13 +58,10 @@ type IpfsDHT struct { ctx context.Context proc goprocess.Process - strmap map[peer.ID]*messageSender - smlk sync.Mutex - - plk sync.Mutex - + plk sync.Mutex protocols []protocol.ID // DHT protocols - client bool + + client bool } // Assert that IPFS assumptions about interfaces aren't broken. These aren't a @@ -145,7 +142,6 @@ func makeDHT(ctx context.Context, h host.Host, dstore ds.Batching, protocols []p self: h.ID(), peerstore: h.Peerstore(), host: h, - strmap: make(map[peer.ID]*messageSender), ctx: ctx, providers: providers.NewProviderManager(ctx, h.ID(), dstore), birth: time.Now(), diff --git a/dht_net.go b/dht_net.go index 7a801f813..c5adf412d 100644 --- a/dht_net.go +++ b/dht_net.go @@ -5,9 +5,11 @@ import ( "context" "fmt" "io" - "sync" + "log" "time" + "golang.org/x/xerrors" + ggio "github.com/gogo/protobuf/io" ctxio "github.com/jbenet/go-context/io" pb "github.com/libp2p/go-libp2p-kad-dht/pb" @@ -112,39 +114,57 @@ func (dht *IpfsDHT) handleNewMessage(s inet.Stream) bool { // sendRequest sends out a request, but also makes sure to // measure the RTT for latency measurements. -func (dht *IpfsDHT) sendRequest(ctx context.Context, p peer.ID, pmes *pb.Message) (*pb.Message, error) { - - ms, err := dht.messageSenderForPeer(p) +func (dht *IpfsDHT) sendRequest(ctx context.Context, p peer.ID, pmes *pb.Message) (_ *pb.Message, err error) { + defer func(started time.Time) { + log.Printf("time taken to send request: %v: err=%v", time.Since(started), err) + }(time.Now()) + s, err := dht.newStream(ctx, p) if err != nil { - return nil, err + return nil, xerrors.Errorf("error creating new stream: %w", err) } - + defer s.Reset() + dr := ggio.NewDelimitedReader(s, inet.MessageSizeMax) + bdw := newBufferedDelimitedWriter(s) start := time.Now() - - rpmes, err := ms.SendRequest(ctx, pmes) + err = bdw.WriteMsg(pmes) if err != nil { - return nil, err + return nil, xerrors.Errorf("error writing message: %w", err) + } + if err := bdw.Flush(); err != nil { + return nil, xerrors.Errorf("error flushing message: %w", err) + } + var reply pb.Message + if err := dr.ReadMsg(&reply); err != nil { + return nil, xerrors.Errorf("error reading reply: %w", err) } - // update the peer (on valid msgs only) - dht.updateFromMessage(ctx, p, rpmes) - + dht.updateFromMessage(ctx, p, &reply) dht.peerstore.RecordLatency(p, time.Since(start)) - logger.Event(ctx, "dhtReceivedMessage", dht.self, p, rpmes) - return rpmes, nil + return &reply, nil +} + +func (dht *IpfsDHT) newStream(ctx context.Context, p peer.ID) (inet.Stream, error) { + return dht.host.NewStream(ctx, p, dht.protocols...) } // sendMessage sends out a message -func (dht *IpfsDHT) sendMessage(ctx context.Context, p peer.ID, pmes *pb.Message) error { - ms, err := dht.messageSenderForPeer(p) +func (dht *IpfsDHT) sendMessage(ctx context.Context, p peer.ID, pmes *pb.Message) (err error) { + defer func(started time.Time) { + log.Printf("time taken to send message: %v: err=%v", time.Since(started), err) + }(time.Now()) + s, err := dht.newStream(ctx, p) if err != nil { - return err + return xerrors.Errorf("error creating new stream: %w", err) } - - if err := ms.SendMessage(ctx, pmes); err != nil { - return err + defer s.Reset() + bdw := newBufferedDelimitedWriter(s) + err = bdw.WriteMsg(pmes) + if err != nil { + return xerrors.Errorf("error writing message: %w", err) + } + if err := bdw.Flush(); err != nil { + return xerrors.Errorf("error flushing message: %w", err) } - logger.Event(ctx, "dhtSentMessage", dht.self, p, pmes) return nil } @@ -156,206 +176,3 @@ func (dht *IpfsDHT) updateFromMessage(ctx context.Context, p peer.ID, mes *pb.Me } return nil } - -func (dht *IpfsDHT) messageSenderForPeer(p peer.ID) (*messageSender, error) { - dht.smlk.Lock() - ms, ok := dht.strmap[p] - if ok { - dht.smlk.Unlock() - return ms, nil - } - ms = &messageSender{p: p, dht: dht} - dht.strmap[p] = ms - dht.smlk.Unlock() - - if err := ms.prepOrInvalidate(); err != nil { - dht.smlk.Lock() - defer dht.smlk.Unlock() - - if msCur, ok := dht.strmap[p]; ok { - // Changed. Use the new one, old one is invalid and - // not in the map so we can just throw it away. - if ms != msCur { - return msCur, nil - } - // Not changed, remove the now invalid stream from the - // map. - delete(dht.strmap, p) - } - // Invalid but not in map. Must have been removed by a disconnect. - return nil, err - } - // All ready to go. - return ms, nil -} - -type messageSender struct { - s inet.Stream - r ggio.ReadCloser - w bufferedWriteCloser - lk sync.Mutex - p peer.ID - dht *IpfsDHT - - invalid bool - singleMes int -} - -// invalidate is called before this messageSender is removed from the strmap. -// It prevents the messageSender from being reused/reinitialized and then -// forgotten (leaving the stream open). -func (ms *messageSender) invalidate() { - ms.invalid = true - if ms.s != nil { - ms.s.Reset() - ms.s = nil - } -} - -func (ms *messageSender) prepOrInvalidate() error { - ms.lk.Lock() - defer ms.lk.Unlock() - if err := ms.prep(); err != nil { - ms.invalidate() - return err - } - return nil -} - -func (ms *messageSender) prep() error { - if ms.invalid { - return fmt.Errorf("message sender has been invalidated") - } - if ms.s != nil { - return nil - } - - nstr, err := ms.dht.host.NewStream(ms.dht.ctx, ms.p, ms.dht.protocols...) - if err != nil { - return err - } - - ms.r = ggio.NewDelimitedReader(nstr, inet.MessageSizeMax) - ms.w = newBufferedDelimitedWriter(nstr) - ms.s = nstr - - return nil -} - -// streamReuseTries is the number of times we will try to reuse a stream to a -// given peer before giving up and reverting to the old one-message-per-stream -// behaviour. -const streamReuseTries = 3 - -func (ms *messageSender) SendMessage(ctx context.Context, pmes *pb.Message) error { - ms.lk.Lock() - defer ms.lk.Unlock() - retry := false - for { - if err := ms.prep(); err != nil { - return err - } - - if err := ms.writeMsg(pmes); err != nil { - ms.s.Reset() - ms.s = nil - - if retry { - logger.Info("error writing message, bailing: ", err) - return err - } else { - logger.Info("error writing message, trying again: ", err) - retry = true - continue - } - } - - logger.Event(ctx, "dhtSentMessage", ms.dht.self, ms.p, pmes) - - if ms.singleMes > streamReuseTries { - go inet.FullClose(ms.s) - ms.s = nil - } else if retry { - ms.singleMes++ - } - - return nil - } -} - -func (ms *messageSender) SendRequest(ctx context.Context, pmes *pb.Message) (*pb.Message, error) { - ms.lk.Lock() - defer ms.lk.Unlock() - retry := false - for { - if err := ms.prep(); err != nil { - return nil, err - } - - if err := ms.writeMsg(pmes); err != nil { - ms.s.Reset() - ms.s = nil - - if retry { - logger.Info("error writing message, bailing: ", err) - return nil, err - } else { - logger.Info("error writing message, trying again: ", err) - retry = true - continue - } - } - - mes := new(pb.Message) - if err := ms.ctxReadMsg(ctx, mes); err != nil { - ms.s.Reset() - ms.s = nil - - if retry { - logger.Info("error reading message, bailing: ", err) - return nil, err - } else { - logger.Info("error reading message, trying again: ", err) - retry = true - continue - } - } - - logger.Event(ctx, "dhtSentMessage", ms.dht.self, ms.p, pmes) - - if ms.singleMes > streamReuseTries { - go inet.FullClose(ms.s) - ms.s = nil - } else if retry { - ms.singleMes++ - } - - return mes, nil - } -} - -func (ms *messageSender) writeMsg(pmes *pb.Message) error { - if err := ms.w.WriteMsg(pmes); err != nil { - return err - } - return ms.w.Flush() -} - -func (ms *messageSender) ctxReadMsg(ctx context.Context, mes *pb.Message) error { - errc := make(chan error, 1) - go func(r ggio.ReadCloser) { - errc <- r.ReadMsg(mes) - }(ms.r) - - t := time.NewTimer(dhtReadMessageTimeout) - defer t.Stop() - - select { - case err := <-errc: - return err - case <-ctx.Done(): - return ctx.Err() - case <-t.C: - return ErrReadTimeout - } -} diff --git a/dht_test.go b/dht_test.go index feb646cfe..8e2860b4a 100644 --- a/dht_test.go +++ b/dht_test.go @@ -458,28 +458,6 @@ func TestValueGetInvalid(t *testing.T) { testSetGet("valid", "newer", nil) } -func TestInvalidMessageSenderTracking(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - dht := setupDHT(ctx, t, false) - defer dht.Close() - - foo := peer.ID("asdasd") - _, err := dht.messageSenderForPeer(foo) - if err == nil { - t.Fatal("that shouldnt have succeeded") - } - - dht.smlk.Lock() - mscnt := len(dht.strmap) - dht.smlk.Unlock() - - if mscnt > 0 { - t.Fatal("should have no message senders in map") - } -} - func TestProvides(t *testing.T) { // t.Skip("skipping test to debug another") ctx, cancel := context.WithCancel(context.Background()) diff --git a/notif.go b/notif.go index fbcb073be..915953551 100644 --- a/notif.go +++ b/notif.go @@ -93,21 +93,6 @@ func (nn *netNotifiee) Disconnected(n inet.Network, v inet.Conn) { } dht.routingTable.Remove(p) - - dht.smlk.Lock() - defer dht.smlk.Unlock() - ms, ok := dht.strmap[p] - if !ok { - return - } - delete(dht.strmap, p) - - // Do this asynchronously as ms.lk can block for a while. - go func() { - ms.lk.Lock() - defer ms.lk.Unlock() - ms.invalidate() - }() } func (nn *netNotifiee) OpenedStream(n inet.Network, v inet.Stream) {} From 0a4458a1e47e87e08e6a756937c4be131c014276 Mon Sep 17 00:00:00 2001 From: Matt Joiner Date: Wed, 20 Feb 2019 15:22:25 +1100 Subject: [PATCH 2/3] Log before and after bootstrap sub-queries --- dht_bootstrap.go | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/dht_bootstrap.go b/dht_bootstrap.go index 2ed1ac1af..dc74fcb4c 100644 --- a/dht_bootstrap.go +++ b/dht_bootstrap.go @@ -138,15 +138,11 @@ func (dht *IpfsDHT) randomWalk(ctx context.Context) error { // runBootstrap builds up list of peers by requesting random peer IDs func (dht *IpfsDHT) runBootstrap(ctx context.Context, cfg BootstrapConfig) error { - bslog := func(msg string) { - logger.Debugf("DHT %s dhtRunBootstrap %s -- routing table size: %d", dht.self, msg, dht.routingTable.Size()) - } - bslog("start") - defer bslog("end") - defer logger.EventBegin(ctx, "dhtRunBootstrap").Done() - doQuery := func(n int, target string, f func(context.Context) error) error { - logger.Infof("Bootstrapping query (%d/%d) to %s", n, cfg.Queries, target) + logger.Infof("starting bootstrap query (%d/%d) to %s (rt_len=%d)", n, cfg.Queries, target, dht.routingTable.Size()) + defer func() { + logger.Infof("finished bootstrap query (%d/%d) to %s (rt_len=%d)", n, cfg.Queries, target, dht.routingTable.Size()) + }() queryCtx, cancel := context.WithTimeout(ctx, cfg.Timeout) defer cancel() err := f(queryCtx) From b94f90cd5f2f489a05a55f0d59bf660a3c19ddc1 Mon Sep 17 00:00:00 2001 From: Matt Joiner Date: Thu, 21 Feb 2019 10:34:07 +1100 Subject: [PATCH 3/3] Add stream pooling --- dht.go | 4 ++ dht_net.go | 43 ++++--------- pool_stream.go | 172 +++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 188 insertions(+), 31 deletions(-) create mode 100644 pool_stream.go diff --git a/dht.go b/dht.go index dc08dd17f..a4b5291da 100644 --- a/dht.go +++ b/dht.go @@ -58,6 +58,9 @@ type IpfsDHT struct { ctx context.Context proc goprocess.Process + streamPoolMu sync.Mutex + streamPool map[peer.ID]map[*poolStream]struct{} + plk sync.Mutex protocols []protocol.ID // DHT protocols @@ -147,6 +150,7 @@ func makeDHT(ctx context.Context, h host.Host, dstore ds.Batching, protocols []p birth: time.Now(), routingTable: rt, protocols: protocols, + streamPool: make(map[peer.ID]map[*poolStream]struct{}), } } diff --git a/dht_net.go b/dht_net.go index c5adf412d..43b102f6a 100644 --- a/dht_net.go +++ b/dht_net.go @@ -8,8 +8,6 @@ import ( "log" "time" - "golang.org/x/xerrors" - ggio "github.com/gogo/protobuf/io" ctxio "github.com/jbenet/go-context/io" pb "github.com/libp2p/go-libp2p-kad-dht/pb" @@ -114,33 +112,24 @@ func (dht *IpfsDHT) handleNewMessage(s inet.Stream) bool { // sendRequest sends out a request, but also makes sure to // measure the RTT for latency measurements. -func (dht *IpfsDHT) sendRequest(ctx context.Context, p peer.ID, pmes *pb.Message) (_ *pb.Message, err error) { +func (dht *IpfsDHT) sendRequest(ctx context.Context, p peer.ID, req *pb.Message) (_ *pb.Message, err error) { defer func(started time.Time) { log.Printf("time taken to send request: %v: err=%v", time.Since(started), err) }(time.Now()) - s, err := dht.newStream(ctx, p) + ps, err := dht.getPoolStream(ctx, p) if err != nil { - return nil, xerrors.Errorf("error creating new stream: %w", err) + return } - defer s.Reset() - dr := ggio.NewDelimitedReader(s, inet.MessageSizeMax) - bdw := newBufferedDelimitedWriter(s) + defer dht.putPoolStream(ps, p) start := time.Now() - err = bdw.WriteMsg(pmes) + reply, err := ps.request(ctx, req) if err != nil { - return nil, xerrors.Errorf("error writing message: %w", err) - } - if err := bdw.Flush(); err != nil { - return nil, xerrors.Errorf("error flushing message: %w", err) - } - var reply pb.Message - if err := dr.ReadMsg(&reply); err != nil { - return nil, xerrors.Errorf("error reading reply: %w", err) + return } // update the peer (on valid msgs only) - dht.updateFromMessage(ctx, p, &reply) + dht.updateFromMessage(ctx, p, reply) dht.peerstore.RecordLatency(p, time.Since(start)) - return &reply, nil + return reply, nil } func (dht *IpfsDHT) newStream(ctx context.Context, p peer.ID) (inet.Stream, error) { @@ -152,20 +141,12 @@ func (dht *IpfsDHT) sendMessage(ctx context.Context, p peer.ID, pmes *pb.Message defer func(started time.Time) { log.Printf("time taken to send message: %v: err=%v", time.Since(started), err) }(time.Now()) - s, err := dht.newStream(ctx, p) + ps, err := dht.getPoolStream(ctx, p) if err != nil { - return xerrors.Errorf("error creating new stream: %w", err) - } - defer s.Reset() - bdw := newBufferedDelimitedWriter(s) - err = bdw.WriteMsg(pmes) - if err != nil { - return xerrors.Errorf("error writing message: %w", err) - } - if err := bdw.Flush(); err != nil { - return xerrors.Errorf("error flushing message: %w", err) + return } - return nil + defer dht.putPoolStream(ps, p) + return ps.send(pmes) } func (dht *IpfsDHT) updateFromMessage(ctx context.Context, p peer.ID, mes *pb.Message) error { diff --git a/pool_stream.go b/pool_stream.go new file mode 100644 index 000000000..294e67c9b --- /dev/null +++ b/pool_stream.go @@ -0,0 +1,172 @@ +package dht + +import ( + "context" + "log" + "sync" + + pbio "github.com/gogo/protobuf/io" + pb "github.com/libp2p/go-libp2p-kad-dht/pb" + inet "github.com/libp2p/go-libp2p-net" + peer "github.com/libp2p/go-libp2p-peer" + "golang.org/x/xerrors" +) + +func (dht *IpfsDHT) getPoolStream(ctx context.Context, p peer.ID) (*poolStream, error) { + dht.streamPoolMu.Lock() + for ps := range dht.streamPool[p] { + dht.deletePoolStreamLocked(ps, p) + if ps.bad() { + log.Printf("got bad pool stream for %v", p) + continue + } + dht.streamPoolMu.Unlock() + log.Printf("reusing pool stream for %v", p) + return ps, nil + } + dht.streamPoolMu.Unlock() + log.Printf("creating new pool stream for %v", p) + return dht.newPoolStream(ctx, p) +} + +func (dht *IpfsDHT) putPoolStream(ps *poolStream, p peer.ID) { + dht.streamPoolMu.Lock() + defer dht.streamPoolMu.Unlock() + if ps.bad() { + log.Printf("putting pool stream for %v but it went bad", p) + return + } + log.Printf("putting pool stream for %v", p) + if dht.streamPool[p] == nil { + dht.streamPool[p] = make(map[*poolStream]struct{}) + } + dht.streamPool[p][ps] = struct{}{} +} + +func (dht *IpfsDHT) deletePoolStream(ps *poolStream, p peer.ID) { + dht.streamPoolMu.Lock() + dht.deletePoolStreamLocked(ps, p) + dht.streamPoolMu.Unlock() +} + +func (dht *IpfsDHT) deletePoolStreamLocked(ps *poolStream, p peer.ID) { + log.Printf("deleting pool stream for %v", p) + delete(dht.streamPool[p], ps) + if len(dht.streamPool) == 0 { + delete(dht.streamPool, p) + } +} + +func (dht *IpfsDHT) newPoolStream(ctx context.Context, p peer.ID) (*poolStream, error) { + s, err := dht.newStream(ctx, p) + if err != nil { + return nil, xerrors.Errorf("opening stream: %w", err) + } + ps := &poolStream{ + stream: s, + w: newBufferedDelimitedWriter(s), + r: pbio.NewDelimitedReader(s, inet.MessageSizeMax), + m: make(chan chan *pb.Message, 1), + } + ps.onReaderErr = func() { + ps.reset() + dht.deletePoolStream(ps, p) + } + go ps.reader() + return ps, nil +} + +type poolStream struct { + stream interface { + Reset() error + } + w bufferedWriteCloser + r pbio.ReadCloser + onReaderErr func() + + mu sync.Mutex + m chan chan *pb.Message + readerErr error +} + +func (me *poolStream) reset() { + me.stream.Reset() +} + +func (me *poolStream) send(m *pb.Message) (err error) { + defer func() { + if err != nil { + log.Printf("error sending message: %v", err) + } + }() + if err := me.w.WriteMsg(m); err != nil { + return xerrors.Errorf("writing message: %w", err) + } + if err := me.w.Flush(); err != nil { + return xerrors.Errorf("flushing: %w", err) + } + return nil +} + +func (me *poolStream) request(ctx context.Context, req *pb.Message) (*pb.Message, error) { + replyChan := make(chan *pb.Message, 1) + me.mu.Lock() + if me.readerErr != nil { + me.mu.Unlock() + return nil, xerrors.Errorf("reader: %w", me.readerErr) + } + select { + case me.m <- replyChan: + default: + me.mu.Unlock() + return nil, xerrors.New("message pipeline full") + } + me.mu.Unlock() + err := me.send(req) + if err != nil { + return nil, err + } + select { + case reply, ok := <-replyChan: + if !ok { + return nil, xerrors.Errorf("reader: %w", err) + } + return reply, nil + case <-ctx.Done(): + return nil, xerrors.Errorf("while waiting for reply: %w", ctx.Err()) + } +} + +func (me *poolStream) reader() { + err := me.readLoop() + me.mu.Lock() + me.readerErr = err + close(me.m) + me.mu.Unlock() + me.onReaderErr() + for mc := range me.m { + close(mc) + } +} + +func (me *poolStream) readLoop() error { + for { + var m pb.Message + err := me.r.ReadMsg(&m) + if err != nil { + return err + } + select { + case mc := <-me.m: + mc <- &m + default: + return xerrors.New("read superfluous message") + } + } +} + +func (me *poolStream) bad() bool { + me.mu.Lock() + defer me.mu.Unlock() + return me.readerErr != nil +}