Skip to content

Commit

Permalink
upstream validation: implementation, first try (#21)
Browse files Browse the repository at this point in the history
  • Loading branch information
Johannes Koch authored and Marcel Ludwig committed Dec 8, 2020
1 parent ce072a9 commit e8f64ef
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 5 deletions.
15 changes: 15 additions & 0 deletions config/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ type Backend struct {
RequestBodyLimit string `hcl:"request_body_limit,optional"`
TTFBTimeout string `hcl:"ttfb_timeout,optional"`
Timeout string `hcl:"timeout,optional"`
SwaggerDef string `hcl:"swagger_definition,optional"`
ValidateReq bool `hcl:"validate_request,optional"`
ValidateRes bool `hcl:"validate_response,optional"`
}

func (b Backend) Body() hcl.Body {
Expand Down Expand Up @@ -79,6 +82,18 @@ func (b *Backend) Merge(other *Backend) (*Backend, []hcl.Body) {
result.Timeout = other.Timeout
}

if other.SwaggerDef != "" {
result.SwaggerDef = other.SwaggerDef
}

if other.ValidateReq {
result.ValidateReq = other.ValidateReq
}

if other.ValidateRes {
result.ValidateRes = other.ValidateRes
}

return &result, bodies
}

Expand Down
10 changes: 10 additions & 0 deletions errors/code.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ const (
BasicAuthFailed
)

const (
UpstreamRequestValidationFailed Code = 6000 + iota
UpstreamResponseValidationFailed
UpstreamResponseBufferingFailed
)

var codes = map[Code]string{
// 1xxx
Server: "Server error",
Expand All @@ -53,6 +59,10 @@ var codes = map[Code]string{
AuthorizationRequired: "Authorization required",
AuthorizationFailed: "Authorization failed",
BasicAuthFailed: "Unauthorized",
// 6xxx
UpstreamRequestValidationFailed: "Upstream request validation failed",
UpstreamResponseValidationFailed: "Upstream response validation failed",
UpstreamResponseBufferingFailed: "Upstream response buffering failed",
}

type Code int
Expand Down
92 changes: 87 additions & 5 deletions handler/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,19 @@ import (
"crypto/tls"
"fmt"
"io"
"io/ioutil"
"math"
"net"
"net/http"
"net/url"
"os"
"regexp"
"strconv"
"strings"
"sync"
"time"

"github.com/getkin/kin-openapi/openapi3filter"
"github.com/hashicorp/hcl/v2"
"github.com/sirupsen/logrus"
"golang.org/x/net/http/httpguts"
Expand Down Expand Up @@ -168,6 +171,44 @@ func (p *Proxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
p.upstreamLog.ServeHTTP(rw, req, logging.RoundtripHandlerFunc(p.roundtrip), startTime)
}

func (p *Proxy) prepareRequestValidation(outreq *http.Request) (context.Context, *openapi3filter.Route, *openapi3filter.RequestValidationInput, error) {
if p.options.ValidateReq || p.options.ValidateRes {
dir, err := os.Getwd()
if err != nil {
return nil, nil, nil, err
}
router := openapi3filter.NewRouter().WithSwaggerFromFile(dir + "/" + p.options.SwaggerDef)
validationCtx := context.Background()
route, pathParams, _ := router.FindRoute(outreq.Method, outreq.URL)

requestValidationInput := &openapi3filter.RequestValidationInput{
Request: outreq,
PathParams: pathParams,
Route: route,
}
return validationCtx, route, requestValidationInput, nil
}
return nil, nil, nil, nil
}

func (p *Proxy) prepareResponseValidation(requestValidationInput *openapi3filter.RequestValidationInput, res *http.Response) (*openapi3filter.ResponseValidationInput, []byte, error) {
if p.options.ValidateRes {
responseValidationInput := &openapi3filter.ResponseValidationInput{
RequestValidationInput: requestValidationInput,
Status: res.StatusCode,
Header: res.Header,
Options: &openapi3filter.Options{IncludeResponseStatus: true /* undefined response codes are invalid */},
}
body, err := ioutil.ReadAll(res.Body)
if err != nil {
return nil, nil, err
}
responseValidationInput.SetBodyBytes(body)
return responseValidationInput, body, nil
}
return nil, nil, nil
}

func (p *Proxy) roundtrip(rw http.ResponseWriter, req *http.Request) {
ctx := req.Context()
if p.options.Timeout > 0 {
Expand Down Expand Up @@ -220,6 +261,23 @@ func (p *Proxy) roundtrip(rw http.ResponseWriter, req *http.Request) {
outreq.Header.Set("X-Forwarded-For", clientIP)
}

validationCtx, route, requestValidationInput, err := p.prepareRequestValidation(outreq)
if err != nil {
// this only happens if os.Getwd() fails
// TODO: use error template from parent endpoint>api>server
p.log.WithField("upstream request validation", err).Error()
couperErr.DefaultJSON.ServeError(couperErr.UpstreamRequestValidationFailed).ServeHTTP(rw, req)
return
}
if p.options.ValidateReq {
if err := openapi3filter.ValidateRequest(validationCtx, requestValidationInput); err != nil {
// TODO: use error template from parent endpoint>api>server
p.log.WithField("upstream request validation", err).Error()
couperErr.DefaultJSON.ServeError(couperErr.UpstreamRequestValidationFailed).ServeHTTP(rw, req)
return
}
}

res, err := p.getTransport(outreq.URL.Scheme, outreq.URL.Host, outreq.Host).RoundTrip(outreq)
roundtripInfo := req.Context().Value(request.RoundtripInfo).(*logging.RoundtripInfo)
roundtripInfo.BeReq, roundtripInfo.BeResp, roundtripInfo.Err = outreq, res, err
Expand Down Expand Up @@ -249,6 +307,26 @@ func (p *Proxy) roundtrip(rw http.ResponseWriter, req *http.Request) {
res.Body = eval.NewReadCloser(src, res.Body)
}

responseValidationInput, body, err := p.prepareResponseValidation(requestValidationInput, res)
if err != nil {
// TODO: use error template from parent endpoint>api>server
p.log.WithField("upstream response validation", err).Error()
couperErr.DefaultJSON.ServeError(couperErr.UpstreamResponseBufferingFailed).ServeHTTP(rw, req)
return
}
if responseValidationInput != nil {
if route != nil {
if err := openapi3filter.ValidateResponse(validationCtx, responseValidationInput); err != nil {
// TODO: use error template from parent endpoint>api>server
p.log.WithField("upstream response validation", err).Error()
couperErr.DefaultJSON.ServeError(couperErr.UpstreamResponseValidationFailed).ServeHTTP(rw, req)
return
}
} else {
p.log.Info("response validation enabled, but no route found")
}
}

removeConnectionHeaders(res.Header)

for _, h := range hopHeaders {
Expand All @@ -272,11 +350,15 @@ func (p *Proxy) roundtrip(rw http.ResponseWriter, req *http.Request) {

rw.WriteHeader(res.StatusCode)

_, err = io.Copy(rw, res.Body)
if err != nil {
defer res.Body.Close()
roundtripInfo.Err = err
return
if body != nil {
rw.Write(body)
} else {
_, err = io.Copy(rw, res.Body)
if err != nil {
defer res.Body.Close()
roundtripInfo.Err = err
return
}
}

res.Body.Close() // close now, instead of defer, to populate res.Trailer
Expand Down
5 changes: 5 additions & 0 deletions handler/proxy_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ type ProxyOptions struct {
ConnectTimeout, Timeout, TTFBTimeout time.Duration
Context []hcl.Body
BackendName string
SwaggerDef string
ValidateReq, ValidateRes bool
CORS *CORSOptions
RequestBodyLimit int64
}
Expand Down Expand Up @@ -50,6 +52,9 @@ func NewProxyOptions(conf *config.Backend, corsOpts *CORSOptions, remainCtx []hc
RequestBodyLimit: bodyLimit,
TTFBTimeout: ttfbD,
Timeout: totalD,
SwaggerDef: conf.SwaggerDef,
ValidateReq: conf.ValidateReq,
ValidateRes: conf.ValidateRes,
}, nil
}

Expand Down

0 comments on commit e8f64ef

Please sign in to comment.