diff --git a/routing/http/server/server.go b/routing/http/server/server.go index 8ce7d063b..cd8f2d005 100644 --- a/routing/http/server/server.go +++ b/routing/http/server/server.go @@ -254,6 +254,13 @@ func (s *server) findProvidersNDJSON(w http.ResponseWriter, provIter iter.Result logger.Warn("FindProviders ndjson write error", "Error", err) return } + + _, err = w.Write([]byte("\n")) + if err != nil { + logger.Warn("FindProviders ndjson write error", "Error", err) + return + } + if f, ok := w.(http.Flusher); ok { f.Flush() } diff --git a/routing/http/server/server_test.go b/routing/http/server/server_test.go index fec5eaf9a..380846801 100644 --- a/routing/http/server/server_test.go +++ b/routing/http/server/server_test.go @@ -2,6 +2,7 @@ package server import ( "context" + "io" "net/http" "net/http/httptest" "testing" @@ -10,6 +11,7 @@ import ( "github.com/ipfs/boxo/routing/http/types" "github.com/ipfs/boxo/routing/http/types/iter" "github.com/ipfs/go-cid" + "github.com/libp2p/go-libp2p/core/peer" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) @@ -48,6 +50,72 @@ func TestHeaders(t *testing.T) { require.Equal(t, "text/plain; charset=utf-8", header) } +func TestResponse(t *testing.T) { + router := &mockContentRouter{} + server := httptest.NewServer(Handler(router)) + t.Cleanup(server.Close) + serverAddr := "http://" + server.Listener.Addr().String() + + pidStr := "12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vn" + cidStr := "bafkreifjjcie6lypi6ny7amxnfftagclbuxndqonfipmb64f2km2devei4" + + pid, err := peer.Decode(pidStr) + require.NoError(t, err) + + cid, err := cid.Decode(cidStr) + require.NoError(t, err) + + results := iter.FromSlice([]iter.Result[types.ProviderResponse]{ + {Val: &types.ReadBitswapProviderRecord{ + Protocol: "transport-bitswap", + Schema: types.SchemaBitswap, + ID: &pid, + Addrs: []types.Multiaddr{}, + }}}, + ) + + router.On("FindProviders", mock.Anything, cid, mock.Anything).Return(results, nil) + urlStr := serverAddr + ProvidePath + cidStr + + t.Run("JSON Response", func(t *testing.T) { + t.Parallel() + + req, err := http.NewRequest(http.MethodGet, urlStr, nil) + require.NoError(t, err) + req.Header.Set("Accept", mediaTypeJSON) + + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + require.Equal(t, 200, resp.StatusCode) + header := resp.Header.Get("Content-Type") + require.Equal(t, mediaTypeJSON, header) + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + require.Equal(t, string(body), `{"Providers":[{"Protocol":"transport-bitswap","Schema":"bitswap","ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vn","Addrs":[]}]}`) + }) + + t.Run("NDJSON Response", func(t *testing.T) { + t.Parallel() + + req, err := http.NewRequest(http.MethodGet, urlStr, nil) + require.NoError(t, err) + req.Header.Set("Accept", mediaTypeNDJSON) + + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + require.Equal(t, 200, resp.StatusCode) + header := resp.Header.Get("Content-Type") + require.Equal(t, mediaTypeNDJSON, header) + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + require.Equal(t, string(body), `{"Protocol":"transport-bitswap","Schema":"bitswap","ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vn","Addrs":[]}`+"\n") + }) +} + type mockContentRouter struct{ mock.Mock } func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid) (iter.ResultIter[types.ProviderResponse], error) {