Skip to content

Commit

Permalink
reduce allocs (#666)
Browse files Browse the repository at this point in the history
* use len instead of nil check

* update changelog link

* fixup jwtlib nil ref

* fixup possible timer leak for backends

* fixed proxy handler websocket option and context handling
  • Loading branch information
malud authored Jan 26, 2023
1 parent 330bf3f commit 47573d4
Show file tree
Hide file tree
Showing 8 changed files with 132 additions and 38 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ Unreleased changes are available as `avenga/couper:edge` container.

* **Fixed**
* Requests to wildcard (`**`) [endpoints](https://docs.couper.io/configuration/block/endpoint) using backends with a wildcard [`path` attribue](https://docs.couper.io/configuration/block/backend#attributes), where the wildcard matches the empty string (regression; since v1.11.0) ([#655](https://github.com/avenga/couper/pull/655))
* Creating request context based jwt, oauth2 and saml (hcl) functions without related definitions ([#666](https://github.com/avenga/couper/pull/666))
* Reduced allocation amount while proxying requests ([#666](https://github.com/avenga/couper/pull/666))
* Removing websockets related headers while the proxy `websockets` option is `false` (or no block definition) ([#666](https://github.com/avenga/couper/pull/666))

---

Expand Down
18 changes: 15 additions & 3 deletions config/runtime/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,15 +118,27 @@ func NewEndpointOptions(confCtx *hcl.EvalContext, endpointConf *config.Endpoint,
if berr != nil {
return nil, berr
}
proxyHandler := handler.NewProxy(backend, proxyConf.HCLBody(), log)

var hasWSblock bool
proxyBody := proxyConf.HCLBody()
for _, b := range proxyBody.Blocks {
if b.Type == "websockets" {
hasWSblock = true
break
}
}

allowWebsockets := proxyConf.Websockets != nil || hasWSblock
proxyHandler := handler.NewProxy(backend, proxyBody, allowWebsockets, log)

p := &producer.Proxy{
Content: proxyConf.HCLBody(),
Content: proxyBody,
Name: proxyConf.Name,
RoundTrip: proxyHandler,
}

allProxies[proxyConf.Name] = p
blockBodies = append(blockBodies, proxyConf.Backend, proxyConf.HCLBody())
blockBodies = append(blockBodies, proxyConf.Backend, proxyBody)
}

allRequests := make(map[string]*producer.Request)
Expand Down
12 changes: 8 additions & 4 deletions eval/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -407,20 +407,24 @@ func (c *Context) getCodeVerifier() (*pkce.CodeVerifier, error) {

// updateFunctions recreates the listed functions with the current evaluation context.
func (c *Context) updateFunctions() {
jwtfn := lib.NewJwtSignFunction(c.eval, c.jwtSigningConfigs, Value)
c.eval.Functions[lib.FnJWTSign] = jwtfn
if len(c.jwtSigningConfigs) > 0 {
jwtfn := lib.NewJwtSignFunction(c.eval, c.jwtSigningConfigs, Value)
c.eval.Functions[lib.FnJWTSign] = jwtfn
} else {
c.eval.Functions[lib.FnJWTSign] = lib.NoOpJwtSignFunction
}
}

// updateRequestRelatedFunctions re-creates the listed functions for the client request context.
func (c *Context) updateRequestRelatedFunctions(origin *url.URL) {
if c.oauth2 != nil {
if len(c.oauth2) > 0 {
oauth2fn := lib.NewOAuthAuthorizationURLFunction(c.eval, c.oauth2, c.getCodeVerifier, origin, Value)
c.eval.Functions[lib.FnOAuthAuthorizationURL] = oauth2fn
}
c.eval.Functions[lib.FnOAuthVerifier] = lib.NewOAuthCodeVerifierFunction(c.getCodeVerifier)
c.eval.Functions[lib.InternalFnOAuthHashedVerifier] = lib.NewOAuthCodeChallengeFunction(c.getCodeVerifier)

if c.saml != nil {
if len(c.saml) > 0 {
samlfn := lib.NewSamlSsoURLFunction(c.saml, origin)
c.eval.Functions[lib.FnSamlSsoURL] = samlfn
}
Expand Down
19 changes: 18 additions & 1 deletion eval/lib/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,23 @@ func NewJWTSigningConfigFromJWT(j *config.JWT) (*JWTSigningConfig, error) {
return c, nil
}

var NoOpJwtSignFunction = function.New(&function.Spec{
Params: []function.Parameter{
{
Name: "jwt_signing_profile_label",
Type: cty.String,
},
{
Name: "claims",
Type: cty.DynamicPseudoType,
},
},
Type: function.StaticReturnType(cty.String),
Impl: func(_ []cty.Value, _ cty.Type) (ret cty.Value, err error) {
return cty.StringVal(""), fmt.Errorf("missing jwt_signing_profile or jwt (with signing_ttl) definitions")
},
})

func NewJwtSignFunction(ctx *hcl.EvalContext, jwtSigningConfigs map[string]*JWTSigningConfig,
evalFn func(*hcl.EvalContext, hcl.Expression) (cty.Value, error)) function.Function {
return function.New(&function.Spec{
Expand All @@ -145,7 +162,7 @@ func NewJwtSignFunction(ctx *hcl.EvalContext, jwtSigningConfigs map[string]*JWTS
Type: function.StaticReturnType(cty.String),
Impl: func(args []cty.Value, _ cty.Type) (ret cty.Value, err error) {
if len(jwtSigningConfigs) == 0 {
return cty.StringVal(""), fmt.Errorf("missing jwt_signing_profile or jwt (with signing_ttl) definitions")
return NoOpJwtSignFunction.Call(nil)
}

label := args[0].AsString()
Expand Down
2 changes: 1 addition & 1 deletion handler/producer/request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func Test_ProduceExpectedStatus(t *testing.T) {
proxies := producer.Proxies{&producer.Proxy{
Content: content,
Name: "proxy",
RoundTrip: handler.NewProxy(backend, content, logEntry),
RoundTrip: handler.NewProxy(backend, content, false, logEntry),
}}

testNames := []string{"request", "proxy"}
Expand Down
55 changes: 28 additions & 27 deletions handler/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"sync"
"time"

"github.com/hashicorp/hcl/v2"
"github.com/hashicorp/hcl/v2/hclsyntax"
"github.com/sirupsen/logrus"

Expand All @@ -30,13 +31,15 @@ var headerBlacklist = []string{"Authorization", "Cookie"}
// Proxy wraps a httputil.ReverseProxy to apply additional configuration context
// and have control over the roundtrip configuration.
type Proxy struct {
allowWS bool
backend http.RoundTripper
context *hclsyntax.Body
logger *logrus.Entry
}

func NewProxy(backend http.RoundTripper, ctx *hclsyntax.Body, logger *logrus.Entry) *Proxy {
func NewProxy(backend http.RoundTripper, ctx *hclsyntax.Body, allowWS bool, logger *logrus.Entry) *Proxy {
proxy := &Proxy{
allowWS: allowWS,
backend: backend,
context: ctx,
logger: logger,
Expand All @@ -54,12 +57,14 @@ func (p *Proxy) RoundTrip(req *http.Request) (*http.Response, error) {
hclCtx := eval.ContextFromRequest(req).HCLContextSync()

// 2. Apply proxy-body
if err := eval.ApplyRequestContext(hclCtx, p.context, req); err != nil {
err := eval.ApplyRequestContext(hclCtx, p.context, req)
if err != nil {
return nil, err
}

// 3. Apply websockets-body
if err := p.applyWebsocketsRequest(req); err != nil {
outCtx, err := p.applyWebsocketsRequest(hclCtx, req)
if err != nil {
return nil, err
}

Expand All @@ -69,7 +74,7 @@ func (p *Proxy) RoundTrip(req *http.Request) (*http.Response, error) {
return nil, err
}

outCtx := context.WithValue(req.Context(), request.EndpointExpectedStatus, seetie.ValueToIntSlice(expStatusVal))
outCtx = context.WithValue(outCtx, request.EndpointExpectedStatus, seetie.ValueToIntSlice(expStatusVal))

*req = *req.WithContext(outCtx)

Expand Down Expand Up @@ -141,54 +146,50 @@ func upgradeType(h http.Header) string {
return ""
}

func (p *Proxy) applyWebsocketsRequest(req *http.Request) error {
ctx := req.Context()

ctx = context.WithValue(ctx, request.WebsocketsAllowed, true)
*req = *req.WithContext(ctx)

hclCtx := eval.ContextFromRequest(req).HCLContextSync()
func (p *Proxy) applyWebsocketsRequest(hclCtx *hcl.EvalContext, req *http.Request) (context.Context, error) {
outCtx := req.Context()
if p.allowWS {
outCtx = context.WithValue(outCtx, request.WebsocketsAllowed, p.allowWS)
} else {
return outCtx, nil
}

// This method needs the 'request.WebsocketsAllowed' flag in the 'req.context'.
if !eval.IsUpgradeRequest(req) {
return nil
if !eval.IsUpgradeRequest(req.WithContext(outCtx)) {
return outCtx, nil
}

wsBody := p.getWebsocketsBody()
if wsBody == nil { // applies if just the websockets attribute is given
return outCtx, nil
}

if err := eval.ApplyRequestContext(hclCtx, wsBody, req); err != nil {
return err
return nil, err
}

attr, ok := wsBody.Attributes["timeout"]
if !ok {
return nil
return outCtx, nil
}

val, err := eval.Value(hclCtx, attr.Expr)
if err != nil {
return err
return nil, err
}

str := seetie.ValueToString(val)

timeout, err := time.ParseDuration(str)
if str != "" && err != nil {
return err
return nil, err
}

ctx = context.WithValue(ctx, request.WebsocketsTimeout, timeout)
*req = *req.WithContext(ctx)

return nil
outCtx = context.WithValue(outCtx, request.WebsocketsTimeout, timeout)
return outCtx, nil
}

func (p *Proxy) registerWebsocketsResponse(req *http.Request) error {
ctx := req.Context()

ctx = context.WithValue(ctx, request.WebsocketsAllowed, true)
*req = *req.WithContext(ctx)

// This method needs the 'request.WebsocketsAllowed' flag in the 'req.context'.
if !eval.IsUpgradeRequest(req) {
return nil
}
Expand Down
57 changes: 56 additions & 1 deletion handler/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package handler_test

import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
Expand All @@ -10,6 +11,7 @@ import (

"github.com/avenga/couper/config"
"github.com/avenga/couper/config/body"
"github.com/avenga/couper/config/request"
"github.com/avenga/couper/eval"
"github.com/avenga/couper/handler"
"github.com/avenga/couper/handler/transport"
Expand All @@ -24,14 +26,15 @@ func TestProxy_BlacklistHeaderRemoval(t *testing.T) {
Origin: "https://1.2.3.4/",
}, nil, logEntry),
&hclsyntax.Body{},
false,
logEntry,
)

outreq := httptest.NewRequest("GET", "https://1.2.3.4/", nil)
outreq.Header.Set("Authorization", "Basic 123")
outreq.Header.Set("Cookie", "123")
outreq = outreq.WithContext(eval.NewContext(nil, &config.Defaults{}, "").WithClientRequest(outreq))
ctx, cancel := context.WithDeadline(outreq.Context(), time.Now().Add(time.Millisecond*50))
ctx, cancel := context.WithDeadline(context.WithValue(context.Background(), request.RoundTripProxy, true), time.Now().Add(time.Millisecond*50))
outreq = outreq.WithContext(ctx)
defer cancel()

Expand All @@ -45,3 +48,55 @@ func TestProxy_BlacklistHeaderRemoval(t *testing.T) {
t.Error("Expected removed Cookie header")
}
}

func TestProxy_WebsocketsAllowed(t *testing.T) {
log, _ := test.NewLogger()
logEntry := log.WithContext(context.Background())

origin := test.NewBackend()

pNotAllowed := handler.NewProxy(
transport.NewBackend(body.NewHCLSyntaxBodyWithStringAttr("origin", origin.Addr()), &transport.Config{
Origin: origin.Addr(),
}, nil, logEntry),
&hclsyntax.Body{},
false,
logEntry,
)

pAllowed := handler.NewProxy(
transport.NewBackend(body.NewHCLSyntaxBodyWithStringAttr("origin", origin.Addr()), &transport.Config{
Origin: origin.Addr(),
}, nil, logEntry),
&hclsyntax.Body{},
true,
logEntry,
)

headers := http.Header{
"Connection": []string{"upgrade"},
"Upgrade": []string{"websocket"},
}

outreqN := httptest.NewRequest("GET", "http://couper.local/ws", nil)
outreqA := httptest.NewRequest("GET", "http://couper.local/ws", nil)

outCtx := context.WithValue(context.Background(), request.RoundTripProxy, true)

for _, r := range []*http.Request{outreqN, outreqA} {
for h := range headers {
r.Header.Set(h, headers.Get(h))
}
}

resN, _ := pNotAllowed.RoundTrip(outreqN.WithContext(outCtx))
resA, _ := pAllowed.RoundTrip(outreqA.WithContext(outCtx))

if resN.StatusCode != http.StatusBadRequest {
t.Errorf("expected a bad request on ws endpoint without related headers, got: %d", resN.StatusCode)
}

if resA.StatusCode != http.StatusSwitchingProtocols {
t.Errorf("expcted passed Connection and Upgrade header which results in 101, got: %d", resA.StatusCode)
}
}
4 changes: 3 additions & 1 deletion handler/transport/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,9 @@ func (b *Backend) withTimeout(req *http.Request, conf *Config) <-chan error {
defer cancelFn()
deadline := make(<-chan time.Time)
if timeout > 0 {
deadline = time.After(timeout)
deadlineTimer := time.NewTimer(timeout)
deadline = deadlineTimer.C
defer deadlineTimer.Stop()
}
select {
case <-deadline:
Expand Down

0 comments on commit 47573d4

Please sign in to comment.