Skip to content

Commit

Permalink
net: add multi listener impl for net.Listener
Browse files Browse the repository at this point in the history
This adds an implementation of net.Listener which listens on and
accepts connections from multiple addresses.

Signed-off-by: Daman Arora <aroradaman@gmail.com>
  • Loading branch information
aroradaman authored and Daman Arora committed May 25, 2024
1 parent fe8a2dd commit 6bd4edf
Show file tree
Hide file tree
Showing 3 changed files with 325 additions and 0 deletions.
138 changes: 138 additions & 0 deletions net/multi_listen.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
/*
Copyright 2024 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package net

import (
"fmt"
"net"
"syscall"
)

// multiListener implements net.Listener and uses multiplexing to listen to and accept
// TCP connections from multiple addresses.
type multiListener struct {
latestAcceptedFDIndex int
fds []int
addrs []net.Addr
stopCh chan struct{}
}

// compile time check to ensure *multiListener implements net.Listener.
var _ net.Listener = &multiListener{}

// NewMultiListener returns *multiListener as net.Listener allowing consumers to
// listen for TCP connections on multiple addresses.
func NewMultiListener(addresses []string) (net.Listener, error) {
ml := &multiListener{
stopCh: make(chan struct{}),
}
for _, address := range addresses {
fd, addr, err := createBindAndListen(address)
if err != nil {
return nil, err
}
ml.fds = append(ml.fds, fd)
ml.addrs = append(ml.addrs, addr)
}
return ml, nil
}

// Accept is part of net.Listener interface.
func (ml *multiListener) Accept() (net.Conn, error) {
return ml.accept()

Check failure on line 56 in net/multi_listen.go

View workflow job for this annotation

GitHub Actions / build (1.21.x, windows-latest)

ml.accept undefined (type *multiListener has no field or method accept, but does have Accept)
}

// Close is part of net.Listener interface.
func (ml *multiListener) Close() error {
close(ml.stopCh)
for _, fd := range ml.fds {
_ = syscall.Close(fd)

Check failure on line 63 in net/multi_listen.go

View workflow job for this annotation

GitHub Actions / build (1.21.x, windows-latest)

cannot use fd (variable of type int) as syscall.Handle value in argument to syscall.Close
}
return nil
}

// Addr is part of net.Listener interface.
func (ml *multiListener) Addr() net.Addr {
return ml.addrs[ml.latestAcceptedFDIndex]
}

// createBindAndListen creates a TCP socket, binds it to the specified address, and starts listening on it.
func createBindAndListen(address string) (int, net.Addr, error) {
host, _, err := net.SplitHostPort(address)
if err != nil {
return -1, nil, err
}

ipFamily := IPFamilyOf(ParseIPSloppy(host))
var network string
var domain int
switch ipFamily {
case IPv4:
network = "tcp4"
domain = syscall.AF_INET
case IPv6:
network = "tcp6"
domain = syscall.AF_INET6
default:
return -1, nil, fmt.Errorf("failed to identify ip family of host '%s'", host)

}

// resolve tcp addr
addr, err := net.ResolveTCPAddr(network, address)
if err != nil {
return -1, nil, err
}

// create socket
fd, err := syscall.Socket(domain, syscall.SOCK_STREAM, 0)
if err != nil {
return -1, nil, err
}

// define socket address for bind
var sockAddr syscall.Sockaddr
if ipFamily == IPv4 {
var ipBytes [4]byte
copy(ipBytes[:], addr.IP.To4())
sockAddr = &syscall.SockaddrInet4{
Addr: ipBytes,
Port: addr.Port,
}
} else {
var ipBytes [16]byte
copy(ipBytes[:], addr.IP.To16())
sockAddr = &syscall.SockaddrInet6{
Addr: ipBytes,
Port: addr.Port,
}
}

// bind socket to specified addr
if err = syscall.Bind(fd, sockAddr); err != nil {
_ = syscall.Close(fd)
return -1, nil, err
}

// start listening on socket
if err = syscall.Listen(fd, syscall.SOMAXCONN); err != nil {
_ = syscall.Close(fd)
return -1, nil, err
}

return fd, addr, nil

Check failure on line 137 in net/multi_listen.go

View workflow job for this annotation

GitHub Actions / build (1.21.x, windows-latest)

cannot use fd (variable of type syscall.Handle) as int value in return statement
}
94 changes: 94 additions & 0 deletions net/multi_listen_darwin.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
//go:build darwin
// +build darwin

/*
Copyright 2024 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package net

import (
"fmt"
"net"
"os"
"syscall"
)

// Accept is part of net.Listener interface.
func (ml *multiListener) accept() (net.Conn, error) {
for {
readFds := &syscall.FdSet{}
maxfd := 0

for _, fd := range ml.fds {
if fd > maxfd {
maxfd = fd
}
addFDToFDSet(fd, readFds)
}

// wait for any of the sockets to be ready for accepting new connection
timeout := syscall.Timeval{Sec: 1, Usec: 0}
err := syscall.Select(maxfd+1, readFds, nil, nil, &timeout)
if err != nil {
return nil, err
}

for i, fd := range ml.fds {
if isFDInFDSet(fd, readFds) {
conn, err := acceptConnection(fd)
if err != nil {
return nil, err
}
ml.latestAcceptedFDIndex = i
return conn, nil
}
}

select {
case <-ml.stopCh:
return nil, fmt.Errorf("multiListener closed")
default:
continue
}
}
}

// addFDToFDSet adds fd to the given fd set
func addFDToFDSet(fd int, p *syscall.FdSet) {
mask := 1 << (uint(fd) % syscall.FD_SETSIZE)
p.Bits[fd/syscall.FD_SETSIZE] |= int32(mask)
}

// isFDInFDSet returns true if fd is in fd set, false otherwise
func isFDInFDSet(fd int, p *syscall.FdSet) bool {
mask := 1 << (uint(fd) % syscall.FD_SETSIZE)
return p.Bits[fd/syscall.FD_SETSIZE]&int32(mask) != 0
}

// acceptConnection accepts connection and returns remote connection object
func acceptConnection(fd int) (net.Conn, error) {
connFD, _, err := syscall.Accept(fd)
if err != nil {
return nil, err
}

conn, err := net.FileConn(os.NewFile(uintptr(connFD), fmt.Sprintf("fd %d", connFD)))
if err != nil {
_ = syscall.Close(connFD)
return nil, err
}
return conn, nil
}
93 changes: 93 additions & 0 deletions net/multi_listen_linux.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
//go:build linux
// +build linux

/*
Copyright 2024 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package net

import (
"fmt"
"net"
"os"
"syscall"
)

func (ml *multiListener) accept() (net.Conn, error) {
for {
readFds := &syscall.FdSet{}
maxfd := 0

for _, fd := range ml.fds {
if fd > maxfd {
maxfd = fd
}
addFDToFDSet(fd, readFds)
}

// wait for any of the sockets to be ready for accepting new connection
timeout := syscall.Timeval{Sec: 1, Usec: 0}
n, err := syscall.Select(maxfd+1, readFds, nil, nil, &timeout)
if err != nil {
return nil, err
}
if n == 0 {
select {
case <-ml.stopCh:
return nil, fmt.Errorf("multiListener closed")
default:
continue
}
}
for i, fd := range ml.fds {
if isFDInFDSet(fd, readFds) {
conn, err := acceptConnection(fd)
if err != nil {
return nil, err
}
ml.latestAcceptedFDIndex = i
return conn, nil
}
}
}
}

// addFDToFDSet adds fd to the given fd set
func addFDToFDSet(fd int, p *syscall.FdSet) {
mask := 1 << (uint(fd) % syscall.FD_SETSIZE)
p.Bits[fd/syscall.FD_SETSIZE] |= int64(mask)
}

// isFDInFDSet returns true if fd is in fd set, false otherwise
func isFDInFDSet(fd int, p *syscall.FdSet) bool {
mask := 1 << (uint(fd) % syscall.FD_SETSIZE)
return p.Bits[fd/syscall.FD_SETSIZE]&int64(mask) != 0
}

// acceptConnection accepts connection and returns remote connection object
func acceptConnection(fd int) (net.Conn, error) {
connFD, _, err := syscall.Accept(fd)
if err != nil {
return nil, err
}

conn, err := net.FileConn(os.NewFile(uintptr(connFD), fmt.Sprintf("fd %d", connFD)))
if err != nil {
_ = syscall.Close(connFD)
return nil, err
}
return conn, nil
}

0 comments on commit 6bd4edf

Please sign in to comment.