Skip to content

Commit

Permalink
feat: use single http.Transport to reuse connections (#12)
Browse files Browse the repository at this point in the history
  • Loading branch information
coadler authored Mar 24, 2023
1 parent 00ed52a commit 7e7d5e6
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 63 deletions.
36 changes: 15 additions & 21 deletions tunneld/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,8 @@ allowed_ip=%s/128`,
}, exists, nil
}

type ipPortKey struct{}

func (api *API) handleTunnelMW(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
Expand Down Expand Up @@ -210,36 +212,28 @@ func (api *API) handleTunnelMW(next http.Handler) http.Handler {
return
}

dialCtx, dialCancel := context.WithTimeout(ctx, api.Options.PeerDialTimeout)
defer dialCancel()

nc, err := api.wgNet.DialContextTCPAddrPort(dialCtx, netip.AddrPortFrom(ip, tunnelsdk.TunnelPort))
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadGateway, tunnelsdk.Response{
Message: "Failed to dial peer.",
Detail: err.Error(),
})
return
}

span := trace.SpanFromContext(ctx)
span.SetAttributes(attribute.Bool("proxy_request", true))

// The transport on the reverse proxy uses this ctx value to know which
// IP to dial. See tunneld.go.
ctx = context.WithValue(ctx, ipPortKey{}, netip.AddrPortFrom(ip, tunnelsdk.TunnelPort))
r = r.WithContext(ctx)

rp := httputil.ReverseProxy{
// This can only happen when it fails to dial.
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
httpapi.Write(ctx, rw, http.StatusBadGateway, tunnelsdk.Response{
Message: "Failed to dial peer.",
Detail: err.Error(),
})
},
Director: func(rp *http.Request) {
rp.URL.Scheme = "http"
rp.URL.Host = r.Host
rp.Host = r.Host
},
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return &tracingConnWrapper{
Conn: nc,
span: span,
ctx: ctx,
}, nil
},
},
Transport: api.transport,
}

span.End()
Expand Down
40 changes: 0 additions & 40 deletions tunneld/tracing.go

This file was deleted.

36 changes: 34 additions & 2 deletions tunneld/tunneld.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ package tunneld
import (
"context"
"fmt"
"net"
"net/http"
"net/netip"
"time"

"golang.org/x/xerrors"
"golang.zx2c4.com/wireguard/conn"
Expand All @@ -15,8 +18,9 @@ import (
type API struct {
*Options

wgNet *netstack.Net
wgDevice *device.Device
wgNet *netstack.Net
wgDevice *device.Device
transport *http.Transport
}

func New(options *Options) (*API, error) {
Expand Down Expand Up @@ -68,6 +72,34 @@ listen_port=%d`,
Options: options,
wgNet: wgNet,
wgDevice: dev,
transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
ip := ctx.Value(ipPortKey{})
if ip == nil {
return nil, xerrors.New("no ip on context")
}

ipp, ok := ip.(netip.AddrPort)
if !ok {
return nil, xerrors.Errorf("ip is incorrect type, got %T", ipp)
}

dialCtx, dialCancel := context.WithTimeout(ctx, options.PeerDialTimeout)
defer dialCancel()

nc, err := wgNet.DialContextTCPAddrPort(dialCtx, ipp)
if err != nil {
return nil, err
}

return nc, nil
},
ForceAttemptHTTP2: false,
MaxIdleConns: 0,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
},
}, nil
}

Expand Down

0 comments on commit 7e7d5e6

Please sign in to comment.