From 4b6f3fe863d6444b63e5b708249586bfe6f72a4f Mon Sep 17 00:00:00 2001 From: Will Norris Date: Fri, 7 Jun 2024 16:25:58 -0700 Subject: [PATCH] override tscert.TailscaledTransport with muxing transport This provides an http.RoundTripper implementation that dynamically routes requests to the correct tsnet server's LocalAPI based on the ClientHelloInfo in the context. Previously, we were just overriding the tscert Dialer. That worked fine the first time it dialed a LocalAPI, and would correctly dial the right tsnet server. However, tscert caches the Transport with that Dialer, so requests that should be routed to different tsnet servers would be routed incorrectly. Updates #19 Updates #53 Updates #66 Signed-off-by: Will Norris --- go.mod | 2 +- go.sum | 4 +-- module.go | 79 +++++++++++++++++++++++++++++++++---------------------- 3 files changed, 51 insertions(+), 34 deletions(-) diff --git a/go.mod b/go.mod index 721666a..a3991b2 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/caddyserver/caddy/v2 v2.8.4 github.com/caddyserver/certmagic v0.21.3 github.com/google/go-cmp v0.6.0 - github.com/tailscale/tscert v0.0.0-20240517230440-bbccfbf48933 + github.com/tailscale/tscert v0.0.0-20240607232451-34704dbdb4b3 go.uber.org/zap v1.27.0 tailscale.com v1.67.0-pre.0.20240602211424-42cfbf427c67 ) diff --git a/go.sum b/go.sum index d77dd7f..d583725 100644 --- a/go.sum +++ b/go.sum @@ -521,8 +521,8 @@ github.com/tailscale/netlink v1.1.1-0.20211101221916-cabfb018fe85 h1:zrsUcqrG2uQ github.com/tailscale/netlink v1.1.1-0.20211101221916-cabfb018fe85/go.mod h1:NzVQi3Mleb+qzq8VmcWpSkcSYxXIg0DkI6XDzpVkhJ0= github.com/tailscale/peercred v0.0.0-20240214030740-b535050b2aa4 h1:Gz0rz40FvFVLTBk/K8UNAenb36EbDSnh+q7Z9ldcC8w= github.com/tailscale/peercred v0.0.0-20240214030740-b535050b2aa4/go.mod h1:phI29ccmHQBc+wvroosENp1IF9195449VDnFDhJ4rJU= -github.com/tailscale/tscert v0.0.0-20240517230440-bbccfbf48933 h1:pV0H+XIvFoP7pl1MRtyPXh5hqoxB5I7snOtTHgrn6HU= -github.com/tailscale/tscert v0.0.0-20240517230440-bbccfbf48933/go.mod h1:kNGUQ3VESx3VZwRwA9MSCUegIl6+saPL8Noq82ozCaU= +github.com/tailscale/tscert v0.0.0-20240607232451-34704dbdb4b3 h1:KKyPUIj4xlNYCryHDjvaQs92jYaEyYrmLAC9Ip4S6js= +github.com/tailscale/tscert v0.0.0-20240607232451-34704dbdb4b3/go.mod h1:kNGUQ3VESx3VZwRwA9MSCUegIl6+saPL8Noq82ozCaU= github.com/tailscale/web-client-prebuilt v0.0.0-20240226180453-5db17b287bf1 h1:tdUdyPqJ0C97SJfjB9tW6EylTtreyee9C44de+UBG0g= github.com/tailscale/web-client-prebuilt v0.0.0-20240226180453-5db17b287bf1/go.mod h1:agQPE6y6ldqCOui2gkIh7ZMztTkIQKH049tv8siLuNQ= github.com/tailscale/wf v0.0.0-20240214030419-6fbb0a674ee6 h1:l10Gi6w9jxvinoiq15g8OToDdASBni4CyJOdHY1Hr8M= diff --git a/module.go b/module.go index 778ab70..c304def 100644 --- a/module.go +++ b/module.go @@ -12,17 +12,20 @@ import ( "crypto/tls" "fmt" "net" + "net/http" "net/netip" "os" "path/filepath" "strconv" "strings" + "sync" "github.com/caddyserver/caddy/v2" "github.com/caddyserver/caddy/v2/modules/caddyhttp" "github.com/caddyserver/certmagic" "github.com/tailscale/tscert" "go.uber.org/zap" + "tailscale.com/client/tailscale" "tailscale.com/hostinfo" "tailscale.com/tsnet" ) @@ -36,7 +39,8 @@ func init() { // Caddy uses tscert to get certificates for Tailscale hostnames. // Update the tscert dialer to dial the LocalAPI of the correct tsnet node, // rather than just always dialing the local tailscaled. - tscert.TailscaledDialer = localAPIDialer + //tscert.TailscaledDialer = localAPIDialer + tscert.TailscaledTransport = &tsnetMuxTransport{} hostinfo.SetApp("caddy") } @@ -317,40 +321,53 @@ func (t *tsnetServerListener) Close() error { return err } -// localAPIDialer finds the node that matches the requested certificate in ctx -// and dials that node's local API. -// If no matching node is found, the default dialer is used, -// which tries to connect to a local tailscaled on the machine. -func localAPIDialer(ctx context.Context, network, addr string) (net.Conn, error) { - if addr != "local-tailscaled.sock:80" { - return nil, fmt.Errorf("unexpected URL address %q", addr) - } +// localAPITransport is an [http.RoundTripper] that sends requests to a [tailscale.LocalClient]'s LocalAPI. +type localAPITransport struct { + *tailscale.LocalClient +} + +func (t *localAPITransport) RoundTrip(req *http.Request) (*http.Response, error) { + return t.DoLocalRequest(req) +} + +// tsnetMuxTransport is an [http.RoundTripper] that sends requests to the LocalAPI +// for the tsnet server that matches the ClientHelloInfo server name. +// If no tsnet server matches, a default Transport is used. +type tsnetMuxTransport struct { + defaultTransport *http.Transport + defaultTransportOnce sync.Once +} + +func (t *tsnetMuxTransport) RoundTrip(req *http.Request) (*http.Response, error) { + ctx := req.Context() + var rt http.RoundTripper clientHello, ok := ctx.Value(certmagic.ClientHelloInfoCtxKey).(*tls.ClientHelloInfo) - if !ok || clientHello == nil { - return tscert.DialLocalAPI(ctx, network, addr) - } - - var tn *tailscaleNode - nodes.Range(func(key, value any) bool { - if n, ok := value.(*tailscaleNode); ok && n != nil { - for _, d := range n.CertDomains() { - // Tailscale doesn't do wildcard certs, but caddy uses MatchWildcard - // for the built-in Tailscale cert manager, so we do so here as well. - if certmagic.MatchWildcard(clientHello.ServerName, d) { - tn = n - return false + if ok && clientHello != nil { + nodes.Range(func(key, value any) bool { + if n, ok := value.(*tailscaleNode); ok && n != nil { + for _, d := range n.CertDomains() { + // Tailscale doesn't do wildcard certs, but caddy uses MatchWildcard + // for the built-in Tailscale cert manager, so we do so here as well. + if certmagic.MatchWildcard(clientHello.ServerName, d) { + if lc, err := n.LocalClient(); err == nil { + rt = &localAPITransport{lc} + } + return false + } } } - } - return true - }) - - if tn != nil { - if lc, err := tn.LocalClient(); err == nil { - return lc.Dial(ctx, network, addr) - } + return true + }) } - return tscert.DialLocalAPI(ctx, network, addr) + if rt == nil { + t.defaultTransportOnce.Do(func() { + t.defaultTransport = &http.Transport{ + DialContext: tscert.TailscaledDialer, + } + }) + rt = t.defaultTransport + } + return rt.RoundTrip(req) }