From 7e7d5e6892ab43f472d48923b6c721b8c601e587 Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Fri, 24 Mar 2023 13:00:05 -0500 Subject: [PATCH] feat: use single `http.Transport` to reuse connections (#12) --- tunneld/api.go | 36 +++++++++++++++--------------------- tunneld/tracing.go | 40 ---------------------------------------- tunneld/tunneld.go | 36 ++++++++++++++++++++++++++++++++++-- 3 files changed, 49 insertions(+), 63 deletions(-) delete mode 100644 tunneld/tracing.go diff --git a/tunneld/api.go b/tunneld/api.go index ccd5987..4efe418 100644 --- a/tunneld/api.go +++ b/tunneld/api.go @@ -180,6 +180,8 @@ allowed_ip=%s/128`, }, exists, nil } +type ipPortKey struct{} + func (api *API) handleTunnelMW(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() @@ -210,36 +212,28 @@ func (api *API) handleTunnelMW(next http.Handler) http.Handler { return } - dialCtx, dialCancel := context.WithTimeout(ctx, api.Options.PeerDialTimeout) - defer dialCancel() - - nc, err := api.wgNet.DialContextTCPAddrPort(dialCtx, netip.AddrPortFrom(ip, tunnelsdk.TunnelPort)) - if err != nil { - httpapi.Write(ctx, rw, http.StatusBadGateway, tunnelsdk.Response{ - Message: "Failed to dial peer.", - Detail: err.Error(), - }) - return - } - span := trace.SpanFromContext(ctx) span.SetAttributes(attribute.Bool("proxy_request", true)) + // The transport on the reverse proxy uses this ctx value to know which + // IP to dial. See tunneld.go. + ctx = context.WithValue(ctx, ipPortKey{}, netip.AddrPortFrom(ip, tunnelsdk.TunnelPort)) + r = r.WithContext(ctx) + rp := httputil.ReverseProxy{ + // This can only happen when it fails to dial. + ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) { + httpapi.Write(ctx, rw, http.StatusBadGateway, tunnelsdk.Response{ + Message: "Failed to dial peer.", + Detail: err.Error(), + }) + }, Director: func(rp *http.Request) { rp.URL.Scheme = "http" rp.URL.Host = r.Host rp.Host = r.Host }, - Transport: &http.Transport{ - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return &tracingConnWrapper{ - Conn: nc, - span: span, - ctx: ctx, - }, nil - }, - }, + Transport: api.transport, } span.End() diff --git a/tunneld/tracing.go b/tunneld/tracing.go deleted file mode 100644 index 108e886..0000000 --- a/tunneld/tracing.go +++ /dev/null @@ -1,40 +0,0 @@ -package tunneld - -import ( - "context" - "net" - - "go.opentelemetry.io/otel" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/trace" -) - -type tracingConnWrapper struct { - net.Conn - - ctx context.Context - span trace.Span -} - -func (n *tracingConnWrapper) Read(b []byte) (int, error) { - _, span := otel.GetTracerProvider().Tracer("").Start(n.ctx, "(net.Conn).Read") - defer span.End() - - nbytes, err := n.Conn.Read(b) - span.SetAttributes(attribute.Int("bytes_read", nbytes)) - return nbytes, err -} - -func (n *tracingConnWrapper) Write(b []byte) (int, error) { - _, span := otel.GetTracerProvider().Tracer("").Start(n.ctx, "(net.Conn).Write") - defer span.End() - - nbytes, err := n.Conn.Write(b) - span.SetAttributes(attribute.Int("bytes_written", nbytes)) - return nbytes, err -} - -func (n *tracingConnWrapper) Close() error { - n.span.AddEvent("connClose") - return n.Conn.Close() -} diff --git a/tunneld/tunneld.go b/tunneld/tunneld.go index 058b886..6585ec0 100644 --- a/tunneld/tunneld.go +++ b/tunneld/tunneld.go @@ -3,7 +3,10 @@ package tunneld import ( "context" "fmt" + "net" + "net/http" "net/netip" + "time" "golang.org/x/xerrors" "golang.zx2c4.com/wireguard/conn" @@ -15,8 +18,9 @@ import ( type API struct { *Options - wgNet *netstack.Net - wgDevice *device.Device + wgNet *netstack.Net + wgDevice *device.Device + transport *http.Transport } func New(options *Options) (*API, error) { @@ -68,6 +72,34 @@ listen_port=%d`, Options: options, wgNet: wgNet, wgDevice: dev, + transport: &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + ip := ctx.Value(ipPortKey{}) + if ip == nil { + return nil, xerrors.New("no ip on context") + } + + ipp, ok := ip.(netip.AddrPort) + if !ok { + return nil, xerrors.Errorf("ip is incorrect type, got %T", ipp) + } + + dialCtx, dialCancel := context.WithTimeout(ctx, options.PeerDialTimeout) + defer dialCancel() + + nc, err := wgNet.DialContextTCPAddrPort(dialCtx, ipp) + if err != nil { + return nil, err + } + + return nc, nil + }, + ForceAttemptHTTP2: false, + MaxIdleConns: 0, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + }, }, nil }