Skip to content

Commit

Permalink
upstream validation: implementation, first try (#21)
Browse files Browse the repository at this point in the history
  • Loading branch information
Johannes Koch committed Oct 8, 2020
1 parent 7b829c3 commit 60764c7
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 6 deletions.
15 changes: 15 additions & 0 deletions config/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ type Backend struct {
Path string `hcl:"path,optional"`
Timeout string `hcl:"timeout,optional"`
TTFBTimeout string `hcl:"ttfb_timeout,optional"`
SwaggerDef string `hcl:"swagger_definition,optional"`
ValidateReq bool `hcl:"validate_request,optional"`
ValidateRes bool `hcl:"validate_response,optional"`
}

// Merge overrides the left backend configuration and returns a new instance.
Expand Down Expand Up @@ -62,5 +65,17 @@ func (b *Backend) Merge(other *Backend) (*Backend, []hcl.Body) {
result.TTFBTimeout = other.TTFBTimeout
}

if other.SwaggerDef != "" {
result.SwaggerDef = other.SwaggerDef
}

if other.ValidateReq {
result.ValidateReq = other.ValidateReq
}

if other.ValidateRes {
result.ValidateRes = other.ValidateRes
}

return &result, bodies
}
12 changes: 12 additions & 0 deletions config/runtime/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ func BuildEntrypointHandlers(conf *config.Gateway, httpConf *HTTPConfig, log *lo
Path: beConf.Path,
Timeout: t,
TTFBTimeout: ttfbt,
SwaggerDef: beConf.SwaggerDef,
ValidateReq: beConf.ValidateReq,
ValidateRes: beConf.ValidateRes,
}, log, conf.Context)
if err != nil {
log.Fatal(err)
Expand Down Expand Up @@ -193,6 +196,9 @@ func BuildEntrypointHandlers(conf *config.Gateway, httpConf *HTTPConfig, log *lo
Path: beConf.Path,
Timeout: t,
TTFBTimeout: ttfbt,
SwaggerDef: beConf.SwaggerDef,
ValidateReq: beConf.ValidateReq,
ValidateRes: beConf.ValidateRes,
}, log, conf.Context)
if err != nil {
log.Fatal(err)
Expand Down Expand Up @@ -262,6 +268,9 @@ func BuildEntrypointHandlers(conf *config.Gateway, httpConf *HTTPConfig, log *lo
Path: beConf.Path,
Timeout: t,
TTFBTimeout: ttfbt,
SwaggerDef: beConf.SwaggerDef,
ValidateReq: beConf.ValidateReq,
ValidateRes: beConf.ValidateRes,
}, log, conf.Context)
if err != nil {
log.Fatal(err)
Expand Down Expand Up @@ -479,6 +488,9 @@ func newInlineBackend(evalCtx *hcl.EvalContext, inlineDef hcl.Body, cors *config
Path: beConf.Path,
Timeout: t,
TTFBTimeout: ttfbt,
SwaggerDef: beConf.SwaggerDef,
ValidateReq: beConf.ValidateReq,
ValidateRes: beConf.ValidateRes,
}, log, evalCtx)
return proxy, beConf, err
}
Expand Down
10 changes: 10 additions & 0 deletions errors/code.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ const (
BasicAuthFailed
)

const (
UpstreamRequestValidationFailed Code = 6000 + iota
UpstreamResponseValidationFailed
UpstreamResponseBufferingFailed
)

