Skip to content

Commit

Permalink
[tailscale] net: add SockTrace API
Browse files Browse the repository at this point in the history
Loosely inspired by nettrace/httptrace, allows functions to be called
when sockets are read from or written to. The hooks are specified via
the context (with a WithSockTrace function).

Only implemented for network sockets on POSIX systems.

Updates tailscale/corp#9230
Updates #58
  • Loading branch information
mihaip committed Feb 28, 2023
1 parent ec180cb commit b2969cb
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 0 deletions.
6 changes: 6 additions & 0 deletions api/go1.99999.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,8 @@
pkg net, func SetDialEnforcer(func(context.Context, []Addr) error) #55
pkg net, func SetResolveEnforcer(func(context.Context, string, string, string, Addr) error) #55
pkg net, func WithSockTrace(context.Context, *SockTrace) context.Context #58
pkg net, func ContextSockTrace(context.Context) *SockTrace #58
pkg net, type SockTrace struct #58
pkg net, type SockTrace struct, DidRead func(int) #58
pkg net, type SockTrace struct, DidWrite func(int) #58
pkg net, type SockTrace struct, WillOverwrite func(*SockTrace) #58
47 changes: 47 additions & 0 deletions src/net/fd_posix.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ type netFD struct {
net string
laddr Addr
raddr Addr

// hooks (if provided) are called after successful reads or writes with the
// number of bytes transferred.
readHook func(int)
writeHook func(int)
}

func (fd *netFD) setAddr(laddr, raddr Addr) {
Expand Down Expand Up @@ -53,83 +58,125 @@ func (fd *netFD) closeWrite() error {

func (fd *netFD) Read(p []byte) (n int, err error) {
n, err = fd.pfd.Read(p)
if fd.readHook != nil && err == nil {
fd.readHook(n)
}
runtime.KeepAlive(fd)
return n, wrapSyscallError(readSyscallName, err)
}

func (fd *netFD) readFrom(p []byte) (n int, sa syscall.Sockaddr, err error) {
n, sa, err = fd.pfd.ReadFrom(p)
if fd.readHook != nil && err == nil {
fd.readHook(n)
}
runtime.KeepAlive(fd)
return n, sa, wrapSyscallError(readFromSyscallName, err)
}
func (fd *netFD) readFromInet4(p []byte, from *syscall.SockaddrInet4) (n int, err error) {
n, err = fd.pfd.ReadFromInet4(p, from)
if fd.readHook != nil && err == nil {
fd.readHook(n)
}
runtime.KeepAlive(fd)
return n, wrapSyscallError(readFromSyscallName, err)
}

func (fd *netFD) readFromInet6(p []byte, from *syscall.SockaddrInet6) (n int, err error) {
n, err = fd.pfd.ReadFromInet6(p, from)
if fd.readHook != nil && err == nil {
fd.readHook(n)
}
runtime.KeepAlive(fd)
return n, wrapSyscallError(readFromSyscallName, err)
}

func (fd *netFD) readMsg(p []byte, oob []byte, flags int) (n, oobn, retflags int, sa syscall.Sockaddr, err error) {
n, oobn, retflags, sa, err = fd.pfd.ReadMsg(p, oob, flags)
if fd.readHook != nil && err == nil {
fd.readHook(n + oobn)
}
runtime.KeepAlive(fd)
return n, oobn, retflags, sa, wrapSyscallError(readMsgSyscallName, err)
}

func (fd *netFD) readMsgInet4(p []byte, oob []byte, flags int, sa *syscall.SockaddrInet4) (n, oobn, retflags int, err error) {
n, oobn, retflags, err = fd.pfd.ReadMsgInet4(p, oob, flags, sa)
if fd.readHook != nil && err == nil {
fd.readHook(n + oobn)
}
runtime.KeepAlive(fd)
return n, oobn, retflags, wrapSyscallError(readMsgSyscallName, err)
}

func (fd *netFD) readMsgInet6(p []byte, oob []byte, flags int, sa *syscall.SockaddrInet6) (n, oobn, retflags int, err error) {
n, oobn, retflags, err = fd.pfd.ReadMsgInet6(p, oob, flags, sa)
if fd.readHook != nil && err == nil {
fd.readHook(n + oobn)
}
runtime.KeepAlive(fd)
return n, oobn, retflags, wrapSyscallError(readMsgSyscallName, err)
}

func (fd *netFD) Write(p []byte) (nn int, err error) {
nn, err = fd.pfd.Write(p)
if fd.writeHook != nil && err == nil {
fd.writeHook(nn)
}
runtime.KeepAlive(fd)
return nn, wrapSyscallError(writeSyscallName, err)
}

func (fd *netFD) writeTo(p []byte, sa syscall.Sockaddr) (n int, err error) {
n, err = fd.pfd.WriteTo(p, sa)
if fd.writeHook != nil && err == nil {
fd.writeHook(n)
}
runtime.KeepAlive(fd)
return n, wrapSyscallError(writeToSyscallName, err)
}

func (fd *netFD) writeToInet4(p []byte, sa *syscall.SockaddrInet4) (n int, err error) {
n, err = fd.pfd.WriteToInet4(p, sa)
if fd.writeHook != nil && err == nil {
fd.writeHook(n)
}
runtime.KeepAlive(fd)
return n, wrapSyscallError(writeToSyscallName, err)
}

func (fd *netFD) writeToInet6(p []byte, sa *syscall.SockaddrInet6) (n int, err error) {
n, err = fd.pfd.WriteToInet6(p, sa)
if fd.writeHook != nil && err == nil {
fd.writeHook(n)
}
runtime.KeepAlive(fd)
return n, wrapSyscallError(writeToSyscallName, err)
}

func (fd *netFD) writeMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err error) {
n, oobn, err = fd.pfd.WriteMsg(p, oob, sa)
if fd.writeHook != nil && err == nil {
fd.writeHook(n + oobn)
}
runtime.KeepAlive(fd)
return n, oobn, wrapSyscallError(writeMsgSyscallName, err)
}

func (fd *netFD) writeMsgInet4(p []byte, oob []byte, sa *syscall.SockaddrInet4) (n int, oobn int, err error) {
n, oobn, err = fd.pfd.WriteMsgInet4(p, oob, sa)
if fd.writeHook != nil && err == nil {
fd.writeHook(n + oobn)
}
runtime.KeepAlive(fd)
return n, oobn, wrapSyscallError(writeMsgSyscallName, err)
}

func (fd *netFD) writeMsgInet6(p []byte, oob []byte, sa *syscall.SockaddrInet6) (n int, oobn int, err error) {
n, oobn, err = fd.pfd.WriteMsgInet6(p, oob, sa)
if fd.writeHook != nil && err == nil {
fd.writeHook(n + oobn)
}
runtime.KeepAlive(fd)
return n, oobn, wrapSyscallError(writeMsgSyscallName, err)
}
Expand Down
4 changes: 4 additions & 0 deletions src/net/sock_posix.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only
poll.CloseFunc(s)
return nil, err
}
if trace := ContextSockTrace(ctx); trace != nil {
fd.readHook = trace.DidRead
fd.writeHook = trace.DidWrite
}

