Skip to content

Commit

Permalink
[tailscale] net/http: add enforcement hook to Transport.RoundTrip, li…
Browse files Browse the repository at this point in the history
…ke our net ones

Updates #55
Updates tailscale/corp#12702

Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
(cherry picked from commit 8df9488)
  • Loading branch information
bradfitz committed Aug 21, 2024
1 parent 0634555 commit 19afa83
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 0 deletions.
1 change: 1 addition & 0 deletions api/go1.99999.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ pkg net, type SockTrace struct, DidRead func(int)
pkg net, type SockTrace struct, DidWrite func(int)
pkg net, type SockTrace struct, WillCloseTCPConn func(syscall.RawConn)
pkg net, type SockTrace struct, WillOverwrite func(*SockTrace)
pkg net/http, func SetRoundTripEnforcer(func(*Request) error)
21 changes: 21 additions & 0 deletions src/net/http/tailscale.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package http

var roundTripEnforcer func(*Request) error

// SetRoundTripEnforcer set a program-global resolver enforcer that can cause
// RoundTrip calls to fail based on the request and its context.
//
// f must be non-nil.
//
// SetRoundTripEnforcer can only be called once, and must not be called
// concurrent with any RoundTrip call; it's expected to be registered during
// init.
func SetRoundTripEnforcer(f func(*Request) error) {
if f == nil {
panic("nil func")
}
if roundTripEnforcer != nil {
panic("already called")
}
roundTripEnforcer = f
}
6 changes: 6 additions & 0 deletions src/net/http/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,12 @@ func validateHeaders(hdrs Header) string {

// roundTrip implements a RoundTripper over HTTP.
func (t *Transport) roundTrip(req *Request) (_ *Response, err error) {
if roundTripEnforcer != nil {
if err := roundTripEnforcer(req); err != nil {
return nil, err
}
}

t.nextProtoOnce.Do(t.onceSetNextProtoDefaults)
ctx := req.Context()
trace := httptrace.ContextClientTrace(ctx)
Expand Down

0 comments on commit 19afa83

Please sign in to comment.