diff --git a/routing/http/client/client.go b/routing/http/client/client.go index 741390794..782ddf54e 100644 --- a/routing/http/client/client.go +++ b/routing/http/client/client.go @@ -29,8 +29,15 @@ import ( ) var ( - _ contentrouter.Client = &client{} - logger = logging.Logger("service/delegatedrouting") + _ contentrouter.Client = &client{} + logger = logging.Logger("service/delegatedrouting") + defaultHTTPClient = &http.Client{ + Transport: &ResponseBodyLimitedTransport{ + RoundTripper: http.DefaultTransport, + LimitBytes: 1 << 20, + UserAgent: defaultUserAgent, + }, + } ) const ( @@ -65,21 +72,21 @@ type httpClient interface { Do(req *http.Request) (*http.Response, error) } -type option func(*client) +type Option func(*client) -func WithIdentity(identity crypto.PrivKey) option { +func WithIdentity(identity crypto.PrivKey) Option { return func(c *client) { c.identity = identity } } -func WithHTTPClient(h httpClient) option { +func WithHTTPClient(h httpClient) Option { return func(c *client) { c.httpClient = h } } -func WithUserAgent(ua string) option { +func WithUserAgent(ua string) Option { return func(c *client) { if ua == "" { return @@ -96,7 +103,7 @@ func WithUserAgent(ua string) option { } } -func WithProviderInfo(peerID peer.ID, addrs []multiaddr.Multiaddr) option { +func WithProviderInfo(peerID peer.ID, addrs []multiaddr.Multiaddr) Option { return func(c *client) { c.peerID = peerID for _, a := range addrs { @@ -105,7 +112,7 @@ func WithProviderInfo(peerID peer.ID, addrs []multiaddr.Multiaddr) option { } } -func WithStreamResultsRequired() option { +func WithStreamResultsRequired() Option { return func(c *client) { c.accepts = mediaTypeNDJSON } @@ -113,14 +120,7 @@ func WithStreamResultsRequired() option { // New creates a content routing API client. // The Provider and identity parameters are option. If they are nil, the `Provide` method will not function. -func New(baseURL string, opts ...option) (*client, error) { - defaultHTTPClient := &http.Client{ - Transport: &ResponseBodyLimitedTransport{ - RoundTripper: http.DefaultTransport, - LimitBytes: 1 << 20, - UserAgent: defaultUserAgent, - }, - } +func New(baseURL string, opts ...Option) (*client, error) { client := &client{ baseURL: baseURL, httpClient: defaultHTTPClient, @@ -171,6 +171,7 @@ func (c *client) FindProviders(ctx context.Context, key cid.Cid) (provs iter.Res if err != nil { return nil, err } + req.Header.Set("Accept", c.accepts) m.host = req.Host @@ -189,13 +190,14 @@ func (c *client) FindProviders(ctx context.Context, key cid.Cid) (provs iter.Res if resp.StatusCode == http.StatusNotFound { resp.Body.Close() m.record(ctx) - return nil, nil + return iter.FromSlice[iter.Result[types.ProviderResponse]](nil), nil } if resp.StatusCode != http.StatusOK { + err := httpError(resp.StatusCode, resp.Body) resp.Body.Close() m.record(ctx) - return nil, httpError(resp.StatusCode, resp.Body) + return nil, err } respContentType := resp.Header.Get("Content-Type") diff --git a/routing/http/client/client_test.go b/routing/http/client/client_test.go index b82c79f6b..bd79094e2 100644 --- a/routing/http/client/client_test.go +++ b/routing/http/client/client_test.go @@ -41,45 +41,76 @@ func (m *mockContentRouter) Provide(ctx context.Context, req *server.WriteProvid } type testDeps struct { - router *mockContentRouter - server *httptest.Server - peerID peer.ID - addrs []multiaddr.Multiaddr - client *client + // recordingHandler records requests received on the server side + recordingHandler *recordingHandler + // recordingHTTPClient records responses received on the client side + recordingHTTPClient *recordingHTTPClient + router *mockContentRouter + server *httptest.Server + peerID peer.ID + addrs []multiaddr.Multiaddr + client *client } -func makeTestDeps(t *testing.T) testDeps { +type recordingHandler struct { + http.Handler + f []func(*http.Request) +} + +func (h *recordingHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + for _, f := range h.f { + f(r) + } + h.Handler.ServeHTTP(w, r) +} + +type recordingHTTPClient struct { + httpClient + f []func(*http.Response) +} + +func (c *recordingHTTPClient) Do(req *http.Request) (*http.Response, error) { + resp, err := c.httpClient.Do(req) + for _, f := range c.f { + f(resp) + } + return resp, err +} + +func makeTestDeps(t *testing.T, clientsOpts []Option, serverOpts []server.Option) testDeps { const testUserAgent = "testUserAgent" peerID, addrs, identity := makeProviderAndIdentity() router := &mockContentRouter{} - server := httptest.NewServer(server.Handler(router)) + recordingHandler := &recordingHandler{ + Handler: server.Handler(router, serverOpts...), + f: []func(*http.Request){ + func(r *http.Request) { + assert.Equal(t, testUserAgent, r.Header.Get("User-Agent")) + }, + }, + } + server := httptest.NewServer(recordingHandler) t.Cleanup(server.Close) serverAddr := "http://" + server.Listener.Addr().String() - c, err := New(serverAddr, WithProviderInfo(peerID, addrs), WithIdentity(identity), WithUserAgent(testUserAgent)) + recordingHTTPClient := &recordingHTTPClient{httpClient: defaultHTTPClient} + defaultClientOpts := []Option{ + WithProviderInfo(peerID, addrs), + WithIdentity(identity), + WithUserAgent(testUserAgent), + WithHTTPClient(recordingHTTPClient), + } + c, err := New(serverAddr, append(defaultClientOpts, clientsOpts...)...) if err != nil { panic(err) } - assertUserAgentOverride(t, c, testUserAgent) return testDeps{ - router: router, - server: server, - peerID: peerID, - addrs: addrs, - client: c, - } -} - -func assertUserAgentOverride(t *testing.T, c *client, expected string) { - httpClient, ok := c.httpClient.(*http.Client) - if !ok { - t.Error("invalid c.httpClient") - } - transport, ok := httpClient.Transport.(*ResponseBodyLimitedTransport) - if !ok { - t.Error("invalid httpClient.Transport") - } - if transport.UserAgent != expected { - t.Error("invalid httpClient.Transport.UserAgent") + recordingHandler: recordingHandler, + recordingHTTPClient: recordingHTTPClient, + router: router, + server: server, + peerID: peerID, + addrs: addrs, + client: c, } } @@ -149,6 +180,10 @@ type osErrContains struct { } func (e *osErrContains) errContains(t *testing.T, err error) { + if e.expContains == "" && e.expContainsWin == "" { + assert.NoError(t, err) + return + } if runtime.GOOS == "windows" && len(e.expContainsWin) != 0 { assert.ErrorContains(t, err, e.expContainsWin) } else { @@ -163,37 +198,90 @@ func TestClient_FindProviders(t *testing.T) { } cases := []struct { - name string - httpStatusCode int - stopServer bool - routerProvs []iter.Result[types.ProviderResponse] - routerErr error - - expProvs []iter.Result[types.ProviderResponse] - expErrContains []osErrContains + name string + httpStatusCode int + stopServer bool + routerProvs []iter.Result[types.ProviderResponse] + routerErr error + clientRequiresStreaming bool + serverStreamingDisabled bool + + expErrContains osErrContains + expProvs []iter.Result[types.ProviderResponse] + expStreamingResponse bool + expJSONResponse bool }{ { - name: "happy case", - routerProvs: bitswapProvs, - expProvs: bitswapProvs, + name: "happy case", + routerProvs: bitswapProvs, + expProvs: bitswapProvs, + expStreamingResponse: true, + }, + { + name: "server doesn't support streaming", + routerProvs: bitswapProvs, + expProvs: bitswapProvs, + serverStreamingDisabled: true, + expJSONResponse: true, + }, + { + name: "client requires streaming but server doesn't support it", + serverStreamingDisabled: true, + clientRequiresStreaming: true, + expErrContains: osErrContains{expContains: "HTTP error with StatusCode=400: no supported content types"}, }, { name: "returns an error if there's a non-200 response", httpStatusCode: 500, - expErrContains: []osErrContains{{expContains: "HTTP error with StatusCode=500: "}}, + expErrContains: osErrContains{expContains: "HTTP error with StatusCode=500"}, }, { name: "returns an error if the HTTP client returns a non-HTTP error", stopServer: true, - expErrContains: []osErrContains{{ + expErrContains: osErrContains{ expContains: "connect: connection refused", expContainsWin: "connectex: No connection could be made because the target machine actively refused it.", - }}, + }, + }, + { + name: "returns no providers if the HTTP server returns a 404 respones", + httpStatusCode: 404, + expProvs: nil, }, } for _, c := range cases { t.Run(c.name, func(t *testing.T) { - deps := makeTestDeps(t) + var clientOpts []Option + var serverOpts []server.Option + var onRespReceived []func(*http.Response) + var onReqReceived []func(*http.Request) + + if c.serverStreamingDisabled { + serverOpts = append(serverOpts, server.WithStreamingResultsDisabled()) + } + if c.clientRequiresStreaming { + clientOpts = append(clientOpts, WithStreamResultsRequired()) + onReqReceived = append(onReqReceived, func(r *http.Request) { + assert.Equal(t, mediaTypeNDJSON, r.Header.Get("Accept")) + }) + } + + if c.expStreamingResponse { + onRespReceived = append(onRespReceived, func(r *http.Response) { + assert.Equal(t, mediaTypeNDJSON, r.Header.Get("Content-Type")) + }) + } + if c.expJSONResponse { + onRespReceived = append(onRespReceived, func(r *http.Response) { + assert.Equal(t, mediaTypeJSON, r.Header.Get("Content-Type")) + }) + } + + deps := makeTestDeps(t, clientOpts, serverOpts) + + deps.recordingHTTPClient.f = append(deps.recordingHTTPClient.f, onRespReceived...) + deps.recordingHandler.f = append(deps.recordingHandler.f, onReqReceived...) + client := deps.client router := deps.router @@ -218,12 +306,7 @@ func TestClient_FindProviders(t *testing.T) { provsIter, err := client.FindProviders(ctx, cid) - for _, exp := range c.expErrContains { - exp.errContains(t, err) - } - if len(c.expErrContains) == 0 { - require.NoError(t, err) - } + c.expErrContains.errContains(t, err) provs := iter.ReadAll[iter.Result[types.ProviderResponse]](provsIter) assert.Equal(t, c.expProvs, provs) @@ -263,8 +346,7 @@ func TestClient_Provide(t *testing.T) { name: "should return a 403 if the payload signature verification fails", cids: []cid.Cid{}, mangleSignature: true, - - expErrContains: "HTTP error with StatusCode=403", + expErrContains: "HTTP error with StatusCode=403", }, { name: "should return error if identity is not provided", @@ -290,7 +372,7 @@ func TestClient_Provide(t *testing.T) { } for _, c := range cases { t.Run(c.name, func(t *testing.T) { - deps := makeTestDeps(t) + deps := makeTestDeps(t, nil, nil) client := deps.client router := deps.router diff --git a/routing/http/server/server.go b/routing/http/server/server.go index f0635f590..814c6733e 100644 --- a/routing/http/server/server.go +++ b/routing/http/server/server.go @@ -9,6 +9,7 @@ import ( "io" "mime" "net/http" + "strings" "time" "github.com/gorilla/mux" @@ -59,16 +60,16 @@ type WriteProvideRequest struct { Bytes []byte } -type serverOption func(s *server) +type Option func(s *server) // WithStreamingResultsDisabled disables ndjson responses, so that the server only supports JSON responses. -func WithStreamingResultsDisabled() serverOption { +func WithStreamingResultsDisabled() Option { return func(s *server) { s.disableNDJSON = true } } -func Handler(svc ContentRouter, opts ...serverOption) http.Handler { +func Handler(svc ContentRouter, opts ...Option) http.Handler { server := &server{ svc: svc, } @@ -169,22 +170,24 @@ func (s *server) findProviders(w http.ResponseWriter, httpReq *http.Request) { var supportsNDJSON bool var supportsJSON bool - accepts := httpReq.Header.Values("Accept") - if len(accepts) == 0 { + acceptHeaders := httpReq.Header.Values("Accept") + if len(acceptHeaders) == 0 { handlerFunc = s.findProvidersJSON } else { - for _, accept := range accepts { - mediaType, _, err := mime.ParseMediaType(accept) - if err != nil { - writeErr(w, "FindProviders", http.StatusBadRequest, fmt.Errorf("unable to parse Accept header: %w", err)) - return - } - - switch mediaType { - case mediaTypeJSON, mediaTypeWildcard: - supportsJSON = true - case mediaTypeNDJSON: - supportsNDJSON = true + for _, acceptHeader := range acceptHeaders { + for _, accept := range strings.Split(acceptHeader, ",") { + mediaType, _, err := mime.ParseMediaType(accept) + if err != nil { + writeErr(w, "FindProviders", http.StatusBadRequest, fmt.Errorf("unable to parse Accept header: %w", err)) + return + } + + switch mediaType { + case mediaTypeJSON, mediaTypeWildcard: + supportsJSON = true + case mediaTypeNDJSON: + supportsNDJSON = true + } } } diff --git a/routing/http/server/server_test.go b/routing/http/server/server_test.go index 842cf95d1..17ac3cace 100644 --- a/routing/http/server/server_test.go +++ b/routing/http/server/server_test.go @@ -42,6 +42,7 @@ func TestHeaders(t *testing.T) { resp, err = http.Get(serverAddr + ProvidePath + "BAD_CID") require.NoError(t, err) + defer resp.Body.Close() require.Equal(t, 400, resp.StatusCode) header = resp.Header.Get("Content-Type") require.Equal(t, "text/plain; charset=utf-8", header) diff --git a/routing/http/types/iter/iter.go b/routing/http/types/iter/iter.go index 37ca634b7..67c6dde00 100644 --- a/routing/http/types/iter/iter.go +++ b/routing/http/types/iter/iter.go @@ -37,6 +37,7 @@ func ReadAll[T any](iter Iter[T]) []T { if iter == nil { return nil } + defer iter.Close() var vs []T for iter.Next() { vs = append(vs, iter.Val())