var codes = map[Code]string{
// 1xxx
Server: "Server error",
Expand All @@ -51,6 +57,10 @@ var codes = map[Code]string{
AuthorizationRequired: "Authorization required",
AuthorizationFailed: "Authorization failed",
BasicAuthFailed: "Unauthorized",
// 6xxx
UpstreamRequestValidationFailed: "Upstream request validation failed",
UpstreamResponseValidationFailed: "Upstream response validation failed",
UpstreamResponseBufferingFailed: "Upstream response buffering failed",
}

type Code int
Expand Down
95 changes: 89 additions & 6 deletions handler/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,17 @@ import (
"errors"
"fmt"
"io"
"io/ioutil"
"math"
"net"
"net/http"
"net/url"
"os"
"strconv"
"strings"
"time"

"github.com/getkin/kin-openapi/openapi3filter"
"github.com/hashicorp/hcl/v2"
"github.com/sirupsen/logrus"
"golang.org/x/net/http/httpguts"
Expand Down Expand Up @@ -52,7 +55,8 @@ type ProxyOptions struct {
ConnectTimeout, Timeout, TTFBTimeout time.Duration
Context []hcl.Body
BackendName string
Hostname, Origin, Path string
Hostname, Origin, Path, SwaggerDef string
ValidateReq, ValidateRes bool
CORS *CORSOptions
}

Expand Down Expand Up @@ -156,6 +160,44 @@ func (p *Proxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
p.upstreamLog.ServeHTTP(rw, req, logging.RoundtripHandlerFunc(p.roundtrip))
}

func (p *Proxy) preparetRequestValidatation(outreq *http.Request) (context.Context, *openapi3filter.Route, *openapi3filter.RequestValidationInput, error) {
if p.options.ValidateReq || p.options.ValidateRes {
dir, err := os.Getwd()
if err != nil {
return nil, nil, nil, err
}
router := openapi3filter.NewRouter().WithSwaggerFromFile(dir + "/" + p.options.SwaggerDef)
validationCtx := context.Background()
route, pathParams, _ := router.FindRoute(outreq.Method, outreq.URL)

requestValidationInput := &openapi3filter.RequestValidationInput{
Request: outreq,
PathParams: pathParams,
Route: route,
}
return validationCtx, route, requestValidationInput, nil
}
return nil, nil, nil, nil
}

func (p *Proxy) prepareResponseValidatation(requestValidationInput *openapi3filter.RequestValidationInput, res *http.Response) (*openapi3filter.ResponseValidationInput, []byte, error) {
if p.options.ValidateRes {
responseValidationInput := &openapi3filter.ResponseValidationInput{
RequestValidationInput: requestValidationInput,
Status: res.StatusCode,
Header: res.Header,
Options: &openapi3filter.Options{IncludeResponseStatus: true /* undefined response codes are invalid */},
}
body, err := ioutil.ReadAll(res.Body)
if err != nil {
return nil, nil, err
}
responseValidationInput.SetBodyBytes(body)
return responseValidationInput, body, nil
}
return nil, nil, nil
}

func (p *Proxy) roundtrip(rw http.ResponseWriter, req *http.Request) {
ctx := req.Context()
if p.options.Timeout > 0 {
Expand Down Expand Up @@ -203,6 +245,23 @@ func (p *Proxy) roundtrip(rw http.ResponseWriter, req *http.Request) {
outreq.Header.Set("X-Forwarded-For", clientIP)
}

validationCtx, route, requestValidationInput, err := p.preparetRequestValidatation(outreq)
if err != nil {
// this only happens if os.Getwd() fails
// TODO: use error template from parent endpoint>api>server
p.log.WithField("upstream request validation", err).Error()
couperErr.DefaultJSON.ServeError(couperErr.UpstreamRequestValidationFailed).ServeHTTP(rw, req)
return
}
if (p.options.ValidateReq) {
if err := openapi3filter.ValidateRequest(validationCtx, requestValidationInput); err != nil {
// TODO: use error template from parent endpoint>api>server
p.log.WithField("upstream request validation", err).Error()
couperErr.DefaultJSON.ServeError(couperErr.UpstreamRequestValidationFailed).ServeHTTP(rw, req)
return
}
}

res, err := p.transport.RoundTrip(outreq)
roundtripInfo := req.Context().Value(request.RoundtripInfo).(*logging.RoundtripInfo)
roundtripInfo.BeReq, roundtripInfo.BeResp, roundtripInfo.Err = outreq, res, err
Expand All @@ -212,6 +271,26 @@ func (p *Proxy) roundtrip(rw http.ResponseWriter, req *http.Request) {
return
}

responseValidationInput, body, err := p.prepareResponseValidatation(requestValidationInput, res)
if err != nil {
// TODO: use error template from parent endpoint>api>server
p.log.WithField("upstream response validation", err).Error()
couperErr.DefaultJSON.ServeError(couperErr.UpstreamResponseBufferingFailed).ServeHTTP(rw, req)
return
}
if responseValidationInput != nil {
if route != nil {
if err := openapi3filter.ValidateResponse(validationCtx, responseValidationInput); err != nil {
// TODO: use error template from parent endpoint>api>server
p.log.WithField("upstream response validation", err).Error()
couperErr.DefaultJSON.ServeError(couperErr.UpstreamResponseValidationFailed).ServeHTTP(rw, req)
return
}
} else {
p.log.Info("response validation enabled, but no route found")
}
}

// Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc)
if res.StatusCode == http.StatusSwitchingProtocols {
p.setRoundtripContext(req, res)
Expand Down Expand Up @@ -242,11 +321,15 @@ func (p *Proxy) roundtrip(rw http.ResponseWriter, req *http.Request) {

rw.WriteHeader(res.StatusCode)

_, err = io.Copy(rw, res.Body)
if err != nil {
defer res.Body.Close()
roundtripInfo.Err = err
return
if body != nil {
rw.Write(body)
} else {
_, err = io.Copy(rw, res.Body)
if err != nil {
defer res.Body.Close()
roundtripInfo.Err = err
return
}
}

res.Body.Close() // close now, instead of defer, to populate res.Trailer
Expand Down

0 comments on commit 60764c7

Please sign in to comment.