Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: pass records limit on routing.FindProviders #299

Merged
merged 3 commits into from
May 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions routing/http/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ import (

type mockContentRouter struct{ mock.Mock }

func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid) (iter.ResultIter[types.ProviderResponse], error) {
args := m.Called(ctx, key)
func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid, limit int) (iter.ResultIter[types.ProviderResponse], error) {
args := m.Called(ctx, key, limit)
return args.Get(0).(iter.ResultIter[types.ProviderResponse]), args.Error(1)
}
func (m *mockContentRouter) ProvideBitswap(ctx context.Context, req *server.BitswapWriteProvideRequest) (time.Duration, error) {
Expand Down Expand Up @@ -302,8 +302,11 @@ func TestClient_FindProviders(t *testing.T) {

findProvsIter := iter.FromSlice(c.routerProvs)

router.On("FindProviders", mock.Anything, cid).
Return(findProvsIter, c.routerErr)
if c.expStreamingResponse {
router.On("FindProviders", mock.Anything, cid, 0).Return(findProvsIter, c.routerErr)
} else {
router.On("FindProviders", mock.Anything, cid, 20).Return(findProvsIter, c.routerErr)
}

provsIter, err := client.FindProviders(ctx, cid)

Expand Down
39 changes: 34 additions & 5 deletions routing/http/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ const (
mediaTypeJSON = "application/json"
mediaTypeNDJSON = "application/x-ndjson"
mediaTypeWildcard = "*/*"

DefaultRecordsLimit = 20
DefaultStreamingRecordsLimit = 0
)

var logger = logging.Logger("service/server/delegatedrouting")
Expand All @@ -41,7 +44,9 @@ type FindProvidersAsyncResponse struct {
}

type ContentRouter interface {
FindProviders(ctx context.Context, key cid.Cid) (iter.ResultIter[types.ProviderResponse], error)
// FindProviders searches for peers who are able to provide a given key. Limit
// indicates the maximum amount of results to return. 0 means unbounded.
FindProviders(ctx context.Context, key cid.Cid, limit int) (iter.ResultIter[types.ProviderResponse], error)
ProvideBitswap(ctx context.Context, req *BitswapWriteProvideRequest) (time.Duration, error)
Provide(ctx context.Context, req *WriteProvideRequest) (types.ProviderResponse, error)
}
Expand Down Expand Up @@ -69,9 +74,27 @@ func WithStreamingResultsDisabled() Option {
}
}

// WithRecordsLimit sets a limit that will be passed to ContentRouter.FindProviders
// for non-streaming requests (application/json). Default is DefaultRecordsLimit.
func WithRecordsLimit(limit int) Option {
return func(s *server) {
s.recordsLimit = limit
}
}

// WithStreamingRecordsLimit sets a limit that will be passed to ContentRouter.FindProviders
// for streaming requests (application/x-ndjson). Default is DefaultStreamingRecordsLimit.
func WithStreamingRecordsLimit(limit int) Option {
return func(s *server) {
s.streamingRecordsLimit = limit
}
}

func Handler(svc ContentRouter, opts ...Option) http.Handler {
server := &server{
svc: svc,
svc: svc,
recordsLimit: DefaultRecordsLimit,
streamingRecordsLimit: DefaultStreamingRecordsLimit,
}

for _, opt := range opts {
Expand All @@ -86,8 +109,10 @@ func Handler(svc ContentRouter, opts ...Option) http.Handler {
}

type server struct {
svc ContentRouter
disableNDJSON bool
svc ContentRouter
disableNDJSON bool
recordsLimit int
streamingRecordsLimit int
}

func (s *server) provide(w http.ResponseWriter, httpReq *http.Request) {
Expand Down Expand Up @@ -170,9 +195,11 @@ func (s *server) findProviders(w http.ResponseWriter, httpReq *http.Request) {

var supportsNDJSON bool
var supportsJSON bool
var recordsLimit int
acceptHeaders := httpReq.Header.Values("Accept")
if len(acceptHeaders) == 0 {
handlerFunc = s.findProvidersJSON
recordsLimit = s.recordsLimit
} else {
for _, acceptHeader := range acceptHeaders {
for _, accept := range strings.Split(acceptHeader, ",") {
Expand All @@ -193,15 +220,17 @@ func (s *server) findProviders(w http.ResponseWriter, httpReq *http.Request) {

if supportsNDJSON && !s.disableNDJSON {
handlerFunc = s.findProvidersNDJSON
recordsLimit = s.streamingRecordsLimit
} else if supportsJSON {
handlerFunc = s.findProvidersJSON
recordsLimit = s.recordsLimit
} else {
writeErr(w, "FindProviders", http.StatusBadRequest, errors.New("no supported content types"))
return
}
}

provIter, err := s.svc.FindProviders(httpReq.Context(), cid)
provIter, err := s.svc.FindProviders(httpReq.Context(), cid, recordsLimit)
if err != nil {
writeErr(w, "FindProviders", http.StatusInternalServerError, fmt.Errorf("delegate error: %w", err))
return
Expand Down
20 changes: 12 additions & 8 deletions routing/http/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func TestHeaders(t *testing.T) {
cb, err := cid.Decode(c)
require.NoError(t, err)

router.On("FindProviders", mock.Anything, cb).
router.On("FindProviders", mock.Anything, cb, DefaultRecordsLimit).
Return(results, nil)

resp, err := http.Get(serverAddr + ProvidePath + c)
Expand Down Expand Up @@ -63,7 +63,7 @@ func TestResponse(t *testing.T) {
cid, err := cid.Decode(cidStr)
require.NoError(t, err)

runTest := func(t *testing.T, contentType string, expected string) {
runTest := func(t *testing.T, contentType string, expectedStream bool, expectedBody string) {
t.Parallel()

results := iter.FromSlice([]iter.Result[types.ProviderResponse]{
Expand All @@ -85,7 +85,11 @@ func TestResponse(t *testing.T) {
server := httptest.NewServer(Handler(router))
t.Cleanup(server.Close)
serverAddr := "http://" + server.Listener.Addr().String()
router.On("FindProviders", mock.Anything, cid).Return(results, nil)
limit := DefaultRecordsLimit
if expectedStream {
limit = DefaultStreamingRecordsLimit
}
router.On("FindProviders", mock.Anything, cid, limit).Return(results, nil)
urlStr := serverAddr + ProvidePath + cidStr

req, err := http.NewRequest(http.MethodGet, urlStr, nil)
Expand All @@ -101,22 +105,22 @@ func TestResponse(t *testing.T) {
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)

require.Equal(t, string(body), expected)
require.Equal(t, string(body), expectedBody)
}

t.Run("JSON Response", func(t *testing.T) {
runTest(t, mediaTypeJSON, `{"Providers":[{"Protocol":"transport-bitswap","Schema":"bitswap","ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vn","Addrs":[]},{"Protocol":"transport-bitswap","Schema":"bitswap","ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vz","Addrs":[]}]}`)
runTest(t, mediaTypeJSON, false, `{"Providers":[{"Protocol":"transport-bitswap","Schema":"bitswap","ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vn","Addrs":[]},{"Protocol":"transport-bitswap","Schema":"bitswap","ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vz","Addrs":[]}]}`)
})

t.Run("NDJSON Response", func(t *testing.T) {
runTest(t, mediaTypeNDJSON, `{"Protocol":"transport-bitswap","Schema":"bitswap","ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vn","Addrs":[]}`+"\n"+`{"Protocol":"transport-bitswap","Schema":"bitswap","ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vz","Addrs":[]}`+"\n")
runTest(t, mediaTypeNDJSON, true, `{"Protocol":"transport-bitswap","Schema":"bitswap","ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vn","Addrs":[]}`+"\n"+`{"Protocol":"transport-bitswap","Schema":"bitswap","ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vz","Addrs":[]}`+"\n")
})
}

type mockContentRouter struct{ mock.Mock }

func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid) (iter.ResultIter[types.ProviderResponse], error) {
args := m.Called(ctx, key)
func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid, limit int) (iter.ResultIter[types.ProviderResponse], error) {
args := m.Called(ctx, key, limit)
return args.Get(0).(iter.ResultIter[types.ProviderResponse]), args.Error(1)
}
func (m *mockContentRouter) ProvideBitswap(ctx context.Context, req *BitswapWriteProvideRequest) (time.Duration, error) {
Expand Down