Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tailscale] net: add TCP socket creation/close hooks to SockTrace API #59

Merged
merged 1 commit into from
Mar 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions api/go1.99999.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,5 @@ pkg net, type SockTrace struct #58
pkg net, type SockTrace struct, DidRead func(int) #58
pkg net, type SockTrace struct, DidWrite func(int) #58
pkg net, type SockTrace struct, WillOverwrite func(*SockTrace) #58
pkg net, type SockTrace struct, DidCreateTCPConn func(syscall.RawConn) #58
pkg net, type SockTrace struct, WillCloseTCPConn func(syscall.RawConn) #58
22 changes: 16 additions & 6 deletions src/net/fd_posix.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ type netFD struct {
// number of bytes transferred.
readHook func(int)
writeHook func(int)
closeHook func()
}

func (fd *netFD) setAddr(laddr, raddr Addr) {
Expand All @@ -39,6 +40,9 @@ func (fd *netFD) setAddr(laddr, raddr Addr) {

func (fd *netFD) Close() error {
runtime.SetFinalizer(fd, nil)
if fd.closeHook != nil {
fd.closeHook()
}
return fd.pfd.Close()
}

Expand All @@ -49,10 +53,16 @@ func (fd *netFD) shutdown(how int) error {
}

func (fd *netFD) closeRead() error {
if fd.closeHook != nil {
fd.closeHook()
}
return fd.shutdown(syscall.SHUT_RD)
}

func (fd *netFD) closeWrite() error {
if fd.closeHook != nil {
fd.closeHook()
}
return fd.shutdown(syscall.SHUT_WR)
}

Expand Down Expand Up @@ -94,7 +104,7 @@ func (fd *netFD) readFromInet6(p []byte, from *syscall.SockaddrInet6) (n int, er
func (fd *netFD) readMsg(p []byte, oob []byte, flags int) (n, oobn, retflags int, sa syscall.Sockaddr, err error) {
n, oobn, retflags, sa, err = fd.pfd.ReadMsg(p, oob, flags)
if fd.readHook != nil && err == nil {
fd.readHook(n + oobn)
fd.readHook(n)
}
runtime.KeepAlive(fd)
return n, oobn, retflags, sa, wrapSyscallError(readMsgSyscallName, err)
Expand All @@ -103,7 +113,7 @@ func (fd *netFD) readMsg(p []byte, oob []byte, flags int) (n, oobn, retflags int
func (fd *netFD) readMsgInet4(p []byte, oob []byte, flags int, sa *syscall.SockaddrInet4) (n, oobn, retflags int, err error) {
n, oobn, retflags, err = fd.pfd.ReadMsgInet4(p, oob, flags, sa)
if fd.readHook != nil && err == nil {
fd.readHook(n + oobn)
fd.readHook(n)
}
runtime.KeepAlive(fd)
return n, oobn, retflags, wrapSyscallError(readMsgSyscallName, err)
Expand All @@ -112,7 +122,7 @@ func (fd *netFD) readMsgInet4(p []byte, oob []byte, flags int, sa *syscall.Socka
func (fd *netFD) readMsgInet6(p []byte, oob []byte, flags int, sa *syscall.SockaddrInet6) (n, oobn, retflags int, err error) {
n, oobn, retflags, err = fd.pfd.ReadMsgInet6(p, oob, flags, sa)
if fd.readHook != nil && err == nil {
fd.readHook(n + oobn)
fd.readHook(n)
}
runtime.KeepAlive(fd)
return n, oobn, retflags, wrapSyscallError(readMsgSyscallName, err)
Expand Down Expand Up @@ -157,7 +167,7 @@ func (fd *netFD) writeToInet6(p []byte, sa *syscall.SockaddrInet6) (n int, err e
func (fd *netFD) writeMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err error) {
n, oobn, err = fd.pfd.WriteMsg(p, oob, sa)
if fd.writeHook != nil && err == nil {
fd.writeHook(n + oobn)
fd.writeHook(n)
}
runtime.KeepAlive(fd)
return n, oobn, wrapSyscallError(writeMsgSyscallName, err)
Expand All @@ -166,7 +176,7 @@ func (fd *netFD) writeMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oob
func (fd *netFD) writeMsgInet4(p []byte, oob []byte, sa *syscall.SockaddrInet4) (n int, oobn int, err error) {
n, oobn, err = fd.pfd.WriteMsgInet4(p, oob, sa)
if fd.writeHook != nil && err == nil {
fd.writeHook(n + oobn)
fd.writeHook(n)
}
runtime.KeepAlive(fd)
return n, oobn, wrapSyscallError(writeMsgSyscallName, err)
Expand All @@ -175,7 +185,7 @@ func (fd *netFD) writeMsgInet4(p []byte, oob []byte, sa *syscall.SockaddrInet4)
func (fd *netFD) writeMsgInet6(p []byte, oob []byte, sa *syscall.SockaddrInet6) (n int, oobn int, err error) {
n, oobn, err = fd.pfd.WriteMsgInet6(p, oob, sa)
if fd.writeHook != nil && err == nil {
fd.writeHook(n + oobn)
fd.writeHook(n)
}
runtime.KeepAlive(fd)
return n, oobn, wrapSyscallError(writeMsgSyscallName, err)
Expand Down
15 changes: 15 additions & 0 deletions src/net/sock_posix.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,21 @@ func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only
if trace := ContextSockTrace(ctx); trace != nil {
fd.readHook = trace.DidRead
fd.writeHook = trace.DidWrite
if (trace.DidCreateTCPConn != nil || trace.WillCloseTCPConn != nil) && len(net) >= 3 && net[0:3] == "tcp" {
// Ignore newRawConn errors (they're not possible in the current
// implementation, but even if they were, we don't want to
// affect socket operations for a trace hook invocation).
if c, err := newRawConn(fd); err == nil {
if trace.DidCreateTCPConn != nil {
trace.DidCreateTCPConn(c)
}
if trace.WillCloseTCPConn != nil {
fd.closeHook = func() {
trace.WillCloseTCPConn(c)
}
}
}
}
}

// This function makes a network file descriptor for the
Expand Down
7 changes: 7 additions & 0 deletions src/net/socktrace.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,16 @@ package net

import (
"context"
"syscall"
)

// SockTrace is a set of hooks to run at various operations on a network socket.
// Any particular hook may be nil. Functions may be called concurrently from
// different goroutines.
type SockTrace struct {
// DidOpenTCPConn is called when a TCP socket was created. The
// underlying raw network connection that was created is provided.
DidCreateTCPConn func(c syscall.RawConn)
// DidRead is called after a successful read from the socket, where n bytes
// were read.
DidRead func(n int)
Expand All @@ -22,6 +26,9 @@ type SockTrace struct {
// subsequent call to WithSockTrace. The provided trace is the new trace
// that will be used.
WillOverwrite func(trace *SockTrace)
// WillCloseTCPConn is called when a TCP socket is about to be closed. The
// underlying raw network connection that is being closed is provided.
WillCloseTCPConn func(c syscall.RawConn)
}

// WithSockTrace returns a new context based on the provided parent
Expand Down