Skip to content

Commit

Permalink
httptransport: fix content negotiation
Browse files Browse the repository at this point in the history
This fixes the handling of the "accept" header to process lists and "q"
parameters properly, and moves it to a common helper function.

Closes: #1441
Signed-off-by: Hank Donnay <hdonnay@redhat.com>
  • Loading branch information
hdonnay committed Dec 8, 2021
1 parent 11cb491 commit d3cbb5a
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 28 deletions.
96 changes: 96 additions & 0 deletions httptransport/common.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package httptransport

import (
"errors"
"fmt"
"mime"
"net/http"
"sort"
"strconv"
"strings"
)

// PickContentType sets the response's "Content-Type" header.
//
// If "Accept" headers are not present in the request, the first element of the
// "allow" slice is used.
//
// If "Accept" headers are present, the first (ordered by "q" value) media type
// in the "allow" slice is chosen. If there are no common media types, "415
// Unsupported Media Type" is written and ErrMediaType is reported.
func pickContentType(w http.ResponseWriter, r *http.Request, allow []string) error {
// There's no canonical algorithm for this, it's all server-dependent
// behavior. Our algorithm is:
//
// - Parse the Accept header(s) as MIME media types joined by commas.
// - Stable sort according to the "q" parameter, defaulting to 1.0 if
// omitted (as specified)
// - Pick the first match.
//
// BUG(hank) Content type negotiation does an O(n*m) comparison driven on
// user input, which may be a DoS issue.
as, ok := r.Header["Accept"]
if !ok {
w.Header().Set("content-type", allow[0])
return nil
}
var acceptable []accept
for _, part := range as {
for _, s := range strings.Split(part, ",") {
a := accept{}
mt, p, err := mime.ParseMediaType(strings.TrimSpace(s))
if err != nil {
return err
}
a.Q = 1.0
if qs, ok := p["q"]; ok {
a.Q, _ = strconv.ParseFloat(qs, 64)
}
typ := strings.Split(mt, "/")
a.Type = typ[0]
a.Subtype = typ[1]
acceptable = append(acceptable, a)
}
}
if len(acceptable) == 0 {
w.Header().Set("content-type", allow[0])
return nil
}
sort.SliceStable(acceptable, func(i, j int) bool { return acceptable[i].Q > acceptable[j].Q })
for _, l := range acceptable {
for _, a := range allow {
if l.Match(a) {
w.Header().Set("content-type", a)
return nil
}
}
}
w.WriteHeader(http.StatusUnsupportedMediaType)
return ErrMediaType
}

// ErrMediaType is returned if no common media types can be found for a given
// request.
var ErrMediaType = errors.New("no common media type")

type accept struct {
Type, Subtype string
Q float64
}

// Match reports whether the type in the "accept" struct matches the provided
// media type, honoring wildcards.
//
// Match panics if the provided media type is not well-formed.
func (a *accept) Match(mt string) bool {
if a.Type == "*" && a.Subtype == "*" {
return true
}
i := strings.IndexByte(mt, '/')
if i == -1 {
// Programmer error -- inputs to this function should be static strings.
panic(fmt.Sprintf("bad media type: %q", mt))
}
t, s := mt[:i], mt[i+1:]
return a.Type == t && (a.Subtype == s || a.Subtype == "*")
}
42 changes: 17 additions & 25 deletions httptransport/discoveryhandler.go
Original file line number Diff line number Diff line change
@@ -1,23 +1,19 @@
package httptransport

import (
"errors"
"io"
"net/http"
"strings"

je "github.com/quay/claircore/pkg/jsonerr"
)

//go:generate go run openapigen.go

var okCT = map[string]string{
"application/vnd.oai.openapi+json": "application/vnd.oai.openapi+json",
"application/json": "application/json",
"application/*": "application/vnd.oai.openapi+json",
"*/*": "application/vnd.oai.openapi+json",
}

// DiscoveryHandler serves the embedded OpenAPI spec.
func DiscoveryHandler() http.Handler {
allow := []string{`application/json`, `application/vnd.oai.openapi+json`}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
resp := &je.Response{
Expand All @@ -27,26 +23,22 @@ func DiscoveryHandler() http.Handler {
je.Error(w, resp, http.StatusMethodNotAllowed)
return
}
w.Header().Set("content-type", okCT["*/*"])
// Add a gate so that requesters expecting the yaml version get some
// sort of error.
if as, ok := r.Header["Accept"]; ok {
bail := true
for _, a := range as {
if ct, ok := okCT[a]; ok {
w.Header().Set("content-type", ct)
bail = false
break
}
switch err := pickContentType(w, r, allow); {
case errors.Is(err, nil):
case errors.Is(err, ErrMediaType):
resp := &je.Response{
Code: "unknown accept type",
Message: "endpoint only allows " + strings.Join(allow, " or "),
}
if bail {
resp := &je.Response{
Code: "unknown accept type",
Message: "endpoint only allows application/json or application/vnd.oai.openapi+json",
}
je.Error(w, resp, http.StatusBadRequest)
return
je.Error(w, resp, http.StatusUnsupportedMediaType)
return
default:
resp := &je.Response{
Code: "unknown other error",
Message: err.Error(),
}
je.Error(w, resp, http.StatusBadRequest)
return
}
w.Header().Set("etag", _openapiJSONEtag)
var err error
Expand Down
9 changes: 6 additions & 3 deletions httptransport/discoveryhandler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,16 @@ func TestDiscoveryEndpoint(t *testing.T) {

r := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/openapi/v1", nil)
req.Header.Set("Accept", "application/json")
req.Header.Set("Accept", "application/yaml, application/json; q=0.4, application/vnd.oai.openapi+json; q=1.0")
h.ServeHTTP(r, req)

resp := r.Result()
if resp.StatusCode != http.StatusOK {
t.Fatalf("got status code: %v want status code: %v", resp.StatusCode, http.StatusOK)
}
if got, want := resp.Header.Get("content-type"), "application/vnd.oai.openapi+json"; got != want {
t.Errorf("got: %q, want: %q", got, want)
}

buf, err := ioutil.ReadAll(resp.Body)
if err != nil {
Expand Down Expand Up @@ -55,8 +58,8 @@ func TestDiscoveryFailure(t *testing.T) {
h.ServeHTTP(r, req)

resp := r.Result()
if resp.StatusCode != http.StatusBadRequest {
t.Fatalf("got status code: %v want status code: %v", resp.StatusCode, http.StatusBadRequest)
if got, want := resp.StatusCode, http.StatusUnsupportedMediaType; got != want {
t.Fatalf("got status code: %v want status code: %v", got, want)
}
}

Expand Down

0 comments on commit d3cbb5a

Please sign in to comment.