diff --git a/client.go b/client.go index 25e81e2..8261631 100644 --- a/client.go +++ b/client.go @@ -7,6 +7,7 @@ import ( "io" "os" "runtime/debug" + "testing" ) // ErrNotSupported is the error returned when the muxer doesn't support @@ -35,12 +36,14 @@ var ErrNoProtocols = errors.New("no protocols specified") // on this ReadWriteCloser. It returns an error if, for example, // the muxer does not know how to handle this protocol. func SelectProtoOrFail[T StringLike](proto T, rwc io.ReadWriteCloser) (err error) { - defer func() { - if rerr := recover(); rerr != nil { - fmt.Fprintf(os.Stderr, "caught panic: %s\n%s\n", rerr, debug.Stack()) - err = fmt.Errorf("panic selecting protocol: %s", rerr) - } - }() + if !testing.Testing() { + defer func() { + if rerr := recover(); rerr != nil { + fmt.Fprintf(os.Stderr, "caught panic: %s\n%s\n", rerr, debug.Stack()) + err = fmt.Errorf("panic selecting protocol: %s", rerr) + } + }() + } errCh := make(chan error, 1) go func() { @@ -70,12 +73,14 @@ func SelectProtoOrFail[T StringLike](proto T, rwc io.ReadWriteCloser) (err error // SelectOneOf will perform handshakes with the protocols on the given slice // until it finds one which is supported by the muxer. func SelectOneOf[T StringLike](protos []T, rwc io.ReadWriteCloser) (proto T, err error) { - defer func() { - if rerr := recover(); rerr != nil { - fmt.Fprintf(os.Stderr, "caught panic: %s\n%s\n", rerr, debug.Stack()) - err = fmt.Errorf("panic selecting one of protocols: %s", rerr) - } - }() + if !testing.Testing() { + defer func() { + if rerr := recover(); rerr != nil { + fmt.Fprintf(os.Stderr, "caught panic: %s\n%s\n", rerr, debug.Stack()) + err = fmt.Errorf("panic selecting one of protocols: %s", rerr) + } + }() + } if len(protos) == 0 { return "", ErrNoProtocols diff --git a/multistream.go b/multistream.go index 17e1ef7..061f07d 100644 --- a/multistream.go +++ b/multistream.go @@ -11,6 +11,7 @@ import ( "os" "runtime/debug" "sync" + "testing" "github.com/multiformats/go-varint" ) @@ -192,12 +193,14 @@ func (msm *MultistreamMuxer[T]) findHandler(proto T) *Handler[T] { // Negotiate performs protocol selection and returns the protocol name and // the matching handler function for it (or an error). func (msm *MultistreamMuxer[T]) Negotiate(rwc io.ReadWriteCloser) (proto T, handler HandlerFunc[T], err error) { - defer func() { - if rerr := recover(); rerr != nil { - fmt.Fprintf(os.Stderr, "caught panic: %s\n%s\n", rerr, debug.Stack()) - err = fmt.Errorf("panic in multistream negotiation: %s", rerr) - } - }() + if !testing.Testing() { + defer func() { + if rerr := recover(); rerr != nil { + fmt.Fprintf(os.Stderr, "caught panic: %s\n%s\n", rerr, debug.Stack()) + err = fmt.Errorf("panic in multistream negotiation: %s", rerr) + } + }() + } // Send the multistream protocol ID // Ignore the error here. We want the handshake to finish, even if the diff --git a/multistream_test.go b/multistream_test.go index 199dc60..4fe5beb 100644 --- a/multistream_test.go +++ b/multistream_test.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "net" + "slices" "sort" "strings" "testing" @@ -802,7 +803,7 @@ func TestNegotiatePeerSendsAndCloses(t *testing.T) { } type rwc struct { - *strings.Reader + strings.Reader } func (*rwc) Write(b []byte) (int, error) { @@ -813,21 +814,68 @@ func (*rwc) Close() error { return nil } -func FuzzMultistream(f *testing.F) { - f.Add("/multistream/1.0.0") - f.Add(ProtocolID) +func FuzzHandler(f *testing.F) { + f.Add("/noise", "/tls", "", "\x13/multistream/1.0.0\n\a/noise\n") - f.Fuzz(func(t *testing.T, b string) { - readStream := strings.NewReader(b) - input := &rwc{readStream} + f.Fuzz(func(t *testing.T, p1, p2, p3, b string) { + input := &rwc{*strings.NewReader(b)} mux := NewMultistreamMuxer[string]() - mux.AddHandler("/a", nil) - mux.AddHandler("/b", nil) + h := func(protocol string, rwc io.ReadWriteCloser) error { + defer rwc.Close() + _, err := io.Copy(io.Discard, rwc) + return err + } + if p1 != "" { + mux.AddHandler(p1, h) + } + if p2 != "" { + mux.AddHandler(p2, h) + } + if p3 != "" { + mux.AddHandler(p3, h) + } _ = mux.Handle(input) }) } +func FuzzSelectOneOf(f *testing.F) { + f.Add("/noise", "/tls", "", "\x13/multistream/1.0.0\n\a/noise\n") + + f.Fuzz(func(t *testing.T, p1, p2, p3, response string) { + protos := make([]string, 0, 3) + if p1 != "" { + protos = append(protos, p1) + } + if p2 != "" { + protos = append(protos, p2) + } + if p3 != "" { + protos = append(protos, p3) + } + + r := &rwc{*strings.NewReader(response)} + + p, err := SelectOneOf(protos, r) + if err != nil { + return + } + if !slices.Contains(protos, p) { + t.Fatal("matched proto which wasn't proposed") + } + }) +} + +func FuzzSelectOrFail(f *testing.F) { + f.Add("/noise", "\x13/multistream/1.0.0\n\a/noise\n") + + f.Fuzz(func(t *testing.T, proto, response string) { + r := &rwc{*strings.NewReader(response)} + + _ = SelectProtoOrFail(proto, r) + }) +} + func TestComparableErrors(t *testing.T) { var err1 error = ErrNotSupported[string]{[]string{"/a"}} if !errors.Is(err1, ErrNotSupported[string]{}) {