diff --git a/go.mod b/go.mod index 721666a..a3991b2 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/caddyserver/caddy/v2 v2.8.4 github.com/caddyserver/certmagic v0.21.3 github.com/google/go-cmp v0.6.0 - github.com/tailscale/tscert v0.0.0-20240517230440-bbccfbf48933 + github.com/tailscale/tscert v0.0.0-20240607232451-34704dbdb4b3 go.uber.org/zap v1.27.0 tailscale.com v1.67.0-pre.0.20240602211424-42cfbf427c67 ) diff --git a/go.sum b/go.sum index d77dd7f..d583725 100644 --- a/go.sum +++ b/go.sum @@ -521,8 +521,8 @@ github.com/tailscale/netlink v1.1.1-0.20211101221916-cabfb018fe85 h1:zrsUcqrG2uQ github.com/tailscale/netlink v1.1.1-0.20211101221916-cabfb018fe85/go.mod h1:NzVQi3Mleb+qzq8VmcWpSkcSYxXIg0DkI6XDzpVkhJ0= github.com/tailscale/peercred v0.0.0-20240214030740-b535050b2aa4 h1:Gz0rz40FvFVLTBk/K8UNAenb36EbDSnh+q7Z9ldcC8w= github.com/tailscale/peercred v0.0.0-20240214030740-b535050b2aa4/go.mod h1:phI29ccmHQBc+wvroosENp1IF9195449VDnFDhJ4rJU= -github.com/tailscale/tscert v0.0.0-20240517230440-bbccfbf48933 h1:pV0H+XIvFoP7pl1MRtyPXh5hqoxB5I7snOtTHgrn6HU= -github.com/tailscale/tscert v0.0.0-20240517230440-bbccfbf48933/go.mod h1:kNGUQ3VESx3VZwRwA9MSCUegIl6+saPL8Noq82ozCaU= +github.com/tailscale/tscert v0.0.0-20240607232451-34704dbdb4b3 h1:KKyPUIj4xlNYCryHDjvaQs92jYaEyYrmLAC9Ip4S6js= +github.com/tailscale/tscert v0.0.0-20240607232451-34704dbdb4b3/go.mod h1:kNGUQ3VESx3VZwRwA9MSCUegIl6+saPL8Noq82ozCaU= github.com/tailscale/web-client-prebuilt v0.0.0-20240226180453-5db17b287bf1 h1:tdUdyPqJ0C97SJfjB9tW6EylTtreyee9C44de+UBG0g= github.com/tailscale/web-client-prebuilt v0.0.0-20240226180453-5db17b287bf1/go.mod h1:agQPE6y6ldqCOui2gkIh7ZMztTkIQKH049tv8siLuNQ= github.com/tailscale/wf v0.0.0-20240214030419-6fbb0a674ee6 h1:l10Gi6w9jxvinoiq15g8OToDdASBni4CyJOdHY1Hr8M= diff --git a/module.go b/module.go index 778ab70..c304def 100644 --- a/module.go +++ b/module.go @@ -12,17 +12,20 @@ import ( "crypto/tls" "fmt" "net" + "net/http" "net/netip" "os" "path/filepath" "strconv" "strings" + "sync" "github.com/caddyserver/caddy/v2" "github.com/caddyserver/caddy/v2/modules/caddyhttp" "github.com/caddyserver/certmagic" "github.com/tailscale/tscert" "go.uber.org/zap" + "tailscale.com/client/tailscale" "tailscale.com/hostinfo" "tailscale.com/tsnet" ) @@ -36,7 +39,8 @@ func init() { // Caddy uses tscert to get certificates for Tailscale hostnames. // Update the tscert dialer to dial the LocalAPI of the correct tsnet node, // rather than just always dialing the local tailscaled. - tscert.TailscaledDialer = localAPIDialer + //tscert.TailscaledDialer = localAPIDialer + tscert.TailscaledTransport = &tsnetMuxTransport{} hostinfo.SetApp("caddy") } @@ -317,40 +321,53 @@ func (t *tsnetServerListener) Close() error { return err } -// localAPIDialer finds the node that matches the requested certificate in ctx -// and dials that node's local API. -// If no matching node is found, the default dialer is used, -// which tries to connect to a local tailscaled on the machine. -func localAPIDialer(ctx context.Context, network, addr string) (net.Conn, error) { - if addr != "local-tailscaled.sock:80" { - return nil, fmt.Errorf("unexpected URL address %q", addr) - } +// localAPITransport is an [http.RoundTripper] that sends requests to a [tailscale.LocalClient]'s LocalAPI. +type localAPITransport struct { + *tailscale.LocalClient +} + +func (t *localAPITransport) RoundTrip(req *http.Request) (*http.Response, error) { + return t.DoLocalRequest(req) +} + +// tsnetMuxTransport is an [http.RoundTripper] that sends requests to the LocalAPI +// for the tsnet server that matches the ClientHelloInfo server name. +// If no tsnet server matches, a default Transport is used. +type tsnetMuxTransport struct { + defaultTransport *http.Transport + defaultTransportOnce sync.Once +} + +func (t *tsnetMuxTransport) RoundTrip(req *http.Request) (*http.Response, error) { + ctx := req.Context() + var rt http.RoundTripper clientHello, ok := ctx.Value(certmagic.ClientHelloInfoCtxKey).(*tls.ClientHelloInfo) - if !ok || clientHello == nil { - return tscert.DialLocalAPI(ctx, network, addr) - } - - var tn *tailscaleNode - nodes.Range(func(key, value any) bool { - if n, ok := value.(*tailscaleNode); ok && n != nil { - for _, d := range n.CertDomains() { - // Tailscale doesn't do wildcard certs, but caddy uses MatchWildcard - // for the built-in Tailscale cert manager, so we do so here as well. - if certmagic.MatchWildcard(clientHello.ServerName, d) { - tn = n - return false + if ok && clientHello != nil { + nodes.Range(func(key, value any) bool { + if n, ok := value.(*tailscaleNode); ok && n != nil { + for _, d := range n.CertDomains() { + // Tailscale doesn't do wildcard certs, but caddy uses MatchWildcard + // for the built-in Tailscale cert manager, so we do so here as well. + if certmagic.MatchWildcard(clientHello.ServerName, d) { + if lc, err := n.LocalClient(); err == nil { + rt = &localAPITransport{lc} + } + return false + } } } - } - return true - }) - - if tn != nil { - if lc, err := tn.LocalClient(); err == nil { - return lc.Dial(ctx, network, addr) - } + return true + }) } - return tscert.DialLocalAPI(ctx, network, addr) + if rt == nil { + t.defaultTransportOnce.Do(func() { + t.defaultTransport = &http.Transport{ + DialContext: tscert.TailscaledDialer, + } + }) + rt = t.defaultTransport + } + return rt.RoundTrip(req) }