// This function makes a network file descriptor for the
// following applications:
Expand Down
46 changes: 46 additions & 0 deletions src/net/socktrace.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package net

import (
"context"
)

// SockTrace is a set of hooks to run at various operations on a network socket.
// Any particular hook may be nil. Functions may be called concurrently from
// different goroutines.
type SockTrace struct {
// DidRead is called after a successful read from the socket, where n bytes
// were read.
DidRead func(n int)
// DidWrite is called after a successful write to the socket, where n bytes
// were written.
DidWrite func(n int)
// WillOverwrite is called when the registered trace is overwritten by a
// subsequent call to WithSockTrace. The provided trace is the new trace
// that will be used.
WillOverwrite func(trace *SockTrace)
}

// WithSockTrace returns a new context based on the provided parent
// ctx. Socket reads and writes made with the returned context will use
// the provided trace hooks. Any previous hooks registered with ctx are
// ovewritten (their WillOverwrite hook will be called).
func WithSockTrace(ctx context.Context, trace *SockTrace) context.Context {
if previous := ContextSockTrace(ctx); previous != nil && previous.WillOverwrite != nil {
previous.WillOverwrite(trace)
}
return context.WithValue(ctx, sockTraceKey{}, trace)
}

// ContextSockTrace returns the SockTrace associated with the
// provided context. If none, it returns nil.
func ContextSockTrace(ctx context.Context) *SockTrace {
trace, _ := ctx.Value(sockTraceKey{}).(*SockTrace)
return trace
}

// unique type to prevent assignment.
type sockTraceKey struct{}

0 comments on commit b2969cb

Please sign in to comment.