diff --git a/multistream.go b/multistream.go index b3de4c0..9f3a1a4 100644 --- a/multistream.go +++ b/multistream.go @@ -204,10 +204,11 @@ func (msm *MultistreamMuxer) Negotiate(rwc io.ReadWriteCloser) (proto string, ha } }() - // Send our protocol ID - if err := delimWriteBuffered(rwc, []byte(ProtocolID)); err != nil { - return "", nil, err - } + // Send the multistream protocol ID + // Ignore the error here. We want the handshake to finish, even if the + // other side has closed this rwc for writing. They may have sent us a + // message and closed. Future writers will get an error anyways. + _ = delimWriteBuffered(rwc, []byte(ProtocolID)) line, err := ReadNextToken(rwc) if err != nil { diff --git a/multistream_test.go b/multistream_test.go index 5664862..ce70c00 100644 --- a/multistream_test.go +++ b/multistream_test.go @@ -665,7 +665,7 @@ func (rob *readonlyBuffer) Close() error { return nil } -func TestNegotiateFail(t *testing.T) { +func TestNegotiatThenWriteFail(t *testing.T) { buf := new(bytes.Buffer) err := delimWrite(buf, []byte(ProtocolID)) @@ -683,9 +683,15 @@ func TestNegotiateFail(t *testing.T) { rob := &readonlyBuffer{bytes.NewReader(buf.Bytes())} _, _, err = mux.Negotiate(rob) + if err != nil { + t.Fatal("Negotiate should not fail here") + } + + _, err = rob.Write([]byte("app data")) if err == nil { - t.Fatal("Negotiate should fail here") + t.Fatal("Write should fail here") } + } type mockStream struct { @@ -747,19 +753,40 @@ func TestNegotiatePeerSendsAndCloses(t *testing.T) { t.Fatal(err) } - s := &mockStream{ - // We mock the closed stream by only expecting a single write. The - // mockstream will error on any more writes (same as writing to a closed - // stream) - expectWrite: [][]byte{delimtedProtocolID}, - toRead: [][]byte{buf.Bytes()}, - } - - mux := NewMultistreamMuxer() - mux.AddHandler("foo", nil) - _, _, err = mux.Negotiate(s) - if err != nil { - t.Fatal("Negotiate should not fail here", err) + type testCase = struct { + name string + s *mockStream + } + + testCases := []testCase{ + { + name: "Able to echo multistream protocol id, but not app protocol id", + s: &mockStream{ + // We mock the closed stream by only expecting a single write. The + // mockstream will error on any more writes (same as writing to a closed + // stream) + expectWrite: [][]byte{delimtedProtocolID}, + toRead: [][]byte{buf.Bytes()}, + }, + }, + { + name: "Not able to write anything. Stream closes too fast", + s: &mockStream{ + expectWrite: [][]byte{}, + toRead: [][]byte{buf.Bytes()}, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + mux := NewMultistreamMuxer() + mux.AddHandler("foo", nil) + _, _, err = mux.Negotiate(tc.s) + if err != nil { + t.Fatal("Negotiate should not fail here", err) + } + }) } }