Skip to content

Commit

Permalink
rpc: test cors, and respect option on rest endpoints (#1054)
Browse files Browse the repository at this point in the history
* rpc: test cors, and respect option on rest endpoints

* version: report v0.9.2-pre
  • Loading branch information
jchappelow authored Oct 7, 2024
1 parent 83341e9 commit 301c5fa
Show file tree
Hide file tree
Showing 4 changed files with 166 additions and 4 deletions.
7 changes: 7 additions & 0 deletions internal/services/jsonrpc/methods.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,13 @@ func (s *Server) handleMethod(ctx context.Context, method jsonrpc.Method, params

argsPtr, handler := maker(ctx, s)

// Treat omitted params as null, which may or may not be acceptable
// depending on the handler's parameters type. Otherwise json.Unmarshal
// always errors with "unexpected end of JSON input".
if params == nil {
params = []byte(`null`)
}

err := json.Unmarshal(params, argsPtr)
if err != nil {
return nil, jsonrpc.NewError(jsonrpc.ErrorInvalidParams, err.Error(), nil)
Expand Down
12 changes: 9 additions & 3 deletions internal/services/jsonrpc/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -305,21 +305,27 @@ func NewServer(addr string, log log.Logger, opts ...Opt) (*Server, error) {
w.Header().Set("content-type", "application/json; charset=utf-8")
http.ServeContent(w, r, "openrpc.json", time.Time{}, bytes.NewReader(s.spec))
})
specHandler = corsHandler(specHandler)
if cfg.enableCORS {
specHandler = corsHandler(specHandler)
}
specHandler = recoverer(specHandler, log)
mux.Handle(pathSpecV1, specHandler)

// aggregate health endpoint handler
var healthHandler http.Handler
healthHandler = http.HandlerFunc(s.healthMethodHandler)
healthHandler = corsHandler(healthHandler)
if cfg.enableCORS {
healthHandler = corsHandler(healthHandler)
}
healthHandler = recoverer(healthHandler, log)
mux.Handle(pathHealthV1, healthHandler)

// service specific health endpoint handler with wild card for service
var userHealthHandler http.Handler
userHealthHandler = http.HandlerFunc(s.handleSvcHealth)
userHealthHandler = corsHandler(userHealthHandler)
if cfg.enableCORS {
userHealthHandler = corsHandler(userHealthHandler)
}
userHealthHandler = recoverer(userHealthHandler, log)
mux.Handle(pathSvcHealthV1, userHealthHandler)

Expand Down
149 changes: 149 additions & 0 deletions internal/services/jsonrpc/server_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
package rpcserver

import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"slices"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -69,3 +73,148 @@ func Test_timeout(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, resp.Error.Code, jsonrpc.ErrorTimeout)
}

func Test_options(t *testing.T) {
logger := log.NewStdOut(log.WarnLevel)

const testOrigin = "whoever"

wantCorsHeaders := http.Header{
"Access-Control-Allow-Credentials": {"true"},
"Access-Control-Allow-Headers": {strings.Join([]string{"Accept", "Content-Type", "Content-Length", "Accept-Encoding", "Authorization", "ResponseType", "Range"}, ", ")},
"Access-Control-Allow-Methods": {strings.Join([]string{http.MethodGet, http.MethodPost, http.MethodOptions}, ", ")},
"Access-Control-Allow-Origin": {testOrigin},
}

for _, tt := range []struct {
name string
path string
withcors bool
reqMeth string
expectStatus int
reqBody io.Reader
}{
// JSON-RPC endpoint
{
name: "no cors, options req",
path: pathRPCV1,
withcors: false,
reqMeth: http.MethodOptions,
expectStatus: http.StatusMethodNotAllowed,
},
{
name: "with cors, options req",
path: pathRPCV1,
withcors: true,
reqMeth: http.MethodOptions,
expectStatus: http.StatusOK,
},
{
name: "no cors, get req",
path: pathRPCV1,
withcors: false,
reqMeth: http.MethodGet,
expectStatus: http.StatusMethodNotAllowed,
},
{
name: "with cors, post empty req",
path: pathRPCV1,
withcors: true,
reqMeth: http.MethodPost,
expectStatus: http.StatusBadRequest, // not a jsonrpc req => 400 status code
reqBody: nil,
},
{
name: "with cors, post json req no method",
path: pathRPCV1,
withcors: true,
reqMeth: http.MethodPost,
expectStatus: http.StatusNotFound, // method not found => 404 status code
reqBody: strings.NewReader(`{"jsonrpc":"2.0","id":2,"method":"rpc.nope"}`),
},
{
name: "with cors, post json req valid method",
path: pathRPCV1,
withcors: true,
reqMeth: http.MethodPost,
expectStatus: http.StatusOK, // method not found => 404 status code
reqBody: strings.NewReader(`{"jsonrpc":"2.0","id":2,"method":"rpc.dummy","params":null}`),
},
{
name: "with cors, post json req valid method (no params)",
path: pathRPCV1,
withcors: true,
reqMeth: http.MethodPost,
expectStatus: http.StatusOK, // method not found => 404 status code
reqBody: strings.NewReader(`{"jsonrpc":"2.0","id":2,"method":"rpc.dummy"}`),
},
// REST endpoints
{
name: "no cors, rest options req",
path: pathSpecV1,
withcors: false,
reqMeth: http.MethodOptions,
expectStatus: http.StatusMethodNotAllowed,
},
{
name: "with cors, rest options req",
path: pathSpecV1,
withcors: true,
reqMeth: http.MethodOptions,
expectStatus: http.StatusOK,
},
{
name: "with cors, rest get req",
path: pathSpecV1,
withcors: true,
reqMeth: http.MethodGet,
expectStatus: http.StatusOK,
},
{
name: "with cors, rest health options req",
path: pathHealthV1,
withcors: true,
reqMeth: http.MethodOptions,
expectStatus: http.StatusOK,
},
} {
t.Run(tt.name, func(t *testing.T) {
opts := []Opt{}
if tt.withcors {
opts = append(opts, WithCORS())
}
srv, err := NewServer("127.0.0.1:", logger, opts...)
require.NoError(t, err)

srv.RegisterMethodHandler(
"rpc.dummy",
MakeMethodHandler(func(context.Context, *any) (*json.RawMessage, *jsonrpc.Error) {
respjson := []byte(`"hi"`)
return (*json.RawMessage)(&respjson), nil
}),
)

r := httptest.NewRequest(tt.reqMeth, tt.path, tt.reqBody)
r.Header.Set("origin", testOrigin)
w := httptest.NewRecorder()
srv.srv.Handler.ServeHTTP(w, r)

assert.Equal(t, tt.expectStatus, w.Code)

if tt.withcors && tt.expectStatus == http.StatusOK {
// expect the cors headers fields
rhdr := w.Result().Header
for hk, hvs := range wantCorsHeaders {
vs, have := rhdr[hk]
if !have {
t.Fatalf("missing cors header %v", hk)
}
if !slices.Equal(vs, hvs) {
t.Errorf("different cors headers: got %v, want %v", vs, hvs)
}
}

}
})
}
}
2 changes: 1 addition & 1 deletion internal/version/version.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
// - 0.6.0+release
// - 0.6.1
// - 0.6.2-alpha0+go1.21.nocgo
const kwilVersion = "0.9.0-pre"
const kwilVersion = "0.9.2-pre" // remove "-pre" for the tagged commit

// KwildVersion may be set at compile time by:
//
Expand Down

0 comments on commit 301c5fa

Please sign in to comment.