diff --git a/net/multi_listen.go b/net/multi_listen.go new file mode 100644 index 00000000..7cb7795b --- /dev/null +++ b/net/multi_listen.go @@ -0,0 +1,195 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package net + +import ( + "context" + "fmt" + "net" + "sync" +) + +// connErrPair pairs conn and error which is returned by accept on sub-listeners. +type connErrPair struct { + conn net.Conn + err error +} + +// multiListener implements net.Listener +type multiListener struct { + listeners []net.Listener + wg sync.WaitGroup + + // connCh passes accepted connections, from child listeners to parent. + connCh chan connErrPair + // stopCh communicates from parent to child listeners. + stopCh chan struct{} +} + +// compile time check to ensure *multiListener implements net.Listener +var _ net.Listener = &multiListener{} + +// MultiListen returns net.Listener which can listen on and accept connections for +// the given network on multiple addresses. Internally it uses stdlib to create +// sub-listener and multiplexes connection requests using go-routines. +// The network must be "tcp", "tcp4" or "tcp6". +// It follows the semantics of net.Listen that primarily means: +// 1. If the host is an unspecified/zero IP address with "tcp" network, MultiListen +// listens on all available unicast and anycast IP addresses of the local system. +// 2. Use "tcp4" or "tcp6" to exclusively listen on IPv4 or IPv6 family, respectively. +// 3. The host can accept names (e.g, localhost) and it will create a listener for at +// most one of the host's IP. +func MultiListen(ctx context.Context, network string, addrs ...string) (net.Listener, error) { + var lc net.ListenConfig + return multiListen( + ctx, + network, + addrs, + func(ctx context.Context, network, address string) (net.Listener, error) { + return lc.Listen(ctx, network, address) + }) +} + +// multiListen implements MultiListen by consuming stdlib functions as dependency allowing +// mocking for unit-testing. +func multiListen( + ctx context.Context, + network string, + addrs []string, + listenFunc func(ctx context.Context, network, address string) (net.Listener, error), +) (net.Listener, error) { + if !(network == "tcp" || network == "tcp4" || network == "tcp6") { + return nil, fmt.Errorf("network %q not supported", network) + } + if len(addrs) == 0 { + return nil, fmt.Errorf("no address provided to listen on") + } + + ml := &multiListener{ + connCh: make(chan connErrPair), + stopCh: make(chan struct{}), + } + for _, addr := range addrs { + l, err := listenFunc(ctx, network, addr) + if err != nil { + // close all the sub-listeners and exit + _ = ml.Close() + return nil, err + } + ml.listeners = append(ml.listeners, l) + } + + for _, l := range ml.listeners { + ml.wg.Add(1) + go func(l net.Listener) { + defer ml.wg.Done() + for { + // Accept() is blocking, unless ml.Close() is called, in which + // case it will return immediately with an error. + conn, err := l.Accept() + // This assumes that ANY error from Accept() will terminate the + // sub-listener. We could maybe be more precise, but it + // doesn't seem necessary. + terminate := err != nil + + select { + case ml.connCh <- connErrPair{conn: conn, err: err}: + case <-ml.stopCh: + // In case we accepted a connection AND were stopped, and + // this select-case was chosen, just throw away the + // connection. This avoids potentially blocking on connCh + // or leaking a connection. + if conn != nil { + _ = conn.Close() + } + terminate = true + } + // Make sure we don't loop on Accept() returning an error and + // the select choosing the channel case. + if terminate { + return + } + } + }(l) + } + return ml, nil +} + +// Accept implements net.Listener. It waits for and returns a connection from +// any of the sub-listener. +func (ml *multiListener) Accept() (net.Conn, error) { + // wait for any sub-listener to enqueue an accepted connection + connErr, ok := <-ml.connCh + if !ok { + // The channel will be closed only when Close() is called on the + // multiListener. Closing of this channel implies that all + // sub-listeners are also closed, which causes a "use of closed + // network connection" error on their Accept() calls. We return the + // same error for multiListener.Accept() if multiListener.Close() + // has already been called. + return nil, fmt.Errorf("use of closed network connection") + } + return connErr.conn, connErr.err +} + +// Close implements net.Listener. It will close all sub-listeners and wait for +// the go-routines to exit. +func (ml *multiListener) Close() error { + // Make sure this can be called repeatedly without explosions. + select { + case <-ml.stopCh: + return fmt.Errorf("use of closed network connection") + default: + } + + // Tell all sub-listeners to stop. + close(ml.stopCh) + + // Closing the listeners causes Accept() to immediately return an error in + // the sub-listener go-routines. + for _, l := range ml.listeners { + _ = l.Close() + } + + // Wait for all the sub-listener go-routines to exit. + ml.wg.Wait() + close(ml.connCh) + + // Drain any already-queued connections. + for connErr := range ml.connCh { + if connErr.conn != nil { + _ = connErr.conn.Close() + } + } + return nil +} + +// Addr is an implementation of the net.Listener interface. It always returns +// the address of the first listener. Callers should use conn.LocalAddr() to +// obtain the actual local address of the sub-listener. +func (ml *multiListener) Addr() net.Addr { + return ml.listeners[0].Addr() +} + +// Addrs is like Addr, but returns the address for all registered listeners. +func (ml *multiListener) Addrs() []net.Addr { + var ret []net.Addr + for _, l := range ml.listeners { + ret = append(ret, l.Addr()) + } + return ret +} diff --git a/net/multi_listen_test.go b/net/multi_listen_test.go new file mode 100644 index 00000000..dae7b479 --- /dev/null +++ b/net/multi_listen_test.go @@ -0,0 +1,498 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package net + +import ( + "context" + "fmt" + "net" + "strconv" + "sync/atomic" + "testing" + "time" +) + +type fakeCon struct { + remoteAddr net.Addr +} + +func (f *fakeCon) Read(_ []byte) (n int, err error) { + return 0, nil +} + +func (f *fakeCon) Write(_ []byte) (n int, err error) { + return 0, nil +} + +func (f *fakeCon) Close() error { + return nil +} + +func (f *fakeCon) LocalAddr() net.Addr { + return nil +} + +func (f *fakeCon) RemoteAddr() net.Addr { + return f.remoteAddr +} + +func (f *fakeCon) SetDeadline(_ time.Time) error { + return nil +} + +func (f *fakeCon) SetReadDeadline(_ time.Time) error { + return nil +} + +func (f *fakeCon) SetWriteDeadline(_ time.Time) error { + return nil +} + +var _ net.Conn = &fakeCon{} + +type fakeListener struct { + addr net.Addr + index int + err error + closed atomic.Bool + connErrPairs []connErrPair +} + +func (f *fakeListener) Accept() (net.Conn, error) { + if f.index < len(f.connErrPairs) { + index := f.index + connErr := f.connErrPairs[index] + f.index++ + return connErr.conn, connErr.err + } + for { + if f.closed.Load() { + return nil, fmt.Errorf("use of closed network connection") + } + } +} + +func (f *fakeListener) Close() error { + f.closed.Store(true) + return nil +} + +func (f *fakeListener) Addr() net.Addr { + return f.addr +} + +var _ net.Listener = &fakeListener{} + +func listenFuncFactory(listeners []*fakeListener) func(_ context.Context, network string, address string) (net.Listener, error) { + index := 0 + return func(_ context.Context, network string, address string) (net.Listener, error) { + if index < len(listeners) { + host, portStr, err := net.SplitHostPort(address) + if err != nil { + return nil, err + } + port, err := strconv.Atoi(portStr) + if err != nil { + return nil, err + } + listener := listeners[index] + addr := &net.TCPAddr{ + IP: ParseIPSloppy(host), + Port: port, + } + if err != nil { + return nil, err + } + listener.addr = addr + index++ + + if listener.err != nil { + return nil, listener.err + } + return listener, nil + } + return nil, nil + } +} + +func TestMultiListen(t *testing.T) { + testCases := []struct { + name string + network string + addrs []string + fakeListeners []*fakeListener + errString string + }{ + { + name: "unsupported network", + network: "udp", + errString: "network \"udp\" not supported", + }, + { + name: "no host", + network: "tcp", + errString: "no address provided to listen on", + }, + { + name: "valid", + network: "tcp", + addrs: []string{"127.0.0.1:12345"}, + fakeListeners: []*fakeListener{{connErrPairs: []connErrPair{}}}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx := context.TODO() + ml, err := multiListen(ctx, tc.network, tc.addrs, listenFuncFactory(tc.fakeListeners)) + + if tc.errString != "" { + assertError(t, tc.errString, err) + } else { + assertNoError(t, err) + } + if ml != nil { + err = ml.Close() + if err != nil { + t.Errorf("Did not expect error: %v", err) + } + } + }) + } +} + +func TestMultiListen_Addr(t *testing.T) { + ctx := context.TODO() + ml, err := multiListen(ctx, "tcp", []string{"10.10.10.10:5000", "192.168.1.10:5000", "127.0.0.1:5000"}, listenFuncFactory( + []*fakeListener{{}, {}, {}}, + )) + if err != nil { + t.Errorf("Did not expect error: %v", err) + } + + if ml.Addr().String() != "10.10.10.10:5000" { + t.Errorf("Expected '10.10.10.10:5000' but got '%s'", ml.Addr().String()) + } + + err = ml.Close() + if err != nil { + t.Errorf("Did not expect error: %v", err) + } +} + +func TestMultiListen_Addrs(t *testing.T) { + ctx := context.TODO() + addrs := []string{"10.10.10.10:5000", "192.168.1.10:5000", "127.0.0.1:5000"} + ml, err := multiListen(ctx, "tcp", addrs, listenFuncFactory( + []*fakeListener{{}, {}, {}}, + )) + if err != nil { + t.Errorf("Did not expect error: %v", err) + } + + gotAddrs := ml.(*multiListener).Addrs() + for i := range gotAddrs { + if gotAddrs[i].String() != addrs[i] { + t.Errorf("expected %q; got %q", addrs[i], gotAddrs[i].String()) + } + + } + + err = ml.Close() + if err != nil { + t.Errorf("Did not expect error: %v", err) + } +} + +func TestMultiListen_Close(t *testing.T) { + testCases := []struct { + name string + addrs []string + runner func(listener net.Listener, acceptCalls int) error + fakeListeners []*fakeListener + acceptCalls int + errString string + }{ + { + name: "close", + addrs: []string{"10.10.10.10:5000", "192.168.1.10:5000", "127.0.0.1:5000"}, + runner: func(ml net.Listener, acceptCalls int) error { + for i := 0; i < acceptCalls; i++ { + _, err := ml.Accept() + if err != nil { + return err + } + } + err := ml.Close() + if err != nil { + return err + } + return nil + }, + fakeListeners: []*fakeListener{{}, {}, {}}, + }, + { + name: "close with pending connections", + addrs: []string{"10.10.10.10:5001", "192.168.1.10:5002", "127.0.0.1:5003"}, + runner: func(ml net.Listener, acceptCalls int) error { + for i := 0; i < acceptCalls; i++ { + _, err := ml.Accept() + if err != nil { + return err + } + } + err := ml.Close() + if err != nil { + return err + } + return nil + }, + fakeListeners: []*fakeListener{{ + connErrPairs: []connErrPair{{ + conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: ParseIPSloppy("10.10.10.10"), Port: 50001}}, + }}}, { + connErrPairs: []connErrPair{{ + conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: ParseIPSloppy("192.168.1.10"), Port: 50002}}, + }, + }}, { + connErrPairs: []connErrPair{{ + conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: ParseIPSloppy("127.0.0.1"), Port: 50003}}, + }}, + }}, + }, + { + name: "close with no pending connections", + addrs: []string{"10.10.10.10:3001", "192.168.1.10:3002", "127.0.0.1:3003"}, + runner: func(ml net.Listener, acceptCalls int) error { + for i := 0; i < acceptCalls; i++ { + _, err := ml.Accept() + if err != nil { + return err + } + } + err := ml.Close() + if err != nil { + return err + } + return nil + }, + fakeListeners: []*fakeListener{{ + connErrPairs: []connErrPair{ + {conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: ParseIPSloppy("10.10.10.10"), Port: 50001}}}, + }}, { + connErrPairs: []connErrPair{ + {conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: ParseIPSloppy("192.168.1.10"), Port: 50002}}}, + }}, { + connErrPairs: []connErrPair{ + {conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: ParseIPSloppy("127.0.0.1"), Port: 50003}}}, + }, + }}, + acceptCalls: 3, + }, + { + name: "close on close", + addrs: []string{"10.10.10.10:5000", "192.168.1.10:5000", "127.0.0.1:5000"}, + runner: func(ml net.Listener, acceptCalls int) error { + for i := 0; i < acceptCalls; i++ { + _, err := ml.Accept() + if err != nil { + return err + } + } + err := ml.Close() + if err != nil { + return err + } + + err = ml.Close() + if err != nil { + return err + } + return nil + }, + fakeListeners: []*fakeListener{{}, {}, {}}, + errString: "use of closed network connection", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx := context.TODO() + ml, err := multiListen(ctx, "tcp", tc.addrs, listenFuncFactory(tc.fakeListeners)) + if err != nil { + t.Errorf("Did not expect error: %v", err) + } + err = tc.runner(ml, tc.acceptCalls) + if tc.errString != "" { + assertError(t, tc.errString, err) + } else { + assertNoError(t, err) + } + + for _, f := range tc.fakeListeners { + if !f.closed.Load() { + t.Errorf("Expeted sub-listener to be closed") + } + } + }) + } +} + +func TestMultiListen_Accept(t *testing.T) { + testCases := []struct { + name string + addrs []string + runner func(listener net.Listener, acceptCalls int) error + fakeListeners []*fakeListener + acceptCalls int + errString string + }{ + { + name: "accept all connections", + addrs: []string{"10.10.10.10:3000", "192.168.1.103:4000", "127.0.0.1:5000"}, + runner: func(ml net.Listener, acceptCalls int) error { + for i := 0; i < acceptCalls; i++ { + _, err := ml.Accept() + if err != nil { + return err + } + } + err := ml.Close() + if err != nil { + return err + } + return nil + }, + fakeListeners: []*fakeListener{{ + connErrPairs: []connErrPair{ + {conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: ParseIPSloppy("10.10.10.10"), Port: 50001}}}, + }}, { + connErrPairs: []connErrPair{ + {conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: ParseIPSloppy("192.168.1.10"), Port: 50002}}}, + }}, { + connErrPairs: []connErrPair{ + {conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: ParseIPSloppy("127.0.0.1"), Port: 50003}}}, + }, + }}, + acceptCalls: 3, + }, + { + name: "accept some connections", + addrs: []string{"10.10.10.10:3000", "192.168.1.103:4000", "172.16.20.10:5000", "127.0.0.1:6000"}, + runner: func(ml net.Listener, acceptCalls int) error { + + for i := 0; i < acceptCalls; i++ { + _, err := ml.Accept() + if err != nil { + return err + } + + } + err := ml.Close() + if err != nil { + return err + } + return nil + }, + fakeListeners: []*fakeListener{{ + connErrPairs: []connErrPair{ + {conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: ParseIPSloppy("10.10.10.10"), Port: 30001}}}, + }}, { + connErrPairs: []connErrPair{ + {conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: ParseIPSloppy("192.168.1.10"), Port: 40001}}}, + {conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: ParseIPSloppy("192.168.1.10"), Port: 40002}}}, + }}, { + connErrPairs: []connErrPair{ + {conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: ParseIPSloppy("172.16.20.10"), Port: 50001}}}, + {conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: ParseIPSloppy("172.16.20.10"), Port: 50002}}}, + {conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: ParseIPSloppy("172.16.20.10"), Port: 50003}}}, + {conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: ParseIPSloppy("172.16.20.10"), Port: 50004}}}, + }}, { + connErrPairs: []connErrPair{ + {conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: ParseIPSloppy("127.0.0.1"), Port: 60001}}}, + {conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: ParseIPSloppy("127.0.0.1"), Port: 60002}}}, + {conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: ParseIPSloppy("127.0.0.1"), Port: 60003}}}, + {conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: ParseIPSloppy("127.0.0.1"), Port: 60004}}}, + {conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: ParseIPSloppy("127.0.0.1"), Port: 60005}}}, + }, + }}, + acceptCalls: 3, + }, + { + name: "accept on closed listener", + addrs: []string{"10.10.10.10:3001", "192.168.1.10:3002", "127.0.0.1:3003"}, + runner: func(ml net.Listener, acceptCalls int) error { + err := ml.Close() + if err != nil { + return err + } + for i := 0; i < acceptCalls; i++ { + _, err := ml.Accept() + if err != nil { + return err + } + } + return nil + }, + fakeListeners: []*fakeListener{{ + connErrPairs: []connErrPair{ + {conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: ParseIPSloppy("10.10.10.10"), Port: 50001}}}, + }}, { + connErrPairs: []connErrPair{ + {conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: ParseIPSloppy("192.168.1.10"), Port: 50002}}}, + }}, { + connErrPairs: []connErrPair{ + {conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: ParseIPSloppy("127.0.0.1"), Port: 50003}}}, + }, + }}, + acceptCalls: 1, + errString: "use of closed network connection", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx := context.TODO() + ml, err := multiListen(ctx, "tcp", tc.addrs, listenFuncFactory(tc.fakeListeners)) + if err != nil { + t.Errorf("Did not expect error: %v", err) + } + + err = tc.runner(ml, tc.acceptCalls) + if tc.errString != "" { + assertError(t, tc.errString, err) + } else { + assertNoError(t, err) + } + }) + } +} + +func assertError(t *testing.T, errString string, err error) { + if err == nil { + t.Errorf("Expected error '%s' but got none", errString) + } + if err.Error() != errString { + t.Errorf("Expected error '%s' but got '%s'", errString, err.Error()) + } +} + +func assertNoError(t *testing.T, err error) { + if err != nil { + t.Errorf("Did not expect error: %v", err) + } +}