Skip to content

Commit

Permalink
feat(routing/http): pass records limit on routing.FindProviders (#299)
Browse files Browse the repository at this point in the history
* feat: indicate if response will be streamable on routing.FindProviders
* refactor: change FindProviders to use "limit int" instead of "stream bool"
  • Loading branch information
hacdias authored May 25, 2023
1 parent 1daddd8 commit a8533c9
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 17 deletions.
11 changes: 7 additions & 4 deletions routing/http/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ import (

type mockContentRouter struct{ mock.Mock }

func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid) (iter.ResultIter[types.ProviderResponse], error) {
args := m.Called(ctx, key)
func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid, limit int) (iter.ResultIter[types.ProviderResponse], error) {
args := m.Called(ctx, key, limit)
return args.Get(0).(iter.ResultIter[types.ProviderResponse]), args.Error(1)
}
func (m *mockContentRouter) ProvideBitswap(ctx context.Context, req *server.BitswapWriteProvideRequest) (time.Duration, error) {
Expand Down Expand Up @@ -302,8 +302,11 @@ func TestClient_FindProviders(t *testing.T) {

findProvsIter := iter.FromSlice(c.routerProvs)

router.On("FindProviders", mock.Anything, cid).
Return(findProvsIter, c.routerErr)
if c.expStreamingResponse {
router.On("FindProviders", mock.Anything, cid, 0).Return(findProvsIter, c.routerErr)
} else {
router.On("FindProviders", mock.Anything, cid, 20).Return(findProvsIter, c.routerErr)
}

provsIter, err := client.FindProviders(ctx, cid)

Expand Down
39 changes: 34 additions & 5 deletions routing/http/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ const (
mediaTypeJSON = "application/json"
mediaTypeNDJSON = "application/x-ndjson"
mediaTypeWildcard = "*/*"

DefaultRecordsLimit = 20
DefaultStreamingRecordsLimit = 0
)

var logger = logging.Logger("service/server/delegatedrouting")
Expand All @@ -41,7 +44,9 @@ type FindProvidersAsyncResponse struct {
}

type ContentRouter interface {
FindProviders(ctx context.Context, key cid.Cid) (iter.ResultIter[types.ProviderResponse], error)
// FindProviders searches for peers who are able to provide a given key. Limit
// indicates the maximum amount of results to return. 0 means unbounded.
FindProviders(ctx context.Context, key cid.Cid, limit int) (iter.ResultIter[types.ProviderResponse], error)
ProvideBitswap(ctx context.Context, req *BitswapWriteProvideRequest) (time.Duration, error)
Provide(ctx context.Context, req *WriteProvideRequest) (types.ProviderResponse, error)
}
Expand Down Expand Up @@ -69,9 +74,27 @@ func WithStreamingResultsDisabled() Option {
}
}

// WithRecordsLimit sets a limit that will be passed to ContentRouter.FindProviders
// for non-streaming requests (application/json). Default is DefaultRecordsLimit.
func WithRecordsLimit(limit int) Option {
return func(s *server) {
s.recordsLimit = limit
}
}

// WithStreamingRecordsLimit sets a limit that will be passed to ContentRouter.FindProviders
// for streaming requests (application/x-ndjson). Default is DefaultStreamingRecordsLimit.
func WithStreamingRecordsLimit(limit int) Option {
return func(s *server) {
s.streamingRecordsLimit = limit
}
}

func Handler(svc ContentRouter, opts ...Option) http.Handler {
server := &server{
svc: svc,
svc: svc,
recordsLimit: DefaultRecordsLimit,
streamingRecordsLimit: DefaultStreamingRecordsLimit,
}

for _, opt := range opts {
Expand All @@ -86,8 +109,10 @@ func Handler(svc ContentRouter, opts ...Option) http.Handler {
}

type server struct {
svc ContentRouter
disableNDJSON bool
svc ContentRouter
disableNDJSON bool
recordsLimit int
streamingRecordsLimit int
}

func (s *server) provide(w http.ResponseWriter, httpReq *http.Request) {
Expand Down Expand Up @@ -170,9 +195,11 @@ func (s *server) findProviders(w http.ResponseWriter, httpReq *http.Request) {

var supportsNDJSON bool
var supportsJSON bool
var recordsLimit int
acceptHeaders := httpReq.Header.Values("Accept")
if len(acceptHeaders) == 0 {
handlerFunc = s.findProvidersJSON
recordsLimit = s.recordsLimit
} else {
for _, acceptHeader := range acceptHeaders {
for _, accept := range strings.Split(acceptHeader, ",") {
Expand All @@ -193,15 +220,17 @@ func (s *server) findProviders(w http.ResponseWriter, httpReq *http.Request) {

if supportsNDJSON && !s.disableNDJSON {
handlerFunc = s.findProvidersNDJSON
recordsLimit = s.streamingRecordsLimit
} else if supportsJSON {
handlerFunc = s.findProvidersJSON
recordsLimit = s.recordsLimit
} else {
writeErr(w, "FindProviders", http.StatusBadRequest, errors.New("no supported content types"))
return
}
}

provIter, err := s.svc.FindProviders(httpReq.Context(), cid)
provIter, err := s.svc.FindProviders(httpReq.Context(), cid, recordsLimit)
if err != nil {
writeErr(w, "FindProviders", http.StatusInternalServerError, fmt.Errorf("delegate error: %w", err))
return
Expand Down
20 changes: 12 additions & 8 deletions routing/http/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func TestHeaders(t *testing.T) {
cb, err := cid.Decode(c)
require.NoError(t, err)

router.On("FindProviders", mock.Anything, cb).
router.On("FindProviders", mock.Anything, cb, DefaultRecordsLimit).
Return(results, nil)

resp, err := http.Get(serverAddr + ProvidePath + c)
Expand Down Expand Up @@ -63,7 +63,7 @@ func TestResponse(t *testing.T) {
cid, err := cid.Decode(cidStr)
require.NoError(t, err)

runTest := func(t *testing.T, contentType string, expected string) {
runTest := func(t *testing.T, contentType string, expectedStream bool, expectedBody string) {
t.Parallel()

results := iter.FromSlice([]iter.Result[types.ProviderResponse]{
Expand All @@ -85,7 +85,11 @@ func TestResponse(t *testing.T) {
server := httptest.NewServer(Handler(router))
t.Cleanup(server.Close)
serverAddr := "http://" + server.Listener.Addr().String()
router.On("FindProviders", mock.Anything, cid).Return(results, nil)
limit := DefaultRecordsLimit
if expectedStream {
limit = DefaultStreamingRecordsLimit
}
router.On("FindProviders", mock.Anything, cid, limit).Return(results, nil)
urlStr := serverAddr + ProvidePath + cidStr

req, err := http.NewRequest(http.MethodGet, urlStr, nil)
Expand All @@ -101,22 +105,22 @@ func TestResponse(t *testing.T) {
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)

require.Equal(t, string(body), expected)
require.Equal(t, string(body), expectedBody)
}

t.Run("JSON Response", func(t *testing.T) {
runTest(t, mediaTypeJSON, `{"Providers":[{"Protocol":"transport-bitswap","Schema":"bitswap","ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vn","Addrs":[]},{"Protocol":"transport-bitswap","Schema":"bitswap","ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vz","Addrs":[]}]}`)
runTest(t, mediaTypeJSON, false, `{"Providers":[{"Protocol":"transport-bitswap","Schema":"bitswap","ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vn","Addrs":[]},{"Protocol":"transport-bitswap","Schema":"bitswap","ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vz","Addrs":[]}]}`)
})

t.Run("NDJSON Response", func(t *testing.T) {
runTest(t, mediaTypeNDJSON, `{"Protocol":"transport-bitswap","Schema":"bitswap","ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vn","Addrs":[]}`+"\n"+`{"Protocol":"transport-bitswap","Schema":"bitswap","ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vz","Addrs":[]}`+"\n")
runTest(t, mediaTypeNDJSON, true, `{"Protocol":"transport-bitswap","Schema":"bitswap","ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vn","Addrs":[]}`+"\n"+`{"Protocol":"transport-bitswap","Schema":"bitswap","ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vz","Addrs":[]}`+"\n")
})
}

type mockContentRouter struct{ mock.Mock }

func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid) (iter.ResultIter[types.ProviderResponse], error) {
args := m.Called(ctx, key)
func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid, limit int) (iter.ResultIter[types.ProviderResponse], error) {
args := m.Called(ctx, key, limit)
return args.Get(0).(iter.ResultIter[types.ProviderResponse]), args.Error(1)
}
func (m *mockContentRouter) ProvideBitswap(ctx context.Context, req *BitswapWriteProvideRequest) (time.Duration, error) {
Expand Down

0 comments on commit a8533c9

Please sign in to comment.