Skip to content

Commit

Permalink
make TestGetHostPortRange unit test deterministic
Browse files Browse the repository at this point in the history
  • Loading branch information
singholt committed Feb 27, 2023
1 parent 7600ecf commit cea7a56
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 30 deletions.
69 changes: 41 additions & 28 deletions agent/utils/ephemeral_ports.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"time"

"github.com/docker/go-connections/nat"
"github.com/pkg/errors"
)

// From https://www.kernel.org/doc/html/latest//networking/ip-sysctl.html#ip-variables
Expand All @@ -33,7 +34,8 @@ const (

var (
// Injection point for UTs
randIntFunc = rand.Intn
randIntFunc = rand.Intn
isPortAvailableFunc = isPortAvailable
// portLock is a mutex lock used to prevent two concurrent tasks to get the same host ports.
portLock sync.Mutex
)
Expand Down Expand Up @@ -132,34 +134,12 @@ func GetHostPortRange(numberOfPorts int, protocol string, dynamicHostPortRange s
func getHostPortRange(numberOfPorts, start, end int, protocol string) (string, int, error) {
var resultStartPort, resultEndPort, n int
for port := start; port <= end; port++ {
portStr := strconv.Itoa(port)
// check if port is available
if protocol == "tcp" {
// net.Listen announces on the local tcp network
ln, err := net.Listen(protocol, ":"+portStr)
// either port is unavailable or some error occurred while listening, we proceed to the next port
if err != nil {
continue
}
// let's close the listener first
err = ln.Close()
if err != nil {
continue
}
} else if protocol == "udp" {
// net.ListenPacket announces on the local udp network
ln, err := net.ListenPacket(protocol, ":"+portStr)
// either port is unavailable or some error occurred while listening, we proceed to the next port
if err != nil {
continue
}
// let's close the listener first
err = ln.Close()
if err != nil {
continue
}
isAvailable, err := isPortAvailableFunc(port, protocol)
if !isAvailable || err != nil {
// either port is unavailable or some error occurred while listening or closing the listener,
// we proceed to the next port
continue
}

// check if current port is contiguous relative to lastPort
if port-resultEndPort != 1 {
resultStartPort = port
Expand All @@ -182,3 +162,36 @@ func getHostPortRange(numberOfPorts, start, end int, protocol string) (string, i

return fmt.Sprintf("%d-%d", resultStartPort, resultEndPort), resultEndPort, nil
}

// isPortAvailable checks if a port is available
func isPortAvailable(port int, protocol string) (bool, error) {
portStr := strconv.Itoa(port)
switch protocol {
case "tcp":
// net.Listen announces on the local tcp network
ln, err := net.Listen(protocol, ":"+portStr)
if err != nil {
return false, err
}
// let's close the listener first
err = ln.Close()
if err != nil {
return false, err
}
return true, nil
case "udp":
// net.ListenPacket announces on the local udp network
ln, err := net.ListenPacket(protocol, ":"+portStr)
if err != nil {
return false, err
}
// let's close the listener first
err = ln.Close()
if err != nil {
return false, err
}
return true, nil
default:
return false, errors.New("invalid protocol")
}
}
26 changes: 24 additions & 2 deletions agent/utils/ephemeral_ports_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ func TestGetHostPortRange(t *testing.T) {
testDynamicHostPortRange string
protocol string
expectedLastAssignedPort []int
isPortAvailableFunc func(port int, protocol string) (bool, error)
numberOfRequests int
expectedError error
}{
Expand All @@ -88,6 +89,7 @@ func TestGetHostPortRange(t *testing.T) {
testDynamicHostPortRange: "40001-40080",
protocol: testTCPProtocol,
expectedLastAssignedPort: []int{40010},
isPortAvailableFunc: func(port int, protocol string) (bool, error) { return true, nil },
numberOfRequests: 1,
expectedError: nil,
},
Expand All @@ -97,6 +99,7 @@ func TestGetHostPortRange(t *testing.T) {
testDynamicHostPortRange: "40001-40080",
protocol: testUDPProtocol,
expectedLastAssignedPort: []int{40040},
isPortAvailableFunc: func(port int, protocol string) (bool, error) { return true, nil },
numberOfRequests: 1,
expectedError: nil,
},
Expand All @@ -106,6 +109,7 @@ func TestGetHostPortRange(t *testing.T) {
testDynamicHostPortRange: "40001-40080",
protocol: testTCPProtocol,
expectedLastAssignedPort: []int{40060, 40000},
isPortAvailableFunc: func(port int, protocol string) (bool, error) { return true, nil },
numberOfRequests: 2,
expectedError: nil,
},
Expand All @@ -115,24 +119,42 @@ func TestGetHostPortRange(t *testing.T) {
testDynamicHostPortRange: "40001-40080",
protocol: testUDPProtocol,
expectedLastAssignedPort: []int{40015},
isPortAvailableFunc: func(port int, protocol string) (bool, error) { return true, nil },
numberOfRequests: 1,
expectedError: nil,
},
{
testName: "contiguous hostPortRange not found",
testName: "contiguous hostPortRange not found, numberOfPorts more than available",
numberOfPorts: 20,
testDynamicHostPortRange: "40001-40005",
protocol: testTCPProtocol,
isPortAvailableFunc: func(port int, protocol string) (bool, error) { return true, nil },
numberOfRequests: 1,
expectedError: errors.New("20 contiguous host ports unavailable"),
},
{
testName: "contiguous hostPortRange not found, no ports available on the host",
numberOfPorts: 5,
testDynamicHostPortRange: "40001-40005",
protocol: testTCPProtocol,
isPortAvailableFunc: func(port int, protocol string) (bool, error) { return false, nil },
numberOfRequests: 1,
expectedError: errors.New("5 contiguous host ports unavailable"),
},
}

// mock isPortAvailable() for unit test
// this ensures that the test doesn't rely on the runtime port availability on the host
isPortAvailableFuncTmp := isPortAvailableFunc
defer func() {
isPortAvailableFunc = isPortAvailableFuncTmp
}()

for _, tc := range testCases {
t.Run(tc.testName, func(t *testing.T) {
for i := 0; i < tc.numberOfRequests; i++ {
isPortAvailableFunc = tc.isPortAvailableFunc
if tc.expectedError == nil {

hostPortRange, err := GetHostPortRange(tc.numberOfPorts, tc.protocol, tc.testDynamicHostPortRange)
assert.NoError(t, err)

Expand Down

0 comments on commit cea7a56

Please sign in to comment.