From 7e833fde1e9afd2df01c625fd344279c5bebff89 Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Fri, 5 Nov 2021 01:52:54 +0100 Subject: [PATCH] global: use netip where possible now There are more places where we'll need to add it later, when Go 1.18 comes out with support for it in the "net" package. Also, allowedips still uses slices internally, which might be suboptimal. Signed-off-by: Jason A. Donenfeld --- conn/bind_linux.go | 50 ++++------ conn/bind_std.go | 19 ++-- conn/bind_windows.go | 19 ++-- conn/bindtest/bindtest.go | 14 +-- conn/conn.go | 37 +------ device/allowedips.go | 42 ++++---- device/allowedips_rand_test.go | 6 +- device/allowedips_test.go | 6 +- device/device_test.go | 6 +- device/endpoint_test.go | 39 ++++---- device/receive.go | 1 - device/uapi.go | 11 +-- go.mod | 7 +- go.sum | 13 ++- ratelimiter/ratelimiter.go | 58 +++-------- ratelimiter/ratelimiter_test.go | 33 ++++--- tun/netstack/examples/http_client.go | 7 +- tun/netstack/examples/http_server.go | 6 +- tun/netstack/go.mod | 1 + tun/netstack/go.sum | 4 + tun/netstack/tun.go | 143 +++++++++++++++------------ tun/tuntest/tuntest.go | 10 +- 22 files changed, 247 insertions(+), 285 deletions(-) diff --git a/conn/bind_linux.go b/conn/bind_linux.go index 7b970e65a..975a6ab60 100644 --- a/conn/bind_linux.go +++ b/conn/bind_linux.go @@ -14,6 +14,7 @@ import ( "unsafe" "golang.org/x/sys/unix" + "golang.zx2c4.com/go118/netip" ) type ipv4Source struct { @@ -70,32 +71,30 @@ var _ Bind = (*LinuxSocketBind)(nil) func (*LinuxSocketBind) ParseEndpoint(s string) (Endpoint, error) { var end LinuxSocketEndpoint - addr, err := parseEndpoint(s) + e, err := netip.ParseAddrPort(s) if err != nil { return nil, err } - ipv4 := addr.IP.To4() - if ipv4 != nil { + if e.Addr().Is4() { dst := end.dst4() end.isV6 = false - dst.Port = addr.Port - copy(dst.Addr[:], ipv4) + dst.Port = int(e.Port()) + dst.Addr = e.Addr().As4() end.ClearSrc() return &end, nil } - ipv6 := addr.IP.To16() - if ipv6 != nil { - zone, err := zoneToUint32(addr.Zone) + if e.Addr().Is6() { + zone, err := zoneToUint32(e.Addr().Zone()) if err != nil { return nil, err } dst := end.dst6() end.isV6 = true - dst.Port = addr.Port + dst.Port = int(e.Port()) dst.ZoneId = zone - copy(dst.Addr[:], ipv6[:]) + dst.Addr = e.Addr().As16() end.ClearSrc() return &end, nil } @@ -266,29 +265,19 @@ func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error { } } -func (end *LinuxSocketEndpoint) SrcIP() net.IP { +func (end *LinuxSocketEndpoint) SrcIP() netip.Addr { if !end.isV6 { - return net.IPv4( - end.src4().Src[0], - end.src4().Src[1], - end.src4().Src[2], - end.src4().Src[3], - ) + return netip.AddrFrom4(end.src4().Src) } else { - return end.src6().src[:] + return netip.AddrFrom16(end.src6().src) } } -func (end *LinuxSocketEndpoint) DstIP() net.IP { +func (end *LinuxSocketEndpoint) DstIP() netip.Addr { if !end.isV6 { - return net.IPv4( - end.dst4().Addr[0], - end.dst4().Addr[1], - end.dst4().Addr[2], - end.dst4().Addr[3], - ) + return netip.AddrFrom4(end.src4().Src) } else { - return end.dst6().Addr[:] + return netip.AddrFrom16(end.dst6().Addr) } } @@ -305,14 +294,13 @@ func (end *LinuxSocketEndpoint) SrcToString() string { } func (end *LinuxSocketEndpoint) DstToString() string { - var udpAddr net.UDPAddr - udpAddr.IP = end.DstIP() + var port int if !end.isV6 { - udpAddr.Port = end.dst4().Port + port = end.dst4().Port } else { - udpAddr.Port = end.dst6().Port + port = end.dst6().Port } - return udpAddr.String() + return netip.AddrPortFrom(end.DstIP(), uint16(port)).String() } func (end *LinuxSocketEndpoint) ClearDst() { diff --git a/conn/bind_std.go b/conn/bind_std.go index cb85cfdaf..a3cbb1584 100644 --- a/conn/bind_std.go +++ b/conn/bind_std.go @@ -10,6 +10,8 @@ import ( "net" "sync" "syscall" + + "golang.zx2c4.com/go118/netip" ) // StdNetBind is meant to be a temporary solution on platforms for which @@ -32,18 +34,23 @@ var _ Bind = (*StdNetBind)(nil) var _ Endpoint = (*StdNetEndpoint)(nil) func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) { - addr, err := parseEndpoint(s) - return (*StdNetEndpoint)(addr), err + e, err := netip.ParseAddrPort(s) + return (*StdNetEndpoint)(&net.UDPAddr{ + IP: e.Addr().AsSlice(), + Port: int(e.Port()), + Zone: e.Addr().Zone(), + }), err } func (*StdNetEndpoint) ClearSrc() {} -func (e *StdNetEndpoint) DstIP() net.IP { - return (*net.UDPAddr)(e).IP +func (e *StdNetEndpoint) DstIP() netip.Addr { + a, _ := netip.AddrFromSlice((*net.UDPAddr)(e).IP) + return a } -func (e *StdNetEndpoint) SrcIP() net.IP { - return nil // not supported +func (e *StdNetEndpoint) SrcIP() netip.Addr { + return netip.Addr{} // not supported } func (e *StdNetEndpoint) DstToBytes() []byte { diff --git a/conn/bind_windows.go b/conn/bind_windows.go index 42e06adbc..26a3af851 100644 --- a/conn/bind_windows.go +++ b/conn/bind_windows.go @@ -15,6 +15,7 @@ import ( "unsafe" "golang.org/x/sys/windows" + "golang.zx2c4.com/go118/netip" "golang.zx2c4.com/wireguard/conn/winrio" ) @@ -128,18 +129,18 @@ func (*WinRingBind) ParseEndpoint(s string) (Endpoint, error) { func (*WinRingEndpoint) ClearSrc() {} -func (e *WinRingEndpoint) DstIP() net.IP { +func (e *WinRingEndpoint) DstIP() netip.Addr { switch e.family { case windows.AF_INET: - return append([]byte{}, e.data[2:6]...) + return netip.AddrFrom4(*(*[4]byte)(e.data[2:6])) case windows.AF_INET6: - return append([]byte{}, e.data[6:22]...) + return netip.AddrFrom16(*(*[16]byte)(e.data[6:22])) } - return nil + return netip.Addr{} } -func (e *WinRingEndpoint) SrcIP() net.IP { - return nil // not supported +func (e *WinRingEndpoint) SrcIP() netip.Addr { + return netip.Addr{} // not supported } func (e *WinRingEndpoint) DstToBytes() []byte { @@ -161,15 +162,13 @@ func (e *WinRingEndpoint) DstToBytes() []byte { func (e *WinRingEndpoint) DstToString() string { switch e.family { case windows.AF_INET: - addr := net.UDPAddr{IP: e.data[2:6], Port: int(binary.BigEndian.Uint16(e.data[0:2]))} - return addr.String() + netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)(e.data[2:6])), binary.BigEndian.Uint16(e.data[0:2])).String() case windows.AF_INET6: var zone string if scope := *(*uint32)(unsafe.Pointer(&e.data[22])); scope > 0 { zone = strconv.FormatUint(uint64(scope), 10) } - addr := net.UDPAddr{IP: e.data[6:22], Zone: zone, Port: int(binary.BigEndian.Uint16(e.data[0:2]))} - return addr.String() + return netip.AddrPortFrom(netip.AddrFrom16(*(*[16]byte)(e.data[6:22])).WithZone(zone), binary.BigEndian.Uint16(e.data[0:2])).String() } return "" } diff --git a/conn/bindtest/bindtest.go b/conn/bindtest/bindtest.go index 7d43fb30b..6a4589613 100644 --- a/conn/bindtest/bindtest.go +++ b/conn/bindtest/bindtest.go @@ -10,8 +10,8 @@ import ( "math/rand" "net" "os" - "strconv" + "golang.zx2c4.com/go118/netip" "golang.zx2c4.com/wireguard/conn" ) @@ -61,9 +61,9 @@ func (c ChannelEndpoint) DstToString() string { return fmt.Sprintf("127.0.0.1:%d func (c ChannelEndpoint) DstToBytes() []byte { return []byte{byte(c)} } -func (c ChannelEndpoint) DstIP() net.IP { return net.IPv4(127, 0, 0, 1) } +func (c ChannelEndpoint) DstIP() netip.Addr { return netip.AddrFrom4([4]byte{127, 0, 0, 1}) } -func (c ChannelEndpoint) SrcIP() net.IP { return nil } +func (c ChannelEndpoint) SrcIP() netip.Addr { return netip.Addr{} } func (c *ChannelBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { c.closeSignal = make(chan bool) @@ -119,13 +119,9 @@ func (c *ChannelBind) Send(b []byte, ep conn.Endpoint) error { } func (c *ChannelBind) ParseEndpoint(s string) (conn.Endpoint, error) { - _, port, err := net.SplitHostPort(s) + addr, err := netip.ParseAddrPort(s) if err != nil { return nil, err } - i, err := strconv.ParseUint(port, 10, 16) - if err != nil { - return nil, err - } - return ChannelEndpoint(i), nil + return ChannelEndpoint(addr.Port()), nil } diff --git a/conn/conn.go b/conn/conn.go index 9cce9adea..35fb6b1b0 100644 --- a/conn/conn.go +++ b/conn/conn.go @@ -9,10 +9,11 @@ package conn import ( "errors" "fmt" - "net" "reflect" "runtime" "strings" + + "golang.zx2c4.com/go118/netip" ) // A ReceiveFunc receives a single inbound packet from the network. @@ -68,8 +69,8 @@ type Endpoint interface { SrcToString() string // returns the local source address (ip:port) DstToString() string // returns the destination address (ip:port) DstToBytes() []byte // used for mac2 cookie calculations - DstIP() net.IP - SrcIP() net.IP + DstIP() netip.Addr + SrcIP() netip.Addr } var ( @@ -119,33 +120,3 @@ func (fn ReceiveFunc) PrettyName() string { } return name } - -func parseEndpoint(s string) (*net.UDPAddr, error) { - // ensure that the host is an IP address - - host, _, err := net.SplitHostPort(s) - if err != nil { - return nil, err - } - if i := strings.LastIndexByte(host, '%'); i > 0 && strings.IndexByte(host, ':') >= 0 { - // Remove the scope, if any. ResolveUDPAddr below will use it, but here we're just - // trying to make sure with a small sanity test that this is a real IP address and - // not something that's likely to incur DNS lookups. - host = host[:i] - } - if ip := net.ParseIP(host); ip == nil { - return nil, errors.New("Failed to parse IP address: " + host) - } - - // parse address and port - - addr, err := net.ResolveUDPAddr("udp", s) - if err != nil { - return nil, err - } - ip4 := addr.IP.To4() - if ip4 != nil { - addr.IP = ip4 - } - return addr, err -} diff --git a/device/allowedips.go b/device/allowedips.go index c08399bbf..7a0b27560 100644 --- a/device/allowedips.go +++ b/device/allowedips.go @@ -12,6 +12,8 @@ import ( "net" "sync" "unsafe" + + "golang.zx2c4.com/go118/netip" ) type parentIndirection struct { @@ -26,7 +28,7 @@ type trieEntry struct { cidr uint8 bitAtByte uint8 bitAtShift uint8 - bits net.IP + bits []byte perPeerElem *list.Element } @@ -51,7 +53,7 @@ func swapU64(i uint64) uint64 { return bits.ReverseBytes64(i) } -func commonBits(ip1 net.IP, ip2 net.IP) uint8 { +func commonBits(ip1, ip2 []byte) uint8 { size := len(ip1) if size == net.IPv4len { a := (*uint32)(unsafe.Pointer(&ip1[0])) @@ -85,7 +87,7 @@ func (node *trieEntry) removeFromPeerEntries() { } } -func (node *trieEntry) choose(ip net.IP) byte { +func (node *trieEntry) choose(ip []byte) byte { return (ip[node.bitAtByte] >> node.bitAtShift) & 1 } @@ -104,7 +106,7 @@ func (node *trieEntry) zeroizePointers() { node.parent.parentBit = nil } -func (node *trieEntry) nodePlacement(ip net.IP, cidr uint8) (parent *trieEntry, exact bool) { +func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry, exact bool) { for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr { parent = node if parent.cidr == cidr { @@ -117,7 +119,7 @@ func (node *trieEntry) nodePlacement(ip net.IP, cidr uint8) (parent *trieEntry, return } -func (trie parentIndirection) insert(ip net.IP, cidr uint8, peer *Peer) { +func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) { if *trie.parentBit == nil { node := &trieEntry{ peer: peer, @@ -207,7 +209,7 @@ func (trie parentIndirection) insert(ip net.IP, cidr uint8, peer *Peer) { } } -func (node *trieEntry) lookup(ip net.IP) *Peer { +func (node *trieEntry) lookup(ip []byte) *Peer { var found *Peer size := uint8(len(ip)) for node != nil && commonBits(node.bits, ip) >= node.cidr { @@ -229,13 +231,14 @@ type AllowedIPs struct { mutex sync.RWMutex } -func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(ip net.IP, cidr uint8) bool) { +func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) bool) { table.mutex.RLock() defer table.mutex.RUnlock() for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() { node := elem.Value.(*trieEntry) - if !cb(node.bits, node.cidr) { + a, _ := netip.AddrFromSlice(node.bits) + if !cb(netip.PrefixFrom(a, int(node.cidr))) { return } } @@ -283,28 +286,29 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) { } } -func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) { +func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) { table.mutex.Lock() defer table.mutex.Unlock() - switch len(ip) { - case net.IPv6len: - parentIndirection{&table.IPv6, 2}.insert(ip, cidr, peer) - case net.IPv4len: - parentIndirection{&table.IPv4, 2}.insert(ip, cidr, peer) - default: + if prefix.Addr().Is6() { + ip := prefix.Addr().As16() + parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer) + } else if prefix.Addr().Is4() { + ip := prefix.Addr().As4() + parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer) + } else { panic(errors.New("inserting unknown address type")) } } -func (table *AllowedIPs) Lookup(address []byte) *Peer { +func (table *AllowedIPs) Lookup(ip []byte) *Peer { table.mutex.RLock() defer table.mutex.RUnlock() - switch len(address) { + switch len(ip) { case net.IPv6len: - return table.IPv6.lookup(address) + return table.IPv6.lookup(ip) case net.IPv4len: - return table.IPv4.lookup(address) + return table.IPv4.lookup(ip) default: panic(errors.New("looking up unknown address type")) } diff --git a/device/allowedips_rand_test.go b/device/allowedips_rand_test.go index 16de1704a..ff56fe6a9 100644 --- a/device/allowedips_rand_test.go +++ b/device/allowedips_rand_test.go @@ -10,6 +10,8 @@ import ( "net" "sort" "testing" + + "golang.zx2c4.com/go118/netip" ) const ( @@ -93,14 +95,14 @@ func TestTrieRandom(t *testing.T) { rand.Read(addr4[:]) cidr := uint8(rand.Intn(32) + 1) index := rand.Intn(NumberOfPeers) - allowedIPs.Insert(addr4[:], cidr, peers[index]) + allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4(addr4), int(cidr)), peers[index]) slow4 = slow4.Insert(addr4[:], cidr, peers[index]) var addr6 [16]byte rand.Read(addr6[:]) cidr = uint8(rand.Intn(128) + 1) index = rand.Intn(NumberOfPeers) - allowedIPs.Insert(addr6[:], cidr, peers[index]) + allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(addr6), int(cidr)), peers[index]) slow6 = slow6.Insert(addr6[:], cidr, peers[index]) } diff --git a/device/allowedips_test.go b/device/allowedips_test.go index 2059a8836..a27499782 100644 --- a/device/allowedips_test.go +++ b/device/allowedips_test.go @@ -9,6 +9,8 @@ import ( "math/rand" "net" "testing" + + "golang.zx2c4.com/go118/netip" ) type testPairCommonBits struct { @@ -98,7 +100,7 @@ func TestTrieIPv4(t *testing.T) { var allowedIPs AllowedIPs insert := func(peer *Peer, a, b, c, d byte, cidr uint8) { - allowedIPs.Insert([]byte{a, b, c, d}, cidr, peer) + allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer) } assertEQ := func(peer *Peer, a, b, c, d byte) { @@ -208,7 +210,7 @@ func TestTrieIPv6(t *testing.T) { addr = append(addr, expand(b)...) addr = append(addr, expand(c)...) addr = append(addr, expand(d)...) - allowedIPs.Insert(addr, cidr, peer) + allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer) } assertEQ := func(peer *Peer, a, b, c, d uint32) { diff --git a/device/device_test.go b/device/device_test.go index 29daeb9c3..84221bed1 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -11,7 +11,6 @@ import ( "fmt" "io" "math/rand" - "net" "runtime" "runtime/pprof" "sync" @@ -19,6 +18,7 @@ import ( "testing" "time" + "golang.zx2c4.com/go118/netip" "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/conn/bindtest" "golang.zx2c4.com/wireguard/tun/tuntest" @@ -96,7 +96,7 @@ type testPair [2]testPeer type testPeer struct { tun *tuntest.ChannelTUN dev *Device - ip net.IP + ip netip.Addr } type SendDirection bool @@ -159,7 +159,7 @@ func genTestPair(tb testing.TB, realSocket bool) (pair testPair) { for i := range pair { p := &pair[i] p.tun = tuntest.NewChannelTUN() - p.ip = net.IPv4(1, 0, 0, byte(i+1)) + p.ip = netip.AddrFrom4([4]byte{1, 0, 0, byte(i + 1)}) level := LogLevelVerbose if _, ok := tb.(*testing.B); ok && !testing.Verbose() { level = LogLevelError diff --git a/device/endpoint_test.go b/device/endpoint_test.go index 57c361ca6..f1ae47e34 100644 --- a/device/endpoint_test.go +++ b/device/endpoint_test.go @@ -7,47 +7,44 @@ package device import ( "math/rand" - "net" + + "golang.zx2c4.com/go118/netip" ) type DummyEndpoint struct { - src [16]byte - dst [16]byte + src, dst netip.Addr } func CreateDummyEndpoint() (*DummyEndpoint, error) { - var end DummyEndpoint - if _, err := rand.Read(end.src[:]); err != nil { + var src, dst [16]byte + if _, err := rand.Read(src[:]); err != nil { return nil, err } - _, err := rand.Read(end.dst[:]) - return &end, err + _, err := rand.Read(dst[:]) + return &DummyEndpoint{netip.AddrFrom16(src), netip.AddrFrom16(dst)}, err } func (e *DummyEndpoint) ClearSrc() {} func (e *DummyEndpoint) SrcToString() string { - var addr net.UDPAddr - addr.IP = e.SrcIP() - addr.Port = 1000 - return addr.String() + return netip.AddrPortFrom(e.SrcIP(), 1000).String() } func (e *DummyEndpoint) DstToString() string { - var addr net.UDPAddr - addr.IP = e.DstIP() - addr.Port = 1000 - return addr.String() + return netip.AddrPortFrom(e.DstIP(), 1000).String() } -func (e *DummyEndpoint) SrcToBytes() []byte { - return e.src[:] +func (e *DummyEndpoint) DstToBytes() []byte { + out := e.DstIP().AsSlice() + out = append(out, byte(1000&0xff)) + out = append(out, byte((1000>>8)&0xff)) + return out } -func (e *DummyEndpoint) DstIP() net.IP { - return e.dst[:] +func (e *DummyEndpoint) DstIP() netip.Addr { + return e.dst } -func (e *DummyEndpoint) SrcIP() net.IP { - return e.src[:] +func (e *DummyEndpoint) SrcIP() netip.Addr { + return e.src } diff --git a/device/receive.go b/device/receive.go index 58574810f..cc3449801 100644 --- a/device/receive.go +++ b/device/receive.go @@ -17,7 +17,6 @@ import ( "golang.org/x/crypto/chacha20poly1305" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" - "golang.zx2c4.com/wireguard/conn" ) diff --git a/device/uapi.go b/device/uapi.go index 2306183a6..98e831104 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -18,6 +18,7 @@ import ( "sync/atomic" "time" + "golang.zx2c4.com/go118/netip" "golang.zx2c4.com/wireguard/ipc" ) @@ -121,8 +122,8 @@ func (device *Device) IpcGetOperation(w io.Writer) error { sendf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes)) sendf("persistent_keepalive_interval=%d", atomic.LoadUint32(&peer.persistentKeepaliveInterval)) - device.allowedips.EntriesForPeer(peer, func(ip net.IP, cidr uint8) bool { - sendf("allowed_ip=%s/%d", ip.String(), cidr) + device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool { + sendf("allowed_ip=%s", prefix.String()) return true }) } @@ -374,16 +375,14 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error case "allowed_ip": device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer) - - _, network, err := net.ParseCIDR(value) + prefix, err := netip.ParsePrefix(value) if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err) } if peer.dummy { return nil } - ones, _ := network.Mask.Size() - device.allowedips.Insert(network.IP, uint8(ones), peer.Peer) + device.allowedips.Insert(prefix, peer.Peer) case "protocol_version": if value != "1" { diff --git a/go.mod b/go.mod index 856bb6c21..b51096070 100644 --- a/go.mod +++ b/go.mod @@ -3,8 +3,9 @@ module golang.zx2c4.com/wireguard go 1.17 require ( - golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 - golang.org/x/net v0.0.0-20211101193420-4a448f8816b3 - golang.org/x/sys v0.0.0-20211103235746-7861aae1554b + golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa + golang.org/x/net v0.0.0-20211111083644-e5c967477495 + golang.org/x/sys v0.0.0-20211110154304-99a53858aa08 + golang.zx2c4.com/go118/netip v0.0.0-20211111135330-a4a02eeacf9d golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 ) diff --git a/go.sum b/go.sum index 37fe06796..78f736705 100644 --- a/go.sum +++ b/go.sum @@ -1,16 +1,19 @@ -golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 h1:7I4JAnoQBe7ZtJcBaYHi5UtiO8tQHbUSXxL+pnGRANg= -golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa h1:idItI2DDfCokpg0N51B2VtiLdJ4vAuXC9fnCb2gACo4= +golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20211101193420-4a448f8816b3 h1:VrJZAjbekhoRn7n5FBujY31gboH+iB3pdLxn3gE9FjU= -golang.org/x/net v0.0.0-20211101193420-4a448f8816b3/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20211111083644-e5c967477495 h1:cjxxlQm6d4kYbhpZ2ghvmI8xnq0AG+jXmzrhzfkyu5A= +golang.org/x/net v0.0.0-20211111083644-e5c967477495/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20211103235746-7861aae1554b h1:1VkfZQv42XQlA/jchYumAnv1UPo6RgF9rJFkTgZIxO4= golang.org/x/sys v0.0.0-20211103235746-7861aae1554b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211110154304-99a53858aa08 h1:WecRHqgE09JBkh/584XIE6PMz5KKE/vER4izNUi30AQ= +golang.org/x/sys v0.0.0-20211110154304-99a53858aa08/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.zx2c4.com/go118/netip v0.0.0-20211111135330-a4a02eeacf9d h1:9+v0G0naRhLPOJEeJOL6NuXTtAHHwmkyZlgQJ0XcQ8I= +golang.zx2c4.com/go118/netip v0.0.0-20211111135330-a4a02eeacf9d/go.mod h1:5yyfuiqVIJ7t+3MqrpTQ+QqRkMWiESiyDvPNvKYCecg= golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 h1:Ug9qvr1myri/zFN6xL17LSCBGFDnphBBhzmILHsM5TY= golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= diff --git a/ratelimiter/ratelimiter.go b/ratelimiter/ratelimiter.go index 2f7aa2ae0..8e78d5e46 100644 --- a/ratelimiter/ratelimiter.go +++ b/ratelimiter/ratelimiter.go @@ -6,9 +6,10 @@ package ratelimiter import ( - "net" "sync" "time" + + "golang.zx2c4.com/go118/netip" ) const ( @@ -30,8 +31,7 @@ type Ratelimiter struct { timeNow func() time.Time stopReset chan struct{} // send to reset, close to stop - tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry - tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry + table map[netip.Addr]*RatelimiterEntry } func (rate *Ratelimiter) Close() { @@ -57,8 +57,7 @@ func (rate *Ratelimiter) Init() { } rate.stopReset = make(chan struct{}) - rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry) - rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry) + rate.table = make(map[netip.Addr]*RatelimiterEntry) stopReset := rate.stopReset // store in case Init is called again. @@ -87,71 +86,39 @@ func (rate *Ratelimiter) cleanup() (empty bool) { rate.mu.Lock() defer rate.mu.Unlock() - for key, entry := range rate.tableIPv4 { + for key, entry := range rate.table { entry.mu.Lock() if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime { - delete(rate.tableIPv4, key) + delete(rate.table, key) } entry.mu.Unlock() } - for key, entry := range rate.tableIPv6 { - entry.mu.Lock() - if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime { - delete(rate.tableIPv6, key) - } - entry.mu.Unlock() - } - - return len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0 + return len(rate.table) == 0 } -func (rate *Ratelimiter) Allow(ip net.IP) bool { +func (rate *Ratelimiter) Allow(ip netip.Addr) bool { var entry *RatelimiterEntry - var keyIPv4 [net.IPv4len]byte - var keyIPv6 [net.IPv6len]byte - // lookup entry - - IPv4 := ip.To4() - IPv6 := ip.To16() - rate.mu.RLock() - - if IPv4 != nil { - copy(keyIPv4[:], IPv4) - entry = rate.tableIPv4[keyIPv4] - } else { - copy(keyIPv6[:], IPv6) - entry = rate.tableIPv6[keyIPv6] - } - + entry = rate.table[ip] rate.mu.RUnlock() // make new entry if not found - if entry == nil { entry = new(RatelimiterEntry) entry.tokens = maxTokens - packetCost entry.lastTime = rate.timeNow() rate.mu.Lock() - if IPv4 != nil { - rate.tableIPv4[keyIPv4] = entry - if len(rate.tableIPv4) == 1 && len(rate.tableIPv6) == 0 { - rate.stopReset <- struct{}{} - } - } else { - rate.tableIPv6[keyIPv6] = entry - if len(rate.tableIPv6) == 1 && len(rate.tableIPv4) == 0 { - rate.stopReset <- struct{}{} - } + rate.table[ip] = entry + if len(rate.table) == 1 { + rate.stopReset <- struct{}{} } rate.mu.Unlock() return true } // add tokens to entry - entry.mu.Lock() now := rate.timeNow() entry.tokens += now.Sub(entry.lastTime).Nanoseconds() @@ -161,7 +128,6 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool { } // subtract cost of packet - if entry.tokens > packetCost { entry.tokens -= packetCost entry.mu.Unlock() diff --git a/ratelimiter/ratelimiter_test.go b/ratelimiter/ratelimiter_test.go index f231fe5b7..3e06ff778 100644 --- a/ratelimiter/ratelimiter_test.go +++ b/ratelimiter/ratelimiter_test.go @@ -6,9 +6,10 @@ package ratelimiter import ( - "net" "testing" "time" + + "golang.zx2c4.com/go118/netip" ) type result struct { @@ -71,21 +72,21 @@ func TestRatelimiter(t *testing.T) { text: "packet following 2 packet burst", }) - ips := []net.IP{ - net.ParseIP("127.0.0.1"), - net.ParseIP("192.168.1.1"), - net.ParseIP("172.167.2.3"), - net.ParseIP("97.231.252.215"), - net.ParseIP("248.97.91.167"), - net.ParseIP("188.208.233.47"), - net.ParseIP("104.2.183.179"), - net.ParseIP("72.129.46.120"), - net.ParseIP("2001:0db8:0a0b:12f0:0000:0000:0000:0001"), - net.ParseIP("f5c2:818f:c052:655a:9860:b136:6894:25f0"), - net.ParseIP("b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc"), - net.ParseIP("a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918"), - net.ParseIP("ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445"), - net.ParseIP("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"), + ips := []netip.Addr{ + netip.MustParseAddr("127.0.0.1"), + netip.MustParseAddr("192.168.1.1"), + netip.MustParseAddr("172.167.2.3"), + netip.MustParseAddr("97.231.252.215"), + netip.MustParseAddr("248.97.91.167"), + netip.MustParseAddr("188.208.233.47"), + netip.MustParseAddr("104.2.183.179"), + netip.MustParseAddr("72.129.46.120"), + netip.MustParseAddr("2001:0db8:0a0b:12f0:0000:0000:0000:0001"), + netip.MustParseAddr("f5c2:818f:c052:655a:9860:b136:6894:25f0"), + netip.MustParseAddr("b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc"), + netip.MustParseAddr("a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918"), + netip.MustParseAddr("ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445"), + netip.MustParseAddr("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"), } now := time.Now() diff --git a/tun/netstack/examples/http_client.go b/tun/netstack/examples/http_client.go index 6ac28596b..b39b45316 100644 --- a/tun/netstack/examples/http_client.go +++ b/tun/netstack/examples/http_client.go @@ -1,4 +1,5 @@ //go:build ignore +// +build ignore /* SPDX-License-Identifier: MIT * @@ -10,9 +11,9 @@ package main import ( "io" "log" - "net" "net/http" + "golang.zx2c4.com/go118/netip" "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun/netstack" @@ -20,8 +21,8 @@ import ( func main() { tun, tnet, err := netstack.CreateNetTUN( - []net.IP{net.ParseIP("192.168.4.29")}, - []net.IP{net.ParseIP("8.8.8.8")}, + []netip.Addr{netip.MustParseAddr("192.168.4.29")}, + []netip.Addr{netip.MustParseAddr("8.8.8.8")}, 1420) if err != nil { log.Panic(err) diff --git a/tun/netstack/examples/http_server.go b/tun/netstack/examples/http_server.go index 577c6ea4c..40f780447 100644 --- a/tun/netstack/examples/http_server.go +++ b/tun/netstack/examples/http_server.go @@ -1,4 +1,5 @@ //go:build ignore +// +build ignore /* SPDX-License-Identifier: MIT * @@ -13,6 +14,7 @@ import ( "net" "net/http" + "golang.zx2c4.com/go118/netip" "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun/netstack" @@ -20,8 +22,8 @@ import ( func main() { tun, tnet, err := netstack.CreateNetTUN( - []net.IP{net.ParseIP("192.168.4.29")}, - []net.IP{net.ParseIP("8.8.8.8"), net.ParseIP("8.8.4.4")}, + []netip.Addr{netip.MustParseAddr("192.168.4.29")}, + []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("8.8.4.4")}, 1420, ) if err != nil { diff --git a/tun/netstack/go.mod b/tun/netstack/go.mod index 8db9f4bcd..46b57bab1 100644 --- a/tun/netstack/go.mod +++ b/tun/netstack/go.mod @@ -6,6 +6,7 @@ require ( golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6 golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7 // indirect golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba // indirect + golang.zx2c4.com/go118/netip v0.0.0-20211105124833-002a02cb0e53 golang.zx2c4.com/wireguard v0.0.0-20210424170727-c9db4b7aaa22 gvisor.dev/gvisor v0.0.0-20211020211948-f76a604701b6 ) diff --git a/tun/netstack/go.sum b/tun/netstack/go.sum index 78c025c54..01bfbc70c 100644 --- a/tun/netstack/go.sum +++ b/tun/netstack/go.sum @@ -805,6 +805,10 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.zx2c4.com/go118/netip v0.0.0-20211104120624-f0ae7a6e37c5 h1:mV4w4F7AtWXoDNkko9odoTdWpNwyDh8jx+S1fOZKDLg= +golang.zx2c4.com/go118/netip v0.0.0-20211104120624-f0ae7a6e37c5/go.mod h1:5yyfuiqVIJ7t+3MqrpTQ+QqRkMWiESiyDvPNvKYCecg= +golang.zx2c4.com/go118/netip v0.0.0-20211105124833-002a02cb0e53 h1:nFvpdzrHF9IPo9xPgayHWObCATpQYKky8VSSdt9lf9E= +golang.zx2c4.com/go118/netip v0.0.0-20211105124833-002a02cb0e53/go.mod h1:5yyfuiqVIJ7t+3MqrpTQ+QqRkMWiESiyDvPNvKYCecg= golang.zx2c4.com/wireguard v0.0.0-20210424170727-c9db4b7aaa22 h1:ytS28bw9HtZVDRMDxviC6ryCJuccw+zXhh04u2IRWJw= golang.zx2c4.com/wireguard v0.0.0-20210424170727-c9db4b7aaa22/go.mod h1:a057zjmoc00UN7gVkaJt2sXVK523kMJcogDTEvPIasg= google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE= diff --git a/tun/netstack/tun.go b/tun/netstack/tun.go index 24d0835dd..f1c03f48c 100644 --- a/tun/netstack/tun.go +++ b/tun/netstack/tun.go @@ -18,6 +18,7 @@ import ( "strings" "time" + "golang.zx2c4.com/go118/netip" "golang.zx2c4.com/wireguard/tun" "golang.org/x/net/dns/dnsmessage" @@ -38,7 +39,7 @@ type netTun struct { events chan tun.Event incomingPacket chan buffer.VectorisedView mtu int - dnsServers []net.IP + dnsServers []netip.Addr hasV4, hasV6 bool } type endpoint netTun @@ -94,7 +95,7 @@ func (*endpoint) ARPHardwareType() header.ARPHardwareType { func (e *endpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) { } -func CreateNetTUN(localAddresses, dnsServers []net.IP, mtu int) (tun.Device, *Net, error) { +func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, *Net, error) { opts := stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol}, @@ -112,25 +113,23 @@ func CreateNetTUN(localAddresses, dnsServers []net.IP, mtu int) (tun.Device, *Ne return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr) } for _, ip := range localAddresses { - if ip4 := ip.To4(); ip4 != nil { - protoAddr := tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: tcpip.Address(ip4).WithPrefix(), - } - tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}) - if tcpipErr != nil { - return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip4, tcpipErr) - } + var protoNumber tcpip.NetworkProtocolNumber + if ip.Is4() { + protoNumber = ipv4.ProtocolNumber + } else if ip.Is6() { + protoNumber = ipv6.ProtocolNumber + } + protoAddr := tcpip.ProtocolAddress{ + Protocol: protoNumber, + AddressWithPrefix: tcpip.Address(ip.AsSlice()).WithPrefix(), + } + tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}) + if tcpipErr != nil { + return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr) + } + if ip.Is4() { dev.hasV4 = true - } else { - protoAddr := tcpip.ProtocolAddress{ - Protocol: ipv6.ProtocolNumber, - AddressWithPrefix: tcpip.Address(ip).WithPrefix(), - } - tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}) - if tcpipErr != nil { - return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr) - } + } else if ip.Is6() { dev.hasV6 = true } } @@ -202,62 +201,83 @@ func (tun *netTun) MTU() (int, error) { return tun.mtu, nil } -func convertToFullAddr(ip net.IP, port int) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) { - if ip4 := ip.To4(); ip4 != nil { - return tcpip.FullAddress{ - NIC: 1, - Addr: tcpip.Address(ip4), - Port: uint16(port), - }, ipv4.ProtocolNumber +func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) { + var protoNumber tcpip.NetworkProtocolNumber + if endpoint.Addr().Is4() { + protoNumber = ipv4.ProtocolNumber } else { - return tcpip.FullAddress{ - NIC: 1, - Addr: tcpip.Address(ip), - Port: uint16(port), - }, ipv6.ProtocolNumber + protoNumber = ipv6.ProtocolNumber } + return tcpip.FullAddress{ + NIC: 1, + Addr: tcpip.Address(endpoint.Addr().AsSlice()), + Port: endpoint.Port(), + }, protoNumber +} + +func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) { + fa, pn := convertToFullAddr(addr) + return gonet.DialContextTCP(ctx, net.stack, fa, pn) } func (net *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.TCPConn, error) { if addr == nil { - panic("todo: deal with auto addr semantics for nil addr") + return net.DialContextTCPAddrPort(ctx, netip.AddrPort{}) } - fa, pn := convertToFullAddr(addr.IP, addr.Port) - return gonet.DialContextTCP(ctx, net.stack, fa, pn) + return net.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(netip.AddrFromSlice(addr.IP), uint16(addr.Port))) +} + +func (net *Net) DialTCPAddrPort(addr netip.AddrPort) (*gonet.TCPConn, error) { + fa, pn := convertToFullAddr(addr) + return gonet.DialTCP(net.stack, fa, pn) } func (net *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) { if addr == nil { - panic("todo: deal with auto addr semantics for nil addr") + return net.DialTCPAddrPort(netip.AddrPort{}) } - fa, pn := convertToFullAddr(addr.IP, addr.Port) - return gonet.DialTCP(net.stack, fa, pn) + return net.DialTCPAddrPort(netip.AddrPortFrom(netip.AddrFromSlice(addr.IP), uint16(addr.Port))) +} + +func (net *Net) ListenTCPAddrPort(addr netip.AddrPort) (*gonet.TCPListener, error) { + fa, pn := convertToFullAddr(addr) + return gonet.ListenTCP(net.stack, fa, pn) } func (net *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) { if addr == nil { - panic("todo: deal with auto addr semantics for nil addr") + return net.ListenTCPAddrPort(netip.AddrPort{}) } - fa, pn := convertToFullAddr(addr.IP, addr.Port) - return gonet.ListenTCP(net.stack, fa, pn) + return net.ListenTCPAddrPort(netip.AddrPortFrom(netip.AddrFromSlice(addr.IP), uint16(addr.Port))) } -func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) { +func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) { var lfa, rfa *tcpip.FullAddress var pn tcpip.NetworkProtocolNumber - if laddr != nil { + if laddr.IsValid() || laddr.Port() > 0 { var addr tcpip.FullAddress - addr, pn = convertToFullAddr(laddr.IP, laddr.Port) + addr, pn = convertToFullAddr(laddr) lfa = &addr } - if raddr != nil { + if raddr.IsValid() || raddr.Port() > 0 { var addr tcpip.FullAddress - addr, pn = convertToFullAddr(raddr.IP, raddr.Port) + addr, pn = convertToFullAddr(raddr) rfa = &addr } return gonet.DialUDP(net.stack, lfa, rfa, pn) } +func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) { + var la, ra netip.AddrPort + if laddr != nil { + la = netip.AddrPortFrom(netip.AddrFromSlice(laddr.IP), uint16(laddr.Port)) + } + if raddr != nil { + ra = netip.AddrPortFrom(netip.AddrFromSlice(raddr.IP), uint16(raddr.Port)) + } + return net.DialUDPAddrPort(la, ra) +} + var ( errNoSuchHost = errors.New("no such host") errLameReferral = errors.New("lame referral") @@ -433,7 +453,7 @@ func dnsStreamRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []by return p, h, nil } -func (tnet *Net) exchange(ctx context.Context, server net.IP, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) { +func (tnet *Net) exchange(ctx context.Context, server netip.Addr, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) { q.Class = dnsmessage.ClassINET id, udpReq, tcpReq, err := newRequest(q) if err != nil { @@ -447,9 +467,9 @@ func (tnet *Net) exchange(ctx context.Context, server net.IP, q dnsmessage.Quest var c net.Conn var err error if useUDP { - c, err = tnet.DialUDP(nil, &net.UDPAddr{IP: server, Port: 53}) + c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, netip.AddrPortFrom(server, 53)) } else { - c, err = tnet.DialContextTCP(ctx, &net.TCPAddr{IP: server, Port: 53}) + c, err = tnet.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(server, 53)) } if err != nil { @@ -600,8 +620,8 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string, zlen = zidx } } - if ip := net.ParseIP(host[:zlen]); ip != nil { - return []string{host[:zlen]}, nil + if ip, err := netip.ParseAddr(host[:zlen]); err == nil { + return []string{ip.String()}, nil } if !isDomainName(host) { @@ -612,7 +632,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string, server string error } - var addrsV4, addrsV6 []net.IP + var addrsV4, addrsV6 []netip.Addr lanes := 0 if tnet.hasV4 { lanes++ @@ -667,7 +687,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string, } break loop } - addrsV4 = append(addrsV4, net.IP(a.A[:])) + addrsV4 = append(addrsV4, netip.AddrFrom4(a.A)) case dnsmessage.TypeAAAA: aaaa, err := result.p.AAAAResource() @@ -679,7 +699,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string, } break loop } - addrsV6 = append(addrsV6, net.IP(aaaa.AAAA[:])) + addrsV6 = append(addrsV6, netip.AddrFrom16(aaaa.AAAA)) default: if err := result.p.SkipAnswer(); err != nil { @@ -695,7 +715,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string, } } // We don't do RFC6724. Instead just put V6 addresess first if an IPv6 address is enabled - var addrs []net.IP + var addrs []netip.Addr if tnet.hasV6 { addrs = append(addrsV6, addrsV4...) } else { @@ -764,12 +784,11 @@ func (tnet *Net) DialContext(ctx context.Context, network, address string) (net. if err != nil { return nil, &net.OpError{Op: "dial", Err: err} } - var addrs []net.IP + var addrs []netip.AddrPort for _, addr := range allAddr { - if strings.IndexByte(addr, ':') != -1 && acceptV6 { - addrs = append(addrs, net.ParseIP(addr)) - } else if strings.IndexByte(addr, '.') != -1 && acceptV4 { - addrs = append(addrs, net.ParseIP(addr)) + ip, err := netip.ParseAddr(addr) + if err == nil && ((ip.Is4() && acceptV4) || (ip.Is6() && acceptV6)) { + addrs = append(addrs, netip.AddrPortFrom(ip, uint16(port))) } } if len(addrs) == 0 && len(allAddr) != 0 { @@ -808,9 +827,9 @@ func (tnet *Net) DialContext(ctx context.Context, network, address string) (net. var c net.Conn if useUDP { - c, err = tnet.DialUDP(nil, &net.UDPAddr{IP: addr, Port: port}) + c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, addr) } else { - c, err = tnet.DialContextTCP(dialCtx, &net.TCPAddr{IP: addr, Port: port}) + c, err = tnet.DialContextTCPAddrPort(dialCtx, addr) } if err == nil { return c, nil diff --git a/tun/tuntest/tuntest.go b/tun/tuntest/tuntest.go index d89db7187..bdf0467b4 100644 --- a/tun/tuntest/tuntest.go +++ b/tun/tuntest/tuntest.go @@ -8,13 +8,13 @@ package tuntest import ( "encoding/binary" "io" - "net" "os" + "golang.zx2c4.com/go118/netip" "golang.zx2c4.com/wireguard/tun" ) -func Ping(dst, src net.IP) []byte { +func Ping(dst, src netip.Addr) []byte { localPort := uint16(1337) seq := uint16(0) @@ -40,7 +40,7 @@ func checksum(buf []byte, initial uint16) uint16 { return ^uint16(v) } -func genICMPv4(payload []byte, dst, src net.IP) []byte { +func genICMPv4(payload []byte, dst, src netip.Addr) []byte { const ( icmpv4ProtocolNumber = 1 icmpv4Echo = 8 @@ -70,8 +70,8 @@ func genICMPv4(payload []byte, dst, src net.IP) []byte { binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length) ip[8] = ttl ip[9] = icmpv4ProtocolNumber - copy(ip[12:], src.To4()) - copy(ip[16:], dst.To4()) + copy(ip[12:], src.AsSlice()) + copy(ip[16:], dst.AsSlice()) chksum = ^checksum(ip[:], 0) binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum)