Skip to content

Commit

Permalink
[tailscale] net: add enforcement hooks
Browse files Browse the repository at this point in the history
Updates #55
Updates tailscale/corp#8944

Signed-off-by: Jenny Zhang <jz@tailscale.com>
(Cherry-picked from 13373ca)
  • Loading branch information
phirework committed Jun 21, 2023
1 parent 52e005d commit 479855c
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 0 deletions.
2 changes: 2 additions & 0 deletions api/go1.99999.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
pkg net, func SetDialEnforcer(func(context.Context, []Addr) error) #55
pkg net, func SetResolveEnforcer(func(context.Context, string, string, string, Addr) error) #55
48 changes: 48 additions & 0 deletions src/net/dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,24 @@ func parseNetwork(ctx context.Context, network string, needsProto bool) (afnet s
return "", 0, UnknownNetworkError(network)
}

// SetResolveEnforcer set a program-global resolver enforcer that can cause resolvers to
// fail based on the context and/or other arguments.
//
// f must be non-nil, it can only be called once, and must not be called
// concurrent with any dial/resolve.
func SetResolveEnforcer(f func(ctx context.Context, op, network, addr string, hint Addr) error) {
if f == nil {
panic("nil func")
}
if resolveEnforcer != nil {
panic("already called")
}
resolveEnforcer = f
}

// resolveEnforcer, if non-nil, is the installed hook from SetResolveEnforcer.
var resolveEnforcer func(ctx context.Context, op, network, addr string, hint Addr) error

// resolveAddrList resolves addr using hint and returns a list of
// addresses. The result contains at least one address when error is
// nil.
Expand All @@ -269,6 +287,13 @@ func (r *Resolver) resolveAddrList(ctx context.Context, op, network, addr string
}
return addrList{addr}, nil
}

if resolveEnforcer != nil {
if err := resolveEnforcer(ctx, op, network, addr, hint); err != nil {
return nil, err
}
}

addrs, err := r.internetAddrList(ctx, afnet, addr)
if err != nil || op != "dial" || hint == nil {
return addrs, err
Expand Down Expand Up @@ -572,9 +597,32 @@ func (sd *sysDialer) dialParallel(ctx context.Context, primaries, fallbacks addr
}
}

// SetDialEnforcer set a program-global dial enforcer that can cause dials to
// fail based on the context and/or Addr(s).
//
// f must be non-nil, it can only be called once, and must not be called
// concurrent with any dial.
func SetDialEnforcer(f func(context.Context, []Addr) error) {
if f == nil {
panic("nil func")
}
if dialEnforcer != nil {
panic("already called")
}
dialEnforcer = f
}

// dialEnforce, if non-nil, is any installed hook from SetDialEnforcer.
var dialEnforcer func(context.Context, []Addr) error

// dialSerial connects to a list of addresses in sequence, returning
// either the first successful connection, or the first error.
func (sd *sysDialer) dialSerial(ctx context.Context, ras addrList) (Conn, error) {
if dialEnforcer != nil {
if err := dialEnforcer(ctx, ras); err != nil {
return nil, err
}
}
var firstErr error // The error from the first address is most relevant.

for i, ra := range ras {
Expand Down

0 comments on commit 479855c

Please sign in to comment.